TL;DR
遅くないです。
lru_cache の maxsize は設定しないとデフォルトで 128 に設定されるので気をつけましょう.
参考
動的計画法で functools.lru_cache 使ったら TLE だった
実装
AtCoder Beginner Contest184のDにて、DPの問題が出題されました。
from functools import lru_cache
def solve(n1: int, n2: int, n3: int) -> float:
@lru_cache
def expect(a: int, b: int, c: int) -> float:
if 100 in (a, b, c):
return 0
return (a/(a+b+c)) * (expect(a+1, b, c) + 1) \
+ (b/(a+b+c)) * (expect(a, b+1, c) + 1) \
+ (c/(a+b+c)) * (expect(a, b, c+1) + 1)
return expect(n1, n2, n3)
if __name__ == '__main__':
a, b, c = map(int, input().split())
print(solve(a, b, c))
こんな雰囲気で実装したらTLE...
def solve(n1: int, n2: int, n3: int) -> float:
memo = [[[-1]*101]*101]*101
def expect(a: int, b: int, c: int) -> float:
if memo[a][b][c] >= 0:
return memo[a][b][c]
elif a == 100 or b == 100 or c == 100:
memo[a][b][c] = 0
else:
memo[a][b] = (a/(a+b+c)) * (expect(a+1, b, c) + 1) \
+ (b/(a+b+c)) * (expect(a, b+1, c) + 1) \
+ (c/(a+b+c)) * (expect(a, b, c+1) + 1)
return memo[a][b]
return expect(n1, n2, n3)
if __name__ == '__main__':
a, b, c = map(int, input().split())
print(solve(a, b, c))
これだといけた。なんで?
functools.lru_cache を見てみる
実装は ココ に。
マルチスレッドな処理にも対応できるようにしているけど、具体的な処理は次のようになってた。
保持してる情報
lru_cache のラッパーには次の状態が保存されている
-
cache
キャッシュ
引数をキーに、返り値を含んだややこしいやつ(root)をバリューに持つ辞書オブジェクト -
hits/misses
cache_info()で呼び出せる.
hitsはキャッシュが使われた回数.missesは設定した関数が呼ばれた回数.
cache_clear()でリセットされる. -
full
len(cache)がmaxsizeを超えたらTrueになる. これ以降設定した関数が呼ばれるたびにrootの中からもっとも古く呼ばれているものを削除していく -
root
これがとてもややこしいが、NEXT,PREVKEY,RESULTの配列で、NEXTには呼ばれた順に、PREVには呼ばれた順の逆順に再帰的にポインタが格納されている。つまり、root[PREV][NEXT]はrootになるようになっている。
また、一番上のKEYとPREVは必ずNoneになっている。
これは実質的に自己参照構造体のようなものを表そうとしているのだと思う。Pythonにはそのようなものがないために便宜的にlistを使って表したためこのような実装になったのだと考えられる。
root の例
例えば、フィボナッチ数列で次のように呼ばれたとする。
from functools import lru_cache
@lru_cache
def fib(n):
if n == 1 or n == 0:
return 1
return fib(n-1) + fib(n-2)
print(fib(3))
この時、 結果を返す順番は fib(2) -> fib(1) -> fib(3) である。この時 root は、
{
"PREV": {
"PREV": {
"PREV": {
"PREV": self,
"NEXT": self,
"KEY": 2,
"RESULT": 1
},
"NEXT": self,
"KEY": 1,
"RESULT": 1
},
"NEXT": self,
"KEY": 3,
"RESULT": 2
},
"NEXT": {
"PREV": self,
"NEXT": {
"PREV": self,
"NEXT": {
"PREV": self,
"NEXT": self,
"KEY": 3,
"RESULT": 2
},
"KEY": 1,
"RESULT": 1
},
"KEY": 2,
"RESULT": 1
},
"KEY": null,
"RESULT": null
}
JSONっぽく書いたが、実際にはリスト型である。ここで、 self と書いたのは root 自身のポインタが格納されていることを表している。書き方がわからなくてこんな感じで書いた。
PREV をみると、外側から順に 3 -> 1 -> 2 となっていて、 NEXT をみると、外側から 2 -> 1 -> 3 となっている。キャッシュされた引数で関数が呼ばれると、この順番が変わる。細かい実装はソースコードを見て欲しい。ここでやっていることは、呼ばれた順番を保存しているということ。
maxsize の挙動
まず、 maxsize とは、 lru_cache に置いてキャッシュするレコードの数である。これが設定されている場合、 misses がカウントされるたびに、 len(cache) を確認していて、 maxsize を超えていないかどうかを確認している。 full になるとそれ以降は毎回 root とキャッシュから一番古い情報を削除している。
maxsize の設定
と、ここまでは実は maxsize が設定されている場合の挙動だ。特に設定しなくてもデフォルトで maxsize=128 になる。最初に maxsize=None をしておくと全く挙動が異なる。もはや順番などを考慮する必要はなくなるため実装も簡易。 hits や misses はあるので fib.cache_info() のようにすることでこれらの情報をみることはできるが、 root やら full やらは存在しない。
今回は何がいけなかったのか
maxsize の設定をしなかったために、 maxsize=128 にセットされてしまい、キャッシュの数が足りなくなってしまったのだと思われる。色々調べた後に maxsize=None を指定したら TLE にもならなかった。また、独自にキャッシュを memo[a][b][c] としていた時よりもわずかに(数十ms程度)速度も改善した。
ということで、不要な時には maxsize=None を明示しましょう。 (ドキュメントもう少し読めばよかった...)