LoginSignup
12
8

More than 1 year has passed since last update.

Algorithm | 最小全域木をPython3で解説(例題あり)

Posted at

最小全域木とは

そもそもとは、「$N$個の頂点と$N-1$本の辺で連結された閉路を持たない無向グラフ」のことである。

一般的に、以下のようなグラフのことをいう。

Qiita-13.jpg

では、全域木とはなにか。全域木は「すべての頂点が辺によって連結している木」のことをいう。

一般的に、以下のようなグラフのことをいう。

Qiita-14.jpg

ここで、次のような木が与えられ、すべての辺に重みがついているとする。この辺の重みの総和が最小になるような全域木を最小全域木とよぶ。

Qiita-15.jpg
Qiita-16.jpg

どのようにしたら最小全域木を求められるのか

最小全域木を求めるには、大きく分けて2つの方法がある。プリム法クラスカル法である。

プリム法とは

プリム法とは、「ある頂点から異なる頂点に出ている、すべての辺の重みを小さい順番に並び替え、採用した重みの最小の辺につながっている頂点から、同様のことを繰り返して、最終的に全域木を作る」方法である。

文字では伝わりにくいので、図を用いて解説していく。今回は、典型アルゴリズム問題集 F - 最小全域木問題 の入力例1をもとに解説していく。

図で解説

まず、プリム法では最小の重みの辺を調べる必要がある。しかし、辺が追加されるたびに重みを並び替えしていると、計算量が多くなってしまい効率的なアルゴリズムとはいえない。そこで、値を追加するだけで最小の値を取り出すことができるheapというものを用いる。

heapに関しては、過去に解説記事をあげているので、よければそちらを参考にしていただきたい。

したがって、heapをあらかじめ用意しておく。

Qiita-17.jpg

頂点$0$からスタートなので、頂点$0$からでているすべての辺と重みをheapに追加する。このとき取り出した最小の重みを持つ辺が、最小全域木の辺となりうるので、頂点$0$→$1$を確定させ、総和であるcostに$10$を足す。

Qiita-18.jpg

次は、今頂点$1$にたどり着いたので、頂点$1$から出ているすべての辺と重みをheapに追加する。このときの最小の重みをもつ辺は頂点$1$→$2$なので、これを確定させ、costに$10$を足す。

Qiita-19.jpg

あとは、同様のことを繰り返していく。

Qiita-20.jpg
Qiita-21 1.jpg

最終的に、すべての頂点が辺でつながれば、これが最小全域木となる。

Qiita-22 1.jpg

コード例(プリム法)

上記で説明したことをコードで実装したものは、以下の通りである。

def main():
    from heapq import heappush, heappop

    n, m = map(int, input().split())

    graph = [[] for _ in range(n)]

    for _ in range(m):
        u, v, c = map(int, input().split())

        graph[u].append((v, c)) # u->vの辺
        graph[v].append((u, c)) # v->uの辺

    # プリム法
    # 頂点がマークされているか確認する配列
    marked = [False for _ in range(n)]

    # マークされている頂点数を数える
    marked_cnt = 0

    # はじめに0頂点をマーク
    marked[0] = True
    marked_cnt += 1

    # heap
    q = []

    # 頂点0に隣接する辺を保存
    for j, c in graph[0]:
        heappush(q, (c, j))

    total = 0

    # すべての頂点をマークするまでwhileループ
    while marked_cnt < n:
        # 最小の重みの辺をheapで取り出す
        c, i = heappop(q)

        # マークされているならスキップ
        if marked[i]:
            continue

        # 頂点をマーク
        marked[i] = True
        marked_cnt += 1

        total += c

        # 頂点iに隣接する辺を保存
        for j, c in graph[i]:
            # マークされていればスキップ
            if marked[j]:
                continue

            heappush(q, (c, j))

    print(total)


if __name__ == '__main__':
    main()

クラスカル法とは

クラスカル法とは、「どの辺も選択していない状態からスタートし、重みが小さい辺から順番に追加していって最終的に全域木を作る」方法である。

こちらも同様に、典型アルゴリズム問題集 F - 最小全域木問題 の入力例1をもとに解説していく。

図で解説

プリム法と異なり、heapを用いる必要はない。クラスカル法は、かわりにUnion-Findというデータ構造を用いる。

Union-Findについても、過去に解説記事を投稿しているので、そちらを参考にしていただきたい。

Qiita-23.jpg

クラスカル法では、はじめからすべての辺をみて、重みが最小のものから選択していく。重みが同じ場合は、頂点番号が小さいものから優先的に選択される。つまり、ここでの重みが最小の辺は頂点$0$→$1$の辺であるので、これを確定させ、総和を求めるcostに重みを足す。

Qiita-24.jpg

次に重みが小さい辺は頂点$1$→$2$の辺であるので、これを確定させ、重みを足す。

Qiita-25.jpg

同様に、現在行った操作を繰り返していく。

Qiita-26.jpg

最終的にすべての頂点が繋がっていれば、それは最小全域木となる。

Qiita-27.jpg

コード例(クラスカル法)

上記で説明したことをコードで実装したものは、以下の通りである。

# Union-Find

from collections import defaultdict

class UnionFind():
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())

# Code

