1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

UnionFind(相当)を10行で実装する

Last updated at Posted at 2020-02-28

**素集合データ構造(Union-Find)**は、要素を互いに交わらないグループに分類するデータ構造です。標準的な実装方法はUnion-Find木です。・・・というのは皆さんすでにご存知と思います。

Union-Find木はQiitaにも実装例がたくさんあるのですが、私は以前から疑問に思っていました。

  1. find()の再帰呼び出しがあって分かりにくい。もっと分かりやすく・短く書けないか?
  2. 組み込みのコレクション型を利用した方が、実は高速なのではないか?

10行実装

そこで、Pythonでdictとfrozensetを使って短い実装を書いてみました。空白を含めても10行です。

class EasyUnionFind:
    def __init__(self, n):
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        group = self._groups[x] | self._groups[y]
        self._groups.update((c, group) for c in group)

    def groups(self):
        return frozenset(self._groups.values())

比較してみた

Union-Find木による実装と比較してみます。
比較対象は、Qiitaで「Union Find Python」で検索して見つけた一番いいね数が多い記事紹介されていた実装です。

今回の10行実装の方が遅いという結果になりました。また、要素数が増えるに従って差が開いているようです。残念。

要素数 Union-Find木実装 今回の10行実装 所要時間の比
1000 0.72秒 1.17秒 1.63
2000 1.46秒 2.45秒 1.68
4000 2.93秒 5.14秒 1.75
8000 6.01秒 11.0秒 1.83

ただ、遅いといってもUnion-Find木の2倍弱程度なので、場合によっては利用価値がある・・・かもしれません。

比較用コードと実行結果

コード:

import random
import timeit
import sys
import platform


class EasyUnionFind:
    """
    dict と frozenset を使った実装。
    """
    def __init__(self, n):
        self._groups = {x: frozenset([x]) for x in range(n)}

    def union(self, x, y):
        group = self._groups[x] | self._groups[y]
        self._groups.update((c, group) for c in group)

    def groups(self):
        return frozenset(self._groups.values())


class UnionFind(object):
    """
    典型的なUnion-Find木による実装。
    https://www.kumilog.net/entry/union-find の実装例をコピーしたが、
    今回不要なメンバ関数を削除し .groups() を追加した。
    """
    def __init__(self, n=1):
        self.par = [i for i in range(n)]
        self.rank = [0 for _ in range(n)]
        self.size = [1 for _ in range(n)]
        self.n = n

    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            self.par[x] = self.find(self.par[x])
            return self.par[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)
        if x != y:
            if self.rank[x] < self.rank[y]:
                x, y = y, x
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1
            self.par[y] = x
            self.size[x] += self.size[y]

    def groups(self):
        groups = {}
        for i in range(self.n):
            groups.setdefault(self.find(i), []).append(i)
        return frozenset(frozenset(group) for group in groups.values())


def test1():
    """2つの実装の結果が同じかどうかテストする。差異があれば AssertionError を送出する。"""
    print("===== TEST1 =====")
    random.seed(20200228)
    n = 2000
    for _ in range(1000):
        elements = range(n)
        pairs = [
            (random.choice(elements), random.choice(elements))
            for _ in range(n // 2)
        ]
        uf1 = UnionFind(n)
        uf2 = EasyUnionFind(n)
        for x, y in pairs:
            uf1.union(x, y)
            uf2.union(x, y)
        assert uf1.groups() == uf2.groups()
    print('ok')
    print()


def test2():
    """
    要素数を増やしながら、2つの実装の所要時間を出力する。
    """
    print("===== TEST2 =====")
    random.seed(20200228)

    def execute_union_find(klass, n, test_datum):
        for pairs in test_datum:
            uf = klass(n)
            for x, y in pairs:
                uf.union(x, y)

    timeit_number = 1
    for n in [1000, 2000, 4000, 8000]:
        print(f"n={n}")
        test_datum = []
        for _ in range(1000):
            elements = range(n)
            pairs = [
                (random.choice(elements), random.choice(elements))
                for _ in range(n // 2)
            ]
            test_datum.append(pairs)

        t = timeit.timeit(lambda: execute_union_find(UnionFind, n, test_datum), number=timeit_number)
        print("UnionFind", t)

        t = timeit.timeit(lambda: execute_union_find(EasyUnionFind, n, test_datum), number=timeit_number)
        print("EasyUnionFind", t)
        print()

def main():
    print(sys.version)
    print(platform.platform())
    print()
    test1()
    test2()

if __name__ == "__main__":
    main()

実行結果

3.7.6 (default, Dec 30 2019, 19:38:28)
[Clang 11.0.0 (clang-1100.0.33.16)]
Darwin-18.7.0-x86_64-i386-64bit

===== TEST1 =====
ok

===== TEST2 =====
n=1000
UnionFind 0.7220867589999997
EasyUnionFind 1.1789850389999987

n=2000
UnionFind 1.460918638999999
EasyUnionFind 2.4546459260000013

n=4000
UnionFind 2.925022847000001
EasyUnionFind 5.142797402000003

n=8000
UnionFind 6.01257184
EasyUnionFind 10.963117657000005
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?