LoginSignup
0
0

More than 5 years have passed since last update.

Weight normalizationのbackpropagation

Last updated at Posted at 2017-01-06

https://www.youtube.com/watch?v=i94OvYb6noo
上記リンクの講義で主張されているintuition重視の発想に感化されたので、例題としてweight normalization論文のbackpropagationの式をこの発想で導出してみました。

話を簡単にするために、重みは2次元ベクトルとしています。バイアス項は省きました。また、右端の出力$y$に非線形の関数を適用する部分も省いてあります。

前向きの計算を書いておくと、以下のようになります。$\boldsymbol{v}$がweight normalization前の重み、$\boldsymbol{w}$がweight normalization後の重みです。$u$,$t$などの記号は適当に決めています。

$u_1=v_1^2$, $u_2 = v_2^2$
$s=u_1 + u_2$
$r={\sqrt s}$ ($\Vert \boldsymbol{v} \Vert$の計算)
$t_1 = v_1 / r$, $t_2 = v_2 / r$ (normalizationの計算)
$w_1 = gt_1$, $w_2 = gt_2$ ($g$は論文にあるとおり)
$a_1 = w_1 x_1$, $a_2 = w_2 x_2$
$y = a_1 + a_2$

では、lossを$L$として、$\frac{\partial L}{\partial y}$からスタートして、上図の右端から左へerrorをbackpropagateさせていきます。

$\frac{\partial L}{\partial a_1} = \frac{\partial L}{\partial y}\frac{\partial y}{\partial a_1}$
$y=a_1 + a_2$なので$\frac{\partial y}{\partial a_1}=1$。よって
$\frac{\partial L}{\partial a_1} = \frac{\partial L}{\partial y}\frac{\partial y}{\partial a_1}=\frac{\partial L}{\partial y}$
つまり、足し算は右からきたerrorを枝分かれさせてそのまま左へ送り出します。上述の講義では、足し算はgradient distributorだと言われています。
$a_2$についても同様です。

ひとつ左に移動します。

$\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial a_1}\frac{\partial a_1}{\partial w_1}$
$a_1 = w_1 x_1$なので$\frac{\partial a_1}{\partial w_1} = x_1$。よって
$\frac{\partial L}{\partial w_1} = \frac{\partial L}{\partial a_1}\frac{\partial a_1}{\partial w_1} = \frac{\partial L}{\partial a_1} x_1$
$w_2$についても同様です。また$x_1$については、
$\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial a_1}\frac{\partial a_1}{\partial x_1}$
$a_1 = w_1 x_1$なので$\frac{\partial a_1}{\partial x_1} = w_1$。よって
$\frac{\partial L}{\partial x_1} = \frac{\partial L}{\partial a_1}\frac{\partial a_1}{\partial x_1} = \frac{\partial L}{\partial a_1} w_1$
これが上述の講義でgradient "switcher"と呼ばれている、掛け算についてのbackpropagationです。
(なお、$x_1$や$x_2$が直前の層の出力である場合も、計算の仕方は変わりません。)
ところで論文に$\nabla_{\boldsymbol{w}}L$とあるのは、いま求めた$w_1$や$w_2$に関する偏微分のことです。つまり、
$\nabla_{\boldsymbol{w}}L = (\frac{\partial L}{\partial w_1} , \frac{\partial L}{\partial w_2}) = \frac{\partial L}{\partial y}\times(x_1, x_2) = \frac{\partial L}{\partial y} \boldsymbol{x}$
のことです。論文にas used normallyという表現(式(3)のすぐ下)がありますが、まさにそのとおりで、これは普通のbackpropagationでも計算する値です。

ひとつ左に移動します。

$\frac{\partial L}{\partial t_1} = \frac{\partial L}{\partial w_1}\frac{\partial w_1}{\partial t_1}$
$w_1 = gt_1$なので$\frac{\partial w_1}{\partial t_1}=g$。よって
$\frac{\partial L}{\partial t_1} = \frac{\partial L}{\partial w_1}\frac{\partial w_1}{\partial t_1} = \frac{\partial L}{\partial w_1} g$
$t_2$についても同様に
$\frac{\partial L}{\partial t_2} = \frac{\partial L}{\partial w_2}\frac{\partial w_2}{\partial t_2} = \frac{\partial L}{\partial w_2} g$
となります。
$g$については、前向きの計算で同じ$g$という値が枝分かれして使われているので、backpropagationでは右から合流してくるerrorをすべて加算します(この説明も上の動画にあります)。
$\frac{\partial L}{\partial g} = \frac{\partial L}{\partial w_1}\frac{\partial w_1}{\partial g} + \frac{\partial L}{\partial w_2}\frac{\partial w_2}{\partial g}
= \frac{\partial L}{\partial w_1} t_1 + \frac{\partial L}{\partial w_2} t_2$
この結果を、$\nabla_{\boldsymbol{w}}L$を使って書き直すと
$\frac{\partial L}{\partial g} = \nabla_{\boldsymbol{w}}L \cdot \boldsymbol{t}$
となります。$\cdot$は内積です。そして、上の図では$\boldsymbol{t} = \frac{\boldsymbol{v}}{\Vert \boldsymbol{v} \Vert}$としているので、
$\frac{\partial L}{\partial g} = \frac{ \nabla_{\boldsymbol{w}}L \cdot \boldsymbol{v}}{\Vert \boldsymbol{v} \Vert}$
となり論文の式(3)の左半分と一致します。

