#はじめに
永続UnionFind(部分永続と完全永続)を書いたので紹介したいと思います。(実行速度検証あり)
(21/9/25若干の修正を加えました(主に計算量の改善))
#部分永続UnionFind
今のバージョンに対して変更可能で、前のバージョンに遡ってクエリに答えることができます。
###①仕組み
https://camypaper.bitbucket.io/2016/12/18/adc2016/
こちらの記事に分かりやすい解説が載っていました。
ポイントとしては
① unionfind(by rank または by size)を用いて実装。
② 各ノードにおいて親が更新された時刻をunion時に記録。
③ find時、親ノードをたどっていく時に指定した時刻tにおいて更に親があるかどうかを判断
(union時に記録した時刻と照らし合わせるとできる。親がなければ、そのノードが時刻tにおける親。あればさらに親をたどる)
といった感じです。
###②コード
https://misteer.hatenablog.com/entry/persistentUF
こちらの記事のコードを参考にunionfind by sizeで書いたものです。
(時刻tにおけるsizeを求められる機能も付いています。)
class PersistentPartUnionFind:
def __init__(self,N):
INF = float('INF')
self.now = 0
self.N = 0
self.parent = [-1 for i in range(N)]
self.time = [INF for i in range(N)]
self.num = [[] for i in range(N)]
for i in range(N):
self.num[i].append((0,1))
def find(self,t,x):
'''
version:tにおけるxの根を見つける
t (any) : version
x (int) : 要素
return : int : 根
'''
while self.time[x] <= t:
x = self.parent[x]
return x
def union(self,x,y):
'''
x,yをつなげる
x (int) : 要素
y (int) : 要素
'''
self.now += 1
x = self.find(self.now,x)
y = self.find(self.now,y)
if x == y:
return
if self.parent[x] > self.parent[y]:
x,y = y,x
self.parent[x] += self.parent[y]
self.parent[y] = x
self.time[y] = self.now
self.num[x].append((self.now,-self.parent[x]))
def same(self,t,x,y):
'''
version:tにおけるx,yが同じかどうかO(logN)
t (any) : version
x (int) : 要素
y (int) : 要素
return : bool : 同じかどうか
'''
return self.find(t,x) == self.find(t,y)
def size(self,t,x):
'''
version:tにおける要素xが含まれる集合の大きさ
t (any) : version
x (int) : 要素
return : int :集合の大きさ
'''
x = self.find(t,x)
ok = 0
ng = len(self.num[x])
while (ng-ok > 1):
mid = (ok+ng)>>1
if self.num[x][mid][0] <= t:
ok = mid
else:
ng = mid
return self.num[x][ok][1]
計算量はノード数N,union回数QとしてfindにO(log(N))程度かかるので、
・union: O(log(N))
・find: O(log(N))
・size: O(log(N)+log(Q))
となります。
③使用例
PUF = PersistentPartUnionFind(10)
PUF.union(0,1) # 1
PUF.union(2,3) # 2
PUF.union(4,9) # 3
PUF.union(4,3) # 4
PUF.union(7,8) # 5
PUF.union(5,7) # 6
PUF.union(5,1) # 7
print(PUF.same(2,2,4))
# False
print(PUF.same(3,2,4))
# False
print(PUF.same(4,2,4))
# True
print(PUF.same(5,2,4))
# True
print(PUF.size(6,5))
# 3
print(PUF.size(7,5))
# 5
④検証
・AtCoder Library Practice Contest A - Disjoint Set Union 提出結果 480ms AC
・AGC002D D - Stamp Rally 提出結果 1709ms AC
AGC002Dは少し工夫をするとギリギリ通せる実行時間になりました。
#完全永続UnionFind
すべてのバージョンに対して変更可能で、すべてのバージョンについてのクエリに答えることができます。
①仕組み
部分永続UnionFindとは似て非なるものです。具体にはUnionFind(by size)を用い、親配列(parents)を完全永続配列*として持っておいて更新していけばよいです。(部分永続UnionFindとは違って力技で実装するイメージです。)
*完全永続配列について↓
②コード
class Node:
def __init__(self,default):
self.rch = None
self.lch = None
self.val = default
class PersistentArray:
def __init__(self, ls, default_ver=0):
'''
ls (list) : 永続配列にしたい配列
default_ver (any): 最初のversion(デフォルト:0)
'''
N = len(ls)
self.N = N
self.K = (N - 1).bit_length()
self.N2 = 1 << self.K
self.dat = [Node(0) for i in range(2**(self.K + 1))]
for i in range(self.N): # 葉の構築
self.dat[self.N2 + i].val = ls[i]
self.build()
self.verdict = dict()
self.verdict[default_ver] = 1 # 各versionの根のindexを格納
def build(self):
for node in range(self.N2 - 1, 0, -1):
self.dat[node].rch = self.dat[node<<1 | 1]
self.dat[node].lch = self.dat[node<<1]
self.dat.pop()
self.dat.pop()
def get_t_x(self, t, x): # ver.tにおけるリストのx番目の値
'''
version:tにおけるindex:xの値を出力(O(logN))
'''
x += self.N2
v = self.dat[self.verdict[t]] # ver.tの根
path = bin(x)[3:]
for i in path:
if i == '0':
v = v.lch
else:
v = v.rch
return v.val
def update_told_tnew_x_val(self, t_old, t_new, x, val):
'''
version:t_oldのindex:xをvalに変更したものをversion:t_newとする(O(logN))
t_old: 変更前のversion
t_new: 変更後のversion
x: 変更するindex
val: 変更後の値
'''
if not (t_old in self.verdict):
raise('No such version exists')
x += self.N2
path = bin(x)[3:]
self.verdict[t_new] = len(self.dat)
new_nodes = [Node(0) for _ in range(len(path)+1)]
v_old = self.dat[self.verdict[t_old]]
v_new = new_nodes[0]
now = 1
for i in path: # ノードをつなげる
if i == '0':
v_new.rch = v_old.rch
v_new.lch = new_nodes[now]
v_new = new_nodes[now]
v_old = v_old.lch
else:
v_new.lch = v_old.lch
v_new.rch = new_nodes[now]
v_new = new_nodes[now]
v_old = v_old.rch
now += 1
v_new.val = val
self.dat.append(new_nodes[0])
def get_t_all(self,t):
'''
version:tにおける配列を出力(O(NlogN))
return : list
'''
if not (t in self.verdict):
raise('No such version exists')
ret = []
for i in range(self.N2,self.N2+self.N):
path = bin(i)[3:]
v = self.dat[self.verdict[t]]
for p in path:
if p == '0':
v = v.lch
else:
v = v.rch
ret.append(v.val)
return ret
def __getitem__(self, xt): return self.get_t_x(xt[0], xt[1])
import collections
class PersistentUnionFind():
def __init__(self, n, default_ver=0):
self.INF = 10**9
self.n = n
self.parents = PersistentArray([-1]*(n), default_ver)
self.groups = dict()
self.groups[default_ver] = n
def find(self, t, x):
'''
version:tにおけるxの根を見つける
t (any) : version
x (int) : 要素
return (int) : 根
'''
p = self.parents[t,x]
while p >= 0:
x = p
p = self.parents[t,x]
return x
def union(self, t_old, t_new, x, y):
'''
version:t_oldのxとyをつなげたversionをversion:t_newとする(O(logN))
t_old (any) : 更新前のversion
t_new (any) : 更新後のversion
x (int) : 要素
y (int) : 要素
'''
x = self.find(t_old,x)
y = self.find(t_old,y)
if x == y:
self.parents.update_told_tnew_x_val(t_old,t_new,x,self.parents[t_old,x])
self.groups[t_new] = self.groups[t_old]
return
self.groups[t_new] = self.groups[t_old]-1
if self.parents[t_old,x] > self.parents[t_old,y]:
x,y = y,x
self.parents.update_told_tnew_x_val(t_old,t_new+self.INF,x,self.parents[t_old,x]+self.parents[t_old,y])
self.parents.update_told_tnew_x_val(t_new+self.INF,t_new,y,x)
def same(self, t, x, y):
'''
version:tにおけるx,yが同じかどうかO(logN)
t (any) : version
x (int) : 要素
y (int) : 要素
return (bool) : 同じかどうか
'''
return self.find(t,x) == self.find(t,y)
def size_t_x(self, t, x):#O(1)xが含まれる集合の要素数
'''
version:tにおける要素xが含まれる集合の大きさ
t (any) : version
x (int) : 要素
return (int) :集合の大きさ
'''
return -self.parents[t,self.find(t,x)]
def members_t_x(self, t, x):
'''
version:tにおける要素が含まれるグループを出力 O(Nlog(N))
t (any) : version
x (int) : 要素
return (list) : 要素リスト
'''
r = self.find(t,x)
ret = [r]
for i in range(self.n):
if r == self.find(t,i):
ret.append(i)
return ret
def root_t(self,t):#根の要素O(N)
'''
version:tの根の数を出力 O(Nlog(N))
t (any) : version
return (list) : 根
'''
ret = []
for i in range(self.n):
if self.parents[t,i] < 0:
ret.append(i)
return ret
def group_count_t(self,t):
'''
version:tの根の数を出力 O(1)
t (any) : version
return (int) : 根の数(グループ数)
'''
return self.groups[t]
def allgroup_members_t(self,t):#O(N)
'''
version:tの根とメンバーをすべて出力 O(Nlog(N))
t (any) : version
return (dict) : keyが根、要素がメンバー(list)
'''
group_members = collections.defaultdict(lambda:[])
for i in range(self.n):
p = self.find(t,i)
if p < 0:
group_members[p].append(p)
else:
group_members[p].append(i)
return group_members
計算量は永続配列の更新、値の取得にO(log(N))かかるので、
・union: O((logN)2)
・find: O((logN)2)
・size: O((logN)2)
となります。
③使用例
PUF = PersistentUnionFind(10)
PUF.union(0,1,1,2)
PUF.union(1,2,4,5)
PUF.union(0,3,4,5)
PUF.union(2,4,2,4)
PUF.union(4,5,1,5)
PUF.union(4,6,8,9)
PUF.union(6,7,4,8)
PUF.union(2,8,4,8)
for i in range(9):
print(PUF.group_count_t(i))
print(PUF.root_t(i))
print(PUF.allgroup_members_t(i))
'''output #defaultdict(~~,は略
10
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
{0: [0], 1: [1], 2: [2], 3: [3], 4: [4], 5: [5], 6: [6], 7: [7], 8: [8], 9: [9]}
9
[0, 1, 3, 4, 5, 6, 7, 8, 9]
{0: [0], 1: [1, 2], 3: [3], 4: [4], 5: [5], 6: [6], 7: [7], 8: [8], 9: [9]}
8
[0, 1, 3, 4, 6, 7, 8, 9]
{0: [0], 1: [1, 2], 3: [3], 4: [4, 5], 6: [6], 7: [7], 8: [8], 9: [9]}
9
[0, 1, 2, 3, 4, 6, 7, 8, 9]
{0: [0], 1: [1], 2: [2], 3: [3], 4: [4, 5], 6: [6], 7: [7], 8: [8], 9: [9]}
7
[0, 1, 3, 6, 7, 8, 9]
{0: [0], 1: [1, 2, 4, 5], 3: [3], 6: [6], 7: [7], 8: [8], 9: [9]}
7
[0, 1, 3, 6, 7, 8, 9]
{0: [0], 1: [1, 2, 4, 5], 3: [3], 6: [6], 7: [7], 8: [8], 9: [9]}
6
[0, 1, 3, 6, 7, 8]
{0: [0], 1: [1, 2, 4, 5], 3: [3], 6: [6], 7: [7], 8: [8, 9]}
5
[0, 1, 3, 6, 7]
{0: [0], 1: [1, 2, 4, 5, 8, 9], 3: [3], 6: [6], 7: [7]}
7
[0, 1, 3, 4, 6, 7, 9]
{0: [0], 1: [1, 2], 3: [3], 4: [4, 5, 8], 6: [6], 7: [7], 9: [9]}
'''
少しわかりにくいですが、途中のversionからの更新もできていることが分かります。
図にすると以下のようになります。
④検証
・AtCoder Library Practice Contest A - Disjoint Set Union 提出結果 2353ms AC
・AGC002D D - Stamp Rally 提出結果 TLE
・yosupo Library Checker Persistent UnionFind 提出結果 AC 3718ms
部分永続UnionFindよりさらに遅くなっていることが分かります汗
#参考記事
#最後に
永続データ構造いろいろと実装して調べてみてはいますが、やはり競技プログラミングの制約だとpythonでは実行時間がギリギリかTLEが多いようです。この辺りが競技プログラミング上級者には選べばれにくい理由なのかもしれません。