Help us understand the problem. What is going on with this article?

競技プログラミングで使う有名グラフアルゴリズムまとめ

0. はじめに

AtCoderなどでは、グラフを扱った問題が多く出るが、その度に一から実装していると時間が掛かりすぎてしまうため、有名なものをあらかじめ持っておく必要がありそう。そこで、Pythonを用いて、ダイクストラ法、ベルマンフォード法、プリム法、クラスカル法、ワーシャルフロイド法を実装した。
コメント、意見等ある方は是非! お待ちしてます!

1. ダイクストラ法

1.1. ダイクストラ法(defaultdictで実装)

defaultdictで実装すると、リストで実装するよりも、ノード数$N$が大きい際には高速に動作する。ただし、経路復元の関数は、うまく書けなかった......。
(2019/7/6 追記)結局できました。1.1.1. を参照してください。

import collections
import heapq


class Dijkstra:
    def __init__(self):
        self.e = collections.defaultdict(list)

    def add(self, u, v, d):
        self.e[u].append([v, d])
        self.e[v].append([u, d])

    def delete(self, u, v):
        self.e[u] = [_ for _ in self.e[u] if _[0] != v]
        self.e[v] = [_ for _ in self.e[v] if _[0] != u]

    def search(self, s):
        """
        :param s: 始点
        :return: 始点から各点までの最短経路
        """
        d = collections.defaultdict(lambda: float('inf'))
        d[s] = 0
        q = []
        heapq.heappush(q, (0, s))
        v = collections.defaultdict(bool)
        while len(q):
            k, u = heapq.heappop(q)
            if v[u]:
                continue
            v[u] = True

            for uv, ud in self.e[u]:
                if v[uv]:
                    continue
                vd = k + ud
                if d[uv] > vd:
                    d[uv] = vd
                    heapq.heappush(q, (vd, uv))

        return d

1.1.1. ダイクストラ法(defaultdictで実装) 経路復元も

defaultdictを用いても、経路復元が実装できました。ABC051 D Candidates of No Shotest Pathでテスト済みです。

import collections
import heapq


class Dijkstra():
    def __init__(self):
        self.e = collections.defaultdict(list)

    def add(self, u, v, d, directed=False):
        """
        #0-indexedでなくてもよいことに注意
        #u = from, v = to, d = cost
        #directed = Trueなら、有向グラフである
        """
        if directed is False:
            self.e[u].append([v, d])
            self.e[v].append([u, d])
        else:
            self.e[u].append([v, d])

    def delete(self, u, v):
        self.e[u] = [_ for _ in self.e[u] if _[0] != v]
        self.e[v] = [_ for _ in self.e[v] if _[0] != u]

    def Dijkstra_search(self, s):
        """
        #0-indexedでなくてもよいことに注意
        #:param s: 始点
        #:return: 始点から各点までの最短経路と最短経路を求めるのに必要なprev
        """
        d = collections.defaultdict(lambda: float('inf'))
        prev = collections.defaultdict(lambda: None)
        d[s] = 0
        q = []
        heapq.heappush(q, (0, s))
        v = collections.defaultdict(bool)
        while len(q):
            k, u = heapq.heappop(q)
            if v[u]:
                continue
            v[u] = True

            for uv, ud in self.e[u]:
                if v[uv]:
                    continue
                vd = k + ud
                if d[uv] > vd:
                    d[uv] = vd
                    prev[uv] = u
                    heapq.heappush(q, (vd, uv))

        return d, prev

    def getDijkstraShortestPath(self, start, goal):
        _, prev = self.Dijkstra_search(start)
        shortestPath = []
        node = goal
        while node is not None:
            shortestPath.append(node)
            node = prev[node]
        return shortestPath[::-1]

ダイクストラ法(defaultdictで実装)の使用例

SoundHound2018 D - Saving Snuukに対する自分の解答

# クラス部分は上と全く一緒
N, M, S, T = map(int, input().split())
UVAB = [list(map(int, input().split())) for i in range(M)]
graph1 = Dijkstra()
graph2 = Dijkstra()

for u, v, a, b in UVAB:
    graph1.add(u, v, a)
    graph2.add(u, v, b)

result_a = graph1.search(S)
result_b = graph2.search(T)
ans_list = [float('inf')]
for i in range(N, 0, -1):
    cur = result_a[i] + result_b[i]
    if cur > ans_list[-1]:
        cur = ans_list[-1]
    ans_list.append(cur)

for i in range(len(ans_list[1:])):
    print(10**15-ans_list[-(i+1)])

1.2. ダイクストラ法(リストで実装)

上にも書いた通り、動作が遅く、TLEすることも十分考えられる。しかし、経路復元が必要なときはこっちを使わざるを得ない。

import collections
import heapq


