最小全域木とは
そもそも木とは、「$N$個の頂点と$N-1$本の辺で連結された閉路を持たない無向グラフ」のことである。
一般的に、以下のようなグラフのことをいう。
では、全域木とはなにか。全域木は「すべての頂点が辺によって連結している木」のことをいう。
一般的に、以下のようなグラフのことをいう。
ここで、次のような木が与えられ、すべての辺に重みがついているとする。この辺の重みの総和が最小になるような全域木を最小全域木とよぶ。
どのようにしたら最小全域木を求められるのか
最小全域木を求めるには、大きく分けて2つの方法がある。プリム法とクラスカル法である。
プリム法とは
プリム法とは、「ある頂点から異なる頂点に出ている、すべての辺の重みを小さい順番に並び替え、採用した重みの最小の辺につながっている頂点から、同様のことを繰り返して、最終的に全域木を作る」方法である。
文字では伝わりにくいので、図を用いて解説していく。今回は、典型アルゴリズム問題集 F - 最小全域木問題 の入力例1をもとに解説していく。
図で解説
まず、プリム法では最小の重みの辺を調べる必要がある。しかし、辺が追加されるたびに重みを並び替えしていると、計算量が多くなってしまい効率的なアルゴリズムとはいえない。そこで、値を追加するだけで最小の値を取り出すことができるheap
というものを用いる。
heap
に関しては、過去に解説記事をあげているので、よければそちらを参考にしていただきたい。
したがって、heap
をあらかじめ用意しておく。
頂点$0$からスタートなので、頂点$0$からでているすべての辺と重みをheap
に追加する。このとき取り出した最小の重みを持つ辺が、最小全域木の辺となりうるので、頂点$0$→$1$を確定させ、総和であるcost
に$10$を足す。
次は、今頂点$1$にたどり着いたので、頂点$1$から出ているすべての辺と重みをheap
に追加する。このときの最小の重みをもつ辺は頂点$1$→$2$なので、これを確定させ、cost
に$10$を足す。
あとは、同様のことを繰り返していく。
最終的に、すべての頂点が辺でつながれば、これが最小全域木となる。
コード例(プリム法)
上記で説明したことをコードで実装したものは、以下の通りである。
def main():
from heapq import heappush, heappop
n, m = map(int, input().split())
graph = [[] for _ in range(n)]
for _ in range(m):
u, v, c = map(int, input().split())
graph[u].append((v, c)) # u->vの辺
graph[v].append((u, c)) # v->uの辺
# プリム法
# 頂点がマークされているか確認する配列
marked = [False for _ in range(n)]
# マークされている頂点数を数える
marked_cnt = 0
# はじめに0頂点をマーク
marked[0] = True
marked_cnt += 1
# heap
q = []
# 頂点0に隣接する辺を保存
for j, c in graph[0]:
heappush(q, (c, j))
total = 0
# すべての頂点をマークするまでwhileループ
while marked_cnt < n:
# 最小の重みの辺をheapで取り出す
c, i = heappop(q)
# マークされているならスキップ
if marked[i]:
continue
# 頂点をマーク
marked[i] = True
marked_cnt += 1
total += c
# 頂点iに隣接する辺を保存
for j, c in graph[i]:
# マークされていればスキップ
if marked[j]:
continue
heappush(q, (c, j))
print(total)
if __name__ == '__main__':
main()
クラスカル法とは
クラスカル法とは、「どの辺も選択していない状態からスタートし、重みが小さい辺から順番に追加していって最終的に全域木を作る」方法である。
こちらも同様に、典型アルゴリズム問題集 F - 最小全域木問題 の入力例1をもとに解説していく。
図で解説
プリム法と異なり、heap
を用いる必要はない。クラスカル法は、かわりにUnion-Find
というデータ構造を用いる。
Union-Find
についても、過去に解説記事を投稿しているので、そちらを参考にしていただきたい。
クラスカル法では、はじめからすべての辺をみて、重みが最小のものから選択していく。重みが同じ場合は、頂点番号が小さいものから優先的に選択される。つまり、ここでの重みが最小の辺は頂点$0$→$1$の辺であるので、これを確定させ、総和を求めるcost
に重みを足す。
次に重みが小さい辺は頂点$1$→$2$の辺であるので、これを確定させ、重みを足す。
同様に、現在行った操作を繰り返していく。
最終的にすべての頂点が繋がっていれば、それは最小全域木となる。
コード例(クラスカル法)
上記で説明したことをコードで実装したものは、以下の通りである。
# Union-Find
from collections import defaultdict
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
def members(self, x):
root = self.find(x)
return [i for i in range(self.n) if self.find(i) == root]
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
def all_group_members(self):
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
return group_members
def __str__(self):
return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())
# Code
def main():
n, m = map(int, input().split())
uf = UnionFind(n)
edges = []
for _ in range(m):
u, v, c = map(int, input().split())
edges.append((c, u, v))
# 重みが小さい順に辺をソート
edges.sort()
cost = 0
for edge in edges:
c, u, v = edge
# 頂点がつながっていなければ
if not uf.same(u, v):
cost += c # 重みを足し
uf.union(u, v) # 頂点同士をつなげる
print(cost)
if __name__ == '__main__':
main()
また、補足として、クラスカル法はYouTubeでわかりやすい解説動画あるので、そちらも併せてご覧頂きたい。
最小全域木を求める問題
ここまでの説明が理解できれば、問題をといてみてほしい。
実際に手を動かしてコードを書くことで、より一層知識として定着するはずだ。
ABC218 E - Destruction
answer
def main():
from heapq import heappush, heappop
n, m = map(int, input().split())
graph = [[] for _ in range(n)]
cost = 0
for _ in range(m):
u, v, c = map(int, input().split())
u, v = u-1, v-1
graph[u].append((v, c)) # u->vの辺
graph[v].append((u, c)) # v->uの辺
if c >= 0:
cost += c
# プリム法
# 頂点がマークされているか確認する配列
marked = [False for _ in range(n)]
# マークされている頂点数を数える
marked_cnt = 0
# はじめに0頂点をマーク
marked[0] = True
marked_cnt += 1
# heap
q = []
# 頂点0に隣接する辺を保存
for j, c in graph[0]:
heappush(q, (c, j))
# すべての頂点をマークするまでwhileループ
while marked_cnt < n:
# 最小の重みの辺をheapで取り出す
c, i = heappop(q)
# マークされているならスキップ
if marked[i]:
continue
# 頂点をマーク
marked[i] = True
marked_cnt += 1
# 罰金でなければ報酬を
# 受け取りすぎているので報酬分の値を引く
if c >= 0:
cost -= c
# 頂点iに隣接する辺を保存
for j, c in graph[i]:
# マークされていればスキップ
if marked[j]:
continue
heappush(q, (c, j))
print(cost)
if __name__ == '__main__':
main()
ARC065 D - Built?
answer
# Union-Find
from collections import defaultdict
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
def members(self, x):
root = self.find(x)
return [i for i in range(self.n) if self.find(i) == root]
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
def all_group_members(self):
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
return group_members
def __str__(self):
return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())
# Code
def main():
n = int(input())
nodes = []
# (x, y)の座標を頂点iに保存
for i in range(n):
x, y = map(int, input().split())
nodes.append((x, y, i))
uf = UnionFind(n)
xnodes = sorted(nodes)
ynodes = sorted(nodes, key=lambda y: y[1])
edges = []
for i in range(n-1):
dx = abs(xnodes[i][0] - xnodes[i+1][0])
dy = abs(xnodes[i][1] - xnodes[i+1][1])
cost = min(dx, dy)
edges.append((cost, xnodes[i][2], xnodes[i+1][2]))
for i in range(n-1):
dx = abs(ynodes[i][0] - ynodes[i+1][0])
dy = abs(ynodes[i][1] - ynodes[i+1][1])
cost = min(dx, dy)
edges.append((cost, ynodes[i][2], ynodes[i+1][2]))
# 重みが小さい順に辺をソート
edges.sort()
ans = 0
for edge in edges:
c, u, v = edge
# 頂点がつながっていなければ
if not uf.same(u, v):
ans += c # 重みを足し
uf.union(u, v) # 頂点同士をつなげる
print(ans)
if __name__ == '__main__':
main()
ABC181 F - Silver Woods
answer
# Union-Find
from collections import defaultdict
class UnionFind():
def __init__(self, n):
self.n = n
self.parents = [-1] * n
def find(self, x):
if self.parents[x] < 0:
return x
else:
self.parents[x] = self.find(self.parents[x])
return self.parents[x]
def union(self, x, y):
x = self.find(x)
y = self.find(y)
if x == y:
return
if self.parents[x] > self.parents[y]:
x, y = y, x
self.parents[x] += self.parents[y]
self.parents[y] = x
def size(self, x):
return -self.parents[self.find(x)]
def same(self, x, y):
return self.find(x) == self.find(y)
def members(self, x):
root = self.find(x)
return [i for i in range(self.n) if self.find(i) == root]
def roots(self):
return [i for i, x in enumerate(self.parents) if x < 0]
def group_count(self):
return len(self.roots())
def all_group_members(self):
group_members = defaultdict(list)
for member in range(self.n):
group_members[self.find(member)].append(member)
return group_members
def __str__(self):
return '\n'.join(f'{r}: {m}' for r, m in self.all_group_members().items())
# Code
def main():
# 斜辺
from math import hypot
n = int(input())
nodes = [tuple(map(int, input().split())) for _ in range(n)]
edges = []
s = n
t = n+1
for i in range(n):
x, y = nodes[i]
edges.append((y+100, i, s))
edges.append((100-y, i, t))
for i in range(n):
for j in range(i):
x1, y1 = nodes[i]
x2, y2 = nodes[j]
edges.append((hypot(x1-x2, y1-y2), i, j))
# 重みが小さい順に辺をソート
edges.sort()
uf = UnionFind(n+2)
for edge in edges:
c, u, v = edge
uf.union(u, v)
# 頂点がつながっていれば
if uf.same(s, t):
print(c/2)
exit()
if __name__ == '__main__':
main()