0
1

More than 1 year has passed since last update.

Algorithm | Union-FindをPython3で解説(例題あり)

Last updated at Posted at 2021-08-20

Union-Findとは

Union-Findは、グループ分けを管理できるもの。

主に2つの操作を行うことができる。

  • グループの接合(Union)
  • 要素がどのグループに属するかの判定(Find)

これら2つが行えるため、Union-Findと呼ばれる。

下記にUnionFindのクラスを掲載してあるので、これを使って説明する。

from collections import defaultdict

class UnionFind():
    """Disjoint-set (union-find) structure over the integers ``0 .. n-1``.

    ``parents[i]`` stores the parent index of element ``i``.  For a root it
    stores a negative number instead: the negated size of that root's group.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups; every element starts as a root."""
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the group containing ``x``, compressing the path."""
        # First pass: climb until we reach an entry holding a negative size.
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]

        # Second pass: point every element visited on the way at the root.
        while x != root:
            above = self.parents[x]
            self.parents[x] = root
            x = above
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Roots store -size, so the smaller stored value is the larger group.
        # The larger group's root is kept; ties keep x's root.
        if self.parents[root_x] <= self.parents[root_y]:
            keep, absorb = root_x, root_y
        else:
            keep, absorb = root_y, root_x

        self.parents[keep] += self.parents[absorb]
        self.parents[absorb] = keep

    def size(self, x):
        """Return the number of elements in the group containing ``x``."""
        root = self.find(x)
        return -self.parents[root]

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same group."""
        root_x = self.find(x)
        root_y = self.find(y)
        return root_x == root_y

    def members(self, x):
        """Return, in ascending order, all elements in the group of ``x``."""
        target = self.find(x)
        found = []
        for candidate in range(self.n):
            if self.find(candidate) == target:
                found.append(candidate)
        return found

    def roots(self):
        """Return, in ascending order, the root element of every group."""
        found = []
        for index, value in enumerate(self.parents):
            if value < 0:
                found.append(index)
        return found

    def group_count(self):
        """Return the number of groups."""
        all_roots = self.roots()
        return len(all_roots)

    def all_group_members(self):
        """Return a defaultdict mapping each root to the list of its members."""
        grouped = defaultdict(list)
        for element in range(self.n):
            root = self.find(element)
            grouped[root].append(element)
        return grouped

    def __str__(self):
        """Return one ``root: [members]`` line per group."""
        lines = []
        for root, group in self.all_group_members().items():
            lines.append(f'{root}: {group}')
        return '\n'.join(lines)

各メソッドの使い方

UnionFindのそれぞれのメソッドについて1つずつ説明していく。

parents

  • 各要素の親要素の番号を保存するためのリスト

image.png

# parents: a new UnionFind(n) holds n entries of -1,
# i.e. every element is a root whose group has size 1
uf_3 = UnionFind(3)
print(uf_3.parents)
uf_5 = UnionFind(5)
print(uf_5.parents)
# Output
>> [-1, -1, -1]
>> [-1, -1, -1, -1, -1]

union(x, y)

  • ある要素xが属するグループと別の要素yが属するグループを接合する

image.png

# union(x, y): the absorbed root's entry becomes the index of the surviving root,
# and the surviving root's entry becomes the negated size of the merged group
uf_3.union(1, 2)
print(uf_3.parents)
uf_3.union(0, 1)
print(uf_3.parents)
uf_5.union(1, 2)
print(uf_5.parents)
uf_5.union(2, 4)
print(uf_5.parents)
# Output
>> [-1, -2, 1]
>> [1, -3, 1]
>> [-1, -2, 1, -1, -1]
>> [-1, -3, 1, -1, 1]

find(x)

  • ある要素xが属するグループの根を返す

image.png

# find(x)
print(uf_3.find(2)) # 2 was merged with 1 by union(1, 2), so its root is 1
print(uf_3.find(1)) # 1 is itself the root, so the result is 1

print(uf_5.find(3)) # 3 was never merged with anything, so it is its own root: 3
print(uf_5.find(4)) # after union(1, 2) and union(2, 4) the root of 4 is 1
# Output
>> 1
>> 1
>> 3
>> 1

size(x)

  • ある要素xの属するグループの大きさを返す

image.png

# size(x): number of elements in the group that contains x
print(uf_3.size(2))
print(uf_3.size(1))

print(uf_5.size(3))
print(uf_5.size(4))
# Output
>> 3
>> 3
>> 1
>> 3

same(x, y)

  • ある要素xとある要素yが同じグループに属しているか判定する

