はじめに
本日は競技プログラミングで頻出の最短経路問題 を、BFSと呼ばれるアルゴリズムを用いて解いてみます。
扱う題材はAtCoder「競プロ典型90問」からの次の問題です。
A63は重みなしグラフでの最短経路問題です。このような問題に対してBFSは非常に強力です。
BFSは浅いnodeから順に探索を進めるアルゴリズムです。したがって、ヒューリスティック的な観点から言えば、比較的浅いnodeに答えがあるという可能性が高いと推論されるときはDFSよりも、BFSが有利とされています。
ちなみに、実際の地図上での最大経路を求めるときは重み付きのグラフの最短経路が重要となります。そのような問題に対してはBFSではなくDijkstra's algorithmと呼ばれるより一般的なアルゴリズムが有効です。重みなしグラフはすべての辺が重み1の重み付きグラフとも言い換えられるのでDijkstra's algorithmはBFSを包括しています。
本稿ではBFSに絞って書きますが、BFSを理解すれば、Dijkstra's algorithmも理解しやすくなると思います。
コードを書いていきましょう
さて、DFSの記事でDFSは本質的にスタックだということを書きましたが、BFSはその逆です。すなわち、BFSはキューのデータ構造によるアルゴリズムです。
キューを使うときはcollections.deque
をimport
します。list
でもlist.pop(0)
とlist.append()
をつかってキューを実装できますが、計算時間が異なるので競技プログラミングならばcollections.deque
のほうが良いともいます。(より高速な実装法をご存じの方がいればぜひご教授ください)
from collections import deque
ちなみに、list.append()
の計算時間は${O(1)}$なのですが、list.pop(0)
は${O(N)}$の計算時間がかかってしまうようです。collections.deque
ではpushとpopはともに${O(1)}$の計算時間で可能です。
N, M = map(int, input().split())
頂点と辺の数をそれぞれ読み取り、空のnodeリストを作ります。
これはDFSの時と同じですね。このリストには各nodeの隣接するnode番号を格納します。駅ごとに「隣の駅」を定義するようなイメージです。
読み取りのメソッドもDFSの時と全く同じです。
for _ in range(M):
a, b = map(int, input().split())
a -= 1
b -= 1
node[a].append(b)
node[b].append(a)
別にaとbをそれぞれ-1してnodeを0から始める必要もないのですが、慣習と思ってこのように書いています。
inf = 10 ** 9
dist = [inf] * N
dist[0] = 0
問題の条件をよく読んでinfの値は絶対に到達しない十分に大きな値に設定してください。この問題の場合頂点数より大きなオーダーであれば問題ないと思います。dist
は頂点1(プログラム上では0)との距離としてそれぞれのnodeに格納された値の初期値になります。
q = deque([0])
探索はまず頂点1(プログラム上では0)からですので、キューに[0]を格納します。そしてwhile q:
でq
の中身がなくなるまでループを回していきます。
while q:
# cur は現在の node
cur = q[0]
q.popleft()
# 現在いる cur に隣接する頂点 e を探索
for e in node[cur]:
# 条件を満たす場合はその node に移動
if dist[e] > dist[cur] + 1:
dist[e] = dist[cur] + 1
q.append(e)
最後は表示です。
for e in dist:
if e == inf:
print(-1)
else:
print(e)
コード全体
以上をまとめてソースコードの全体はこのようになります。
from collections import deque
N, M = map(int, input().split())
node = [[] for _ in range(N)]
for _ in range(M):
a, b = map(int, input().split())
a -= 1
b -= 1
node[a].append(b)
node[b].append(a)
inf = 10 ** 9
dist = [inf] * N
dist[0] = 0
q = deque([0])
while q:
cur = q[0]
q.popleft()
for e in node[cur]:
if dist[e] > dist[cur] + 1:
dist[e] = dist[cur] + 1
q.append(e)
for e in dist:
if e == inf:
print(-1)
else:
print(e)
グラフの可視化(おまけ)
おまけで適当に書いてみたコードも添えておきます。このコードは先ほどのnodeの情報からデータ構造を可視化するコードです。
import matplotlib.pyplot as plt
import networkx as nx
def display_graph(node):
G = nx.Graph()
for i, dots in enumerate(node):
for dot in dots:
if dot >= i:
G.add_edge(i, dot)
plt.figure(figsize=(6,6))
nx.draw(G,with_labels=True)
plt.show()
def main():
display_graph(node)
if __name__ == "__main__":
main()
結果はこのように表示されます。
こうして見れば連結グラフなことも一目瞭然ですね。
NetworkXというライブラリを使えばこのようにして簡単にグラフを可視化できます。NetworkXについてはお気に入りのライブラリのひとつなので別の機会に詳細な記事を書く予定です。