対象者
前回の記事の続きです。
本記事では逆伝播について、計算グラフを用いて簡単に説明します。連鎖律についてはここで簡単に触れていますので省略します。
どうしても数式多くなりがちな逆伝播ですが、実装を交えて極力直感的な理解で済むように説明したいと思います。
次回の記事はこちら
目次
スカラでの逆伝播
例の如くまずはスカラでの逆伝播から見ていきます。とはいえスカラでもちゃんと理論的に追うとなかなか大変なのが逆伝播です。
が、ものすごく簡単に説明していきます。
スカラでの逆伝播理論
まずは簡単に以下の計算グラフで考えていきましょう。
図中の順伝播をまずは丁寧に書くと
\begin{align}
m &= wx \\
n &= m + b \\
y &= \sigma(n)
\end{align}
のようになっていますね。この逆伝播を考えます。
それぞれの偏微分を考えてみます。
$\cfrac{\partial m}{\partial x} = w \quad \quad$ $\cfrac{\partial m}{\partial w} = x$
$\cfrac{\partial n}{\partial m} = 1 \quad \quad$ $\cfrac{\partial n}{\partial b} = 1$
$\cfrac{\partial y}{\partial n} = \sigma'(n)$
こんな感じですね。ではここから出力$y$に対する入力$x$と重み$w$、バイアス$b$の偏微分を連鎖律を用いて計算してみましょう。
$\cfrac{\partial y}{\partial x} = $ $\cfrac{\partial y}{\partial n}$ $\cfrac{\partial n}{\partial x} = $ $\cfrac{\partial y}{\partial n}$ $\cfrac{\partial n}{\partial m}$ $\cfrac{\partial m}{\partial x}$ $=\cfrac{}{}$ $\sigma'(n) \cfrac{}{}$ $\times \cfrac{}{} $ $1 \cfrac{}{}$ $\times\cfrac{}{} $ $w \cfrac{}{} $
$\cfrac{\partial y}{\partial w} = $ $\cfrac{\partial y}{\partial n}$ $\cfrac{\partial n}{\partial w} = $ $\cfrac{\partial y}{\partial n}$ $\cfrac{\partial n}{\partial m}$ $\cfrac{\partial m}{\partial w}$ $=\cfrac{}{}$ $\sigma'(n) \cfrac{}{}$ $\times \cfrac{}{} $ $1 \cfrac{}{}$ $\times\cfrac{}{} $ $x \cfrac{}{} $
$\cfrac{\partial y}{\partial b} = $ $\cfrac{\partial y}{\partial n}$ $\cfrac{\partial n}{\partial b}$ $=\cfrac{}{}$ $\sigma'(n) \cfrac{}{}$ $\times \cfrac{}{} $ $1 \cfrac{}{}$
<< 注意 >>
無理矢理高さを揃えているため「$\frac{}{}$」が所々混じってます。なにかいい方法あれば教えてください。あと数式に色を振る他の方法もぜひ...
となりますね。これに上流からの誤差$E$を流せば図にあるような逆伝播になります。
やっぱり数式が多く複雑に見えますが、つまり計算ノードを通る時に偏微分を乗算してるだけです。
これでもスカラなので簡単な方だったりします。
スカラでの逆伝播実装
では実装してみます。数式通りにコーディングしましょう。実装先のコードはこちらです。
def backward(self, grad):
"""
逆伝播の実装
"""
dact = grad*self.act.backward(self.x, self.y)
self.grad_w = dact*self.x
self.grad_b = dact
self.grad_x = dact*self.w
return self.grad_x
中間層はこのままでOKですが、出力層については処理を一部追加する必要がありますね。
後ほど触れます。
行列での逆伝播
続いて行列での逆伝播を考えてみましょう。こちらは非常に数式が多く大変ですが、きちんと追いかけてみるとわかるはずです。
先ほどと同じように数式に色を振って、ぱっと見でわかるようにしようと思います。気が重い...
行列での逆伝播理論
以下の計算グラフを考えます。
順伝播で使ったものと同じですね。2ニューロンしかないレイヤーモデルですが、これでもご覧の通りかなり数式が多くなります。
まずはスカラの時と同じようにそれぞれの部分の偏微分を求めておきましょう。
$\cfrac{\partial l}{\partial x_1} = w_{1, 1} \quad$ $\cfrac{\partial l}{\partial w_{1, 1}} = x_1$
$\cfrac{\partial m}{\partial l} = 1 \quad$ $\cfrac{\partial m}{\partial b_1} = 1 \quad$ $\cfrac{\partial m}{\partial r} = 1$
$\cfrac{\partial n}{\partial x_1} = w_{1, 2} \quad$ $\cfrac{\partial n}{\partial w_{1, 2}} = x_1 \quad$
$\cfrac{\partial y_1}{\partial m} = \sigma'_1(m)$
$\cfrac{\partial p}{\partial x_2} = w_{2, 2} \quad$ $\cfrac{\partial p}{\partial w_{2, 2}} = x_2$
$\cfrac{\partial q}{\partial p} = 1 \quad$ $\cfrac{\partial q}{\partial b_2} = 1 \quad$ $\cfrac{\partial q}{\partial n} = 1$
$\cfrac{\partial r}{\partial x_2} = w_{2, 1} \quad$ $\cfrac{\partial r}{\partial w_{2, 1}} = x_2$
$\cfrac{\partial y_2}{\partial q} = \sigma'_2(q)$
数式ばっかりでうんざりですね...とはいえ仕方ないです。
では連鎖律を適用してみましょう。
$\cfrac{\partial y_1}{\partial w_{1, 1}} =$ $\cfrac{\partial y_1}{\partial m}$ $\cfrac{\partial m}{\partial l}$ $\cfrac{\partial l}{\partial w_{1, 1}}$ $=\cfrac{}{}$ $\sigma'_1(m) \cfrac{}{}$ $\times \cfrac{}{}$ $1 \cfrac{}{}$ $\times \cfrac{}{}$ $x_1 \cfrac{}{}$
$\cfrac{\partial y_2}{\partial w_{1, 2}} =$ $\cfrac{\partial y_2}{\partial q}$ $\cfrac{\partial q}{\partial n}$ $\cfrac{\partial n}{\partial w_{1, 2}}$ $=\cfrac{}{}$ $\sigma'_2(q) \cfrac{}{}$ $\times \cfrac{}{}$ $1 \cfrac{}{}$ $\times \cfrac{}{}$ $x_1 \cfrac{}{}$
$\cfrac{\partial y_1}{\partial w_{2, 1}} =$ $\cfrac{\partial y_1}{\partial m}$ $\cfrac{\partial m}{\partial r}$ $\cfrac{\partial r}{\partial w_{2, 1}}$ $=\cfrac{}{}$ $\sigma'_1(m) \cfrac{}{}$ $\times \cfrac{}{}$ $1 \cfrac{}{}$ $\times \cfrac{}{}$ $x_2 \cfrac{}{}$
$\cfrac{\partial y_2}{\partial w_{1, 2}} =$ $\cfrac{\partial y_2}{\partial q}$ $\cfrac{\partial q}{\partial p}$ $\cfrac{\partial p}{\partial w_{2, 2}}$ $=\cfrac{}{}$ $\sigma'_2(q) \cfrac{}{}$ $\times \cfrac{}{}$ $1 \cfrac{}{}$ $\times \cfrac{}{}$ $x_2 \cfrac{}{}$
\Rightarrow
\boldsymbol{dW} =
\left(
\begin{array}{cc}
\cfrac{\partial y_1}{\partial w_{1, 1}} & \cfrac{\partial y_2}{\partial w_{1, 2}} \\
\cfrac{\partial y_1}{\partial w_{2, 1}} & \cfrac{\partial y_2}{\partial w_{2, 2}}
\end{array}
\right)
=
\left(
\begin{array}{c}
x_1 \\
x_2
\end{array}
\right)
\left(
\begin{array}{cc}
\sigma'_1(m) & \sigma'_2(q)
\end{array}
\right)
=
\boldsymbol{X}^{\top}\boldsymbol{\sigma'}
$\cfrac{\partial y_1}{\partial b_1} =$ $\cfrac{\partial y_1}{\partial m}$ $\cfrac{\partial m}{\partial b_1}$ $=\cfrac{}{}$ $\sigma'_1(m)\cfrac{}{}$ $\times \cfrac{}{}$ $1\cfrac{}{}$
\Rightarrow
\boldsymbol{dB} =
\left(
\begin{array}{cc}
\cfrac{\partial y_1}{\partial b_1} & \cfrac{\partial y_2}{\partial b_2}
\end{array}
\right)
=
\left(
\begin{array}{cc}
\sigma'_1(m) & \sigma'_2(q)
\end{array}
\right)
$\cfrac{\partial y_1}{\partial x_1} + \cfrac{\partial y_2}{\partial x_1} =$ $\cfrac{\partial y_1}{\partial m}$ $\cfrac{\partial m}{\partial l}$ $\cfrac{\partial l}{\partial x_1}$ $+\cfrac{}{}$ $\cfrac{\partial y_2}{\partial q}$ $\cfrac{\partial q}{\partial n}$ $\cfrac{\partial n}{\partial x_1}$ $=\cfrac{}{}$ $\sigma'1(m) \cfrac{}{}$ $\times \cfrac{}{}$ $1 \cfrac{}{}$ $\times \cfrac{}{}$ $w{1, 1} \cfrac{}{}$ $+\cfrac{}{}$ $\sigma'2(q) \cfrac{}{}$ $\times \cfrac{}{}$ $1 \cfrac{}{}$ $\times \cfrac{}{}$ $w{1, 2} \cfrac{}{}$
\Rightarrow
\boldsymbol{dX} =
\left(
\begin{array}{cc}
\cfrac{\partial y_1}{\partial x_1} + \cfrac{\partial y_2}{\partial x_1} & \cfrac{\partial y_1}{\partial x_2} + \cfrac{\partial y_2}{\partial x_2}
\end{array}
\right)
=
\left(
\begin{array}{cc}
\sigma'_1(m) & \sigma'_2(q)
\end{array}
\right)
\left(
\begin{array}{cc}
w_{1, 1} & w_{2, 1} \\
w_{1, 2} & w_{2, 2}
\end{array}
\right)
=
\boldsymbol{\sigma'}\boldsymbol{W}^{\top}
のようになります。これらに上流からの誤差を要素積で乗算することで上の図のような誤差が伝播します。
なお、$x_1$や$x_2$は分岐して他の計算ノードに値を流していたことからも分かる通り、流れてきた偏微分は足し合わせて計算する必要があります。なので
\left(
\begin{array}{cc}
\cfrac{\partial y_1}{\partial x_1} + \cfrac{\partial y_2}{\partial x_1} & \cfrac{\partial y_1}{\partial x_2} + \cfrac{\partial y_2}{\partial x_2}
\end{array}
\right)
となっているんですね。この式は直感的(そして理論的)には、出力$y_1$や$y_2$に対しての入力$x_1$や$x_2$の影響を表しています。
行列での逆伝播理論バッチ考慮版
さて、ここまではバッチサイズ$N=1$を暗に仮定していたりします。それでは$N \ne 1$の時はどうなるでしょうか?
答えは簡単です。バイアス$B$に関する誤差のみ変わります。
これを数式で見てみましょう。入力$N \times L$、出力$N \times M$とすると
\begin{align}
\underbrace{\boldsymbol{dW}}_{L \times M} &= \underbrace{\boldsymbol{X}^{\top}}_{L \times N} \underbrace{\boldsymbol{\sigma'}}_{N \times M} \\
\underbrace{\boldsymbol{dB}}_{1 \times M} &= \underbrace{\boldsymbol{\sigma'}}_{N\times M} \\
\underbrace{\boldsymbol{dX}}_{N \times L} &= \underbrace{\boldsymbol{\sigma'}}_{N \times M} \underbrace{\boldsymbol{W}^{\top}}_{M \times L}
\end{align}
となります。流れてくる勾配は入力時のそれぞれの形状と同じである必要があるため$\underbrace{\boldsymbol{dW}}_{L \times M}$や$\underbrace{\boldsymbol{dB}}_{1 \times M}$、$\underbrace{\boldsymbol{dX}}_{N \times L}$となりますが、この中で$\underbrace{\boldsymbol{dB}}_{1 \times M}$だけ形状が一致していません。
ではどうするのかというと、バイアス$\underbrace{\boldsymbol{B}}_{1 \times M}$が順伝播の際はブロードキャスト機能によって自動的に$\underbrace{\boldsymbol{B}}_{N \times M}$とされていたことを思い出します。
つまり、全てのバッチデータに対して全く同じバイアスを適用しているということは、同じバイアスを$N$個に分岐して流したのと同義ですので、和を取ればいいということになります。
\underbrace{\boldsymbol{dB}}_{1 \times M} = \sum_{i=1}^{N}{\underbrace{\boldsymbol{\sigma'}_i}_{1\times M}} \xrightarrow{\textrm{coding}} \textrm{sum}(\boldsymbol{\sigma'}, \textrm{axis}=0)
これで実装のための理論が終了しました。
行列での逆伝播実装
では実装していきます。スカラでの実装を書き換えます。
def backward(self, grad):
"""
逆伝播の実装
"""
dact = grad*self.act.backward(self.u, self.y)
self.grad_w = self.x.T@dact
self.grad_b = np.sum(dact, axis=0)
self.grad_x = dact@self.w.T
return self.grad_x
それぞれ行列演算に変わっていますね。また、numpy.sum
関数はaxis=0
と指定しないと全ての要素を足した結果を返してしまうため注意しましょう。
出力層のオーバーライド
出力層の逆伝播については少し処理を加える必要があります。
def backward(self, t):
"""
逆伝播の実装
"""
# 出力層の活性化関数がsoftmax関数で損失関数が交差エントロピー誤差の場合
# 誤差の伝播を場合分けしておく
if isinstance(self.act, type(get_act("softmax"))) \
and isinstance(self.errfunc, type(get_err("Cross"))):
dact = self.y - t
self.grad_w = self.x.T@dact
self.grad_b = np.sum(dact, axis=0)
self.grad_x = dact@self.w.T
return self.grad_x
elif isinstance(self.act, type(get_act("sigmoid"))) \
and isinstance(self.errfunc, type(get_err("Binary"))):
dact = self.y - t
self.grad_w = self.x.T@dact
self.grad_b = np.sum(dact, axis=0)
self.grad_x = dact@self.w.T
return self.grad_x
else:
grad = self.errfunc.backward(self.y, t)
return super().backward(grad)
オーバーライドの内容自体は簡単ですね。出力層の活性化関数がsoftmax関数で損失関数に交差エントロピー誤差を用いている場合は、活性化関数を逆伝播する誤差までコードのようにスキップできます。
また、出力層の活性化関数がsigmoid関数で損失関数に二値交差エントロピー誤差を用いている場合は、活性化関数を逆伝播する誤差までコードのようにスキップできます。
それ以外の場合は損失関数の微分を計算して誤差を流す必要がありますね。
__init__
メソッドの実装
さて、上記で出てきたerrfunc
をメンバに持たせておきましょう。
修正前
def __init__(self, *, prev=1, n=1,
name="", wb_width=1,
act="ReLU", err_func="square",
act_dict={}, **kwds):
self.prev = prev # 一つ前の層の出力数 = この層への入力数
self.n = n # この層の出力数 = 次の層への入力数
self.name = name # この層の名前
# 重みとバイアスを設定
self.w = wb_width*np.random.randn(prev, n)
self.b = wb_width*np.random.randn(n)
# 活性化関数(クラス)を取得
self.act = get_act(act, **act_dict)
# 損失関数(クラス)を取得
self.errfunc = get_errfunc(err_func)
2020/7/8修正
errfunc
を持たせるのはBaseLayer
ではなくOutputLayer
でした。
def __init__(self, *, err_func="Square", **kwds):
# 損失関数(クラス)を取得
self.errfunc = get_err(err_func)
super().__init__(**kwds)
いずれ損失関数についても記事を書きます。
おわりに
数式のカラーリングの大変さなんとかならないかな...