PythonでDisjoint Sparse Tableをかきました。
好きにつかってやってください!
コード
長いので折りたたみにしてます。
コード詳細
class DisjointSparseTable:
"""Disjoint Sparse Table
リストに対して、事前構築O(NlogN)、区間演算O(1)でできます。
演算funcは、結合則を満たすものだけです。
結合則: func(a, b) == func(b, a) (順番は関係なし)
Attributes:
lst (list): 対象のリスト
func (Callable[[int,int],int]): 演算 (xor, gcd, min, max, ...)
Examples:
>>> dst = DisjointSparseTable([9, 1, 6, 10, 2], min)
>>> dst.prod(0, 4)
1
>>> dst.prod(2, 4)
2
>>> def xor(i, j): return i ^ j
>>> dst = DisjointSparseTable([9, 1, 6, 10, 2], xor)
>>> dst.prod(0, 4) # 9^1^6^10^2=6
6
>>> dst.prod(0, 2) # 9^1^6=14
14
"""
def __init__(self, lst: list, func: "function") -> None:
length = 1 << (len(lst) - 1).bit_length()
self.my_lst = lst + [0] * (length - len(lst))
self.__func = func
# テーブルの作成
self.tables = []
dist = length >> 1
gc = 2 # グループの数
while dist:
now_table = [None] * length
left_start = 0
left, right = 0, dist - 1
for _ in range(gc):
if left_start:
val = self.my_lst[left]
for i in range(left, right + 1):
if i != left:
val = func(val, self.my_lst[i])
now_table[i] = val
else:
# right_start
val = self.my_lst[right]
for i in range(right, left - 1, -1):
if i != right:
val = func(val, self.my_lst[i])
now_table[i] = val
left += dist
right += dist
left_start ^= 1
self.tables.append(now_table)
dist >>= 1
gc <<= 1
# self.bit_lengthリストの作成
self.bit_length = [0]
for i in range(1, (len(lst) - 1).bit_length() + 1):
for s in range(1 << (i - 1)):
self.bit_length.append(i)
def prod(self, left: int, right: int) -> int:
"""区間演算
閉区間 [left, right] で演算 func を行った結果を返します。
Args:
left (int): 区間の左端のインデックス
right (int): 区間の右端のインデックス
Returns:
int: 閉区間 [left, right] で演算 func を行った結果
Note:
閉区間 [left, right] なのに注意してね。
"""
assert 0 <= left <= right < len(self.my_lst), \
"0<=left<=right<len(self.my_lst), left={0}, right={1}" \
.format(left, right)
if left == right:
return self.my_lst[left]
# table_idx = len(self.tables) - (left ^ right).bit_length()
table_idx = len(self.tables) - self.bit_length[left ^ right]
left_el = self.tables[table_idx][left]
right_el = self.tables[table_idx][right]
return self.__func(left_el, right_el)
つかいかた
docstringに書いてある通りですが、別の例も書いておきます。
# 上のコードをコピペしてから書いてね。
import math
dst = DisjointSparseTable([10, 30, 7, 21, 6], lambda x, y: math.gcd(x, y))
print(dst.prod(0, 1)) # 10
print(dst.prod(2, 3)) # 7
print(dst.prod(0, 4)) # 1
def mul(x, y): return x * y
dst2 = DisjointSparseTable([10, 30, 7, 21, 6], mul)
print(dst2.prod(0, 2)) # 2100
print(dst2.prod(3, 4)) # 126
print(dst2.prod(0, 4)) # 264600
ACコード
参考URL
ほぼここに書いてあるのをコードにしただけです!!ありがとうございます!!!
noshi91のメモ - Disjoint Sparse Table と セグ木に関するポエム
編集履歴
2023/07/23 bit_lengthの前計算で、クエリをO(1)で処理できるように改善しました。