class Dijkstra():
    def __init__(self, N):
        self.N = N
        self.e = [[float('inf') for i in range(self.N)] for i in range(N)]

    def add(self, u, v, d, directed=False):
        """
        0-indexedでなくてもよいことに注意
        u = from, v = to, d = cost
        directed = Trueなら、有向グラフである
        """
        if directed is False:
            self.e[u][v] = d
            self.e[v][u] = d
        else:
            self.e[u][v] = d

    def delete(self, u, v):
        self.e[u] = [_ for _ in self.e[u] if _[0] != v]
        self.e[v] = [_ for _ in self.e[v] if _[0] != u]

    def Dijkstra_search(self, s):
        """
        0-indexedであることに注意
        s =  始点
        return: 始点から各点までの最短経路
        """
        d = [float('inf') for i in range(self.N)]
        d[s] = 0
        q = []
        heapq.heappush(q, (0, s))
        v = collections.defaultdict(bool)
        while len(q):
            k, u = heapq.heappop(q)
            if v[u]:
                continue

            for uv, ud in enumerate(self.e[u]):
                if v[uv]:
                    continue
                vd = k + ud
                if d[uv] > vd:
                    d[uv] = vd
                    heapq.heappush(q, (vd, uv))

        return d

    def getDijkstraPath(self, s, t):
        # sからtへの最短経路の経路復元
        prev = [s] * self.N  # 最短経路の直前の頂点
        d = [float("inf")] * self.N
        used = [False] * self.N
        d[s] = 0

        while True:
            v = -1
            for i in range(self.N):
                if (not used[i]) and (v == -1):
                    v = i
                elif (not used[i]) and d[i] < d[v]:
                    v = i
            if v == -1:
                break
            used[v] = True

            for i in range(self.N):
                if d[i] > d[v] + self.e[v][i]:
                    d[i] = d[v] + self.e[v][i]
                    prev[i] = v

        path = [t]
        while prev[t] != s:
            path.append(prev[t])
            prev[t] = prev[prev[t]]
        path.append(s)
        path = path[::-1]
        return path

1.3. ダイクストラ法(隣接行列をリストで表現した、密なグラフ用のダイクストラ)

SoundHound2018 D - Saving Snuukは余裕でTLEした。

class Dijkstra():
    def __init__(self, N):
        self.N = N
        self.e = [[float('inf') for i in range(self.N)] for i in range(N)]

    def add(self, u, v, d, directed=False):
        if directed is False:
            self.e[u][v] = d
            self.e[v][u] = d
        else:
            self.e[u][v] = d

    def delete(self, u, v):
        self.e[u] = [_ for _ in self.e[u] if _[0] != v]
        self.e[v] = [_ for _ in self.e[v] if _[0] != u]

    def DijkstraSearch(self, s):
        d = [float('inf') for i in range(self.N)]
        d[s] = 0
        pred = [-1 for i in range(self.N)]
        visited = [False for i in range(self.N)]

        while True:
            u = -1
            sd = float('inf')
            for i in range(0, self.N):
                if not visited[i] and d[i] < sd:
                    sd = d[s]
                    u = i

            if u == -1:
                break

            visited[u] = True
            for v in range(0, self.N):
                w = self.e[u][v]
                if v == u:
                    continue
                newLen = d[u] + w
                if newLen < d[v]:
                    d[v] = newLen
                    pred[v] = u
        return d

2. ベルマンフォード法

リストでdefaultdictで実装する。defaultdictでの実装ももちろん可能。

class BellmanFord():
    def __init__(self, N):
        self.N = N
        self.edges = []

    def add(self, u, v, d, directed=False):
        """
        u = from, v = to, d = cost
        directed = Trueのとき、有向グラフである。
        """
        if directed is False:
            self.edges.append([u, v, d])
            self.edges.append([v, u, d])
        else:
            self.edges.append([u, v, d])

    def BellmanFord_search(self, s):
        """
        :param s: 始点
        :return: d[i] 始点sから各点iまでの最短経路
        """
        d = [float('inf') for i in range(self.N)]
        d[s] = 0
        numEdges = len(self.edges)
        while True:
            update = False
            for i in range(numEdges):
                e = self.edges[i]
                # e: 辺iについて [from,to,cost]
                if d[e[0]] != float("inf") and d[e[1]] > d[e[0]] + e[2]:
                    d[e[1]] = d[e[0]] + e[2]
                    update = True
            if not update:
                break
        return d

    def BellmanFord_negative_bool(self, start, numNodes):
        # 負の閉路の検出, Trueなら負の閉路が存在する
        d = [float('inf') for i in range(self.N)]
        d[start] = 0
        numEdges = len(self.edges)
        for i in range(numNodes):
            for j in range(numEdges):
                e = self.edges[j]
                if d[e[1]] > d[e[0]] + e[2]:
                    d[e[1]] = d[e[0]] + e[2]
                    if i == numNodes-1:
                        return True, d
        return False, d

3. プリム法

import heapq


