曖昧な理解だったのを、自前で実装できるくらいに理解しようと図解しました。その際の資料を公開します。

内容は、ほぼ"Understanding the backward pass through Batch Normalization Layer"の焼き直しです。

全結合NN の Batch Normalization

いつ行うの?

全結合のニューラルネットワークの場合、Affinの後、活性化(例:ReLU)の前
BN_NN_配置.PNG

入力は?

Affinの出力を 行 として、
(図は入力層→NN第一層での例)
BN_C_01_.PNG

ミニバッチ数分のAffin出力を並べた行列が入力
BN_C_02_.PNG

入力行列をどう演算するの?

要素毎(列内)で正規化の演算します。
 BN_C_03_.PNG

演算式は?

上図の1列の演算を示す

入力:

 上図の1列の値 { $x_1$ ... $x_N$ } ( $N$:ミニバッチ数)
 学習値 $γ$ , $β$

出力:
y_i = BatchNorm _{γβ} ( x_i )
ミニバッチ平均:
\mu _x \leftarrow \frac {1}{N} \sum _{i=1} ^{N} x_i
ミニバッチ分散:
\sigma^2 _x \leftarrow \frac {1}{N} \sum _{i=1} ^{N} (x_i - \mu_x)^2
正規化:
\hat{x_i} \leftarrow \frac {x_i - \mu_x}{\sqrt{\sigma^2_x+ \epsilon}}

変倍・移動

y_i \leftarrow γ \hat{x_i} + β 
\equiv BatchNorm _{γβ} ( x_i )

効果は?

  • 学習を速く進行させることができる(学習係数を大きくすることができる)
  • 初期値にそれほど依存しない(初期値に対してそこまで神経質にならなくてよい)
  • 過学習を抑制する(Dropout などの必要性を減らす)

逆伝播値の求め方は?

入力
識別名 次元
$d_{out}$ $(N,D)$ 後続層からの逆伝搬値
順伝播の計算時に保持した値
識別名 次元 備考
$f_3$ $(N,D)$ $x_i-μ_x$
$f_5$ $(D,)$ $ \frac{1}{N} \sum_{i=1} ^{N} (x_i - \mu_x)^2$ $=\sigma^2 _x$
$f_7$ $(D,)$ $\frac{1}{\sqrt{\sigma^2 _x + \epsilon}}$ $=\frac{1}{\sqrt{f_5 + \epsilon}}$
$f_9$ $(N,D)$ $\frac{x_i - \mu_x}{\sqrt{\sigma^2 _x + \epsilon}}$ $=\hat{x_i}$
計算過程値
識別名 次元
$d_{12a}$ $(N,D)$ $d_{out}\circ γ$
$d_{3a}$ $(N,D)$ $f_7\circ(d_{12a}-f_3\circ \frac{1}{f_5 + \epsilon} \circ \frac{1}{N}\sum ^N _{i=1}(f _3 \circ d _{12a} )) $
出力
識別名 次元
$d_x$ $(N,D)$ $d_{3a} - \frac{1}{N}\sum ^N _{i=1} d _{3a}$
$d_γ$ $(D,)$ $\sum ^N _{i=1}(\hat{x_i} \circ d _{out}$)
$d_β$ $(D,)$ $\sum ^N _{i=1} d _{out}$

逆伝播の式はどう求めた?

計算グラフで
BN_A_0_All_1_withLineNum.PNG

$d_{out}$から、各ノードの逆伝播の出力値を辿り、$d_x$を求めます。
「誤差逆伝播法等に用いる 計算グラフ の基本パーツ」 を前提にしています。

ノード⑮ 

BN_B_15.PNG

加算ノードの逆伝播は、入力値をそのまま伝搬。

出力識別名 次元
$d_{15}$ $(N,D)$ $d_{out}$

ノード⑭

BN_B_14.PNG
Broadcastノードの逆伝播は、入力値の総和。

出力識別名 次元
$d_β = d_{14}$ $(D,)$ $\sum^N_{i=1} d_{15}$

ノード⑫

BN_B_12.PNG

乗算ノードの逆伝播は、順伝播の入れ替えし伝搬

出力識別名 次元
$d_{12a}$ $(N,D)$ $f_{11} \circ d_{15}$
$d_{12b}$ $(N,D)$ $f_9 \circ d_{15}$

ノード⑪

BN_B_11.PNG

Broadcastノードの逆伝播は、入力値の総和。

出力識別名 次元
$d_γ = d_{11}$ $(D,)$ $\sum^N_{i=1} d_{12b}$

ノード⑨

BN_B_09.PNG

乗算ノードの逆伝播は、順伝播の入れ替えし伝搬

出力識別名 次元
$d_{9a}$ $(N,D)$ $f_8 \circ d_{12a}$
$d_{9b}$ $(N,D)$ $f_3 \circ d_{12a}$

ノード⑧

BN_B_08.PNG

Broadcastノードの逆伝播は、入力値の総和。

出力識別名 次元
$d_8$ $(D,)$ $\sum^N_{i=1} d_{9b}$

