TL;DR
遅くないです。
lru_cache
の maxsize
は設定しないとデフォルトで 128
に設定されるので気をつけましょう.
参考
動的計画法で functools.lru_cache
使ったら TLE
だった
実装
AtCoder Beginner Contest184のDにて、DPの問題が出題されました。
from functools import lru_cache
def solve(n1: int, n2: int, n3: int) -> float:
@lru_cache
def expect(a: int, b: int, c: int) -> float:
if 100 in (a, b, c):
return 0
return (a/(a+b+c)) * (expect(a+1, b, c) + 1) \
+ (b/(a+b+c)) * (expect(a, b+1, c) + 1) \
+ (c/(a+b+c)) * (expect(a, b, c+1) + 1)
return expect(n1, n2, n3)
if __name__ == '__main__':
a, b, c = map(int, input().split())
print(solve(a, b, c))
こんな雰囲気で実装したらTLE...
def solve(n1: int, n2: int, n3: int) -> float:
memo = [[[-1]*101]*101]*101
def expect(a: int, b: int, c: int) -> float:
if memo[a][b][c] >= 0:
return memo[a][b][c]
elif a == 100 or b == 100 or c == 100:
memo[a][b][c] = 0
else:
memo[a][b] = (a/(a+b+c)) * (expect(a+1, b, c) + 1) \
+ (b/(a+b+c)) * (expect(a, b+1, c) + 1) \
+ (c/(a+b+c)) * (expect(a, b, c+1) + 1)
return memo[a][b]
return expect(n1, n2, n3)
if __name__ == '__main__':
a, b, c = map(int, input().split())
print(solve(a, b, c))
これだといけた。なんで?
functools.lru_cache
を見てみる
実装は ココ に。
マルチスレッドな処理にも対応できるようにしているけど、具体的な処理は次のようになってた。
保持してる情報
lru_cache
のラッパーには次の状態が保存されている
-
cache
キャッシュ
引数をキーに、返り値を含んだややこしいやつ(root
)をバリューに持つ辞書オブジェクト -
hits
/misses
cache_info()
で呼び出せる.
hits
はキャッシュが使われた回数.misses
は設定した関数が呼ばれた回数.
cache_clear()
でリセットされる. -
full
len(cache)
がmaxsize
を超えたらTrue
になる. これ以降設定した関数が呼ばれるたびにroot
の中からもっとも古く呼ばれているものを削除していく -
root
これがとてもややこしいが、NEXT
,PREV
KEY
,RESULT
の配列で、NEXT
には呼ばれた順に、PREV
には呼ばれた順の逆順に再帰的にポインタが格納されている。つまり、root[PREV][NEXT]
はroot
になるようになっている。
また、一番上のKEY
とPREV
は必ずNone
になっている。
これは実質的に自己参照構造体のようなものを表そうとしているのだと思う。Pythonにはそのようなものがないために便宜的にlist
を使って表したためこのような実装になったのだと考えられる。
root
の例
例えば、フィボナッチ数列で次のように呼ばれたとする。
from functools import lru_cache
@lru_cache
def fib(n):
if n == 1 or n == 0:
return 1
return fib(n-1) + fib(n-2)
print(fib(3))
この時、 結果を返す順番は fib(2)
-> fib(1)
-> fib(3)
である。この時 root
は、
{
"PREV": {
"PREV": {
"PREV": {
"PREV": self,
"NEXT": self,
"KEY": 2,
"RESULT": 1
},
"NEXT": self,
"KEY": 1,
"RESULT": 1
},
"NEXT": self,
"KEY": 3,
"RESULT": 2
},
"NEXT": {
"PREV": self,
"NEXT": {
"PREV": self,
"NEXT": {
"PREV": self,
"NEXT": self,
"KEY": 3,
"RESULT": 2
},
"KEY": 1,
"RESULT": 1
},
"KEY": 2,
"RESULT": 1
},
"KEY": null,
"RESULT": null
}
JSONっぽく書いたが、実際にはリスト型である。ここで、 self
と書いたのは root
自身のポインタが格納されていることを表している。書き方がわからなくてこんな感じで書いた。
PREV
をみると、外側から順に 3 -> 1 -> 2 となっていて、 NEXT
をみると、外側から 2 -> 1 -> 3 となっている。キャッシュされた引数で関数が呼ばれると、この順番が変わる。細かい実装はソースコードを見て欲しい。ここでやっていることは、呼ばれた順番を保存しているということ。
maxsize
の挙動
まず、 maxsize
とは、 lru_cache
に置いてキャッシュするレコードの数である。これが設定されている場合、 misses
がカウントされるたびに、 len(cache)
を確認していて、 maxsize
を超えていないかどうかを確認している。 full
になるとそれ以降は毎回 root
とキャッシュから一番古い情報を削除している。
maxsize
の設定
と、ここまでは実は maxsize
が設定されている場合の挙動だ。特に設定しなくてもデフォルトで maxsize=128
になる。最初に maxsize=None
をしておくと全く挙動が異なる。もはや順番などを考慮する必要はなくなるため実装も簡易。 hits
や misses
はあるので fib.cache_info()
のようにすることでこれらの情報をみることはできるが、 root
やら full
やらは存在しない。
今回は何がいけなかったのか
maxsize
の設定をしなかったために、 maxsize=128
にセットされてしまい、キャッシュの数が足りなくなってしまったのだと思われる。色々調べた後に maxsize=None
を指定したら TLE
にもならなかった。また、独自にキャッシュを memo[a][b][c]
としていた時よりもわずかに(数十ms程度)速度も改善した。
ということで、不要な時には maxsize=None
を明示しましょう。 (ドキュメントもう少し読めばよかった...)