bitDPって難しいですよね。DPという考え方自体が難しい上、集合を2進数で考えなきゃならない・・・。他記事の解説を見ても、なかなか頭に入らなかったので、初心者(自分)向けに、解説したいと思います。
巡回セールスマン問題
ABC180 / E問 がサンプル問題として適しています。要約すると、以下のような問題です。
全ての都市と道路で繋がっている都市がN個あります。全都市をちょうど1回ずつ訪れて、戻ってくる最短ルートの距離を求めてください。
通常の全探索でDFSなどを使うと、計算量が$O(N!)$になってしまいます。この問題ではNの制約が$N \leq 17$なので、$17! = 355687428096000$は確実にTLE(Time Limit Exceeded)です。そこでbitDPを活用します。これにより計算量を$O(N!) → O(N^2 \times 2^N)$に削減できます。
DPテーブルの説明
bitDPでは、以下のような2次元配列dpを用いて最短経路を記録していきます。
dp[これまで訪れた都市の集合][最後に訪れた都市] = 最短経路
例えば、N=4
の場合、全集合は{0, 1, 2, 3}
となります。今、都市{1, 2, 3}
に訪れていて、最後に都市2にいる場合、dp[{1,2,3}][2]
としてそれまでの最短経路を記録します。このように、訪れた都市の順番を考慮せず、集合を用いて計算を効率化しています。
ビット演算で集合を使う準備
先ほどdpリストのindexには、dp[{1,2,3}][2]
のように、集合{1,2,3}
を用いました。しかし、これをそのまま実装してしまうと、リストのindexには集合を使うことができないのでエラーになります。そこで、集合の代わりに2進数を使います。
まずは初めに、BitDPにおける集合を2進数で表現・操作する以下3つのポイントを説明します。
- 集合を2進数で表現
- 集合
s
に要素i
が含まれているかの判定 - 集合
s
に要素i
を加える
というのも、いきなりDPアルゴリズムの部分まで含めた解説に入ると、2進数操作のところで「なんだこれ??」って結構混乱するからです。ビット演算を全く知らないという方は、「Python ビット演算 超入門」などを参照してみて下さい。
集合を2進数で表現
例えば、N=4
のとき、{1,2,3}
を2進数に変換してみましょう。これは{1,2,3} → 1110
と変換できます。集合の要素iを用いて、2進数の右からi番目に1を立てていくイメージです。
{0} → 0001
{1,3} → 1010
{0,1,2,3} → 1111
ただし、実際の実装では2進数を使うのではなく、10進数を2進数と見立てて使います。N個の集合の全パターンは$2^N$個の10進数で表現できます。10進数・2進数・集合の対応表は以下のコードで確認できます。
import pandas as pd, numpy as np
n = 3
df = pd.DataFrame()
df.index.name = "10進数"
df["2進数"] = [bin(i)[2:].zfill(n) for i in range(2**n)]
df["集合"] = [set([j for j in range(n) if bit[-(j+1)] == "1"]) for bit in df["2進数"]]
display(df)
>>>
10進数 2進数 集合
0 000 {}
1 001 {0}
2 010 {1}
3 011 {0, 1}
4 100 {2}
5 101 {0, 2}
6 110 {1, 2}
7 111 {0, 1, 2}
集合sに都市iが含まれているかの判定
例えば、s = 5, i = 1
のとき、5
→ 101
→ {0,2}
という集合に要素1
が含まれているかを判定したいと思います。これは、集合sの右からi番目にビットが立っているかを判定すれば良いです。
s = 5 # 集合{0,2}
i = 1
if s >> i & 1 == 1: # 集合{0,2}に要素1が含まれているか判定
print(f"集合{bin(s)[2:]}に, {i}は含まれる")
else:
print(f"集合{bin(s)[2:]}に, {i}は含まれない")
>>> 集合101に, 1は含まれない
s >> i & 1 == 1
の部分がポイントです。計算の流れは以下の通りです。
- 右シフトしてi桁目を1桁目に移動:
101 >> 1 → 10
- 1桁目と1を論理積
&
で比較し、ビットが立っているかを判定 - 実際の計算は
10 & 01 → 0
(結果が0
なので、要素1
は集合{0,2}
に含まれない)
これで集合s
に要素i
が含まれるかの判定ができました。
集合sに要素iを加える
例えば、s = 5, i = 1
のとき、5
→ 101
→ {0,2}
という集合に要素1
を加えたいと思います。集合101
→ {0,2}
に要素1を加えると、集合111
→ {0,1,2}
になります。
これは以下のように実装できます。
s = 5 # 集合{0,2}
i = 1
ns = s | 1 << i # 集合{0,2}に要素1を追加
print(bin(s)[2:], "→", bin(ns)[2:])
>>> 101 → 111
s | 1 << i
の部分がポイントです。i = 1
, s = 5
(5
→ 101
→ {0,2}
)のとき、計算の流れは以下の通りです。
-
1 << i
で右からi
番目にビットを立てた2進数を作成:1 << 1 → 010
- 集合
s
と1 << i
を論理和|
で比較して、集合s
の右からi
番目にビットを立てる - 実際の計算は
101 | 010 → 111
(111
は集合{0,1,2}
で、要素1
が追加されていることを確認できる)
これで集合s
に要素i
を追加できました。
実装
それでは、実装コードを見てみましょう。ポイントとなる場所で区切って、順番に解説します。また、ABC180 / E問 の入力例2を用いて、解説を進めたいと思います。
まずは全体のコードどうぞ。
# 都市の数を受け取る
n = int(input())
# 各都市の座標を受け取る
pos = [list(map(int, input().split())) for i in range(n)]
# 都市間の距離を計算するための配列を初期化
dist = [[0] * n for _ in range(n)]
# 都市間の距離を計算
for u in range(n):
for v in range(n):
a, b, c = pos[u]
p, q, r = pos[v]
dist[u][v] = abs(p - a) + abs(q - b) + max(0, r - c)
# dpテーブルを初期化。dp[s][v]は、都市の集合sを訪れていて、最後に都市vにいるときの最短距離
# スタート地点は都市0とする。最初の時点で最短経路は0。
dp = [[float("inf")] * n for _ in range(2**n)]
dp[1][0] = 0
for s in range(1, 2**n):
for u in range(n):
for v in range(n):
# 集合s(今まで訪れた都市)のうち、uに訪れていて、vに訪れていないとき
if s >> u & 1 and not s >> v & 1:
ns = s | 1 << v #集合sにvを追加
dp[ns][v] = min(dp[ns][v], dp[s][u] + dist[u][v])
# p[-1]を利用して、都市0に戻る。
# dp[-1][v]はすべての都市を通って、最後に都市vにいる最短経路。
# そこからdist[v][0]でv→0に移動する距離を加えた最短経路の最小が答えになる。
ans = min([dp[-1][v] + dist[v][0] for v in range(n)])
print(ans)
① 入力を受け取り、都市間の距離を計算するための配列を初期化
都市uから都市vまでの距離を計算して、distに格納します。問題文通り、そのまま計算を行います。dist[u][v]
で、uからvまでの距離を取得できます。
# 都市の数を受け取る
n = int(input())
# 各都市の座標を受け取る
pos = [list(map(int, input().split())) for i in range(n)]
# 都市間の距離を計算するための配列を初期化
dist = [[0] * n for _ in range(n)]
# 都市間の距離を計算
for u in range(n):
for v in range(n):
a, b, c = pos[u]
p, q, r = pos[v]
dist[u][v] = abs(p - a) + abs(q - b) + max(0, r - c)
② dpを初期化
dpの初期値を設定します。全ての都市を通らなければならない条件があるため、答えとなる最短経路には全ての都市が必ず含まれます。なので、どこからスタートしても良いわけですが、わかりやすく最初は都市0にいるとして、スタートします。
最初は都市0にいるので、dp[1][0] = 0
として、その時点での最短経路は0と設定します。
# dpテーブルを初期化。dp[s][v]は、都市集合sを訪れて最後に都市vにいるときの最短距離
dp = [[float("inf")] * n for _ in range(2**n)]
# スタート地点は都市0とする。最初の時点では最短経路は0。
dp[1][0] = 0
③ dpを更新
for s in range(1, 2**n):
for u in range(n):
for v in range(n):
# 集合s(今まで訪れた都市)のうち、uには訪れていて、vには訪れていないとき
if s >> u & 1 and not s >> v & 1:
ns = s | 1 << v #集合sにvを追加
dp[ns][v] = min(dp[ns][v], dp[s][u] + dist[u][v])
dpを更新する条件は、「集合s(これまで訪れた都市)の中で、uには訪れていて、vには訪れていない」というものです。これはs >> u & 1 and not s >> v & 1
という書き方で実装しています。dpの遷移先は、ns = s | 1 << v
として、集合sに都市vを追加し、新たな集合nsを定義します。このnsを用いて、dp[ns][v] を更新します。具体的には、現在の都市uまでの最短経路dp[s][u]
に、uからvまでの距離dist[u][v]
を加えて、テーブルを更新します。
dp[ns][v] = min(dp[ns][v], dp[s][u] + dist[u][v])
ABC180 / E問の-入力例2でのdp遷移全体の流れは以下のような感じです。図の2進数は右から左、配列は左から右にidが増えるのでご注意を。
最終的なDPテーブルは以下のようになります。
④ 答えを出力
さて、先ほどのdpテーブルのどこが答えになるのでしょうか。巡回セールスマン問題では、最後に元の都市に戻って来ないといけません。なので、3つある都市をすべて訪れた集合111
であって、かつ、最後に都市0にいる場所が答えです(一番左下)。ところが、その最短経路の値がinfになっています。これは、全ての都市を回って、最後に都市0に帰ってくる経路がないということです。それもそのはず、最初に都市0にいるとしたので、dpの更新条件に「遷移先はまだ訪れてない都市」を設定している以上、都市0に戻って来ることが出来ません。
そのため、全ての都市を通った後、最後に都市0に向かい、その中で最短の経路を答えとして出力します。dp[-1]
はすべての都市を通って、0以外の都市にいる最短経路が格納されていますので、これを利用します。
# p[-1]を利用して、都市0に戻る。
# dp[-1][v]はすべての都市を通って、最後に都市vにいる最短経路。
# そこからdist[v][0]でv→0に移動する距離を加えた最短経路の最小が答えになる。
ans = min([dp[-1][v] + dist[v][0] for v in range(n)])
print(ans)
最後に
以上が巡回セールスマンの解説になります。自分は最初にいる都市と、最後にいる都市の考え方で結構苦戦しました。実は上の実装コードも説明用に書いたので、効率的でない部分があります。ただ、しっかり理解すれば、自分で書き換えたりして応用が効くと思います。あと自分は「PythonでbitDPを使い巡回セールスマン問題を解く」この記事を参考にしました。
練習問題を解きたい方は
- ABC318 / D問 (緑Diff)
- ABC301 / E問 (青Diff)
などに挑戦してみてください。
最後に、初めて学んだ内容なので、何か間違いがありましたら、コメントで教えていただけると嬉しいです。