1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

pythonで永続UnionFind実装

Last updated at Posted at 2021-09-23

#はじめに
永続UnionFind(部分永続と完全永続)を書いたので紹介したいと思います。(実行速度検証あり)
(21/9/25若干の修正を加えました(主に計算量の改善))

#部分永続UnionFind
今のバージョンに対して変更可能で、前のバージョンに遡ってクエリに答えることができます。

###①仕組み
https://camypaper.bitbucket.io/2016/12/18/adc2016/

こちらの記事に分かりやすい解説が載っていました。
ポイントとしては
① unionfind(by rank または by size)を用いて実装。
② 各ノードにおいて親が更新された時刻をunion時に記録。
③ find時、親ノードをたどっていく時に指定した時刻tにおいて更に親があるかどうかを判断
(union時に記録した時刻と照らし合わせるとできる。親がなければ、そのノードが時刻tにおける親。あればさらに親をたどる)
といった感じです。

###②コード
https://misteer.hatenablog.com/entry/persistentUF

こちらの記事のコードを参考にunionfind by sizeで書いたものです。
(時刻tにおけるsizeを求められる機能も付いています。)

Persistent_Part_UF.py
class PersistentPartUnionFind:
    def __init__(self,N):
        INF = float('INF')
        self.now = 0
        self.N = 0
        self.parent = [-1 for i in range(N)]
        self.time = [INF for i in range(N)]
        self.num = [[] for i in range(N)]
        for i in range(N):
            self.num[i].append((0,1))
    
    def find(self,t,x):
        '''
        version:tにおけるxの根を見つける
        t (any) : version
        x (int) : 要素
        return : int : 根
        '''
        while self.time[x] <= t:
            x = self.parent[x]
        return x
    
    def union(self,x,y):
        '''
        x,yをつなげる
        x (int) : 要素
        y (int) : 要素
        '''
        self.now += 1
        x = self.find(self.now,x)
        y = self.find(self.now,y)

        if x == y:
            return
        
        if self.parent[x] > self.parent[y]:
            x,y = y,x
        
        
        self.parent[x] += self.parent[y]
        self.parent[y] = x
        self.time[y] = self.now
        self.num[x].append((self.now,-self.parent[x]))

    def same(self,t,x,y):
        '''
        version:tにおけるx,yが同じかどうかO(logN)
        t (any) : version
        x (int) : 要素
        y (int) : 要素
        return : bool : 同じかどうか
        '''
        return self.find(t,x) == self.find(t,y)

    def size(self,t,x):
        '''
        version:tにおける要素xが含まれる集合の大きさ
        t (any) : version
        x (int) : 要素
        return : int :集合の大きさ  
        '''
        x = self.find(t,x)

        ok = 0
        ng = len(self.num[x])
        while (ng-ok > 1):
            mid = (ok+ng)>>1
            if self.num[x][mid][0] <= t:
                ok = mid
            else:
                ng = mid
        
        return self.num[x][ok][1]

計算量はノード数N,union回数QとしてfindにO(log(N))程度かかるので、
・union: O(log(N))
・find: O(log(N))
・size: O(log(N)+log(Q))
となります。

③使用例

testA.py
PUF = PersistentPartUnionFind(10)
PUF.union(0,1) # 1
PUF.union(2,3) # 2
PUF.union(4,9) # 3
PUF.union(4,3) # 4
PUF.union(7,8) # 5
PUF.union(5,7) # 6
PUF.union(5,1) # 7

print(PUF.same(2,2,4))
# False
print(PUF.same(3,2,4))
# False
print(PUF.same(4,2,4))
# True
print(PUF.same(5,2,4))
# True

print(PUF.size(6,5))
# 3
print(PUF.size(7,5))
# 5

④検証

・AtCoder Library Practice Contest A - Disjoint Set Union 提出結果 480ms AC
・AGC002D D - Stamp Rally 提出結果 1709ms AC

AGC002Dは少し工夫をするとギリギリ通せる実行時間になりました。

#完全永続UnionFind
すべてのバージョンに対して変更可能で、すべてのバージョンについてのクエリに答えることができます。

①仕組み

部分永続UnionFindとは似て非なるものです。具体にはUnionFind(by size)を用い、親配列(parents)を完全永続配列*として持っておいて更新していけばよいです。(部分永続UnionFindとは違って力技で実装するイメージです。)

*完全永続配列について↓

②コード

Persistent_UF.py
class Node:
    def __init__(self,default):
        self.rch = None
        self.lch = None
        self.val = default

