競技プログラミングを Ruby でやっているのですが、適切なアルゴリズムを選んでも実装次第で時間制限超過 (TLE) になってしまうことがあります。最近 bit DP を用いた Held–Karp algorithm を書く際にハマったので、どのように書くと良いのか色々試してみました。
※ Held–Karp algorithm は、巡回セールスマン問題などを時間計算量 O(N2·2N) で解くアルゴリズムです。動的計画法の一種で、グラフの「通過した頂点の集合」と「最後に通過した頂点」の組を部分問題として計算・記録します。
TL;DR
基本的には多重ループの内側の処理をなるべく減らします。(当たり前?)
- 計算量の N2 にあたる二重ループの回数自体を減らすため、集合をビット列で表した整数のビットが立っている桁と立っていない桁を事前に列挙します
- ビット列が変わる度にリストを求め直すよりも、前回からの差分だけ更新するほうが速いです
- DPのテーブルの各セルへは1回のみ書き込むようにできます
- そのため書き込み済みの値を毎回読み込んで比較する必要がありません
- 解く問題によっては配列の集計メソッドを使えるかもしれません(最小値の計算には効果が無さそうでした)
- その他、二次元配列を参照する際は行の抽出部分だけループの外に出しておきます
例えば有向グラフの全頂点を1回ずつ通るパス(全部で n! 通り)の最短距離を求める場合は、以下のようにできます1。
##
# 有向グラフの全頂点を1回ずつ通るパスの最短距離を求める
#
# @param n [Integer] 頂点の数
# @param d [Array] 隣接行列、有向辺 u → v の距離は d[v][u]
# @return [Float] 最短経路の距離
#
def bit_dp(n, d)
pow_n = 1 << n
z = 1e30
table = Array.new(pow_n) { Array.new(n, z) }
n.times { |v| table[1 << v][v] = 0.0 }
# 二進 n 桁整数の、ビットの立っている桁と立っていない桁を列挙(初期化)
lasts = []
nexts = [*0...n].reverse!
(1...pow_n).each do |bitset| # 注意: 0 を除くこと
row = table[bitset]
# ビットの立っている桁と立っていない桁を列挙(前回から差分更新)
v = nexts.pop
nexts.concat(lasts.pop(v))
lasts << v
nexts.each do |v|
dv = d[v]
# 最小値を計算
d_new = z # テーブルの値は読み込まなくていい
lasts.each do |u|
d_tmp = row[u] + dv[u]
d_new = d_tmp if d_tmp < d_new
end
# ここの参照先は毎回異なり、同じセルへ2回書き込むことは起きない
table[bitset | 1 << v][v] = d_new
end
end
table[-1].min
end
#--- main ---#
require 'benchmark'
puts RUBY_DESCRIPTION
n, seed = gets.split.values_at(0, 1).map!(&:to_i)
puts "n = #{n}, seed = #{seed}"
srand(seed)
dist = Array.new(n) { Array.new(n) { rand } }
Benchmark.bmbm do |x|
x.report { puts bit_dp(n, dist) }
end
試したコード
初期の形
「通過済みの頂点の集合」をビット列で表した整数 bitset
がある値をとる際に、通過済みの頂点 u
から未通過の頂点 v
へパスを伸ばしたときの距離を全パターン調べて比較していきます。単純に実装すると、 bitset
の u
, v
ビット目が立っているかどうか確認する必要があります。
def bit_dp(n, d)
pow_n = 1 << n
z = 1e30
table = Array.new(pow_n) { Array.new(n, z) }
n.times { |v| table[1 << v][v] = 0.0 }
(1...pow_n).each do |bitset|
n.times do |v|
n.times do |u|
next if bitset[v] == 1 || bitset[u] == 0
d_tmp = table[bitset][u] + d[v][u]
table[bitset | 1 << v][v] = d_tmp if d_tmp < table[bitset | 1 << v][v]
end
end
end
table[-1].min
end
二重ループの添字をあらかじめ絞り込む
上のコードで u
と v
の二重ループは、毎回ある桁にビットが立っているか確認していて非効率です。 u
と v
のとる添字を列挙しておくことで、必要最小限のループ(約 1/4 の量)を回すことができます。
(1...pow_n).each do |bitset|
- n.times do |v|
- n.times do |u|
- next if bitset[v] == 1 || bitset[u] == 0
-
+ lasts, nexts = (0...n).partition { |v| bitset[v] == 1 }
+
+ nexts.each do |v|
+ lasts.each do |u|
d_tmp = table[bitset][u] + d[v][u]
ちなみに二重ループを Array#product
で一重にすることもできますが、かえって遅くなってしまいました。
保存先のセルへの読み書きを減らす
DPテーブルのセル table[bitset | 1 << v][v]
との比較を毎回やっていますが、 u
のループを回している間はこのセルの位置は変化しません。 Array#[]
や Array#[]=
をループの外に出し、代わりに適当なローカル変数を使うようにします。
nexts.each do |v|
+ d_new = table[bitset | 1 << v][v]
lasts.each do |u|
d_tmp = table[bitset][u] + d[v][u]
- table[bitset | 1 << v][v] = d_tmp if d_tmp < table[bitset | 1 << v][v]
+ d_new = d_tmp if d_tmp < d_new
end
+ table[bitset | 1 << v][v] = d_new
end
その他、ループ内の同じ処理を外に出す
二次元配列 table
, d
を参照する処理のうち u
の関わらない部分(行の抽出)は、ループの外側で済ませられます。
これでコードの一番内側のループは非常に短くなりました。(さらに短くする検討は後ほどします)
(1...pow_n).each do |bitset|
+ row = table[bitset]
lasts, nexts = (0...n).partition { |v| bitset[v] == 1 }
nexts.each do |v|
+ dv = d[v]
d_new = table[bitset | 1 << v][v]
lasts.each do |u|
- d_tmp = table[bitset][u] + d[v][u]
+ d_tmp = row[u] + dv[u]
d_new = d_tmp if d_tmp < d_new
end
table[bitset | 1 << v][v] = d_new
end
end
DPテーブルから読み込むのをやめる
定数倍高速化には寄与しない箇所についても、少しでも実行時間制限への余裕を稼ぐために改善を検討していきます。
実は前節のコードだと、テーブルの同じセルに2回以上書き込むことはありません。したがって v
のループの始めに d_new = table[bitset | 1 << v][v]
としている箇所は、単にテーブルを初期化したときの値 z
を読み込んでいるだけです。
※ 初期化で table[1 << v][v] = ...
とした箇所を上書きしないよう、 bitset
のループの添字は 0
を除いておく必要があります。
nexts.each do |v|
dv = d[v]
- d_new = table[bitset | 1 << v][v]
+ d_new = z
lasts.each do |u|
こうなるとテーブル初期化時に z
で埋める必要も無いのですが、 Ruby だと結局 nil
で埋めることになるため効果を見込めないのと、処理後のテーブルが nil
で歯抜けになるため解く問題によっては面倒になりえます。
ビットの桁のリストを差分更新する
ビットの立っている桁を毎回 O(N) で列挙していましたが、似た bitset
からは似た結果が得られるので非効率な気がします。
bitset lasts nexts
0b1100_0110 -> [7, 6, 2, 1] , [5, 4, 3, 0]
0b1100_0111 -> [7, 6, 2, 1, 0], [5, 4, 3]
0b1100_1000 -> [7, 6, 3] , [5, 4, 2, 1, 0]
そこで前回の結果を利用して差分だけ更新するように変えてみます。 bitset
が 1 増えたとき、全体の 1/2 の場面(偶数→奇数)では要素 1 つがリスト間を移動するだけです。 1/4 の場面では 2 つ、 1/8 の場面では 3 つ、…となるため平均的には要素 2 つの移動だけで済みます。つまり差分更新の計算量は ならし O(1) です。
+ lasts = []
+ nexts = [*0...n].reverse!
+
(1...pow_n).each do |bitset|
row = table[bitset]
- lasts, nexts = (0...n).partition { |v| bitset[v] == 1 }
+ v = nexts.pop
+ nexts.concat(lasts.pop(v))
+ lasts << v
※ この実装は pow_n
回目の更新で v = nil
となり例外が発生します。今回はループが pow_n - 1
回なので問題ありません。
ここまで変形したコードをコメント付きで記事冒頭に載せています。
配列のメソッドで集計する
高速化のためには、C言語で実装された Ruby のメソッドに任せる(自分で書かない)というのも手です。
u
のループでしていることは最小値の計算なので、 Array#min
で置き換えられます。(いまは bitset != 0
を前提としているので、配列 lasts
は空ではありません)
ただし最小値の計算の場合、最初に #map
で添字から値に変換する必要があるため、非効率な可能性があります。 #sum
などで集計するタイプの問題なら効率的かもしれません。
nexts.each do |v|
dv = d[v]
- d_new = z
- lasts.each do |u|
- d_tmp = row[u] + dv[u]
- d_new = d_tmp if d_tmp < d_new
- end
+ d_new = lasts.map do |u|
+ row[u] + dv[u]
+ end.min
table[bitset | 1 << v][v] = d_new
end
計測
手元のPC上でメソッドを何回か実行したときの平均秒数を以下の表に示します。
バージョン: Ruby 2.7.1 ( AtCoder と同じもの)
コード | n = 16 | n = 17 | n = 18 |
---|---|---|---|
初期の形 | 1.862 | 4.116 | 9.182 |
添字絞り込み | 0.878 | 1.972 | 4.252 |
読み書き削減 | 0.671 | 1.480 | 3.213 |
行を抽出 | 0.618 | 1.350 | 2.930 |
読み込み削減 | 0.589 | 1.290 | 2.783 |
桁を差分更新 | 0.462 | 1.015 | 2.377 |
map, min 利用 | 0.510 | 1.130 | 2.547 |
- 二重ループの添字をあらかじめ絞り込むのは効果が大きく、時間が半減しています。
- ループ内部を簡略化していくことで、さらに3割ほど改善できています。
- ビットの立っている桁のリストを差分更新するのも、無視できない程度には効果が出ています。
- 一方で、最小値の計算を配列のメソッドで行うと、少し遅くなっています。
TLE 回避のためであれば、読み書き削減まで適用したらあとは書きやすいように書いてもよさそうです。
おまけ
以前にスマホのパターンロックの総数を求めた際も同じアルゴリズムを使いましたが、コードは添字の絞り込みまでしかしていませんでした。今回調べた高速化を施したところ、同じく実行時間が3割ほど減りました。
+lasts = []
+nexts = [*0...N].reverse!
+
(1...POW_N).each do |bitset|
+ row = table[bitset]
+
# 使用済みの点 i から未使用の点 j へ線を引くことで新しいパターンを作れる
- lasts, nexts = (0...N).partition { |k| bitset[k] == 1 }
+ k = nexts.pop
+ nexts.concat(lasts.pop(k))
+ lasts << k
+
- lasts.product(nexts) do |i, j|
- s = passing_points[i][j]
- next if bitset & s != s # 通過する点が1つでも未使用なら線を引けない
- table[bitset | 1 << j][j] += table[bitset][i]
- end
+ nexts.each do |j|
+ ppj = passing_points[j] # 対称行列なので列を抜き出すのと同じ
+ cnt = lasts.sum do |i|
+ s = ppj[i]
+ bitset.allbits?(s) ? row[i] : 0 # 通過する点が1つでも未使用なら線を引けない
+ end
+ table[bitset | 1 << j][j] = cnt
+ end
end
参考
-
巡回セールスマン問題(ハミルトン閉路の最短距離)に適用する場合は、頂点 n を外して残りの (n-1) 点で bit DP するといいです。初期値の登録や最後の集計の際に、頂点 n との距離を考慮します。 ↩