ひとつ左に移動します。

$r$については、前向きの計算で枝分かれして同じ値が使われているので、合流してくるerrorを加算します。
$\frac{\partial L}{\partial r} = \frac{\partial L}{\partial t_1}\frac{\partial t_1}{\partial r} + \frac{\partial L}{\partial t_2}\frac{\partial t_2}{\partial r}$
$t_1 = v_1 / r$と、上の結果を使うと
$\frac{\partial L}{\partial r} = - \frac{\partial L}{\partial w_1} g \frac{v_1}{r^2} - \frac{\partial L}{\partial t_2} g \frac{v_2}{r^2}$

ひとつ左に移動します。

$\frac{\partial L}{\partial s} = \frac{\partial L}{\partial r}\frac{\partial r}{\partial s}$
$r={\sqrt s}$より
$\frac{\partial L}{\partial s} = \frac{\partial L}{\partial r} \frac{1}{2r}$

ひとつ左に移動します。

$\frac{\partial L}{\partial u_1} = \frac{\partial L}{\partial s}\frac{\partial s}{\partial u_1} = \frac{\partial L}{\partial s}$
$u_2$についても同様です。

ひとつ左に移動します。

$v_1$, $v_2$は、それぞれ前向きの計算で同じ値が枝分かれして使われているので、合流してくるerrorを加算します。
$\frac{\partial L}{\partial v_1} = \frac{\partial L}{\partial u_1}\frac{\partial u_1}{\partial v_1} + \frac{\partial L}{\partial t_1}\frac{\partial t_1}{\partial v_1}$
$u_1 = v_1^2$より$\frac{\partial u_1}{\partial v_1} = 2 v_1$です。
また、$t_1 = \frac{v_1}{r}$より、$\frac{\partial t_1}{\partial v_1} = \frac{1}{r}$です。
$\frac{\partial L}{\partial t_1}$については、上のほうでおこなった計算に戻ると分かるので
$\frac{\partial L}{\partial v_1} = \frac{\partial L}{\partial u_1}\frac{\partial u_1}{\partial v_1} + \frac{\partial L}{\partial t_1}\frac{\partial t_1}{\partial v_1}
= 2 \frac{\partial L}{\partial s} v_1 + \frac{\partial L}{\partial w_1} g \frac{1}{r}$
となります。この式を、これまでの結果を使って書き直すと、次のようになります。

\begin{align}
\frac{\partial L}{\partial v_1} 
& = 2 \frac{\partial L}{\partial s} v_1 + \frac{\partial L}{\partial w_1} g \frac{1}{r} \\
& = 2 \frac{\partial L}{\partial r} \frac{1}{2r} v_1 + \frac{\partial L}{\partial w_1} \frac{g}{r} \\
& = - \bigg(\frac{\partial L}{\partial w_1} g \frac{v_1}{r^2} + \frac{\partial L}{\partial t_2} g \frac{v_2}{r^2} \bigg) \frac{1}{r} v_1 + \frac{\partial L}{\partial w_1} \frac{g}{r} \\
& = - \bigg(\frac{\partial L}{\partial w_1} g \frac{v_1}{r} + \frac{\partial L}{\partial t_2} g \frac{v_2}{r}\bigg) \frac{1}{r^2} v_1 + \frac{\partial L}{\partial w_1} \frac{g}{r} \\
& = - \bigg(\frac{\partial L}{\partial w_1} t_1 + \frac{\partial L}{\partial t_2} t_2\bigg) \frac{g}{r^2} v_1 + \frac{\partial L}{\partial w_1} \frac{g}{r} \\
& = - \frac{\partial L}{\partial g} \frac{g}{r^2} v_1 + \frac{\partial L}{\partial w_1} \frac{g}{r}
\end{align}

$v_2$に関する偏微分も同様に
$\frac{\partial L}{\partial v_2} = - \frac{\partial L}{\partial g} \frac{g}{r^2} v_2 + \frac{\partial L}{\partial w_2}\frac{g}{r}$
となります。$v_1$と$v_2$に関する結果をまとめて書くと
$\frac{\partial L}{\partial \boldsymbol{v}} = - \frac{\partial L}{\partial g} \frac{g}{r^2} \boldsymbol{v} + \frac{\partial L}{\partial \boldsymbol{w}}\frac{g}{r}$
となります。
$r=\Vert \boldsymbol{v} \Vert$だったので、
$\nabla_\boldsymbol{v} L = - \frac{g \nabla_g L}{\Vert \boldsymbol{v} \Vert^2} \boldsymbol{v} + \frac{g}{\Vert \boldsymbol{v} \Vert} \nabla_\boldsymbol{w} L$
となり、これは論文の式(3)の右半分に一致します。

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0