0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【Python】Numpy配列の+=演算に意外な落とし穴(初心者向け)

Posted at

a += bの挙動がa = a + bと違うの?!

ことの経緯を説明すると...

最近,E資格を取ろうとして,某有名のディープラーニングの本を迷わず購入!
毎日夜中まで思わず読んでしまっていて...

今日もお昼に買った明治牛乳をちびちび飲みながら,本を参照しながら自動微分を実装していて,自動微分の構想面白い!って思いながら,下記の動作確認コードを実行した.

自動微分を動作確認しようぜ
x0 = Variable(np.array(1.0)) # Nodeを作る
x1 = Variable(np.array(1.0)) # Nodeを作る
t = add(x0, x1) # Add計算して新しいNodeを作る
y = add(x0, t) # Add計算して新しいNodeを作る
y.backward() # 自動微分

print(x0.grad, x1.grad) # 微分を出力

図で示すとこんな感じ
x0.png

偏微分を求めてみると,

y = t + x_0 = 2x_0 + x_1

なので

 \frac{\partial y}{\partial x_0} = 2, \frac{\partial y}{\partial x_1} = 1

ですね.ふむふむ,余裕っすね.さあ,結果見てみよう.

output
2.0 2.0 # 期待値:2.0 1.0

なんで?!本通りに実装したはずなのに,なぜ結果が違う?
そもそも,x1の偏微分の2.0ってどう計算すれば出てくるんだい?

gradを計算する処理を見てみると(関係ないところを省略)

core.py
class Variable:
    ...
    def backward(self):
        ...
        for f in funcs:
            ...
            for x, gx in zip(f.inputs, gxs):
                if x.grad is None:
                    x.grad = gx
                else:
                    x.grad += gx # 実はこれはダメだったとは,当時の自分は知らなかった.

泣きながらclaudeくんに問い合わせして,かつ色々デバッグしてみて理由がわかった.

Numpyが頑張って高速化を実現したからこその問題だった

まず,なぜ間違った結果が出たのかでいうと,以下のコードでx.gradが更新されるタイミングと同時に,gxの値も更新されてしまっていた.

x.grad += gx

理由辿り着くと以下の感じになる.

1. x.gradの初期値はgxを参照しているので,同じアドレスを持っている.

if x.grad is None:
    x.grad = gx # -> x.grad address: 4336285360; gx address: 4336285360

2.x.grad += gxの時実は以下のメソッドを使っている.

x.grad.__iadd__(gx)

__iadd__メソッドによって,現在のインスタンスを そのまま更新 する.
x.gradとgxが同じアドレスを指しているので,x.gradの値の更新→ gxが指しているアドレスの値が更新→gxも更新してしまったように見える.
(C++のポインターを学んだ方は既にわかったと思います)

一方,下記の書き方だとセーフ.

x.grad = x.grad + gx

なぜかっていうと,x.grad + gxは新しいインスタンスを作って,x.gradが新しいインスタンスを指すように更新っていう手順を踏んで計算することになるので,上記の事象が発生しなくて済む.

修正後
print(f"Before: x.grad id={id(x.grad)}, gx id={id(gx)}")
# x.grad += gxは使ってはいけない.
# numpy配列の+=演算時gxはx.gradと同じアドレスを持つ(モリ効率と計算速度の最適化のためin-placeを使っている故)
# x.gradを更新すると同時にgxも更新されてしまう事象が起きる
x.grad = x.grad + gx
print(f"After: x.grad id={id(x.grad)}, gx id={id(gx)}")
output
Before: x.grad id=4403108272, gx id=4403108272
After: x.grad id=4376983216, gx id=4403108272
grads: 2.0 1.0 # 正解!

numpyはよく使っているけど,まだまだ理解が足りないなぁって実感した.
それでは!

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?