#はじめに
この記事はABC233 G. Strongest TakahashiをPythonでTLEせずに解くテクニックについて書いていきます.
Pythonの実行時間を短縮するための小技のようなものを挙げていくので緑~水色の人も読んで損はないと思います.
使用言語はPyPy3です.
この問題についてですがDifficultyが2438と高めなのも厄介ですが, とにかくPythonでは制限時間がきついです.
問題の内容と解法については問題ページから確認できます.
##1. 公式解法の通りに実装
まず公式解法の通りに再帰関数を使ってDFSを実装してみました.
計算量はO(N^6)です
N = int(input())
S = [input() for _ in range(N)]
dp = [[[[None]*N for _ in range(N)] for _ in range(N)] for _ in range(N)]
def dfs(Rlo,Rhi,Clo,Chi):
if (Rlo > Rhi) or (Clo > Chi):
return 0
if not dp[Rlo][Rhi][Clo][Chi] is None:
return dp[Rlo][Rhi][Clo][Chi]
dp[Rlo][Rhi][Clo][Chi] = max(Rhi-Rlo+1,Chi-Clo+1)
for i in range(Rlo,Rhi+1):
for j in range(Clo,Chi+1):
if S[i][j] == "#":
break
else:
dp[Rlo][Rhi][Clo][Chi] = min(dp[Rlo][Rhi][Clo][Chi],dfs(Rlo,i-1,Clo,Chi)+dfs(i+1,Rhi,Clo,Chi))
for i in range(Clo,Chi+1):
for j in range(Rlo,Rhi+1):
if S[j][i] == "#":
break
else:
dp[Rlo][Rhi][Clo][Chi] = min(dp[Rlo][Rhi][Clo][Chi],dfs(Rlo,Rhi,Clo,i-1)+dfs(Rlo,Rhi,i+1,Chi))
return dp[Rlo][Rhi][Clo][Chi]
print(dfs(0,N-1,0,N-1))
###結果
はい、余裕でTLEです(笑).
およそ半分のケースで引っかかっています.
##2. O(N^5)に落としてみた
dfs内の二重forループを一重にできればO(N^5)に落とせそうです.
そこで前処理として「i行目のl列目からr列目までに"#"は含まれるか」を判定できる3次元配列を作りました.
コード内のpre_rとpre_cがこれに当たります.
N = int(input())
S = [input() for _ in range(N)]
dp = [[[[None]*N for _ in range(N)] for _ in range(N)] for _ in range(N)]
pre_r = [[[0]*N for _ in range(N)] for _ in range(N)]
for R in range(N):
for l in range(N):
for r in range(N):
for k in range(l,r+1):
if S[R][k] == "#":
pre_r[R][l][r] = 1
pre_c = [[[0]*N for _ in range(N)] for _ in range(N)]
for C in range(N):
for l in range(N):
for r in range(N):
for k in range(l,r+1):
if S[k][C] == "#":
pre_c[C][l][r] = 1
def dfs(Rlo,Rhi,Clo,Chi):
if (Rlo > Rhi) or (Clo > Chi):
return 0
if not dp[Rlo][Rhi][Clo][Chi] is None:
return dp[Rlo][Rhi][Clo][Chi]
dp[Rlo][Rhi][Clo][Chi] = max(Rhi-Rlo+1,Chi-Clo+1)
for i in range(Rlo,Rhi+1):
if not pre_r[i][Clo][Chi]:
dp[Rlo][Rhi][Clo][Chi] = min(dp[Rlo][Rhi][Clo][Chi],dfs(Rlo,i-1,Clo,Chi)+dfs(i+1,Rhi,Clo,Chi))
for i in range(Clo,Chi+1):
if not pre_c[i][Rlo][Rhi]:
dp[Rlo][Rhi][Clo][Chi] = min(dp[Rlo][Rhi][Clo][Chi],dfs(Rlo,Rhi,Clo,i-1)+dfs(Rlo,Rhi,i+1,Chi))
return dp[Rlo][Rhi][Clo][Chi]
print(dfs(0,N-1,0,N-1))
###結果
これまた余裕でTLEです(笑).
TLEの数はさっきとほとんど変わりません.
##3. テクニック①:4次元配列をやめる
コード中の4次元配列dpを1次元に書き換えてみます.
やり方は簡単で, 次のような変換を行ってやればいいです.
conv = lambda Rlo,Rhi,Clo,Chi:Rlo*(N+1)**3+Rhi*(N+1)**2+Clo*(N+1)+Chi
dp = [0]*(N+1)**4
dp[conv(Rlo,Rhi,Clo,Chi)] #もともとはdp[Rlo][Rhi][Clo][Chi]
###結果
TLEの数を見るに, おそらく少し早くなりました.
Pythonのlistは動的配列なので1次元にすると早くなるのも納得できるような気がします.
でもまだまだ道は長そうです...
##4. テクニック②:配列参照を行う回数を減らす
いままでのコードだとdp[conv(Rlo,Rhi,Clo,Chi)]という配列参照をdfs内で何度も行っていました.
この回数を減らしてみます.
具体的には適当な値で代替して, 最後にdpに代入すればいいです.
# ~~~~~~~ 略
def dfs(Rlo,Rhi,Clo,Chi):
if (Rlo > Rhi) or (Clo > Chi):
return 0
if not dp[conv(Rlo,Rhi,Clo,Chi)] is None:
return dp[conv(Rlo,Rhi,Clo,Chi)]
tmp = max(Rhi-Rlo+1,Chi-Clo+1) #tmpと置いた
for i in range(Rlo,Rhi+1):
if not pre_r[i][Clo][Chi]:
tmp = min(tmp,dfs(Rlo,i-1,Clo,Chi)+dfs(i+1,Rhi,Clo,Chi))
for i in range(Clo,Chi+1):
if not pre_c[i][Rlo][Rhi]:
tmp = min(tmp,dfs(Rlo,Rhi,Clo,i-1)+dfs(Rlo,Rhi,i+1,Chi))
dp[conv(Rlo,Rhi,Clo,Chi)] = tmp
return tmp
print(dfs(0,N-1,0,N-1))
###結果
ここまでTLEの数を減らすことができました.
このテクニックは結構な場面で使うような気がします.
しかしまだTLEは残っています.
##5. テクニック③:dpの初期値を変えてみる
今までdpの初期値をNoneとしていたのを-1と変えてみます.
dp = [-1]*N**4
###結果
TLE×3!!!
dpの初期値を変えてみると結構早くなることもあるような気がします.
この後色々試してみましたが, 再帰を使った方針ではこれ以上の高速化は望めませんでした.
誰かいい方法を知っていれば教えてほしいです.
マイメロ、ショック...
##6. 結論:再帰をやめる
知ってる人も多いと思いますが, Pythonは再帰が遅いです. PyPyは特に遅いです.
理由はよくわかりませんがメモリが関係していると思います(名推理).
再帰を回避する方法として次の2パターンが考えられます.
①非再帰dfsをする
②ボトムアップに解く
①に関してはスタックでやるらしいですが, やり方しらないですごめんなさい.
先ほどまではトップダウンに, つまり一番大きい長方形から始めて徐々に細かく見ていくという方法で解いていました.
これを逆にボトムアップに一番小さい長方形から解いていくようにすると再帰を使わないで通常のDP問題として扱えます.
実装のコツはコード中のdr,dcのように長方形の一辺を決めるループを一番外に持ってくることです.
N = int(input())
S = [input() for _ in range(N)]
conv = lambda Rlo,Rhi,Clo,Chi:Rlo*(N+1)**3+Rhi*(N+1)**2+Clo*(N+1)+Chi
dp = [0]*(N+1)**4
pre_r = [[[0]*N for _ in range(N)] for _ in range(N)]
for R in range(N):
for l in range(N):
for r in range(N):
for k in range(l,r+1):
if S[R][k] == "#":
pre_r[R][l][r] = 1
break
pre_c = [[[0]*N for _ in range(N)] for _ in range(N)]
for C in range(N):
for l in range(N):
for r in range(N):
for k in range(l,r+1):
if S[k][C] == "#":
pre_c[C][l][r] = 1
break
for dr in range(1,N+1):
for dc in range(1,N+1):
for r in range(N):
if r+dr >= N+1:
break
for c in range(N):
if c+dc >= N+1:
break
if dr == 1 and dc == 1:
dp[conv(r,r+1,c,c+1)] = 1 if S[r][c] == "#" else 0
continue
tmp = max(dr,dc)
for i in range(r,r+dr):
if not pre_r[i][c][c+dc-1]:
tmp = min(tmp,dp[conv(r,i,c,c+dc)]+dp[conv(i+1,r+dr,c,c+dc)])
for j in range(c,c+dc):
if not pre_c[j][r][r+dr-1]:
tmp = min(tmp,dp[conv(r,r+dr,c,j)]+dp[conv(r,r+dr,j+1,c+dc)])
dp[conv(r,r+dr,c,c+dc)] = tmp
print(dp[conv(0,N,0,N)])
###結果
実行時間は1752msでした.
なんとかギリギリ通すことができました.
##おわりに
こういったチャレンジングな問題を通じて効率的なPythonの書き方を体得していけたらいいですね.
ちなみにこの問題を本番中にPythonで通した人は2人いますが
2人とも公式解法よりも少し効率のいい解法で解いていて実行時間に余裕をもって通していました.
本当にすごい思います.
##追記
Twitterで@mink1618033さんが再帰でも通せる方法を教えて下さいました!
1次元配列にするためのconvを関数にするのではなくベタ書きするのがポイントらしいです!
@mink1618033さんのコード