image.png

# same(x, y): True when find(x) and find(y) return the same root
print(uf_3.same(1, 2))
print(uf_3.same(0, 2))
print(uf_5.same(1, 4))
print(uf_5.same(1, 3)) # 1 and 3 are not connected, so this is False
# Output
>> True
>> True
>> True
>> False

members(x)

  • ある要素xが属するグループの要素をリストで返す

image.png

# members(x): every element whose root equals the root of x, in ascending order
print(uf_3.members(0))
print(uf_3.members(1))
print(uf_5.members(1))
print(uf_5.members(3))
# Output
>> [0, 1, 2]
>> [0, 1, 2]
>> [1, 2, 4]
>> [3]

roots()

  • すべてのグループの根の要素をリストで返す

image.png

# roots(): indices whose parents entry is negative, i.e. one root per group
print(uf_3.roots())
print(uf_5.roots())
# Output
>> [1]
>> [0, 1, 3]

group_count()

  • その木のグループの数を返す

image.png

# group_count(): the length of roots(), i.e. the number of groups
print(uf_3.group_count())
print(uf_5.group_count())
# Output
>> 1
>> 3

all_group_members()

  • 各グループの根をキー、そのグループに属する要素のリストを値とする辞書を返す
# all_group_members(): defaultdict mapping each root to the list of its members
print(uf_3.all_group_members())
print(uf_5.all_group_members())
# Output
# >> defaultdict(<class 'list'>, {1: [0, 1, 2]})
# >> defaultdict(<class 'list'>, {0: [0], 1: [1, 2, 4], 3: [3]})

ここからは実際の問題と解答例を載せていく。

例題1: ATC001 B - Union Find

class UnionFind():
    """Disjoint-set structure over ``0 .. n-1``.

    ``parents[i]`` is the parent index of ``i``; for a root it is the negated
    size of the root's group.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups."""
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the group containing ``x``, compressing the path."""
        # Climb to the root (the entry holding a negative value).
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]

        # Re-point every element on the climbed path directly at the root.
        while x != root:
            above = self.parents[x]
            self.parents[x] = root
            x = above
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Smaller stored value means larger group; ties keep x's root.
        if self.parents[root_x] <= self.parents[root_y]:
            keep, absorb = root_x, root_y
        else:
            keep, absorb = root_y, root_x

        self.parents[keep] += self.parents[absorb]
        self.parents[absorb] = keep

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same group."""
        root_x = self.find(x)
        root_y = self.find(y)
        return root_x == root_y

if __name__ == '__main__':
    # First line: number of elements n and number of queries q.
    n, q = map(int, input().split())
    uf = UnionFind(n)

    for _ in range(q):
        # p == 0 merges the groups of a and b; any other p asks whether
        # a and b are currently in the same group.
        p, a, b = map(int, input().split())
        if p == 0:
            uf.union(a, b)
            continue
        print('Yes' if uf.same(a, b) else 'No')

例題2: ABC049 D - 連結

from collections import defaultdict