class PersistentArray:
    def __init__(self, ls, default_ver=0):
        '''
        ls (list) : 永続配列にしたい配列
        default_ver (any): 最初のversion(デフォルト:0)
        '''
        N = len(ls)
        self.N = N
        self.K = (N - 1).bit_length()
        self.N2 = 1 << self.K
        self.dat = [Node(0) for i in range(2**(self.K + 1))]
        for i in range(self.N):  # 葉の構築
            self.dat[self.N2 + i].val = ls[i]
        self.build()
        self.verdict = dict()
        self.verdict[default_ver] = 1 # 各versionの根のindexを格納

    def build(self):
        for node in range(self.N2 - 1, 0, -1):
            self.dat[node].rch = self.dat[node<<1 | 1]
            self.dat[node].lch = self.dat[node<<1]
            self.dat.pop()
            self.dat.pop()

    def get_t_x(self, t, x):  # ver.tにおけるリストのx番目の値
        '''
        version:tにおけるindex:xの値を出力(O(logN))
        '''
        x += self.N2
        v = self.dat[self.verdict[t]] # ver.tの根
        path = bin(x)[3:]
        for i in path:
            if i == '0':
                v = v.lch
            else:
                v = v.rch
        return v.val

    def update_told_tnew_x_val(self, t_old, t_new, x, val):
        '''
        version:t_oldのindex:xをvalに変更したものをversion:t_newとする(O(logN))
        t_old: 変更前のversion
        t_new: 変更後のversion
        x: 変更するindex
        val: 変更後の値
        '''
        if not (t_old in self.verdict):
            raise('No such version exists')
        x += self.N2
        path = bin(x)[3:]
        self.verdict[t_new] = len(self.dat)
        new_nodes = [Node(0) for _ in range(len(path)+1)]
        v_old = self.dat[self.verdict[t_old]]
        v_new = new_nodes[0]
        now = 1
        for i in path: # ノードをつなげる
            if i == '0':
                v_new.rch = v_old.rch
                v_new.lch = new_nodes[now]
                v_new = new_nodes[now]
                v_old = v_old.lch 
            else:
                v_new.lch = v_old.lch
                v_new.rch = new_nodes[now]
                v_new = new_nodes[now]
                v_old = v_old.rch
            now += 1
        v_new.val = val
        self.dat.append(new_nodes[0])

    def get_t_all(self,t):
        '''
        version:tにおける配列を出力(O(NlogN))
        return : list
        '''
        if not (t in self.verdict):
            raise('No such version exists')
        ret = []
        for i in range(self.N2,self.N2+self.N):
            path = bin(i)[3:]
            v = self.dat[self.verdict[t]]
            for p in path:
                if p == '0':
                    v = v.lch
                else:
                    v = v.rch
            ret.append(v.val)
        return ret

    def __getitem__(self, xt): return self.get_t_x(xt[0], xt[1])

import collections
class PersistentUnionFind():
    def __init__(self, n, default_ver=0):
        self.INF = 10**9
        self.n = n
        self.parents = PersistentArray([-1]*(n), default_ver)
        self.groups = dict()
        self.groups[default_ver] = n

    def find(self, t, x):
        '''
        version:tにおけるxの根を見つける
        t (any) : version
        x (int) : 要素
        return (int) : 根
        '''
        p = self.parents[t,x]
        while p >= 0:
            x = p
            p = self.parents[t,x]
        return x

    def union(self, t_old, t_new, x, y):
        '''
        version:t_oldのxとyをつなげたversionをversion:t_newとする(O(logN))
        t_old (any) : 更新前のversion
        t_new (any) : 更新後のversion
        x (int) : 要素
        y (int) : 要素
        '''
        x = self.find(t_old,x)
        y = self.find(t_old,y)
        if x == y:
            self.parents.update_told_tnew_x_val(t_old,t_new,x,self.parents[t_old,x])
            self.groups[t_new] = self.groups[t_old]
            return
        self.groups[t_new] = self.groups[t_old]-1
        if self.parents[t_old,x] > self.parents[t_old,y]:
            x,y = y,x
        self.parents.update_told_tnew_x_val(t_old,t_new+self.INF,x,self.parents[t_old,x]+self.parents[t_old,y])
        self.parents.update_told_tnew_x_val(t_new+self.INF,t_new,y,x)

    def same(self, t, x, y):
        '''
        version:tにおけるx,yが同じかどうかO(logN)
        t (any) : version
        x (int) : 要素
        y (int) : 要素
        return (bool) : 同じかどうか
        '''
        return self.find(t,x) == self.find(t,y)

    def size_t_x(self, t, x):#O(1)xが含まれる集合の要素数
        '''
        version:tにおける要素xが含まれる集合の大きさ
        t (any) : version
        x (int) : 要素
        return (int) :集合の大きさ  
        '''
        return -self.parents[t,self.find(t,x)]

    def members_t_x(self, t, x):
        '''
        version:tにおける要素が含まれるグループを出力 O(Nlog(N))
        t (any) : version
        x (int) : 要素
        return (list) : 要素リスト
        '''
        r = self.find(t,x)
        ret = [r]
        for i in range(self.n):
            if r == self.find(t,i):
                ret.append(i)
        return ret

    def root_t(self,t):#根の要素O(N)
        '''
        version:tの根の数を出力 O(Nlog(N))
        t (any) : version
        return (list) : 根
        '''
        ret = []
        for i in range(self.n):
            if self.parents[t,i] < 0:
                ret.append(i)
        return ret

    def group_count_t(self,t):
        '''
        version:tの根の数を出力 O(1)
        t (any) : version
        return (int) : 根の数(グループ数)
        '''
        return self.groups[t]

    def allgroup_members_t(self,t):#O(N)
        '''
        version:tの根とメンバーをすべて出力 O(Nlog(N))
        t (any) : version
        return (dict) : keyが根、要素がメンバー(list)
        '''
        group_members = collections.defaultdict(lambda:[])
        for i in range(self.n):
            p = self.find(t,i)
            if p < 0:
                group_members[p].append(p)
            else:
                group_members[p].append(i)
        return group_members

