bit演算#

ビット演算は整数をビット列として扱い、各ビットに対して論理演算やシフト演算を行う。高速・省メモリなため、競技プログラミングや低レベルプログラミングで頻繁に使われる。

演算子

名前

例 (a=0b1010, b=0b1100)

結果

&

AND

a & b

0b1000

|

OR

a | b

0b1110

^

XOR

a ^ b

0b0110

~

NOT

~a

-11(2の補数)

<<

左シフト

a << 1

0b10100

>>

右シフト

a >> 1

0b0101

a = 0b1010  # 10
b = 0b1100  # 12

print(f"a     = {a:04b} ({a})")
print(f"b     = {b:04b} ({b})")
print(f"a & b = {a & b:04b} ({a & b})")   # AND
print(f"a | b = {a | b:04b} ({a | b})")   # OR
print(f"a ^ b = {a ^ b:04b} ({a ^ b})")   # XOR
print(f"a << 1= {a << 1:05b} ({a << 1})") # 左シフト(×2)
print(f"a >> 1= {a >> 1:04b} ({a >> 1})") # 右シフト(÷2)
a     = 1010 (10)
b     = 1100 (12)
a & b = 1000 (8)
a | b = 1110 (14)
a ^ b = 0110 (6)
a << 1= 10100 (20)
a >> 1= 0101 (5)

ビットマスク#

1 << k で k ビット目だけが立ったマスクを作り、特定ビットの操作を行う。

操作

コード

k ビット目を立てる(set)

x | (1 << k)

k ビット目を消す(clear)

x & ~(1 << k)

k ビット目を反転(toggle)

x ^ (1 << k)

k ビット目を取得(get)

(x >> k) & 1

x = 0b0101  # 5

def show(label, val, width=4):
    print(f"{label:30s} = {val:0{width}b} ({val})")

show("x", x)
show("set bit 1   : x | (1<<1)", x | (1 << 1))
show("clear bit 2 : x & ~(1<<2)", x & ~(1 << 2))
show("toggle bit 0: x ^ (1<<0)", x ^ (1 << 0))
print(f"{'get bit 2   : (x>>2)&1':30s} = {(x >> 2) & 1}")
x                              = 0101 (5)
set bit 1   : x | (1<<1)       = 0111 (7)
clear bit 2 : x & ~(1<<2)      = 0001 (1)
toggle bit 0: x ^ (1<<0)       = 0100 (4)
get bit 2   : (x>>2)&1         = 1

よく使うイディオム#

2の累乗判定#

n & (n - 1) == 0 が成り立つとき n は 2 の累乗(n > 0 を前提)。

なぜ? 2の累乗は 100...0 の形。n - 1011...1 になるので AND が 0 になる。

最下位ビットの取り出し(LSB)#

n & (-n) で最下位の立っているビットだけを取り出せる。Fenwick Tree(BIT)の根幹。

立っているビット数(ポップカウント)#

bin(n).count('1') または n.bit_count()(Python 3.10+)。

def is_power_of_2(n: int) -> bool:
    return n > 0 and (n & (n - 1)) == 0

def lsb(n: int) -> int:
    """最下位の立っているビットを取り出す"""
    return n & (-n)

def popcount(n: int) -> int:
    return bin(n).count('1')

# 2の累乗判定
for n in [0, 1, 2, 3, 4, 6, 8, 16]:
    print(f"is_power_of_2({n:2d}) = {is_power_of_2(n)}")

print()

# LSB
n = 0b10110  # 22
print(f"n   = {n:05b} ({n})")
print(f"LSB = {lsb(n):05b} ({lsb(n)})")

print()

# ポップカウント
for n in [0b0000, 0b1010, 0b1111, 255]:
    print(f"popcount({n:08b}) = {popcount(n)}")
is_power_of_2( 0) = False
is_power_of_2( 1) = True
is_power_of_2( 2) = True
is_power_of_2( 3) = False
is_power_of_2( 4) = True
is_power_of_2( 6) = False
is_power_of_2( 8) = True
is_power_of_2(16) = True

n   = 10110 (22)
LSB = 00010 (2)

popcount(00000000) = 0
popcount(00001010) = 2
popcount(00001111) = 4
popcount(11111111) = 8

応用:部分集合の列挙(ビット全探索)#

n 個の要素からなる集合の全部分集合を 0 から 2^n - 1 までのビット列で表現する。各ビットが「その要素を選ぶかどうか」に対応する。計算量は O(2^n)。

items = ["A", "B", "C"]
n = len(items)

print("全部分集合:")
for mask in range(1 << n):
    subset = [items[i] for i in range(n) if (mask >> i) & 1]
    print(f"  {mask:03b} -> {subset}")

print()

# 応用例: 合計がちょうど target になる部分集合を列挙
values = [1, 3, 5, 7]
target = 8
n = len(values)

print(f"values={values}, target={target} になる部分集合:")
for mask in range(1 << n):
    subset = [values[i] for i in range(n) if (mask >> i) & 1]
    if sum(subset) == target:
        print(f"  {subset}")
全部分集合:
  000 -> []
  001 -> ['A']
  010 -> ['B']
  011 -> ['A', 'B']
  100 -> ['C']
  101 -> ['A', 'C']
  110 -> ['B', 'C']
  111 -> ['A', 'B', 'C']

values=[1, 3, 5, 7], target=8 になる部分集合:
  [3, 5]
  [1, 7]

応用:ビット DP#

状態を集合(ビット列)で表すDPをビットDPという。典型例として巡回セールスマン問題(TSP)がある。

dp[S][v] = 頂点集合 S を訪問済みで、現在頂点 v にいるときの最小コスト。

\[ dp[S | (1 << u)][u] = \min(dp[S | (1 << u)][u], \; dp[S][v] + \text{cost}[v][u]) \]

計算量: O(2^n × n^2)

import math

def tsp(cost: list[list[int]]) -> int:
    """TSP: 全頂点を1度ずつ巡回して0に戻る最小コスト"""
    n = len(cost)
    INF = math.inf
    # dp[mask][v]: mask の頂点を訪問済みで現在 v にいる最小コスト
    dp = [[INF] * n for _ in range(1 << n)]
    dp[1][0] = 0  # 頂点0からスタート(bit 0 が立った状態)

    for S in range(1 << n):
        for v in range(n):
            if dp[S][v] == INF:
                continue
            if not (S >> v) & 1:
                continue
            for u in range(n):
                if (S >> u) & 1:
                    continue  # 既訪問
                nS = S | (1 << u)
                dp[nS][u] = min(dp[nS][u], dp[S][v] + cost[v][u])

    full = (1 << n) - 1
    return min(dp[full][v] + cost[v][0] for v in range(1, n))

# 4頂点の例
cost = [
    [0, 2, 9, 10],
    [1, 0, 6,  4],
    [15, 7, 0, 8],
    [6, 3, 12, 0],
]
print(f"TSP最小コスト: {tsp(cost)}")  # 期待値: 21
TSP最小コスト: 21