class UnionFind():
    """Disjoint-set structure over ``0 .. n-1``.

    ``parents[i]`` is the parent index of ``i``; for a root it is the negated
    size of the root's group.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups."""
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the group containing ``x``, compressing the path."""
        # Climb to the root (the entry holding a negative value).
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]

        # Re-point every element on the climbed path directly at the root.
        while x != root:
            above = self.parents[x]
            self.parents[x] = root
            x = above
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Smaller stored value means larger group; ties keep x's root.
        if self.parents[root_x] <= self.parents[root_y]:
            keep, absorb = root_x, root_y
        else:
            keep, absorb = root_y, root_x

        self.parents[keep] += self.parents[absorb]
        self.parents[absorb] = keep


if __name__ == '__main__':
    n, road_count, rail_count = map(int, input().split())

    uf1 = UnionFind(n)  # roads
    uf2 = UnionFind(n)  # railways

    # Endpoints are given 1-indexed; shift them to 0-indexed before merging.
    for _ in range(road_count):
        p, q = map(int, input().split())
        uf1.union(p - 1, q - 1)

    for _ in range(rail_count):
        r, s = map(int, input().split())
        uf2.union(r - 1, s - 1)

    # Key each city by its (road root, railway root) pair.
    keys = [(uf1.find(city), uf2.find(city)) for city in range(n)]

    # Count how many cities share each pair.
    d = defaultdict(int)
    for key in keys:
        d[key] += 1

    # For every city, print the number of cities sharing its pair.
    print(*[d[key] for key in keys])

例題3: ABC075 C - Bridge

class UnionFind():
    """Disjoint-set structure over ``0 .. n-1``.

    ``parents[i]`` is the parent index of ``i``; for a root it is the negated
    size of the root's group.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups."""
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the group containing ``x``, compressing the path."""
        # Climb to the root (the entry holding a negative value).
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]

        # Re-point every element on the climbed path directly at the root.
        while x != root:
            above = self.parents[x]
            self.parents[x] = root
            x = above
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Smaller stored value means larger group; ties keep x's root.
        if self.parents[root_x] <= self.parents[root_y]:
            keep, absorb = root_x, root_y
        else:
            keep, absorb = root_y, root_x

        self.parents[keep] += self.parents[absorb]
        self.parents[absorb] = keep

    def size(self, x):
        """Return the number of elements in the group containing ``x``."""
        root = self.find(x)
        return -self.parents[root]


if __name__ == '__main__':
    n, m = map(int, input().split())

    # Edges are read 1-indexed and stored 0-indexed as (a, b) pairs.
    edges = []
    for _ in range(m):
        a, b = map(int, input().split())
        edges.append((a - 1, b - 1))

    ans = 0
    for removed in range(m):
        # Rebuild the structure from scratch with every edge except one.
        uf = UnionFind(n)
        for index, (a, b) in enumerate(edges):
            if index != removed:
                uf.union(a, b)

        # Count the removed edge when vertex 0's group no longer has n vertices.
        if uf.size(0) < n:
            ans += 1

    print(ans)

例題4: ABC120 D - Decayed Bridges

from collections import defaultdict


class UnionFind():
    """Disjoint-set structure over ``0 .. n-1``.

    ``parents[i]`` is the parent index of ``i``; for a root it is the negated
    size of the root's group.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups."""
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the group containing ``x``, compressing the path."""
        # Climb to the root (the entry holding a negative value).
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]

        # Re-point every element on the climbed path directly at the root.
        while x != root:
            above = self.parents[x]
            self.parents[x] = root
            x = above
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Smaller stored value means larger group; ties keep x's root.
        if self.parents[root_x] <= self.parents[root_y]:
            keep, absorb = root_x, root_y
        else:
            keep, absorb = root_y, root_x

        self.parents[keep] += self.parents[absorb]
        self.parents[absorb] = keep

    def size(self, x):
        """Return the number of elements in the group containing ``x``."""
        root = self.find(x)
        return -self.parents[root]

    def same(self, x, y):
        """Return True when ``x`` and ``y`` belong to the same group."""
        root_x = self.find(x)
        root_y = self.find(y)
        return root_x == root_y


if __name__ == '__main__':
    n, m = map(int, input().split())
    uf = UnionFind(n)

    # Pairs are read 1-indexed and stored 0-indexed in input order.
    pairs = []
    for _ in range(m):
        a, b = map(int, input().split())
        pairs.append((a - 1, b - 1))

    # Start from n*(n-1)/2 (every pair of elements disconnected) and process
    # the input pairs from last to first, skipping the very first one.
    ans = [n * (n - 1) // 2]

    for a, b in reversed(pairs[1:]):
        current = ans[-1]
        if not uf.same(a, b):
            # Joining two separate groups connects size(a) * size(b) new pairs.
            current -= uf.size(a) * uf.size(b)
            uf.union(a, b)
        ans.append(current)

    # Values were accumulated back to front, so print them reversed.
    for value in ans[::-1]:
        print(value)

例題5: ABC214 D - Sum of Maximum Weights


class UnionFind():
    """Disjoint-set structure over ``0 .. n-1``.

    ``parents[i]`` is the parent index of ``i``; for a root it is the negated
    size of the root's group.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups."""
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the group containing ``x``, compressing the path."""
        # Climb to the root (the entry holding a negative value).
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]

        # Re-point every element on the climbed path directly at the root.
        while x != root:
            above = self.parents[x]
            self.parents[x] = root
            x = above
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Smaller stored value means larger group; ties keep x's root.
        if self.parents[root_x] <= self.parents[root_y]:
            keep, absorb = root_x, root_y
        else:
            keep, absorb = root_y, root_x

        self.parents[keep] += self.parents[absorb]
        self.parents[absorb] = keep

    def size(self, x):
        """Return the number of elements in the group containing ``x``."""
        root = self.find(x)
        return -self.parents[root]


if __name__ == '__main__':
    n = int(input())

    # Each edge is stored as (weight, u, v) with 0-indexed endpoints.
    edges = []
    for _ in range(n - 1):
        u, v, w = map(int, input().split())
        edges.append((w, u - 1, v - 1))

    uf = UnionFind(n)
    ans = 0

    # Visit edges in ascending weight order; each edge contributes its weight
    # times the product of the two group sizes it joins.
    for w, a, b in sorted(edges):
        ans += w * uf.size(a) * uf.size(b)
        uf.union(a, b)

    print(ans)

例題6: ARC032 B - 道路工事

class UnionFind():
    """Disjoint-set structure over ``0 .. n-1``.

    ``parents[i]`` is the parent index of ``i``; for a root it is the negated
    size of the root's group.
    """

    def __init__(self, n):
        """Create ``n`` singleton groups."""
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        """Return the root of the group containing ``x``, compressing the path."""
        # Climb to the root (the entry holding a negative value).
        root = x
        while self.parents[root] >= 0:
            root = self.parents[root]

        # Re-point every element on the climbed path directly at the root.
        while x != root:
            above = self.parents[x]
            self.parents[x] = root
            x = above
        return root

    def union(self, x, y):
        """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return

        # Smaller stored value means larger group; ties keep x's root.
        if self.parents[root_x] <= self.parents[root_y]:
            keep, absorb = root_x, root_y
        else:
            keep, absorb = root_y, root_x

        self.parents[keep] += self.parents[absorb]
        self.parents[absorb] = keep

    def roots(self):
        """Return, in ascending order, the root element of every group."""
        found = []
        for index, value in enumerate(self.parents):
            if value < 0:
                found.append(index)
        return found

    def group_count(self):
        """Return the number of groups."""
        all_roots = self.roots()
        return len(all_roots)


