LoginSignup
7
3

More than 3 years have passed since last update.

Union-Find構造をPythonで実装した

Last updated at Posted at 2019-03-03

ABCで実行時間オーバーになってしまったので作ってみました。
Union-Find構造は、ノードのグループ構造や、所属グループ判定を高速に処理できます。
詳細はこちら
解説によれば、以下の計算量となるようです。

  1. $\mathrm{unite}(u, v)$: 頂点 u が属するグループと頂点 v が属するグループを併合し、同じグループにする $(O(α(n)))$
  2. $\mathrm{find}(v)$: 頂点$v$が属するグループ番号を得る $(O(α(n)))$
  3. $\mathrm{size}(v)$: 頂点$v$が属するグループと同じグループに属する頂点数を得る $(O(1))$
    $n$は管理する頂点数、$\alpha$はアッカーマン関数の逆関数

実際のコードは以下です。
リポジトリ

追記
$\mathrm{unite}$ではグループの大きさでなくて、木構造の深さで結合方向を決めないといけないので間違ってますね
上記の参考資料の通り、経路圧縮による深さの変化は無視する方針で修正できそうです
2019/5/16 修正しました

UnionFindNode.py
class UnionFindNode(object):
    """
    Union-Find構造
    ノードのグループ併合や、所属グループ判定を高速に処理する
    """

    def __init__(self, group_id, parent=None, value=None):
        self.group_id_ = group_id
        self.parent_ = parent
        self.value = value
        self.rank_ = 1

    def __str__(self):
        template = "UnionFindNode(group_id: {}, \n\tparent: {}, value: {}, size: {})"
        return template.format(self.group_id_, self.parent_, self.value, self.rank_)

    def is_root(self):
        return not self.parent_

    def root(self):
        parent = self
        while not parent.is_root():
            parent = parent.parent_
            self.parent_ = parent
        return parent

    def find(self):
        root = self.root()
        return root.group_id_

    def rank(self):
        root = self.root()
        return root.rank_

    def unite(self, unite_node):
        root = self.root()
        unite_root = unite_node.root()

        if root.group_id_ != unite_root.group_id_:
            if root.rank() > unite_root.rank():
                unite_root.parent_ = root
                root.rank_ = max(root.rank_, unite_root.rank_ + 1)
            else:
                root.parent_ = unite_root
                unite_root.rank_ = max(root.rank_ + 1, unite_root.rank_)


if __name__ == "__main__":
    node_list = [UnionFindNode(i) for i in range(4)]
    node_list[0].unite(node_list[3])
    node_list[1].unite(node_list[2])
    node_list[0].unite(node_list[2])
    print("\n".join(list(map(str, node_list))))
    print()
    print("\n".join(list(map(lambda x: str(x.root()), node_list))))
7
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
3