マージソート#
マージソート(merge sort) は分割統治法に基づくソートアルゴリズム。配列を半分に分割して再帰的にソートし、2つのソート済み配列をマージ(併合)する。
計算量 |
|
|---|---|
最悪時間計算量 |
\(O(n \log n)\) |
平均時間計算量 |
\(O(n \log n)\) |
最良時間計算量 |
\(O(n \log n)\) |
空間計算量 |
\(O(n)\)(補助配列が必要) |
特徴
安定ソート
入力に依存しない安定した \(O(n \log n)\)
外部ソート(ディスク上の大規模データ)に適用しやすい
Python の
sorted()/list.sort()の内部は Timsort(マージソート + 挿入ソートの混合)
アルゴリズム#
要素数が1になるまで再帰的に分割
マージする:2つのソート済みリストを先頭から比較して小さい方を順に取り出す。
def merge_sort(arr: list) -> list:
if len(arr) <= 1:
return arr
mid = len(arr) // 2
left = merge_sort(arr[:mid])
right = merge_sort(arr[mid:])
return _merge(left, right)
def _merge(left: list, right: list) -> list:
result = []
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result
data = [38, 27, 43, 3, 9, 82, 10]
print("入力:", data)
print("出力:", merge_sort(data))
入力: [38, 27, 43, 3, 9, 82, 10]
出力: [3, 9, 10, 27, 38, 43, 82]
import random
random.seed(2)
large = random.sample(range(10000), 100)
assert merge_sort(large) == sorted(large)
print("ランダム100件: OK")
ランダム100件: OK
in-place マージソート(Bottom-up)#
再帰を使わず、マージ幅を 1 → 2 → 4 → … と倍増させながらボトムアップで処理する。補助配列は使うが再帰スタックが不要。
def merge_sort_bottomup(arr: list) -> list:
a = arr.copy()
n = len(a)
width = 1
while width < n:
for lo in range(0, n, 2 * width):
mid = min(lo + width, n)
hi = min(lo + 2 * width, n)
merged = _merge(a[lo:mid], a[mid:hi])
a[lo:hi] = merged
width *= 2
return a
data = [38, 27, 43, 3, 9, 82, 10]
print(merge_sort_bottomup(data))
assert merge_sort_bottomup(data) == sorted(data)
[3, 9, 10, 27, 38, 43, 82]
応用:転倒数の計算#
転倒数(inversion count) とは配列の中で \(i < j\) かつ \(a[i] > a[j]\) となるペアの個数。マージソートのマージ操作中にカウントできる。計算量 \(O(n \log n)\)。
def count_inversions(arr: list) -> tuple[list, int]:
"""(ソート済み配列, 転倒数) を返す"""
if len(arr) <= 1:
return arr, 0
mid = len(arr) // 2
left, lc = count_inversions(arr[:mid])
right, rc = count_inversions(arr[mid:])
merged, mc = _merge_count(left, right)
return merged, lc + rc + mc
def _merge_count(left: list, right: list) -> tuple[list, int]:
result = []
count = 0
i = j = 0
while i < len(left) and j < len(right):
if left[i] <= right[j]:
result.append(left[i])
i += 1
else:
result.append(right[j])
count += len(left) - i # left[i:] はすべて right[j] より大きい
j += 1
result.extend(left[i:])
result.extend(right[j:])
return result, count
arr = [2, 4, 1, 3, 5]
sorted_arr, inv = count_inversions(arr)
print(f"配列: {arr}")
print(f"転倒数: {inv}") # (2,1), (4,1), (4,3) → 3
print(f"ソート済み: {sorted_arr}")
配列: [2, 4, 1, 3, 5]
転倒数: 3
ソート済み: [1, 2, 3, 4, 5]