search
LoginSignup
38

More than 1 year has passed since last update.

posted at

updated at

UnionFindをPythonでやるための最低限の実装

こんにちは、しいです。

たまにUnionFindの問題が出ても他人のライブラリ使いがちだったので復習しとこうとおもって書きます。忘れっぽい自分用のノートみたいなものだとおもってもらえれば

重要例題 AtCoder Typical Contest 001 B - Union Find
https://atcoder.jp/contests/atc001/tasks/unionfind_a

UnionFindの純粋な問題ですね
解説スライドもついているのでこれに沿って理解していきます。

https://www.slideshare.net/chokudai/union-find-49066733
Union find(素集合データ構造) from AtCoder Inc

UnionFindの機能

  1. グループ作り (Union)
  2. グループに属しているかどうか (Find)

この2つだけです。グループ作成の具体的なイメージはこんな感じ↓
input で与えられた者同士をつなげて木として扱います。

グループの作成

グループに属しているか

グループに属しているかどうかは同一の親を持つかどうかで判断します。

高速化のtip1:経路圧縮

縦に長くなると毎回親を探すのに時間がかかってしまいます。しかし今回はどのグループに属しているかが大事なので誰と誰がつながっているかはあまり重要ではないためfindするたびに親に直接つなぎ合わせてあげることで高速化します。

高速化のtip2 : ランク

木の高さを保持しておき低いほうを高いほうにつなげることで経路圧縮の計算量を減らし高速化する

実装(rankなし)

まず各要素の親をpar(parentsの略)で管理します。
最初は各要素はどこのグループにも属していないので自分自身が親になります。要素数をNとすると

par = [i for i in range(N+1)]

となります。

次に自分の親を見つける関数をfind、要素x,yが同じグループかを確認する関数をsameとすると

find

自分が親であるときは自分の番号を返しそれ以外の時はもう一度findを行うことで親を探すと同時につなぎなおしています(経路圧縮)

def find(x):
    if par[x] == x:
        return x
    else:
        par[x] = find(par[x]) #経路圧縮
        return par[x]

same

お互いの親が一緒であるかを確認して同じグループかどうか判断しています。

def same(x,y):
    return find(x) == find(y)

unite

それぞれの親を確認して異なる場合のみ親を統一します。

def unite(x,y):
    x = find(x)
    y = find(y)
    if x == y:
        return 0
    par[x] = y

以上をまとめ、今回の例題を解いてみると

N,Q = map(int,input().split())
par = [i for i in range(N+1)]
def find(x):
    if par[x] == x:
        return x
    else:
        par[x] = find(par[x]) #経路圧縮
        return par[x]
def same(x,y):
    return find(x) == find(y)
def unite(x,y):
    x = find(x)
    y = find(y)
    if x == y:
        return 0
    par[x] = y
for i in range(Q):
    p,a,b = map(int,input().split())
    if p == 0:
        unite(a,b)
    else:
        if same(a,b):
            print('Yes')
        else:
            print('No')

rankを実装するパターン

rankを実装する場合は各頂点の高さを管理してあげる必要があります

N,Q = map(int,input().split())
par = [i for i in range(N+1)]
rank = [0]*(N+1)
def find(x):
    if par[x] == x:
        return x
    else:
        par[x] = find(par[x]) #経路圧縮
        return par[x]
def same(x,y):
    return find(x) == find(y)
def unite(x,y):
    x = find(x)
    y = find(y)
    if x == y:
        return 0
    if rank[x] < rank[y]:
        par[x] = y
    else:
        par[y] = x
        if rank[x]==rank[y]:rank[x]+=1

for i in range(Q):
    p,a,b = map(int,input().split())
    if p == 0:
        unite(a,b)
    else:
        if same(a,b):
            print('Yes')
        else:
            print('No')

最後に

以上がPythonによる最低限の実装でした。
もし何か間違ったとこがあればどんどんいってください

参考にさせてもらったサイト様

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
38