Pythonのメモ化機能 LRU Cache
PythonではLRU Cacheという機能があり、問題をまず再帰関数でプログラムしたあとLRU Cacheをオンにすることで劇的にスピードアップできる場合があります。
でもその仕組を理解した上で、自分でDICT型を使って実装すれば、さらに効率化できるという例を今回は示したいと思います。
例題
トランプの52枚をシャッフルしたとき、どのカードも隣のカードと数字が同じにならない(ペアがまったくない)確率を求めよ
解法#1 すべての順列を発生させる
まず今後のプログラムの答え合わせ用にpythonのdistinct_permutationsを使ってすべての順列を発生させて数えてみます。数字を13までではとても無理なので数字は4までの16枚で走らせます。
答えは約3.59%となりました。
from more_itertools import distinct_permutations
N, s = 4, 4
def nopair(s):
for i in range(1,len(s)):
if s[i]==s[i-1]: return False
return True
cards = "".join([str(i)*s for i in range(1,N+1)])
print(f"N={N}, cards:{cards}")
nopr, total = 0, 0
for p in distinct_permutations(cards):
nopr += nopair(''.join(p))
total += 1
print(f"No Pair:{nopr}, Total:{total}, Probability:{nopr/total:08f}")
# N=4, cards:1111222233334444
# No Pair:2265024, Total:63063000, Probability:0.035917
解法#2 再帰関数で実装
再帰関数の引数のデータを以下のように定義します。これをカードの残り枚数の合計が0になるまで続けてペアがあった数字がなかった場合の数を数えます。ここでのポイントは後のメモ化のことを考えずに素直にインプリすることです。
以下のコードでは、すべての数値は0始まりで記憶しているので、数字は$0 - 12$で表されます。
引数 | データ | 初期値 |
---|---|---|
cards | 各数字のカードの残り枚数 | [4,4,4,4] |
lastr | 前回の数字 | -1 |
pairs | ペアがあった数字をビット表現 | 0b0000 |
from copy import copy
N, Suits = 4,4
bts = [1<<r for r in range(N)]
def shuff(cards, lastr, pairs): # Remaining cards, last rank, pair found(bit expression)
if sum(cards)==0:
return 1, pairs.bit_count()==0 # (return total, 1: pair found, 0: not found)
total, num = 0, 0
for r in range(N):
if cards[r] > 0:
cards1 = copy(cards)
cards1[r] -= 1
total1, num1 = shuff(cards1,r, pairs | (bts[r]*int(r==lastr)))
total += total1
num += num1
return (total, num)
total, rperf = shuff([Suits for i in range(N)], -1, 0)
print(f"N = {N}, Answer: {(rperf/total):.010f} = {rperf} / {total}")
# N = 4, Answer: 0.0359168451 = 2265024 / 63063000
時間はかかりましたが、解法#1と同じ答えがでたのでプログラムは正しそうです。このプログラムのコードをなるべく変更せずにメモ化をして行きます。
解法#3 DICT変数を使ってメモ化
まず関数の引数をそのままDICT変数にメモ化します。そのため引数keyにできるようにhashableなものに変更する必要があります。リストが2重になったりしているので関数hashableを定義します
def hashable(args):
return tuple(tuple(l) if type(l)==list else l for l in args)
print(hashable( [[4,4,4,4], -1, 0]))
# ((4, 4, 4, 4), -1, 0)
このhashableで作ったhkeyをキーにして、関数の出力を辞書に格納して。入力が同じならその出力を返します。そのために関数shuffの入口と出口に以下の4行(####1,2,3,4) を挿入します。
mem = dict() #### 1
def shuff(cards, lastr, pairs): # Remaining cards, last rank, pair found(bit expression)
hkey = hashable((cards, lastr, pairs)) #### 2
if hkey in mem: return mem[hkey] #### 3
:
mem[hkey] = (total, num) #### 4
return (total, num)
# N = 7, Answer: 2778291737177034960 / 66475579247327250000 = 0.0417941712
これはほぼLRU Cacheを使ったのと同等の効果が期待できます。
以下のように$N=7$くらいまでは求められるようになりました。でも$N=13$まではさらなる改善が必要です。
解法 | N=4 | N=7 |
---|---|---|
#1 | 96sec | - |
#2 | 130sec | - |
#3 | 0sec | 143sec |
解法#4 データの対称性を使う
解法#3のメモ化では、以下のように同じ答えになるデータを別のものとしていますが、これを同じキーにできれば更に効率を上げることができるはずです。
ケース | データ | |
---|---|---|
1を2枚取った場合 | ((2, 4, 4, 4), 0, 1) | - |
2を2枚取った場合 | ((4, 2, 4, 4), 1, 2) | - |
このデータを数字ごとに (残り枚数、前のカードだった、ペアがあった) を記録するようにすると、ソートすることによって、同じものとみなすことができます。
ケース | データ | ソート |
---|---|---|
((2,4,4,4),0,1) | ((2,1,1),(4,0,0),(4,0,0),(4,0,0)) | ((2,1,1),(4,0,0),(4,0,0),(4,0,0)) |
((4,2,4,4),1,2) | ((4,0,0),(2,1,1),(4,0,0),(4,0,0)) | ((2,1,1),(4,0,0),(4,0,0),(4,0,0)) |
このデータの並べ替えをしてソートする関数symmは以下のようになります。
def symm(args): # create hash value considering symmetries
cards, lastr, bits = args
return tuple(tuple(l) for l in sorted([[cards[r],int(r==lastr),int(bits & bts[r] > 0)] for r in range(N)]))
この関数symmをhashableの代わりに使えば他の変更は必要ありません。
$N=13$でも約 4.5% という答えを約20秒でだすことができました。
N = 13, Answer: 0.0454762823 = 4184920420968817245135211427730337964623328025600 / 92024242230271040357108320801872044844750000000000
解法 | N=4 | N=7 | N=13 |
---|---|---|---|
#1 | 96sec | - | - |
#2 | 130sec | - | - |
#3 | 0sec | 143sec | - |
#4 | 0sec | 0sec | 20sec |
まとめ
LRU Cacheと同様の機能を自分でDICT型を使って実装し、そのデータを解析することによりさらなる効率化ができる可能性があることを示しました。しかもオリジナルのコードはそのまま使えるので、変更によりバグの入るリスクを少なくできます。
(開発環境:Google Colab)
この考え方はProject Euler Problem 687: Shuffling Cardsを解くのに役に立ちます