二分探索#
二分探索(binary search) はソート済みのリストや配列のデータにおける検索で、検索したい値が中央の値より大きいかどうかで検索範囲を半分ずつに絞り込んでいく
計算量は\(O(\log N)\):1回探索するごとに探索範囲が半分ずつになっていくので要素数\(N\)に対し \(N=2^k\)回の計算で済む→\(k=\log_2 N\)
def binary_search(arr: list[int], target: int) -> int:
"""二分探索でarrにおけるtargetのindexを返す
arr: list[int]
ソート済みのlist
"""
# 探索範囲の左端のインデックス
left = 0
# 探索範囲の右端のインデックス
right = len(arr) - 1
# left が right を超えるまで探索を続ける
# left <= right の間は、まだ探索範囲が残っている
while left <= right:
# 探索範囲の中央のインデックスを求める
mid = (left + right) // 2
# 中央の値が target と一致したら、その位置を返す
if arr[mid] == target:
return mid # targetのindex
# 中央の値が target より小さい場合
elif arr[mid] < target:
# target は中央より右側にあるはずなので、探索範囲の左端を mid + 1 にずらす
left = mid + 1
# 中央の値が target より大きい場合
else:
# target は中央より左側にあるはずなので、範囲の右端を mid - 1 にずらす
right = mid - 1
# 探索範囲がなくなっても見つからなかった(arrにtargetが含まれない)場合
return -1
numbers = [1, 3, 5, 7, 9, 11, 13]
binary_search(numbers, 7)
3
関連:bisectパッケージ#
Pythonの標準パッケージ bisect は二分探索で実装されている
主な関数#
ソート済みlist a に対し、 値 x が挿入されるべき位置の左端/右端のインデックスが得られる
関数 |
説明 |
|---|---|
|
|
|
|
from bisect import bisect_left, bisect_right
a = [1, 3, 3, 5, 7] # aはソート済みのlist
print(bisect_left(a, 3)) # 1 (3が並ぶ左端)
print(bisect_right(a, 3)) # 3 (3が並ぶ右端)
print(bisect_left(a, 4)) # 3 (4を挿入するなら [1,3,3,4,5,7] の位置3)
print(bisect_right(a, 4)) # 3 (等しい値がないので left と同じ)
1
3
3
3
応用例#
値が存在するか確認する#
bisect_left が返すインデックスの値が x と等しければ存在する。
def contains(a, x):
i = bisect_left(a, x)
return i < len(a) and a[i] == x
a = [1, 3, 5, 7, 9]
print(contains(a, 5)) # True
print(contains(a, 4)) # False
True
False
\(X\) 以上の最小値 / \(X\) より大きい最小値を求める#
競プロでよく使うパターン。C++ の lower_bound / upper_bound に相当する。
bisect_left(a, x)→ x 以上の最小値のインデックス(lower_bound)bisect_right(a, x)→ x より大きい最小値のインデックス(upper_bound)
a = [1, 3, 5, 5, 7, 9]
def lower_bound(a, x):
"""x 以上の最小値を返す(存在しなければ None)"""
i = bisect_left(a, x)
return a[i] if i < len(a) else None
def upper_bound(a, x):
"""x より大きい最小値を返す(存在しなければ None)"""
i = bisect_right(a, x)
return a[i] if i < len(a) else None
print(lower_bound(a, 5)) # 5 (5以上の最小値)
print(upper_bound(a, 5)) # 7 (5より大きい最小値)
print(lower_bound(a, 6)) # 7 (6以上の最小値)
print(upper_bound(a, 9)) # None (9より大きい値がない)
5
7
7
None
範囲内の要素数を数える#
bisect_right(a, hi) - bisect_left(a, lo) で lo <= x <= hi を満たす要素数が求まる。
def count_range(a, lo, hi):
"""a の中で lo <= x <= hi を満たす要素数を返す"""
return bisect_right(a, hi) - bisect_left(a, lo)
a = [1, 2, 3, 4, 5, 6, 7]
print(count_range(a, 3, 5)) # 3 ([3, 4, 5])
print(count_range(a, 1, 7)) # 7
print(count_range(a, 4, 4)) # 1 ([4])
3
7
1