ノード⑦

BN_B_07.PNG

逆伝播は、$\frac{1}{x}$を$x$で微分し、入力値に乗じた値 

出力識別名 次元
$d_7$ $(D,)$ $- \frac{1}{(f_6) ^2} \circ d_8$

ノード⑥

BN_B_06.PNG

逆伝播は、$\sqrt{x+\epsilon}$ を $x$ で微分し、入力値に乗じた値

出力識別名 次元
$d_6$ $(D,)$ $ \frac{1}{2\sqrt{f_5 +\epsilon} } \circ d_7$

  (※)$\epsilon$ は 小さい値 (例:$1^{-7}$)

ノード⑤

BN_B_05.PNG

総和の逆伝播は、入力値をそのままN個に分配し、次元を$(D,)$ →$(N,D)$ に変える

出力識別名 次元
$d_5$ $(N,D)$ $\frac{1}{N} d_6$

ノード④

BN_B_04.PNG

逆伝播は、$x^2$ を $x$ で微分し、入力値に乗じた値

出力識別名 次元
$d_4$ $(N,D)$ $2$ $f_3 \circ d_5$

ノード③

BN_B_03.PNG

マイナスの加算と分岐のノード
分岐の逆伝播は加算
加算の逆伝播は分岐

出力識別名 次元
$d_{3a}$ $(N,D)$ $d_{9a} + d_4$
$d_{3b}$ $(N,D)$ $-d_{3a}$

ノード②

BN_B_02.PNG

Broadcastノードの逆伝播は、入力値の総和。

出力識別名 次元
$d_2$ $(D,)$ $\sum^N_{i=1} d_{3b}$

ノード①

BN_B_01.PNG

総和の逆伝播は、入力値をそのままN個に分配し、次元を$(D,)$ →$(N,D)$ に変える

出力識別名 次元
$d_1$ $(N,D)$ $\frac{1}{N} d_2$

ノード⓪

BN_B_00.PNG

分岐ノードの逆伝播は、入力の加算

出力識別名 次元
$d_x$ $(N,D)$ $d_{3a} + d_1$

計算グラフから求めた逆伝播式の整理

計算過程値 $d_{12a}$

d_{15}=d_{out}\\
f_{11} = Broadcast(γ) \\

 の2式を

d_{12a}=f_{11} \circ d_{15}

 に代入すると

d_{12a}=γ \circ d_{out}

計算過程値 $d_{3a}$

d_{9a}= f_8 \circ d_{12a} \\

f_8 = Broadcast(f_7)\\

d_4 = 2 (f_3 \circ d_5)\\

d_5 = \frac{1}{N} d_6\\

d_6 = \frac{1}{2\sqrt{f_5 +\epsilon} } \circ d_7\\

d_7 = - \frac{1}{(f_6) ^2} \circ d_8\\

f_6 = \sqrt{f_5 +\epsilon }\\

d_8 = \sum^N_{i=1} d_{9b}\\

d_{9b} = f_3 \circ d_{12a}\\

 の9式を

d_{3a}=d_{9a} + d_4

 に代入すると

\begin{align}
d_3a &= f_7 \circ d_{12a} + 2 (f_3 \circ \frac{1}{N} \frac{1}{2\sqrt{f_5 +\epsilon} } \circ  \frac{-1}{(\sqrt{f_5 +\epsilon }) ^2} \circ \sum^N_{i=1} (f_3 \circ d_{12a})) \\
&= f_7 \circ d_{12a} + 2 (f_3 \circ \frac{1}{N} \frac{1}{2} f_7 \circ  \frac{-1}{(\sqrt{f_5 +\epsilon }) ^2} \circ \sum^N_{i=1} (f_3 \circ d_{12a})) \\
&= f_7 \circ (d_{12a} - f_3 \circ \frac{1}{f_5 +\epsilon  } \circ \frac{1}{N}\sum^N_{i=1} (f_3 \circ d_{12a}))
\end{align}

出力 $dx$

d_1=\frac{1}{N} d_2\\
d_2=\sum^N _{i=1} d_{3b}\\
d_{3b}=-d_{3a}\\

 の3式を

dx=d_{3a}+d_1 

 に代入すると

dx=d_{3a}-\frac{1}{N} \sum^N _{i=1} d_{3a}

CNN の Batch Normalization

CNNの場合はいつ行うの?

CNNの場合、Convolutionの後、活性化(例:ReLU)の前
BN_CNN_配置.PNG

CNNの場合の入力は?

Convolution の出力の チャンネルをシリアライズし1行とし、
BN_C_04.PNG

ミニバッチ数の行数とした行列。

以後の計算は、全結合のBatch Normalization と同じ。

参考文献

関連項目

  Mind で Neural Network (準備編2) 順伝播・逆伝播 図解
  誤差逆伝播法等に用いる 計算グラフ の基本パーツ
  シンプルなNNで 学習失敗時の挙動と Batch Normalization の効果を見る

Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account log in.