0
0

More than 3 years have passed since last update.

Pythonのlru_cacheが遅かった(と勘違いしたので調査した)件

Last updated at Posted at 2020-11-23

TL;DR

遅くないです。
lru_cachemaxsize は設定しないとデフォルトで 128 に設定されるので気をつけましょう.
参考

動的計画法で functools.lru_cache 使ったら TLE だった

実装

AtCoder Beginner Contest184のDにて、DPの問題が出題されました。

from functools import lru_cache

def solve(n1: int, n2: int, n3: int) -> float:
    @lru_cache
    def expect(a: int, b: int, c: int) -> float:
        if 100 in (a, b, c):
            return 0

        return (a/(a+b+c)) * (expect(a+1, b, c) + 1) \
                + (b/(a+b+c)) * (expect(a, b+1, c) + 1) \
                + (c/(a+b+c)) * (expect(a, b, c+1) + 1)

    return expect(n1, n2, n3)

if __name__ == '__main__':
    a, b, c = map(int, input().split())
    print(solve(a, b, c))

こんな雰囲気で実装したらTLE...

def solve(n1: int, n2: int, n3: int) -> float:
    memo = [[[-1]*101]*101]*101
    def expect(a: int, b: int, c: int) -> float:
        if memo[a][b][c] >= 0:
            return memo[a][b][c]
        elif a == 100 or b == 100 or c == 100:
            memo[a][b][c] = 0
        else:
            memo[a][b] = (a/(a+b+c)) * (expect(a+1, b, c) + 1) \
                + (b/(a+b+c)) * (expect(a, b+1, c) + 1) \
                + (c/(a+b+c)) * (expect(a, b, c+1) + 1)
        return memo[a][b]

    return expect(n1, n2, n3)

if __name__ == '__main__':
    a, b, c = map(int, input().split())
    print(solve(a, b, c))

これだといけた。なんで?

functools.lru_cache を見てみる

実装は ココ に。

マルチスレッドな処理にも対応できるようにしているけど、具体的な処理は次のようになってた。

保持してる情報

lru_cache のラッパーには次の状態が保存されている

  • cache
    キャッシュ
    引数をキーに、返り値を含んだややこしいやつ(root)をバリューに持つ辞書オブジェクト

  • hits / misses
    cache_info() で呼び出せる.
    hits はキャッシュが使われた回数. misses は設定した関数が呼ばれた回数.
    cache_clear() でリセットされる.

  • full
    len(cache)maxsize を超えたら True になる. これ以降設定した関数が呼ばれるたびに root の中からもっとも古く呼ばれているものを削除していく

  • root
    これがとてもややこしいが、 NEXT, PREV KEY, RESULT の配列で、 NEXT には呼ばれた順に、 PREV には呼ばれた順の逆順に再帰的にポインタが格納されている。つまり、 root[PREV][NEXT]root になるようになっている。
    また、一番上の KEYPREV は必ず None になっている。
     これは実質的に自己参照構造体のようなものを表そうとしているのだと思う。Pythonにはそのようなものがないために便宜的に list を使って表したためこのような実装になったのだと考えられる。

root の例

例えば、フィボナッチ数列で次のように呼ばれたとする。

from functools import lru_cache

@lru_cache
def fib(n):
    if n == 1 or n == 0:
        return 1

    return fib(n-1) + fib(n-2)

print(fib(3))

この時、 結果を返す順番は fib(2) -> fib(1) -> fib(3) である。この時 root は、

{
  "PREV": {
    "PREV": {
      "PREV": {
        "PREV": self,
        "NEXT": self,
        "KEY": 2,
        "RESULT": 1
      },
      "NEXT": self,
      "KEY": 1,
      "RESULT": 1
    },
    "NEXT": self,
    "KEY": 3,
    "RESULT": 2
  },
  "NEXT": {
    "PREV": self,
    "NEXT": {
      "PREV": self,
      "NEXT": {
        "PREV": self,
        "NEXT": self,
        "KEY": 3,
        "RESULT": 2
      },
      "KEY": 1,
      "RESULT": 1
    },
    "KEY": 2,
    "RESULT": 1
  },
  "KEY": null,
  "RESULT": null
}

JSONっぽく書いたが、実際にはリスト型である。ここで、 self と書いたのは root 自身のポインタが格納されていることを表している。書き方がわからなくてこんな感じで書いた。
PREV をみると、外側から順に 3 -> 1 -> 2 となっていて、 NEXT をみると、外側から 2 -> 1 -> 3 となっている。キャッシュされた引数で関数が呼ばれると、この順番が変わる。細かい実装はソースコードを見て欲しい。ここでやっていることは、呼ばれた順番を保存しているということ。

maxsize の挙動

まず、 maxsize とは、 lru_cache に置いてキャッシュするレコードの数である。これが設定されている場合、 misses がカウントされるたびに、 len(cache) を確認していて、 maxsize を超えていないかどうかを確認している。 full になるとそれ以降は毎回 root とキャッシュから一番古い情報を削除している。

maxsize の設定

と、ここまでは実は maxsize が設定されている場合の挙動だ。特に設定しなくてもデフォルトで maxsize=128 になる。最初に maxsize=None をしておくと全く挙動が異なる。もはや順番などを考慮する必要はなくなるため実装も簡易。 hitsmisses はあるので fib.cache_info() のようにすることでこれらの情報をみることはできるが、 root やら full やらは存在しない。

今回は何がいけなかったのか

maxsize の設定をしなかったために、 maxsize=128 にセットされてしまい、キャッシュの数が足りなくなってしまったのだと思われる。色々調べた後に maxsize=None を指定したら TLE にもならなかった。また、独自にキャッシュを memo[a][b][c] としていた時よりもわずかに(数十ms程度)速度も改善した。

ということで、不要な時には maxsize=None を明示しましょう。 (ドキュメントもう少し読めばよかった...)

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0