概要
AtCoderのDP(Dynamic Programming: 動的計画法)問題を解いていたところ、他の言語ならACになるコードがPython/PyPyだとTLEになるという現象に遭遇したため、備忘録として残しておきます。
DP問題について
こちらの記事で詳しく解説されているようにDP問題の解き方はいくつかありますが、ざっくり分けるとボトムアップ的にfor文を回してDPテーブルを構築する手法と、トップダウン的に再帰関数を呼び出すメモ化再帰の二通りがあります。いずれにしても計算済みの結果を保持しておくことで処理の高速化を図るという点では同様ですが、メモ化再帰の場合は関数呼び出しのオーバーヘッドがあるため、一部の言語では制限時間内に処理が終わらずTLEになるということがあります。
例題
解法1:for文によるボトムアップ的なDP
N, K = map(int, input().split())
h = list(map(int, input().split()))
dp = [10**9] * N
for i in range(N):
if i == 0:
dp[0] = 0
elif i == 1:
dp[1] = abs(h[0]-h[1])
else:
for j in range(1, K+1):
if j > i:
break
dp[i] = min(dp[i], dp[i - j] + abs(h[i - j] - h[i]))
print(dp[N-1])
結果
言語 | 結果 |
---|---|
Python | TLE |
PyPy | AC |
Go | AC |
Pythonで提出するとTLEになり、PyPyで提出するとACになります。また、Goなどの高速な言語で同様のコードを書いてもACを取れます。
解法2:メモ化再帰(関数呼び出し回数:大)
import sys
sys.setrecursionlimit(10**6)
N, K = map(int, input().split())
h = list(map(int, input().split()))
INF = 10**9
dp = [INF] * N
def solve(n):
if dp[n] != INF:
return dp[n]
if n == 0:
dp[n] = 0
elif n == 1:
dp[n] = abs(h[0] - h[1])
else:
for i in range(1, K + 1):
if n - i < 0:
break
cost = solve(n - i) + abs(h[n - i] - h[n])
dp[n] = min(dp[n], cost)
return dp[n]
print(solve(N-1))
package main
import (
"fmt"
"math"
)
var (
N, K, INF int
h, dp []int
)
func solve(n int) int {
if dp[n] != INF {
return dp[n]
}
if n == 0 {
dp[n] = 0
} else if n == 1 {
dp[n] = int(math.Abs(float64(h[n] - h[n-1])))
} else {
for i := 1; i <= K; i++ {
if n - i < 0 {
break
}
var cost = solve(n - i) + int(math.Abs(float64(h[n-i]-h[n])))
dp[n] = int(math.Min(float64(dp[n]), float64(cost)))
}
}
return dp[n]
}
func main() {
fmt.Scan(&N, &K)
h = make([]int, N)
dp = make([]int, N)
INF = int(math.Pow(10, 9))
for i := 0; i < N; i++ {
fmt.Scan(&h[i])
}
for i := 0; i < N; i++ {
dp[i] = INF
}
fmt.Println(solve(N - 1))
}
結果
言語 | 結果 |
---|---|
Python | TLE |
PyPy | TLE |
Go | AC |
いわゆる普通のメモ化再帰なのですが、PythonやPyPyで提出するとTLEになり、Go等の高速な言語で書いて提出するとACになります。Python/PyPyでACを取るには、もう一工夫が必要です。
解法3:メモ化再帰(関数呼び出し回数:小)
import sys
sys.setrecursionlimit(10**6)
N, K = map(int, input().split())
h = list(map(int, input().split()))
INF = 10**9
dp = [INF] * N
def solve(n):
if dp[n] != INF:
return dp[n]
if n == 0:
dp[n] = 0
elif n == 1:
dp[n] = abs(h[0] - h[1])
else:
for i in range(1, K+1):
if n - i < 0:
break
if dp[n - i] != INF:
# 関数呼び出し回数を減らすため、dp[n-i]が計算済みの場合はそれを使う
cost = dp[n - i] + abs(h[n - i] - h[n])
else:
cost = solve(n - i) + abs(h[n - i] - h[n])
dp[n] = min(dp[n], cost)
return dp[n]
print(solve(N-1))
結果
言語 | 結果 |
---|---|
Python | TLE |
PyPy | AC |
Go | AC |
解法2とほぼ同じコードですが、こちらの場合はsolve(n-i)を呼び出す前にdp[n-i]が計算済みかどうかをチェックし、計算済みならそれを使うようにしています。こうすることで関数呼び出しのオーバーヘッドを減らし、PyPyでACさせることができるようになります。ただし、この工夫をもってしてもPythonで提出するとTLEになります。
まとめ
各解法と言語別の結果は下記のようになりました。つまりPythonで提出するとどうやってもTLE、PyPyだと工夫すればAC、Go等の高速な言語なら特に意識せずACを取れるということになります。
ということで、DP問題はC++やGo等の高速な言語で提出するのが得策かもしれません。どうしてもPythonで解きたい場合は、下記を意識した方が良さそうです。
- 可能であればメモ化再帰ではなく、for文によるボトムアップ的DPを行う
- メモ化再帰を行う場合は、関数呼び出しの回数を極力減らす(関数を呼ぶ前にDPテーブルをチェックする)
- PythonではなくPyPyで提出する
その他お気づきの点や認識等間違っている部分がありましたら、ご指摘頂けると助かります。