Atcoder Beginner Contest 184にこのような問題があります
制限時間:2sec/ メモリ制限:1024MB
袋の中に金貨が$A$枚、銀貨が$B$枚、銅貨が$C$枚入っています。袋の中にあるいずれかの種類の硬貨が100$枚になるまで以下の操作を繰り返します。
操作:袋の中から硬貨をランダムに$1$枚取り出す。(どの硬貨も等確率で選ばれる。) 取り出した硬貨と同じ種類の硬貨を$2$枚袋に戻す。
操作回数の期待値を求めてください。制約:
・ $0 \leq A, B, C \leq 99$
・ $A + B + C \geq 1$
ABC184 D - increment of coins
この問題は文中の「繰り返す」という単語からも察せられる通り、再帰を用いて考えていくのですが(詳しい解説はリンク先の解説を見てください。)、素直に再帰関数を記述すると、例えば$(A, B, C) = (0, 0, 1)$のようなケースでは処理回数が莫大になり、TLE(制限時間オーバー)となってしまいます。そこで、メモ化回帰をするということが必要になっていきます。これは以下の記事がわかりやすいのでそちらを参照してください。
メモ化再帰を用いてC++で実装したコードがこちらになります。
# include <bits/stdc++.h>
using namespace std;
double dp[101][101][101];
double count(int a,int b,int c){
if(dp[a][b][c]){//メモ化部分。すでに計算済みだったら計算済みの数値を返す
return dp[a][b][c];
}
double d = (double)a/(a+b+c);
double e = (double)b/(a+b+c);
double f = (double)c/(a+b+c);
if(a == 100 || b== 100|| c==100){
return 0;
}else{
double dd = d*(count(a+1, b, c)+1) + e*(count(a, b+1, c)+1) + f*(count(a, b, c+1)+1);
dp[a][b][c] = dd;
return dp[a][b][c];
}
}
int main(void){
int a, b, c;
cin >> a >> b >> c;
cout << setprecision(10) << count(a, b, c) << endl;
return 0;
}
もう少し記述を最適化できると思いますがこれでACできます。C++初心者なので出力するときに有効数字を記述してあげなくてWAを1回出しました。
C++は自分で配列を用意し、計算結果を入れ、チェックするという処理が必要でしたが、Pythonではどう書けばよいでしょうか。実はPythonにはfunctools.lru_cacheというライブラリがその役割を果たしてくれます。
具体的には
@functools.lru_cache(maxsize=128, typed=False)
関数をメモ化用の呼び出し可能オブジェクトでラップし、最近の呼び出し最大 maxsize 回まで保存するするデコレータです。高価な関数や I/O に束縛されている関数を定期的に同じ引数で呼び出すときに、時間を節約できます。
結果のキャッシュには辞書が使われるので、関数の位置引数およびキーワード引数はハッシュ可能でなくてはなりません。
引数のパターンが異なる場合は、異なる呼び出しと見なされ別々のキャッシュエントリーとなります。 例えば、 f(a=1, b=2) と f(b=2, a=1) はキーワード引数の順序が異なっているので、2つの別個のキャッシュエントリーになります。https://docs.python.org/ja/3/library/functools.html
とあるように、関数の計算結果をmaxsize回保存してくれます。maxsize=Noneとすることによって保存回数を無制限にできます。
例えば、フィボナッチ数列において
@lru_cache(maxsize=None)
def fib_memo(n):
if(n < 2):
return n
else:
return fib_memo(n-1)+fib_memo(n-2)
このように記述するだけでメモ化の役割を果たしてくれます。
そこで@lru_cache
デコーダーを実際に使用して上の問題を解くと
from functools import lru_cache
@lru_cache(maxsize=None) #メモ化
def count(a, b, c):
if(a == 100 or b == 100 or c == 100):
return 0
else:
return (a/(a+b+c))*(count(a+1, b, c)+1) + (b/(a+b+c))*(count(a, b+1, c)+1) + (c/(a+b+c))*(count(a, b, c+1)+1)
a,b,c = map(int, input().split())
print(count(a, b, c))
かなり簡単に記述することができました。
しかし、その一方で実行時間を比較してみると
C++ | Python | |
---|---|---|
実行時間(ms) | 34 | 1155 |
となり、かなりの差があります。なので制約がこれより大きくなるとPythonではTLEする危険性もあります。(ちなみに、PythonでTLEした際に回避策として用いられるPypy3だと1452msでした)。
まとめ
Pythonは記述が楽なので早解きする際には使っていきたい一方で制約が大きいとTLEする可能性もありました。なので急いでいない場合にはC++で慎重に書いた方が安全です。両方使えるようになると良いとこ取りできて良さそうですね。