if __name__ == '__main__':
    n, m = map(int, input().split())
    uf = UnionFind(n)

    # Endpoints are given 1-indexed; shift to 0-indexed and merge them.
    for _ in range(m):
        a, b = map(int, input().split())
        uf.union(a - 1, b - 1)

    # Print one less than the number of groups.
    print(uf.group_count() - 1)

例題7: ABC231 D - Neighbors

def main():
    """Read a graph from standard input and print ``Yes`` or ``No``.

    Input is ``n m`` followed by ``m`` lines ``a b`` with 1-indexed vertices.
    ``No`` is printed as soon as a pair joins two vertices that are already
    in the same group (the pair would close a cycle), or at the end when some
    vertex appears in more than two pairs.  Otherwise ``Yes`` is printed.

    The function returns normally in every case instead of calling ``exit()``,
    which is a helper injected by the ``site`` module and is not guaranteed
    to exist in every run mode.
    """

    class UnionFind():
        """Disjoint-set structure over ``0 .. n-1``.

        ``parents[i]`` is the parent index of ``i``; for a root it is the
        negated size of the root's group.  Only the operations this solver
        uses are defined.
        """

        def __init__(self, n):
            """Create ``n`` singleton groups."""
            self.n = n
            self.parents = [-1] * n

        def find(self, x):
            """Return the root of the group containing ``x``, compressing the path."""
            root = x
            while self.parents[root] >= 0:
                root = self.parents[root]
            # Re-point every element on the climbed path directly at the root.
            while x != root:
                above = self.parents[x]
                self.parents[x] = root
                x = above
            return root

        def union(self, x, y):
            """Merge the groups containing ``x`` and ``y`` (no-op if already merged)."""
            x = self.find(x)
            y = self.find(y)

            if x == y:
                return

            # Roots store -size: keep the root of the larger group as ``x``.
            if self.parents[x] > self.parents[y]:
                x, y = y, x

            self.parents[x] += self.parents[y]
            self.parents[y] = x

        def same(self, x, y):
            """Return True when ``x`` and ``y`` belong to the same group."""
            return self.find(x) == self.find(y)

    n, m = map(int, input().split())

    uf = UnionFind(n)
    degree = [0] * n  # number of pairs each vertex appears in

    for _ in range(m):
        a, b = map(int, input().split())
        a, b = a - 1, b - 1
        degree[a] += 1
        degree[b] += 1
        if uf.same(a, b):  # already connected: this pair would close a cycle
            print('No')
            return
        uf.union(a, b)  # connect vertex a and vertex b

    # No vertex may have three or more edges; default=0 covers n == 0.
    if max(degree, default=0) <= 2:
        print('Yes')
    else:
        print('No')


# Run the solver only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

まとめ

UnionFindは、要素をグループ分けして管理するのに便利なデータ構造だ。
グループごとに仕分け、操作する問題があれば、積極的にUnionFindを使ってみよう。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1