はじめに
世の中には、過去の状態にアクセスしたいという需要が存在します。それを高速に実現する一つが永続データ構造です。
参考記事
永続配列について
①永続配列の使い道
永続セグメント木を考える前に永続配列について紹介します。
というようにデータを構築したうえで
→Ver.0の4番目を知りたい!
→Ver.1の2番目を知りたい!
のような命令(クエリ)にこたえるために、愚直にデータ構築をやると配列の長さをN、version数をQとして、NQのメモリとそれを構築する計算量がかかってしまうことが分かります。ですがよく見ると(よく見なくても)変更は一か所ずつなので、高速化ができそうだということも分かります。ここで使えるのが永続配列というデータ構造です。
②永続配列の構造
図は二分木で構築した永続配列(Ver.0)です。二分木の葉の部分にのみ値を持ちほかのノードは値を持ちません(根から該当箇所にアクセスするために値を持たないノードが存在しているといった具合です)。
根からのアクセスの仕方を見ておきましょう。
indexの1番にアクセスしたいときは根からLLRとたどればよく、
indexの6番にアクセスしたいときは根からRRLとたどればいいことが分かります。
Lを0、Rを1とすると、indexの1番は、001→1,indexの6番は、110→6となり、インデックス番号とたどり方が
1対1で対応していることが分かります。アクセスはこの性質をつかえばよさそうです。計算量はO(logN)です。
③更新とクエリの答え方
Ver.0の3番目を3に変えたものをVer.1とする操作をやってみます。これは赤いところのようにノードを追加し、pathをつなげば実現できます(計算量はO(logN)です。
クエリ(Ver.○○のindex○○の値を取得)に対しては、ver.0の時は青色の根から、ver.1の時は赤色の根から始め、末端のノードにアクセスすればしていけばよいです。
例えばVer.1の2番目を知りたい場合には緑の矢印をたどればよいです。
④実装例
pythonでの実装例を紹介します。
class Node:
def __init__(self,default):
self.rch = None
self.lch = None
self.val = default
class PersistentArray:
def __init__(self, ls, default_ver=0):
'''
ls (list) : 永続配列にしたい配列
default_ver (any): 最初のversion(デフォルト:0)
'''
N = len(ls)
self.N = N
self.K = (N - 1).bit_length()
self.N2 = 1 << self.K
self.dat = [Node(None) for i in range(2**(self.K + 1))]
for i in range(self.N): # 葉の構築
self.dat[self.N2 + i].val = ls[i]
self.build()
self.verdict = dict()
self.verdict[default_ver] = 1 # 各versionの根のindexを格納
def build(self):
for node in range(self.N2 - 1, 0, -1):
self.dat[node].rch = self.dat[node<<1 | 1]
self.dat[node].lch = self.dat[node<<1]
self.dat.pop()
self.dat.pop()
def get_t_x(self, t, x): # ver.tにおけるリストのx番目の値
'''
version:tにおけるindex:xの値を出力(O(logN))
'''
x += self.N2
v = self.dat[self.verdict[t]] # ver.tの根
path = bin(x)[3:]
for i in path:
if i == '0':
v = v.lch
else:
v = v.rch
return v.val
def update_told_tnew_x_val(self, told, tnew, x, val):
'''
version:toldのindex:xをvalに変更したものをversion:tnewとする(O(logN))
told: 変更前のversion
tnew: 変更後のversion
x: 変更するindex
val: 変更後の値
'''
if not (told in self.verdict):
raise('No such version exists')
x += self.N2
path = bin(x)[3:]
self.verdict[tnew] = len(self.dat)
new_nodes = [Node(None) for _ in range(len(path)+1)]
v_old = self.dat[self.verdict[told]]
v_new = new_nodes[0]
now = 1
for i in path: # ノードをつなげる
if i == '0':
v_new.rch = v_old.rch
v_new.lch = new_nodes[now]
v_new = new_nodes[now]
v_old = v_old.lch
else:
v_new.lch = v_old.lch
v_new.rch = new_nodes[now]
v_new = new_nodes[now]
v_old = v_old.rch
now += 1
v_new.val = val
self.dat.append(new_nodes[0])
def get_t_all(self,t):
'''
version:tにおける配列を出力(O(NlogN))
return : list
'''
if not (t in self.verdict):
raise('No such version exists')
ret = []
for i in range(self.N2,self.N2+self.N):
path = bin(i)[3:]
v = self.dat[self.verdict[t]]
for p in path:
if p == '0':
v = v.lch
else:
v = v.rch
ret.append(v.val)
return ret
def __getitem__(self, xt): return self.get_t_x(xt[0], xt[1])
Nodeクラスに左子(lch)、右子(rch)、値(val)を持たせ、二分木を構築しています。
⑤使用例
SG = PersistentArray([2,4,6,8,9,1,3,5])
SG.update_told_tnew_x_val(0,1,3,3)
SG.update_told_tnew_x_val(1,2,4,5)
SG.update_told_tnew_x_val(1,3,5,2)
print(SG.get_t_x(0,3))
# 8
print(SG.get_t_x(1,3))
# 3
print(SG.get_t_all(0))
# [2, 4, 6, 8, 9, 1, 3, 5]
print(SG.get_t_all(1))
# [2, 4, 6, 3, 9, 1, 3, 5]
print(SG.get_t_all(2))
# [2, 4, 6, 3, 5, 1, 3, 5]
print(SG.get_t_all(3))
# [2, 4, 6, 3, 9, 2, 3, 5]
⑥完全永続性
赤色矢印のように過去のversionから枝分かれするようなversionを作製できるのが完全永続、できないのが部分永続です。今回の永続配列は、過去の状態からの枝分かれも作成できるので完全永続となっています。
永続セグメント木
実は二分木の永続配列が分かれば永続セグメント木はあまり難しくありません。具体的には永続配列の二分木の親ノードに子ノードから求められる値を記録していき、区間クエリ時に読み取るといった操作が加わります。
①永続セグメント木の構造
今回紹介した永続配列は完全二分木なので、そのままセグメント木への拡張ができます。具体にはアクセスするときにpathとして使っていたノードに親の値を持たせます。セグメント木に乗せる関数は加算(sum)としました。具体には図のようになります。
②更新のやり方
先ほどの永続配列とほぼ同様です。異なるところは親の値の更新が入るところです。計算用は変わらずO(logN)でできます。
③区間クエリの取得
根からdfsをしていき、読み取るべきノードまで行ったらストップします。また区間外に出る場合には枝かりをすることで計算量を抑えることができます(もっといいアルゴリズムがあるかもです)。
Ver.1の[1,4)を取得する場合、考えるべき区間は緑の部分です。枝かり+該当箇所でストップすることで緑矢印部分のpath探索で済みます(4+9=13)。(次項の実装のところでは厳密には、末端の枝かりが甘いですが、大した計算量ではないので、場合分けは省いています。)
④実装例
pythonでの実装例を紹介します。
import math
class Node:
def __init__(self,default):
self.rch = None
self.lch = None
self.val = default
class PersistentSegTree:
DEFAULT = {
'min': 1 << 60,
'max': -(1 << 60),
'sum': 0,
'prd': 1,
'gcd': 0,
'lmc': 1,
'^': 0,
'&': (1 << 60) - 1,
'|': 0,
}
FUNC = {
'min': min,
'max': max,
'sum': (lambda x, y: x + y),
'prd': (lambda x, y: x * y),
'gcd': math.gcd,
'lmc': (lambda x, y: (x * y) // math.gcd(x, y)),
'^': (lambda x, y: x ^ y),
'&': (lambda x, y: x & y),
'|': (lambda x, y: x | y),
}
def __init__(self, ls, mode='min', func=None, default=None, default_ver=0):
'''
ls (list) : 永続配列にしたい配列
mode : モード(デフォルト:min)
func: 関数を指定(デフォルト:None) ※modeを指定する場合は不要
default: 単位元を指定(デフォルト:None) ※modeを指定する場合は不要
default_ver (any): 最初のversion(デフォルト:0)
'''
N = len(ls)
if default == None:
self.default = self.DEFAULT[mode]
else:
self.default = default
if func == None:
self.func = self.FUNC[mode]
else:
self.func = func
self.N = N
self.K = (N - 1).bit_length()
self.N2 = 1 << self.K
self.dat = [Node(self.default) for _ in range(2**(self.K + 1))]
for i in range(self.N): # 葉の構築
self.dat[self.N2 + i].val = ls[i]
self.build()
self.verdict = dict()
self.verdict[default_ver] = 1 # 各versionの根のindexを格納
def build(self):
for node in range(self.N2 - 1, 0, -1):
self.dat[node].rch = self.dat[node<<1 | 1]
self.dat[node].lch = self.dat[node<<1]
self.dat[node].val = self.func(self.dat[node<<1 | 1].val,self.dat[node<<1].val)
self.dat.pop()
self.dat.pop()
def get_t_x(self, t, x): # ver.tにおけるリストのx番目の値
'''
version:tにおけるindex:xの値を出力(O(logN))
return : int
'''
x += self.N2
v = self.dat[self.verdict[t]] # ver.tの根
path = bin(x)[3:]
for i in path:
if i == '0':
v = v.lch
else:
v = v.rch
return v.val
def update_told_tnew_x_val(self, told, tnew, x, val):
'''
version:toldのindex:xをvalに変更したものをtnewとする(O(logN))
told: 変更前のversion
tnew: 変更後のversion
x: 変更するindex
val: 変更後の値
'''
if not (told in self.verdict):
raise('No such version exists')
x += self.N2
path = bin(x)[3:]
self.verdict[tnew] = len(self.dat)
new_nodes = [Node(self.default) for _ in range(len(path)+1)]
v_old = self.dat[self.verdict[told]]
v_new = new_nodes[0]
now = 1
for i in path: # ノードをつなげる
if i == '0':
v_new.rch = v_old.rch
v_new.lch = new_nodes[now]
v_new = new_nodes[now]
v_old = v_old.lch
else:
v_new.lch = v_old.lch
v_new.rch = new_nodes[now]
v_new = new_nodes[now]
v_old = v_old.rch
now += 1
v_new.val = val
for node in range(len(path)-1,-1,-1): # 付け加えたノードの値を子から親に更新
new_nodes[node].val = self.func(new_nodes[node].lch.val,new_nodes[node].rch.val)
self.dat.append(new_nodes[0])
def get_t_all(self,t):
'''
version:tにおける配列を出力(O(NlogN))
return : list
'''
if not (t in self.verdict):
raise('No such version exists')
ret = []
for i in range(self.N2,self.N2+self.N):
path = bin(i)[3:]
v = self.dat[self.verdict[t]]
for p in path:
if p == '0':
v = v.lch
else:
v = v.rch
ret.append(v.val)
return ret
def query_t(self, t, L, R):
'''
version:tにおける区間クエリ[L,R)(O(logN))
return : int
'''
L += self.N2; R += self.N2
L0 = L; R0 = R
ls = set() # 見るべきノード番号
while L < R:
if L & 1:
ls.add(L)
L += 1
if R & 1:
R -= 1
ls.add(R)
L >>= 1
R >>= 1
valdef = self.default
d = [(1,0,self.dat[self.verdict[t]])]
while d: # 二分木上をdfs
num, depth, node = d.pop()
if num in ls:
valdef = self.func(node.val, valdef)
continue
if depth == self.K:
continue
if num >= (L0 >> (self.K-depth)):
d.append((num << 1,depth+1, node.lch))
if R0 > ((num << 1 | 1) << (self.K-depth-1)):
d.append((num << 1 | 1, depth+1 , node.rch))
return valdef
def __getitem__(self, tx): return self.get_t_x(tx[0], tx[1])
⑤使用例
SG = PersistentSegTree([2,4,6,8,9,1,3,5],mode='sum')
SG.update_told_tnew_x_val(0,1,3,3)
SG.update_told_tnew_x_val(1,2,4,5)
SG.update_told_tnew_x_val(1,3,5,2)
print(SG.get_t_all(1))
# [2, 4, 6, 3, 9, 1, 3, 5]
print(SG.get_t_all(2))
# [2, 4, 6, 3, 5, 1, 3, 5]
print(SG.get_t_all(3))
# [2, 4, 6, 3, 9, 2, 3, 5]
print(SG.query_t(0,0,6))
# 30
print(SG.query_t(1,0,6))
# 25
print(SG.query_t(2,0,6))
# 21
print(SG.query_t(3,0,6))
# 26
⑥競技プログラミングでの使用例
・Atcoder ABC165F LIS on Tree 提出結果(TLE) https://atcoder.jp/contests/abc165/submissions/25330532
何回か提出すると2100msくらいのTLEにもなったので、TLが2.5secなら通っていたかもです。なかなか制約が厳しいようです。
・yosupo Range Kth Smallest 提出結果AC 3414ms https://judge.yosupo.jp/submission/62383
少し工夫すると区間のk番目の値をlog(N)で求めるアルゴリズムが書けます。
最後に
今回作っては見ましたが、競技プログラミングにおいてpythonで通せる永続セグメント木問題はあまりないのかもしれません。。。そろそろC++を学ばなければいけませんね。。。
LGTMしていただけると大変励みになりますので参考になった方いましたらよろしくお願いいたします。
python を使ったデータ構造の記事は他にも書いてありますのでよかったら見ていってください。