メモ化とは
メモ化(英: memoization)とは、プログラムの高速化のための最適化技法の一種であり、サブルーチン呼び出しの結果を後で再利用するために保持し、そのサブルーチン(関数)の呼び出し毎の再計算を防ぐ手法である。wikiより
簡単に言うと、関数の引数と戻り値をペアで覚えておいて、次回同じ引数が来た時に関数の中を通らず結果を返すことで高速化する手法です。
Python3.8では標準ライブラリfunctools.lru_cache
でメモ化が可能です。こちらとか参照ください。ただしこちらは引数にnumpy配列を取れなかったので以下、自作しました。
メモ化デコレータ作成
from numba import njit
from functools import wraps
import numpy as np
import time
numpy配列に対応するために
引数を無理やりstrにキャストして辞書のキーにしています。
def cache(f):
c = {}
@wraps(f)
def _wrapper(*args,**kwargs):
key = ''.join([str(arg) for arg in args]) + ''.join([str(k)+str(v) for k,v in kwargs.items()])
if key not in c:
c[key] = f(*args,**kwargs)
return c[key]
return _wrapper
ついでに時間計測デコレータ
def proc_time(f):
@wraps(f)
def _wrapper(*args, **kwargs):
st = time.perf_counter_ns()
res = f(*args, **kwargs)
t = (time.perf_counter_ns() - st) * 0.000000001
print(f'{format(t,".2E")} sec')
return res
return _wrapper
自作メモ化テスト
@proc_time
@cache
def calc(a,b,c):
"""負荷のあるテキトーな計算"""
for _ in range(10000):
for i in range(1, len(a)-1):
a[i] = (a[i-1]+a[i+1])*0.5 + b
a[len(a)-1] = a[len(a)-2]
return a+c
calc(np.zeros(100), 1, True)
calc(np.zeros(100), 1, True)
calc(np.zeros(100), 1, c=True)
1.84E+00 sec
1.12E-03 sec
1.57E+00 sec
1回目は普通に計算しているので遅いです。
2回目はメモ化で1回目の結果を呼び出しているので速いです。
numpy配列に対応したメモ化になっています。
3回目は引数の与え方が違うので1回目や2回目と別物と判断、計算しているので遅いです。
メモ化デコレータとnjitデコレータを併用する
calc関数に@njitデコレータ追加しました。
njitで高速化するので計算回数を100から10000に増やしています。
@proc_time
@cache
@njit # 追加
def calc2(a,b,c):
"""負荷のあるテキトーな計算"""
for _ in range(10000):
for i in range(1, len(a)-1):
a[i] = (a[i-1]+a[i+1])*0.5 + b
a[len(a)-1] = a[len(a)-2]
return a+c
calc2(np.zeros(10000),0,True) # ダミー実行
calc2(np.zeros(10000),1,True)
calc2(np.zeros(10000),1,True)
9.64E-01 sec
6.48E-01 sec
2.66E-04 sec
1回目はjitコンパイル時間を除外するためのダミー実行です。
cacheの効果は2回目と3回目の差で確認でき、
njit併用でもメモ化動作していることが確認ます。
njitした関数からメモ化した関数を呼び出す
@cache
@njit
def _calc3(a,b,c):
for _ in range(10000):
for i in range(1, len(a)-1):
a[i] = (a[i-1]+a[i+1])*0.5 + b
a[len(a)-1] = a[len(a)-2]
return a+c
@proc_time
@njit
def calc3(a,b,c):
return _calc3(a,b,c)
calc3(np.zeros(10000,dtype=np.int8),0,True)
calc3(np.zeros(10000,dtype=np.int8),1,True)
calc3(np.zeros(10000,dtype=np.int8),1,True)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name '_calc3': Cannot determine Numba type of <class 'function'>
File "..\..\..\AppData\Local\Temp\ipykernel_4136\4257892594.py", line 13:
<source missing, REPL/exec in use?>
うまくいきません。
まとめ
- メモ化デコレータを作成しました。以下特徴を持っています。
- メモ化した関数の引数にnumpy配列を取れる
- 条件付きで、メモ化デコレータとnjitデコレータを併用できる
- keyが長い文字列なので検索効率悪いです。