本記事は,Batch Normalizationのfowradの定義を愚直に微分して,
backward(誤差逆伝播)の計算式を導出した際のメモです.
forwadの定義
- 入力:$x[N][C][H][W]$
- ※ x[N][C][H][W]はN:バッチサイズ, C:チャンネル数, H:hight, W:widthの4次元テンソル
- ※ x[N][C][H][W]がバッチサイズ=256, 256 x 256pixelのRBG画像である場合,N=256, C=3,H=W=256となる
- 出力:$y[N][C][H][W]$
- 重みパラメータ:
- $\gamma[C]$
- $\beta[C]$
- 定数:
- $\epsilon$
- foward計算式:
- $y[N][C][H][W] = \gamma[C]\times\hat{x}[N][C][H][W]+\beta[C]$
- $\hat{x}[N][C][H][W]=\frac{x[N][C][H][W]-\mu[C]}{\sqrt{\sigma^2[C]+\epsilon}}$
- $\mu[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}x[n][C][h][w]$
- $\sigma^2[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}\left(x[n][C][h][w]-\mu[C]\right)^2$
- $y[N][C][H][W] = \gamma[C]\times\hat{x}[N][C][H][W]+\beta[C]$
backwardの定義
- 損失関数 (loss function):$L$
- 入力:$\delta y[N][C][H][W]:= \frac{\partial L}{\partial y[N][C][H][W]}$
- 出力:
- $\delta x[N][C][H][W]:=\frac{\partial L}{\partial x[N][C][H][W]}$
- $\delta \gamma[C]:=\frac{\partial L}{\partial \gamma[C]}$
- $\delta \beta[C]:=\frac{\partial L}{\partial \beta[C]}$
-
定数:
- $\epsilon$
-
backwardの計算式
- $\delta x[N][C][H][W]=\frac{\gamma[C]}{\sqrt{\sigma^2[C]+\epsilon}}\left[\delta y[N][C][H][W] -\frac{1}{NHW}\left(\delta \beta[C]+\delta \gamma[C]\times\hat{x}[N][C][H][W] \right)\right]$
- $\delta \gamma[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\left(\delta y[n][C][h][w]\times\hat{x}[n][C][h][w] \right)$
- $\delta \beta[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\delta y[n][C][h][w]$
δβの計算式の導出
\begin{align}
\delta \beta[C]
&:=
\frac
{\partial L}
{\partial \beta[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\frac
{\partial y[n][C][h][w]}
{\partial \beta[C]}
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\delta y[n][C][h][w]
\times
\frac
{\partial y[n][C][h][w]}
{\partial \beta[C]}
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\delta y[n][C][h][w]
\times
\frac
{\partial }
{\partial \beta[C]}
\left(
\gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
\right)
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}
\delta y[n][C][h][w]
\end{align}
δγの計算式の導出
\begin{align}
\delta \gamma[C]
&:=
\frac
{\partial L}
{\partial \gamma[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\frac
{\partial y[n][C][h][w]}
{\partial \gamma[C]}
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\delta y[n][C][h][w]
\times
\frac
{\partial }
{\partial \gamma[C]}
\left(
\gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
\right)
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\delta y[n][C][h][w]
\times
\hat{x}[n][C][h][w]
\right)
\end{align}
δxの計算式の導出
\begin{align}
\delta x[N][C][H][W]
&:=
\frac
{\partial L}
{\partial x[N][C][H][W]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\frac
{\partial y[n][C][h][w]}
{\bf{\partial x[N][C][H][W]}}
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\frac
{\partial }
{\bf{\partial x[N][C][H][W]}}
\left(
\gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
\right)
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\gamma[C]
\frac
{\partial }
{\bf{\partial x[N][C][H][W]}}
\left(
\frac
{x[n][C][h][w]-\mu[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\right)
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\left[
\frac
{\partial \left(x[n][C][h][w]-\mu[C] \right) }
{\bf{\partial x[N][C][H][W]}}
\times
\frac
{1}
{\sqrt{\sigma^2[C]+\epsilon}}
+
(x[n][C][h][w]-\mu[C])
\times
\frac
{\partial \left( \frac
{1}
{\sqrt{\sigma^2[C]+\epsilon}}
\right) }
{\bf{\partial x[N][C][H][W]}}
\right]
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\left[
\left(
\delta^{n,h,w}_{N,H,W}
-
\frac
{\partial \mu[C]}
{\bf{\partial x[N][C][H][W]}}
\right)
\times
\frac
{1}
{\sqrt{\sigma^2[C]+\epsilon}}
+
(x[n][C][h][w]-\mu[C])
\times
\left(
-
\frac
{1}
{2(\sqrt{\sigma^2[C]+\epsilon})^3}
\frac
{\partial \sigma^2[C]}
{\bf{\partial x[N][C][H][W]}}
\right)
\right]
\right) \\
&=
\frac
{\gamma[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\left[
\delta^{n,h,w}_{N,H,W}
-
\frac
{\partial \mu[C]}
{\bf{\partial x[N][C][H][W]}}
-
\frac
{x[n][C][h][w]-\mu[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\frac
{1}
{2\sqrt{\sigma^2[C]+\epsilon}}
\frac
{\partial \sigma^2[C]}
{\bf{\partial x[N][C][H][W]}}
\right]
\right) \\
&\because
\frac
{\partial \mu[C]}
{\bf{\partial x[N][C][H][W]}}
=
\frac
{1}
{NHW} \\
&\because
\frac
{\partial \sigma^2[C]}
{\bf{\partial x[N][C][H][W]}}
=
\frac
{2}
{NHW}
(\bf{x[N][C][H][W]}-\mu[C]) \\
&=
\frac
{\gamma[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
\frac
{\partial L}
{\partial y[n][C][h][w]}
\times
\left[
\delta^{n,h,w}_{N,H,W}
-
\frac
{1}
{NHW}
-
\frac
{1}
{NHW}
\frac
{x[n][C][h][w]-\mu[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\frac
{\bf{x[N][C][H][W]}-\mu[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\right]
\right) \\
&=
\frac
{\gamma[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
\delta y[n][C][h][w]
\times
\left[
\delta^{n,h,w}_{N,H,W}
-
\frac
{1}
{NHW}
\left(
1
+
\hat{x}[n][C][h][w]
\times
\bf{\hat{x}[N][C][H][W]}
\right)
\right]
\right) \\
&=
\frac
{\gamma[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
\delta y[n][C][h][w]
\times
\delta^{n,h,w}_{N,H,W}
-
\frac
{1}
{NHW}
\left(
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
\delta y[n][C][h][w]
+
\bf{\hat{x}[N][C][H][W]}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}} \left(
\delta y[n][C][h][w]
\times
\hat{x}[n][C][h][w]
\right)
\right)
\right]\\
&=
\frac
{\gamma[C]}
{\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
\delta y[N][C][H][W]
-
\frac
{1}
{NHW}
\left(
\delta \beta[C]
+
\bf{\hat{x}[N][C][H][W]}
\times
\delta \gamma[C]
\right)
\right]\\
\end{align}
※ Batch Normalizationの3種類のモード(per_activationとspatial)について
cudnnでは,batchNormのモードとしてcudnnBatchNormMode_tに下記の3種が定義されています.
- [1] CUDNN_BATCHNORM_PER_ACTIVATION
- [2] CUDNN_BATCHNORM_SPATIAL
- [3] CUDNN_BATCHNORM_SPATIAL_PERSISTENT
本稿は,[2] CUDNN_BATCHNORM_SPATIALの誤差逆伝播を導出したものです.
なお,
[1]は畳み込み層以外の出力を正規化するのに用いられます.
[2]は畳み込み層の出力を正規化するのに用いられます.一般に,batchNormalizationと言うと[2]を指します.
[3]は[2]の高速化版だそうです.chainerなどのdeeplearing frameworkでは一般に[3]をcallしています.ただし,下記のような条件が付きます.
- cudnnBatchNormalizationBackward()に,cudnnBatchNormalizationForwardTraining()で計算したsavedMean及びsavedInvVarianceを渡す必要がある(NaNは不可)
- overflowが起きた場合NaNが出力される.cudnnQueryRuntimeError()でoverflowが起きてないかチェックする必要がある
chainer/chainer/functions/normalization/batch_normalization.py