def main():
    n, m = map(int, input().split())
    uf = UnionFind(n)

    edges = []

    for _ in range(m):
        u, v, c = map(int, input().split())
        edges.append((c, u, v))

    # 重みが小さい順に辺をソート
    edges.sort()

    cost = 0

    for edge in edges:
        c, u, v = edge
        # 頂点がつながっていなければ
        if not uf.same(u, v):
            cost += c # 重みを足し
            uf.union(u, v) # 頂点同士をつなげる

    print(cost)


if __name__ == '__main__':
    main()

また、補足として、クラスカル法はYouTubeでわかりやすい解説動画あるので、そちらも併せてご覧頂きたい。

最小全域木を求める問題

ここまでの説明が理解できれば、問題をといてみてほしい。
実際に手を動かしてコードを書くことで、より一層知識として定着するはずだ。

ABC218 E - Destruction

answer

def main():
    from heapq import heappush, heappop

    n, m = map(int, input().split())

    graph = [[] for _ in range(n)]

    cost = 0

    for _ in range(m):
        u, v, c = map(int, input().split())
        u, v = u-1, v-1
        graph[u].append((v, c))  # u->vの辺
        graph[v].append((u, c))  # v->uの辺
        if c >= 0:
            cost += c

    # プリム法
    # 頂点がマークされているか確認する配列
    marked = [False for _ in range(n)]

    # マークされている頂点数を数える
    marked_cnt = 0

    # はじめに0頂点をマーク
    marked[0] = True
    marked_cnt += 1

    # heap
    q = []

    # 頂点0に隣接する辺を保存
    for j, c in graph[0]:
        heappush(q, (c, j))

    # すべての頂点をマークするまでwhileループ
    while marked_cnt < n:
        # 最小の重みの辺をheapで取り出す
        c, i = heappop(q)

        # マークされているならスキップ
        if marked[i]:
            continue

        # 頂点をマーク
        marked[i] = True
        marked_cnt += 1

        # 罰金でなければ報酬を
        # 受け取りすぎているので報酬分の値を引く
        if c >= 0:
            cost -= c

        # 頂点iに隣接する辺を保存
        for j, c in graph[i]:
            # マークされていればスキップ
            if marked[j]:
                continue

            heappush(q, (c, j))

    print(cost)


if __name__ == '__main__':
    main()

ARC065 D - Built?

answer

# Union-Find

from collections import defaultdict


class UnionFind():
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())

# Code


def main():
    n = int(input())

    nodes = []

    # (x, y)の座標を頂点iに保存
    for i in range(n):
        x, y = map(int, input().split())
        nodes.append((x, y, i))

    uf = UnionFind(n)

    xnodes = sorted(nodes)
    ynodes = sorted(nodes, key=lambda y: y[1])

    edges = []

    for i in range(n-1):
        dx = abs(xnodes[i][0] - xnodes[i+1][0])
        dy = abs(xnodes[i][1] - xnodes[i+1][1])
        cost = min(dx, dy)
        edges.append((cost, xnodes[i][2], xnodes[i+1][2]))

    for i in range(n-1):
        dx = abs(ynodes[i][0] - ynodes[i+1][0])
        dy = abs(ynodes[i][1] - ynodes[i+1][1])
        cost = min(dx, dy)
        edges.append((cost, ynodes[i][2], ynodes[i+1][2]))

    # 重みが小さい順に辺をソート
    edges.sort()

    ans = 0

    for edge in edges:
        c, u, v = edge
        # 頂点がつながっていなければ
        if not uf.same(u, v):
            ans += c  # 重みを足し
            uf.union(u, v)  # 頂点同士をつなげる

    print(ans)


if __name__ == '__main__':
    main()

ABC181 F - Silver Woods

answer

# Union-Find

from collections import defaultdict

class UnionFind():
    def __init__(self, n):
        self.n = n
        self.parents = [-1] * n

    def find(self, x):
        if self.parents[x] < 0:
            return x
        else:
            self.parents[x] = self.find(self.parents[x])
            return self.parents[x]

    def union(self, x, y):
        x = self.find(x)
        y = self.find(y)

        if x == y:
            return

        if self.parents[x] > self.parents[y]:
            x, y = y, x

        self.parents[x] += self.parents[y]
        self.parents[y] = x

    def size(self, x):
        return -self.parents[self.find(x)]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def members(self, x):
        root = self.find(x)
        return [i for i in range(self.n) if self.find(i) == root]

    def roots(self):
        return [i for i, x in enumerate(self.parents) if x < 0]

    def group_count(self):
        return len(self.roots())

    def all_group_members(self):
        group_members = defaultdict(list)
        for member in range(self.n):
            group_members[self.find(member)].append(member)
        return group_members

    def __str__(self):
        return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())

# Code

def main():
    # 斜辺
    from math import hypot

    n = int(input())
    nodes = [tuple(map(int, input().split())) for _ in range(n)]
    edges = []

    s = n
    t = n+1

    for i in range(n):
        x, y = nodes[i]
        edges.append((y+100, i, s))
        edges.append((100-y, i, t))


    for i in range(n):
        for j in range(i):
            x1, y1 = nodes[i]
            x2, y2 = nodes[j]
            edges.append((hypot(x1-x2, y1-y2), i, j))

    # 重みが小さい順に辺をソート
    edges.sort()

    uf = UnionFind(n+2)

    for edge in edges:
        c, u, v = edge
        uf.union(u, v)
        # 頂点がつながっていれば
        if uf.same(s, t):
            print(c/2)
            exit()


if __name__ == '__main__':
    main()
12
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
8