計算量は永続配列の更新、値の取得にO(log(N))かかるので、
・union: O((logN)2)
・find: O((logN)2)
・size: O((logN)2)
となります。

③使用例

testB.py
PUF = PersistentUnionFind(10)
PUF.union(0,1,1,2)
PUF.union(1,2,4,5)
PUF.union(0,3,4,5)
PUF.union(2,4,2,4)
PUF.union(4,5,1,5)
PUF.union(4,6,8,9)
PUF.union(6,7,4,8)
PUF.union(2,8,4,8)
for i in range(9):
    print(PUF.group_count_t(i))
    print(PUF.root_t(i))
    print(PUF.allgroup_members_t(i))

'''output #defaultdict(~~,は略
10
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
{0: [0], 1: [1], 2: [2], 3: [3], 4: [4], 5: [5], 6: [6], 7: [7], 8: [8], 9: [9]}
9
[0, 1, 3, 4, 5, 6, 7, 8, 9]
{0: [0], 1: [1, 2], 3: [3], 4: [4], 5: [5], 6: [6], 7: [7], 8: [8], 9: [9]}
8
[0, 1, 3, 4, 6, 7, 8, 9]
{0: [0], 1: [1, 2], 3: [3], 4: [4, 5], 6: [6], 7: [7], 8: [8], 9: [9]}
9
[0, 1, 2, 3, 4, 6, 7, 8, 9]
{0: [0], 1: [1], 2: [2], 3: [3], 4: [4, 5], 6: [6], 7: [7], 8: [8], 9: [9]}
7
[0, 1, 3, 6, 7, 8, 9]
{0: [0], 1: [1, 2, 4, 5], 3: [3], 6: [6], 7: [7], 8: [8], 9: [9]}
7
[0, 1, 3, 6, 7, 8, 9]
{0: [0], 1: [1, 2, 4, 5], 3: [3], 6: [6], 7: [7], 8: [8], 9: [9]}
6
[0, 1, 3, 6, 7, 8]
{0: [0], 1: [1, 2, 4, 5], 3: [3], 6: [6], 7: [7], 8: [8, 9]}
5
[0, 1, 3, 6, 7]
{0: [0], 1: [1, 2, 4, 5, 8, 9], 3: [3], 6: [6], 7: [7]}
7
[0, 1, 3, 4, 6, 7, 9]
{0: [0], 1: [1, 2], 3: [3], 4: [4, 5, 8], 6: [6], 7: [7], 9: [9]}
'''

少しわかりにくいですが、途中のversionからの更新もできていることが分かります。
図にすると以下のようになります。
PUF1.png

④検証

・AtCoder Library Practice Contest A - Disjoint Set Union 提出結果 2353ms AC
・AGC002D D - Stamp Rally 提出結果 TLE

・yosupo Library Checker Persistent UnionFind 提出結果 AC 3718ms

部分永続UnionFindよりさらに遅くなっていることが分かります汗

#参考記事

#最後に
永続データ構造いろいろと実装して調べてみてはいますが、やはり競技プログラミングの制約だとpythonでは実行時間がギリギリかTLEが多いようです。この辺りが競技プログラミング上級者には選べばれにくい理由なのかもしれません。

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?