Pythonは遅い遅いと言われて久しい、そして遅い。というか、LLで速い言語なんてあるはずがないので、そもそも遅いと文句を言うこと自体がお門違いという気もするが、労力が同じならプログラムは速い方が良いに決まっている。
ただし、下回りのライブラリが速い場合は別である。そして、numpyが生まれた。というわけで、Pythonで数値計算を行う場合、numpyを使えば言うほど遅くないのだが、問題はforループである。シミュレーションなどの計算で、データを順繰りに読み込んで処理するという処理は割と頻出するのだが、そんとき、ループのなかでnumpyを呼び出すコードはpure pythonになってしまうので、とても遅いと、少なくとも一般的にそう思われていると思う。
ところで、最近「じけいれつかいせき」に凝っていて、ちょうどKalmanFilterのパラメータ・フィッティングを行ったので、この遅い遅いと言われるforループをjitコンパイルすることでどの程度早くなるのか見てみることにした(よく、マンデルブロ集合とか、ローレンツの偏微分方程式とかを計算して20倍速くなったよ!みたいな話を聞くが、あれは全部pythonでやった場合なので、そんだけ早くなるのだが、そもそもnumpyのコードなので、どうよ?という話)。
class ModelContainer:
# ...(中略)...
def loss(self, ys):
yp, xs, Ps = self.filt(ys)
return self.loss_fn(yp, ys)
def filt(self, ys):
# prepare
Yt = self.Yt
H, R, Xt = Yt.pack()
F, G, Q = Xt.pack()
# css/csr matrices are usually slower
F = F.toarray()
G = G.toarray()
KF = KalmanFilter(F, G, Q, H, R)
T = ys.shape[0]
# allocate memory
xs = np.zeros((T, Yt.xdim()))
Ps = np.zeros((T, Yt.xdim(), Yt.xdim()))
yp = np.zeros((T, Yt.ydim()))
# initialize covar matrix
Ps[0] = np.diag([1.0]*Yt.xdim())
# filter it!
for t in range(1, T):
xp, P = KF.predict(xs[t-1], Ps[t-1])
yp[t], D = Yt.observe(xp, P)
xs[t], Ps[t] = KF.filt(xs[t-1], Ps[t-1], ys[t])
return values
#
# ...(中略)...
#
def loss(x): return ModelContainer(model(x), _RMSE).loss(ys)
print(datetime.strftime(datetime.now(), '%H:%M:%S'))
bnds = ((-1, 1), (-1, 1), (-1, 1), (-1, 1), (-1, 1), (-1, 1), (0, None), (0, None), (0, None))
res = minimize(loss, (0, 0, 0, 0, 0, 0, 1, 1, 1), method='TNC', bounds=bnds, options={'disp':True})
print(datetime.strftime(datetime.now(), '%H:%M:%S'))
この、ModelContainerというのが、自前のカルマンフィルタで、データをフィルタリングするクラスである。KalmanFilter自体はだいぶ前に書いた記事にのっている。なんだかんだで、classというものは便利なのでこんな感じにしてしまったのだが、本来数値計算を行うようなコードでは、あまりクラスとか使わないほうがいいかも。実際、このあとnjitするところでちょっと手間がかかった。その話は後述。predict、observe、filtなどの関数名だけでは意味が分からないと思われるかもしれないが、中で走っている処理はnumpyによる行列演算だけである。
さて、これで、パラメータフィッティングを行うと、大体このくらいの時間がかかった。
14:48:57
14:51:49
fun: 0.0052101067256022215
jac: array([ 6.62647020e-06, 1.18649014e-05, 4.51221525e-05, 3.55355502e-05,
2.41178605e-06, -2.20449527e-05, -6.92137324e-06, -1.82874550e-06,
-1.68335137e-04])
message: 'Max. number of function evaluations reached'
nfev: 1010
nit: 14
status: 3
success: False
x: array([-4.24785315e-02, 7.15799804e-02, -8.85617298e-04, -2.37834841e-02,
-3.88928185e-02, 1.10227625e-02, 1.58767720e+00, 1.32529537e+00,
1.25988155e-02])
>>> datetime(2000,1,1,14,51,49)-datetime(2000,1,1,14,48,57)
datetime.timedelta(seconds=172)
というわけで、172秒である。また、nfevという数値が1010になっているのは、最適化のために異なるパラーメータで1010回の計算が行われたことを示している。ようするに、ModelContainerのlossが実行された回数がこれだ。一方で、モデルのパラメータ数が9個、データ点数が1977個である。ざっと計算して、1010×1977回フィルタリングの計算(forループの中身)が実行されているだろうということは、うすぼんやりの私にもすぐわかる。というわけで、そこをなんとかできないか見てみよう。
class ModelContainer:
# ... 略 ....
def filt(self, ys):
# ... 略 ....
for t in range(1, T):
xp, P = KF.predict(xs[t-1], Ps[t-1])
yp[t], D = Yt.observe(xp, P)
xs[t], Ps[t] = KF.filt(xs[t-1], Ps[t-1], ys[t])
predict、observe、filt関数の中身はnumpyで計算しているだけだが、numpy呼び出しと、演算式の実行はPythonなので、このあたりのオーバーヘッドをなんとかできるかもしれない。そこで、numbaを使うことにする。
Numba
Numbaが何か知らない人はこの記事を読んでいないと思うが、一応説明する。そもそも、LLが遅いのはバイトコードを実行するからであり、ということは、Javaの処理系がやってくれているようにJust in timeで特定の処理を機械語にコンパイルできれば、早くなるだろうと期待される。それをやってくれるのがNumbaである。もっとも単純なケースではnjitというデコレーターをつけるだけで終わりなのだが、機械語までコンパイルする手前、型が解決していなければコンパイルはできない。そのため、複雑な処理をしている場合は、型指定をしたりする必要があるケースもあるので注意。というわけで、numbaしてやれば数値計算などの処理は格段に速くなるのである。すごーい。
ちなみに、なぜ、そんな小難しいことがホイホイできるのかというと、イリノイ大学の人たちが、LLVM(このLLは当初Low Levelを意味した。Lightweight Languageとは何の関係もないのだが、現在は、LLVMという呼称自体が、得に何の頭文字でもないと宣言されているらしい。なんでやねん)という素晴らしいソフトウェアを書いてくれたおかげである。一時はやったので、なんなの?LLVMって?と思っている人も多いと思うが、LLVMは実はkaleidoscopeというぱっとした機能のない言語を開発するために作られた巨大なコンパイラ基盤である(大嘘)。
冗談はともかく、プログラミング言語というのは、一見全然違うようにも見えるが、実はコードの最適化などの処理に関しては共通して使える処理が多々ある(レジスタ割り当ての最適化とか)。そこで、オレオレ言語を作りたい君たちが、一旦LLVMをターゲットとしたIR(命令セット)までコンパイルしてやればその後を引き継いで、ちゃんと(最適化された)機械語を出力してくれる、そういう夢のような仕組みがLLVMなのだ。そして、とあるPythonistaがそれをPythonに適用したライブラリ、それがNumba(ここでエコーがかかる)。
というわけで、さっそく使ってみる。
クラスメソッドのJIT化
というわけで、普通ならここで、元のコードにnjitデコレーターを当てるだけで200倍の高速化が出来ました!などという結論になるのだが、今回私が高速化しようとしているコードは既にnumpyを呼んでいるだけなので、そもそも、それほど遅くない(…と思うんだよねー、そう期待している)。なので、それほどのご利益は得られないと思われるが、それでも、ループの中で結構な回数の素のPythonが実行されているので、好奇心からちょっとやってみた。
ただし、通常クラスメソッドをnumbaしようとすると結構面倒になることが多い。そこで、次のような手法を試みることにした。
class ModelContainer:
def __init__(self, Yt, loss_fn=_RMSE):
self.Yt = Yt
self.loss_fn = loss_fn
def loss(self, ys):
yp, xs, Ps = self.filt(ys)
return self.loss_fn(yp, ys)
@staticmethod
@njit
def fast_filt(xs, Ps, ys, yp, T, F, G, Q, H, R):
for t in range(1, T):
# prediction
xp = F @ xs[t-1]
P = F @ Ps[t-1] @ F.T + G @ Q @ G.T
f = yp[t] = H @ xp
D = H @ P @ H.T + R
# kalman gain
K = (np.linalg.solve(D.T, H @ P.T)).T
# update
xs[t] = xp + K @ (ys[t] - f)
Ps[t] = P - K @ H @ P
def filt(self, ys):
# ... (略) ...
Ps[0] = np.diag([1.0]*Yt.xdim())
self.fast_filt(xs, Ps, ys, yp, T, F, G, Q, H, R)
# filter it!
#for t in range(1, T):
# xp, P = KF.predict(xs[t-1], Ps[t-1])
# yp[t], D = Yt.observe(xp, P)
# xs[t], Ps[t] = KF.filt(xs[t-1], Ps[t-1], ys[t])
# return values
return yp, xs, Ps
なんかもう、オブジェクト指向もクソもないようなコードになってしまったが、実際、クラスオブジェクトをnumba化(?)するのは結構な手間なので、ピンポイントで最適化したい場所をstaticmethodとして切り出すと、型しても何もしなくとも、ちゃんとJITコンパイルされる模様である(ただし、ここまでやるならいっそのこと普通に関数として定義したらいいんじゃないの?という考え方もあるよね…まー、それでもいいかなぁ)。冒頭で、数値計算するときは、あんましオブジェクト指向的にしないほうがいいかも、などとのたまったのはこれが理由だ。特に、forの三重ループを回して、素のPythonで計算するようなコードを書くつもりなら、numbaでなんびゃく倍とか高速化できるらしいので、それを当てにして、なるべく簡素に書いたほうが後々困らないと思う。
ちなみに、もし、コードの中でどうしてもself.func()など、クラス・オブジェクトのメソッドの呼び出しを行いたい場合は、クラス自体をnumbaに対応させなければならないようで、結構めんどくさい(ので、今回も諦めた)。そこが、numbaのちょっと困ったところ(だが、まー、仕組み上、どうしようもないよね、それはさ)。では実行してみる。
14:44:19
14:45:56
fun: 0.0052101067256022215
jac: array([ 6.62647020e-06, 1.18649014e-05, 4.51221525e-05, 3.55355502e-05,
2.41178605e-06, -2.20449527e-05, -6.92137324e-06, -1.82874550e-06,
-1.68335137e-04])
message: 'Max. number of function evaluations reached'
nfev: 1010
nit: 14
status: 3
success: False
x: array([-4.24785315e-02, 7.15799804e-02, -8.85617298e-04, -2.37834841e-02,
-3.88928185e-02, 1.10227625e-02, 1.58767720e+00, 1.32529537e+00,
1.25988155e-02])
>>> datetime(2000,1,1,14,45,56)-datetime(2000,1,1,14,44,19)
datetime.timedelta(seconds=97)
時間が前後しているのは、記事の構成上、「Numba前→Numba後」のようにした方がいいかなと思ったため。まー、どうでもいいよね。
というわけで、172秒だったものが97秒になっている。まー、半分くらい?やってみる前は3~4倍くらいになればいいなと思っていたのだが、ちょっと甘かったようである。Pythonを舐めすぎていた。この結果を見て、2倍になったんだから早くなったんじゃん!(よって、素のPythonは遅い)と思うのか、あー、JITにしてもこのくらいしか差が出ないのか~(Python思ったほど遅くないのかなぁ?)と思うのかは、あなた次第である。
さて、ここまで来ると、「いやさ、お前関数呼び出しをアンラップしただろ?それで早くなったんじゃねーの?」と思う人もいるかもしれない。そこで、fast_filtからnjitを取り除いて実行してみた。そうすれば、ちゃんとした比較になるもんね(たぶん)。
16:17:00
16:19:53
fun: 0.0052101067256022215
jac: array([ 6.62647020e-06, 1.18649014e-05, 4.51221525e-05, 3.55355502e-05,
2.41178605e-06, -2.20449527e-05, -6.92137324e-06, -1.82874550e-06,
-1.68335137e-04])
message: 'Max. number of function evaluations reached'
nfev: 1010
nit: 14
status: 3
success: False
x: array([-4.24785315e-02, 7.15799804e-02, -8.85617298e-04, -2.37834841e-02,
-3.88928185e-02, 1.10227625e-02, 1.58767720e+00, 1.32529537e+00,
1.25988155e-02])
>>> datetime(2000,1,1,16,19,53)-datetime(2000,1,1,16,17,0)
datetime.timedelta(seconds=173)
関数呼び出しをアンラップしたことによる影響はほとんどない。やっぱし、forループ内のnumpy呼び出し等々がnumbaによって高速化されているのが、早くなった理由だと思われる。実際、百万とか二百万程度のループと関数呼び出し「だけ」なら、素のPythonだって一瞬で終わる。だから、やっぱり高速化されたのはnumpy呼び出しのオーバーヘッドだろう、というのがここでの結論。もし、そこまでPythonを馬鹿にしている人がいるのなら、結構な回数のforループを(適当な計算処理を入れて)実行してみるとよい。思ったほど遅くはないはず。
スパース行列の演算速度
ちなみに、行列の大きさが11×11~20×20とかそのくらいになるので、np.arrayではなく、scipy.sparseのにした方が、余計な計算が行われず早いのかと思って当初はそうしていたのだが、信じられないほど遅かったので結局np.arrayで統一することにした。言語処理とかしている人から比べたらアホみたいに小さい行列なのかもしれないけど、無駄な計算はしない方が早いかなーと思ったりもしたのである。全然そんなことはなかった。たぶん、数倍速くなったので、不必要にsparseな行列は使わないほうがいいよ(そもそも、block_diagが便利だったのでそうしていたというのもあった)。これも、「分かっていない」せいなので恥ずかしいから知らんぷりしようと思ったのだが、同じ勘違いをしている人が、もしかしてひょっとすると、p<0.01で棄却されるくらい僅かな確率で、この記事を読んでいるかもしれないので念のため。
結論
パラメータ数が増えたり、データ数が大きくなったりすれば、2倍の高速化、がご利益となる場面もそのうち出てくるかもしれない。ただ、今のところはそこまで速度にこだわる必要はないかなぁ…と思ったりもしたのでした。もし、4倍くらいの速度になるのなら、ちゃんとクラスオブジェクトをNumba化してどーのこーのというのもアリかなと思っていたのだが、このくらいなら待ってもいいや。どうせ1~2分だしさ(でも、北極の白クマさんたちを救うには、その2分を1分にするのが大事だぞ、そこのエンバイロメンタリストの君!)。
というわけで、それなりにnumpyなどで書かれたコードならば、numbaで爆速ということもないということが分かりました。ただ、numpy呼び出しのオーバーヘッドだけで、実行時間全体の50%というのはやっぱり、ちょっと遅いかな(どっちやねん)。
ちなみに、もとのコードでは、ループ外でnp.zerosで配列を確保したり、スパースな行列をtoarray()したりしているので、それをやめたら早くなるかと思って、試しにちょっと実験してみたが、配列の確保や、スパース行列の変換は1000回くらい繰り返したところで一瞬だということが分かった(minimizeのiterationが1010回なので、高々そのくらいのはず)。メモリを確保するので、それなりに遅いのかなーと思ったのだが、1000回くらいなら体感的には一瞬。気にする必要はないみたい。
というわけで、jitにした方が速いことは速いのだが、そもそも素のpythonでもnumpyを使えば、計算処理はそれほど遅くないと個人的には思う。なんか、遅い遅い言われてるから、余計なオーバーヘッドがもっとあるのかと思っていたのだが、ちょっとバカにしすぎていたようだ。ごめんねPythonくん。
以上です。