LoginSignup
10
5

More than 3 years have passed since last update.

DP問題をPythonで解く際の注意点

Last updated at Posted at 2020-06-18

概要

AtCoderのDP(Dynamic Programming: 動的計画法)問題を解いていたところ、他の言語ならACになるコードがPython/PyPyだとTLEになるという現象に遭遇したため、備忘録として残しておきます。

DP問題について

こちらの記事で詳しく解説されているようにDP問題の解き方はいくつかありますが、ざっくり分けるとボトムアップ的にfor文を回してDPテーブルを構築する手法と、トップダウン的に再帰関数を呼び出すメモ化再帰の二通りがあります。いずれにしても計算済みの結果を保持しておくことで処理の高速化を図るという点では同様ですが、メモ化再帰の場合は関数呼び出しのオーバーヘッドがあるため、一部の言語では制限時間内に処理が終わらずTLEになるということがあります。

例題

DPまとめコンテストB問題: Frog2

解法1:for文によるボトムアップ的なDP

N, K = map(int, input().split())
h = list(map(int, input().split()))
dp = [10**9] * N

for i in range(N):
    if i == 0:
        dp[0] = 0
    elif i == 1:
        dp[1] = abs(h[0]-h[1])
    else:
        for j in range(1, K+1):
            if j > i:
                break
            dp[i] = min(dp[i], dp[i - j] + abs(h[i - j] - h[i]))

print(dp[N-1])

結果

言語 結果
Python TLE
PyPy AC
Go AC

Pythonで提出するとTLEになり、PyPyで提出するとACになります。また、Goなどの高速な言語で同様のコードを書いてもACを取れます。

解法2:メモ化再帰(関数呼び出し回数:大)

Pythonで書いた場合
import sys
sys.setrecursionlimit(10**6)

N, K = map(int, input().split())
h = list(map(int, input().split()))
INF = 10**9
dp = [INF] * N

def solve(n):
    if dp[n] != INF:
        return dp[n]

    if n == 0:
        dp[n] = 0
    elif n == 1:
        dp[n] = abs(h[0] - h[1])
    else:
        for i in range(1, K + 1):
            if n - i < 0:
                break
            cost = solve(n - i) + abs(h[n - i] - h[n])
            dp[n] = min(dp[n], cost)

    return dp[n]

print(solve(N-1))
Goで書いた場合
package main

import (
    "fmt"
    "math"
)

var (
    N, K, INF int
    h, dp []int
)

func solve(n int) int {
    if dp[n] != INF {
        return dp[n]
    }

    if n == 0 {
        dp[n] = 0
    } else if n == 1 {
        dp[n] = int(math.Abs(float64(h[n] - h[n-1])))
    } else {
        for i := 1; i <= K; i++ {
            if n - i < 0 {
                break
            }
            var cost =  solve(n - i) + int(math.Abs(float64(h[n-i]-h[n])))
            dp[n] = int(math.Min(float64(dp[n]), float64(cost)))
        }
    }

    return dp[n]
}

func main() {
    fmt.Scan(&N, &K)
    h = make([]int, N)
    dp = make([]int, N)
    INF = int(math.Pow(10, 9))

    for i := 0; i < N; i++ {
        fmt.Scan(&h[i])
    }
    for i := 0; i < N; i++ {
        dp[i] = INF
    }

    fmt.Println(solve(N - 1))
}

結果

言語 結果
Python TLE
PyPy TLE
Go AC

いわゆる普通のメモ化再帰なのですが、PythonやPyPyで提出するとTLEになり、Go等の高速な言語で書いて提出するとACになります。Python/PyPyでACを取るには、もう一工夫が必要です。

解法3:メモ化再帰(関数呼び出し回数:小)

import sys
sys.setrecursionlimit(10**6)

N, K = map(int, input().split())
h = list(map(int, input().split()))
INF = 10**9
dp = [INF] * N

def solve(n):
    if dp[n] != INF:
        return dp[n]

    if n == 0:
        dp[n] = 0
    elif n == 1:
        dp[n] = abs(h[0] - h[1])
    else:
        for i in range(1, K+1):
            if n - i < 0:
                break
            if dp[n - i] != INF:
                # 関数呼び出し回数を減らすため、dp[n-i]が計算済みの場合はそれを使う
                cost = dp[n - i] + abs(h[n - i] - h[n])
            else:
                cost = solve(n - i) + abs(h[n - i] - h[n])
            dp[n] = min(dp[n], cost)

    return dp[n]

print(solve(N-1))

結果

言語 結果
Python TLE
PyPy AC
Go AC

解法2とほぼ同じコードですが、こちらの場合はsolve(n-i)を呼び出す前にdp[n-i]が計算済みかどうかをチェックし、計算済みならそれを使うようにしています。こうすることで関数呼び出しのオーバーヘッドを減らし、PyPyでACさせることができるようになります。ただし、この工夫をもってしてもPythonで提出するとTLEになります。

まとめ

各解法と言語別の結果は下記のようになりました。つまりPythonで提出するとどうやってもTLE、PyPyだと工夫すればAC、Go等の高速な言語なら特に意識せずACを取れるということになります。

Screen Shot 2020-06-18 at 22.01.10.png

ということで、DP問題はC++やGo等の高速な言語で提出するのが得策かもしれません。どうしてもPythonで解きたい場合は、下記を意識した方が良さそうです。

  • 可能であればメモ化再帰ではなく、for文によるボトムアップ的DPを行う
  • メモ化再帰を行う場合は、関数呼び出しの回数を極力減らす(関数を呼ぶ前にDPテーブルをチェックする)
  • PythonではなくPyPyで提出する

その他お気づきの点や認識等間違っている部分がありましたら、ご指摘頂けると助かります。

10
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
5