class Prim():
    # 無向グラフであるという前提に注意
    def __init__(self, N):
        self.edge = [[] for i in range(N)]
        self.N = N

    def add(self, u, v, d):
        """
        u = from, v = to, d = cost
        0-indexedであることに注意、graph.add(u-1, v-1)とする必要がある
        """
        self.edge[u].append([d, v])  # コスト、e_toとなっていることに注意
        self.edge[v].append([d, u])

    def delete(self, u, v):
        self.edge[u] = [_ for _ in self.edge[u] if _[0] != v]
        self.edge[v] = [_ for _ in self.edge[v] if _[0] != u]

    def Prim(self):
        """
        return: 最小全域木のコストの和
        """
        used = [True] * self.N  # True:不使用
        edgelist = []
        for e in self.edge[0]:
            heapq.heappush(edgelist, e)
        used[0] = False
        res = 0
        while len(edgelist) != 0:
            minedge = heapq.heappop(edgelist)
            if not used[minedge[1]]:
                continue
            v = minedge[1]
            used[v] = False
            for e in self.edge[v]:
                if used[e[1]]:
                    heapq.heappush(edgelist, e)
            res += minedge[0]
        return res

4. クラスカル法

class Kruskal_UnionFind():
    # 無向グラフであるという前提に注意
    def __init__(self, N):
        self.edges = []
        self.rank = [0] * N
        self.par = [i for i in range(N)]
        self.counter = [1] * N

    def add(self, u, v, d):
        """
        u = from, v = to, d = cost
        """
        self.edges.append([u, v, d])

    def find(self, x):
        if self.par[x] == x:
            return x
        else:
            self.par[x] = self.find(self.par[x])
            return self.par[x]

    def unite(self, x, y):
        x = self.find(x)
        y = self.find(y)
        if x != y:
            z = self.counter[x] + self.counter[y]
            self.counter[x], self.counter[y] = z, z
        if self.rank[x] < self.rank[y]:
            self.par[x] = y
        else:
            self.par[y] = x
            if self.rank[x] == self.rank[y]:
                self.rank[x] += 1

    def size(self, x):
        x = self.find(x)
        return self.counter[x]

    def same(self, x, y):
        return self.find(x) == self.find(y)

    def Kruskal(self):
        """
        return: 最小全域木のコストの和
        """
        edges = sorted(self.edges, key=lambda x: x[2])  # costでself.edgesをソートする
        res = 0
        for e in edges:
            if not self.same(e[0], e[1]):
                self.unite(e[0], e[1])
                res += e[2]
        return res

クラスカル法の使用例

AtCoder Beginner Contest 065 D - Built?に対する自分の解答

# クラス部分は上と一緒
N = int(input())
XY = [[i] + list(map(int, input().split())) for i in range(N)]

graph = Kruskal_UnionFind(N)
XY = sorted(XY, key=lambda x: x[1])
X_costs = [[XY[i-1][0], XY[i][0], abs(XY[i-1][1] - XY[i][1])] for i in range(1, N)]
XY = sorted(XY, key=lambda x: x[2])
Y_costs = [[XY[i-1][0], XY[i][0], abs(XY[i-1][2] - XY[i][2])] for i in range(1, N)]

for i in range(N-1):
    x0, x1, d = X_costs[i]
    graph.add(x0, x1, d)
    y0, y1, d = Y_costs[i]
    graph.add(y0, y1, d)

print(graph.Kruskal())

5. ワーシャルフロイド法

(2019/7/8 追記) グラフが負の閉路を持つかを判定するように変更しました。これに合わせて、下のコードも少し変わっています。

class WarshallFloyd():
    def __init__(self, N):
        self.N = N
        self.d = [[float("inf") for i in range(N)]
                  for i in range(N)]  # d[u][v] : 辺uvのコスト(存在しないときはinf)

    def add(self, u, v, c, directed=False):
        """
        0-indexedであることに注意
        u = from, v = to, c = cost
        directed = Trueなら、有向グラフである
        """
        if directed is False:
            self.d[u][v] = c
            self.d[v][u] = c
        else:
            self.d[u][v] = c

    def WarshallFloyd_search(self):
        # これを d[i][j]: iからjへの最短距離 にする
        # 本来無向グラフでのみ全域木を考えるが、二重辺なら有向でも行けそう
        # d[i][i] < 0 なら、グラフは負のサイクルを持つ
        for k in range(self.N):
            for i in range(self.N):
                for j in range(self.N):
                    self.d[i][j] = min(
                        self.d[i][j], self.d[i][k] + self.d[k][j])
        hasNegativeCycle = False
        for i in range(self.N):
            if self.d[i][i] < 0:
                hasNegativeCycle = True
                break
        for i in range(self.N):
            self.d[i][i] = 0
        return hasNegativeCycle, self.d

ワーシャルフロイド法の使用例

AtCoder Beginner Contest 012 D - バスと避けられない運命に対する自分の解答

# クラス部分は上と一緒
N, M = map(int, input().split())
ABT = [list(map(int, input().split())) for i in range(M)]
graph = WarshallFloyd(N)
for a, b, t in ABT:
    graph.add(a-1, b-1, t)

hasNegativeCycle, d = graph.WarshallFloyd_search()
ans = sorted([[i, max(d[i])] for i in range(N)], key=lambda x: x[1])
print(ans[0][1])
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした