話の経緯
某「○○さんのプログラムブラッシュアップしてほしいんだ」
私(ブラッシュアップとはまた抽象的な)「わかりました」
で、
○○さん「・・・ということをするプログラムなんです」
私「なるほど。で、(引数が多すぎるとかいろいろツッコミたいのは置いといて)ブラッシュアップって具体的に何したらいいんだろう」
○○さん「何したらいいんですかね。あ、このプログラムすごい実行に時間がかかるんですが」
私「ほうほう(プログラムを眺める)」
def calc(d, a, x1, x2, x3):
sum_ = 0
for i in range(len(d)):
sum_ += (d[i] - (a[0]*x1[i] + a[1]*x2[i] + a[2]*x3[i]))**2
return sum_
私(´・ω・`)
○○さん(´・ω・`)
私「この、d, a, x1, x2, x3って(話聞いてる限りだと)NumPy配列だよね」
○○さん「ええ多分」
私「なんで自力で計算してるの?」
○○さん「え?」
というわけで上記のコードを爆速化する方法です。
行列演算に直すんだ!
私「Pythonは遅い。特に関数呼び出しはすっごく遅い。後、この要素アクセスなんだけどこれ実体はメソッド呼び出しだからむっちゃコストがかかってる。一方でNumPyは速い。数値計算を速くするために作られたものだからすごい速い。つまり、Pythonの世界に持ってきて自分で計算なんてしないでNumPyにやらせればいい。これって積和の計算だからNumPyで計算できるはず」
○○さん「はあ」
私「えーと、aが3要素で、x1, x2, x3は、iでインデックスしてるんだしdと同じ要素数か(Nとする)。てことは、こいつらまとめて3×N行列にできるよね(Xとする)。そうすればaとXで行列計算できる」
○○さん「はあ」
私「後はx1, x2, x3を一つの行列にまとめる方法がわかればいいんでけど、concatenateか?んー、shapeが(N,)だからうまくいかない、(1,N)にreshapeすればいいかな」
○○さん「・・・」
私「行列計算により、1×N行列、まあN要素のベクトルになるわけだけど・・・、dから計算値を引くのも、差分を二乗するのも、和を取るのもNumPyで完結できるな」
○○さん「・・・・・・」
私「まあ私が書いちゃうと、こんな感じ」
def calc_improve(d, a, x1, x2, x3):
x = np.concatenate((
x1.reshape(1, -1),
x2.reshape(1, -1),
x3.reshape(1, -1)
))
return np.sum((d - np.dot(a, x))**2)
本当に速くなったか計ってみよう
さてというわけで自力で書いてたコードをNumPyだけで計算が完結するように書き換えました。時間の都合で○○さんプログラムでは速度比較できなかったのですがcalc
とcalc_improve
でどれぐらい速くなったか比べてみましょう。え?タイトルでネタバレしてる?(笑)
念のため、計算結果が変わらないかも確認しましょう。なお実行時間だけわかればいいので計算させてるデータは適当です。
LOOP = 100
SIZE = 10000
d = np.arange(0, SIZE*0.1, 0.1)
x1 = np.arange(SIZE*1*0.1, SIZE*2*0.1, 0.1)
x2 = np.arange(SIZE*2*0.1, SIZE*3*0.1, 0.1)
x3 = np.arange(SIZE*3*0.1, SIZE*4*0.1, 0.1)
a = np.array([1, 2, 3])
print(calc(d, a, x1, x2, x3))
print(calc_improve(d, a, x1, x2, x3))
import time
elapsed = []
for l in range(LOOP):
start = time.time()
calc(d, a, x1, x2, x3)
end = time.time()
elapsed.append(end - start)
print(np.average(elapsed))
elapsed = []
for l in range(LOOP):
start = time.time()
calc_improve(d, a, x1, x2, x3)
end = time.time()
elapsed.append(end - start)
print(np.average(elapsed))
結果です。
2743250833749.798
2743250833749.2515
0.7366799211502075
0.0009990835189819336
計算結果(上2行)微妙にずれてるけど、まあ誤差の範囲だしむしろ自力計算の方が間違ってる気はする。
で、実行時間ですが、
>>> 0.7366799211502075 / 0.0009990835189819336
737.3556936469981
圧倒的ですね。ここまで速くなるとは思いませんでした。
おわりに
というわけで今回は改善をお願いされたプログラムを爆速化してみました。
まあ爆速化というよりそもそもNumPyのことがわかってなくて自力で書いていたのが原因ですが、NumPyだけで計算が完結できるように書くと非常に高速に計算ができるいうことがおわかりいただけたかと思います。