Union-Findとは
Union-Find
は、グループ分けを管理できるもの。
主に2つの操作を行うことができる。
- グループの接合(Union)
- グループに属するかの判定(Find)
これら2つが行えるため、Union-Findと呼ばれる。
下記にUnionFindのクラスを参照してあるので、これを使って説明する。
from collections import defaultdict
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
def members(self, x):
root = self.find(x)
return [i for i in range(self.n) if self.find(i) == root]
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
def all_group_members(self):
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
return group_members
def __str__(self):
return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())
各メソッドの使い方
UnionFind
のそれぞれのメソッドについて1つずつ説明していく。
parents
- 各要素の親要素の番号を保存するためのリスト
# parents
uf_3 = UnionFind(3)
print(uf_3.parents)
uf_5 = UnionFind(5)
print(uf_5.parents)
# 出力
>> [-1, -1, -1]
>> [-1, -1, -1, -1, -1]
union(x, y)
- ある要素
x
が属するグループと別の要素y
が属するグループを接合する
# union(x, y)
uf_3.union(1, 2)
print(uf_3.parents)
uf_3.union(0, 1)
print(uf_3.parents)
uf_5.union(1, 2)
print(uf_5.parents)
uf_5.union(2, 4)
print(uf_5.parents)
# 出力
>> [-1, -2, 1]
>> [1, -3, 1]
>> [-1, -2, 1, -1, -1]
>> [-1, -3, 1, -1, 1]
find(x)
- ある要素
x
が属するグループの根を返す
# parents
print(uf_3.find(2)) # 0とunion()したので、親は1
print(uf_3.find(1)) # 1の親はもちろん1
print(uf_5.find(3)) # 3はどれともつながっていないので、3
print(uf_5.find(4)) # union(1, 2)とunion(2, 4)より親は1
# 出力例
>> 1
>> 1
>> 3
>> 1
size(x)
- ある要素
x
の属するグループの大きさを返す
# size(x)
print(uf_3.size(2))
print(uf_3.size(1))
print(uf_5.size(3))
print(uf_5.size(4))
# 出力
>> 3
>> 3
>> 1
>> 3
same(x, y)
- ある要素
x
とある要素y
が同じグループに属しているか判定する
# same(x, y)
print(uf_3.same(1, 2))
print(uf_3.same(0, 2))
print(uf_5.same(1, 4))
print(uf_5.same(1, 3)) # 1,3はつながっていないためFalse
# 出力
>> True
>> True
>> True
>> False
members(x)
- ある要素
x
が属するグループの要素をリストで返す
# members()
print(uf_3.members(0))
print(uf_3.members(1))
print(uf_5.members(1))
print(uf_5.members(3))
# 出力
>> [0, 1, 2]
>> [0, 1, 2]
>> [1, 2, 4]
>> [3]
roots()
- その木に属するすべての根の要素をリストで返す
# roots()
print(uf_3.roots())
print(uf_5.roots())
# 出力
>> [1]
>> [0, 1, 3]
group_count()
- その木のグループの数を返す
# group_count()
print(uf_3.group_count())
print(uf_5.group_count())
# 出力
>> 1
>> 3
all_group_members()
- 木に属する要素とそのグループの中身を辞書で返す
# all_group_members
print(uf_3.all_group_members())
print(uf_5.all_group_members())
# 出力
# >> defaultdict(<class 'list'>, {1: [0, 1, 2]})
# >> defaultdict(<class 'list'>, {0: [0], 1: [1, 2, 4], 3: [3]})
ここからは実際の問題と解答例を載せていく。
例題1: ATC001 B - Union Find
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def same(self, x, y):
return self.find(x) == self.find(y)
if __name__ == '__main__':
n, q = map(int, input().split())
uf = UnionFind(n)
for i in range(q):
p, a, b = map(int, input().split())
if p == 0:
uf.union(a, b)
else:
if uf.same(a, b):
print('Yes')
else:
print('No')
例題2: ABC049 D - 連結
from collections import defaultdict
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
if __name__ == '__main__':
n, k, l = map(int, input().split())
uf1 = UnionFind(n) # 道路
uf2 = UnionFind(n) # 鉄道
for i in range(k):
p, q = map(int, input().split())
uf1.union(p-1, q-1)
for i in range(l):
r, s = map(int, input().split())
uf2.union(r-1, s-1)
d = defaultdict(int)
result = []
for i in range(n):
result.append((uf1.find(i), uf2.find(i))) # 各都市がつながっている親(根)探し
d[(uf1.find(i), uf2.find(i))] += 1 # 数える
ans = []
for re in result:
ans.append(d[re])
print(*ans)
例題3: ABC075 C - Bridge
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
if __name__ == '__main__':
n, m = map(int, input().split())
ans = 0
alist = []
blist = []
for i in range(m):
a, b = map(int, input().split())
alist.append(a-1)
blist.append(b-1)
for i in range(m):
uf = UnionFind(n)
for j in range(m):
if j != i: # 辺をひとつ潰してunion
uf.union(alist[j], blist[j])
if uf.size(0) < n: # 連結している頂点の数の比較
ans += 1
print(ans)
例題4: ABC120 D - Decayed Bridges
from collections import defaultdict
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
if __name__ == '__main__':
n, m = map(int, input().split())
uf = UnionFind(n)
result = []
for i in range(m):
a, b = map(int, input().split())
result.append((a-1, b-1))
ans = []
ans.append(n*(n-1)//2)
for i in range(m-1, 0, -1):
a, b = result[i]
if uf.same(a, b):
uf.union(a, b)
ans.append(ans[m-i-1])
else:
ans.append(ans[m-i-1]-(uf.size(a)*uf.size(b)))
uf.union(a, b)
for a in reversed(ans):
print(a)
例題5: ABC214 D - Sum of Maximum Weights
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
if __name__ == '__main__':
n = int(input())
es = []
for i in range(n-1):
u, v, w = map(int, input().split())
u -= 1
v -= 1
es.append([w, (u, v)])
es.sort()
uf = UnionFind(n)
ans = 0
for w, e in es:
a, b = e
ans += w * uf.size(a) * uf.size(b)
uf.union(a, b)
print(ans)
例題6: ARC032 B - 道路工事
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
if __name__ == '__main__':
n, m = map(int, input().split())
uf = UnionFind(n)
for i in range(m):
a, b = map(int, input().split())
a, b = a-1, b-1
uf.union(a, b)
ans = uf.group_count()-1
print(ans)
例題7: ABC231 D - Neighbors
def main():
from collections import defaultdict
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
def members(self, x):
root = self.find(x)
return [i for i in range(self.n) if self.find(i) == root]
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
def all_group_members(self):
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
return group_members
def __str__(self):
return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())
n, m = map(int, input().split())
uf = UnionFind(n)
degree = [0] * n # 頂点から出ている辺を数えるための配列
for _ in range(m):
a, b = map(int, input().split())
a, b = a-1, b-1
degree[a] += 1
degree[b] += 1
if uf.same(a, b): # 互いがつながっていたらアウト
print('No')
exit()
uf.union(a, b) # 頂点aと頂点bを結合
if max(degree) <= 2: # ひとつの頂点から3本以上の辺がでていなければ
print('Yes')
else:
print('No')
if __name__ == '__main__':
main()
まとめ
UnionFind
は、木をグループわけするのに便利なものだ。
グループごとに仕分け、操作する問題があれば、積極的にUnionFind
を使ってみよう。