LoginSignup
2
2

More than 5 years have passed since last update.

Batch Normalizationの誤差逆伝播(Backward)の導出

Last updated at Posted at 2019-01-02

本記事は,Batch Normalizationのfowradの定義を愚直に微分して,
backward(誤差逆伝播)の計算式を導出した際のメモです.

forwadの定義

  • 入力:$x[N][C][H][W]$
    • ※ x[N][C][H][W]はN:バッチサイズ, C:チャンネル数, H:hight, W:widthの4次元テンソル
    • ※ x[N][C][H][W]がバッチサイズ=256, 256 x 256pixelのRBG画像である場合,N=256, C=3,H=W=256となる
  • 出力:$y[N][C][H][W]$
  • 重みパラメータ:
    • $\gamma[C]$
    • $\beta[C]$
  • 定数:
    • $\epsilon$
  • foward計算式:
    • $y[N][C][H][W] = \gamma[C]\times\hat{x}[N][C][H][W]+\beta[C]$
      • $\hat{x}[N][C][H][W]=\frac{x[N][C][H][W]-\mu[C]}{\sqrt{\sigma^2[C]+\epsilon}}$
      • $\mu[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}x[n][C][h][w]$
      • $\sigma^2[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}\left(x[n][C][h][w]-\mu[C]\right)^2$

backwardの定義

  • 損失関数 (loss function):$L$
  • 入力:$\delta y[N][C][H][W]:= \frac{\partial L}{\partial y[N][C][H][W]}$
  • 出力:
    • $\delta x[N][C][H][W]:=\frac{\partial L}{\partial x[N][C][H][W]}$
    • $\delta \gamma[C]:=\frac{\partial L}{\partial \gamma[C]}$
    • $\delta \beta[C]:=\frac{\partial L}{\partial \beta[C]}$
  • 定数:

    • $\epsilon$
  • backwardの計算式

    • $\delta x[N][C][H][W]=\frac{\gamma[C]}{\sqrt{\sigma^2[C]+\epsilon}}\left[\delta y[N][C][H][W] -\frac{1}{NHW}\left(\delta \beta[C]+\delta \gamma[C]\times\hat{x}[N][C][H][W] \right)\right]$
    • $\delta \gamma[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\left(\delta y[n][C][h][w]\times\hat{x}[n][C][h][w] \right)$
    • $\delta \beta[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\delta y[n][C][h][w]$

δβの計算式の導出

\begin{align}
\delta \beta[C]
&:=
\frac
  {\partial L}
  {\partial  \beta[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \beta[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \beta[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  }
    {\partial  \beta[C]}
    \left(
      \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
    \right)
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}
  \delta y[n][C][h][w]
\end{align}

δγの計算式の導出

\begin{align}
\delta \gamma[C]
&:=
\frac
  {\partial L}
  {\partial \gamma[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \gamma[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  }
    {\partial  \gamma[C]}
    \left(
      \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
    \right)
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \hat{x}[n][C][h][w]
  \right)
\end{align}

δxの計算式の導出

\begin{align}
\delta x[N][C][H][W]
&:=
\frac
  {\partial L}
  {\partial x[N][C][H][W]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\bf{\partial  x[N][C][H][W]}}
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  }
    {\bf{\partial  x[N][C][H][W]}}
  \left(
    \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
  \right)
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \gamma[C]
  \frac
    {\partial  }
    {\bf{\partial  x[N][C][H][W]}}
  \left(
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
  \right)
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \frac
      {\partial \left(x[n][C][h][w]-\mu[C] \right) }
      {\bf{\partial  x[N][C][H][W]}}
    \times
    \frac
      {1}
      {\sqrt{\sigma^2[C]+\epsilon}}
    +
    (x[n][C][h][w]-\mu[C])
    \times
    \frac
      {\partial \left( \frac
                          {1}
                          {\sqrt{\sigma^2[C]+\epsilon}}
                \right) }
      {\bf{\partial  x[N][C][H][W]}}
  \right]
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \left(
      \delta^{n,h,w}_{N,H,W}
      -
      \frac
        {\partial \mu[C]}
        {\bf{\partial  x[N][C][H][W]}}
    \right)
    \times
    \frac
      {1}
      {\sqrt{\sigma^2[C]+\epsilon}}
    +
    (x[n][C][h][w]-\mu[C])
    \times
    \left(
      -
      \frac
        {1}
        {2(\sqrt{\sigma^2[C]+\epsilon})^3}
      \frac
        {\partial \sigma^2[C]}
        {\bf{\partial  x[N][C][H][W]}}
    \right)
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {\partial \mu[C]}
      {\bf{\partial  x[N][C][H][W]}}
    -
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
    \frac
      {1}
      {2\sqrt{\sigma^2[C]+\epsilon}}
      \frac
        {\partial \sigma^2[C]}
        {\bf{\partial  x[N][C][H][W]}}
  \right]
\right) \\
&\because
\frac
  {\partial \mu[C]}
  {\bf{\partial  x[N][C][H][W]}}
=
\frac
  {1}
  {NHW} \\
&\because
\frac
  {\partial \sigma^2[C]}
  {\bf{\partial  x[N][C][H][W]}}
=
\frac
  {2}
  {NHW}
(\bf{x[N][C][H][W]}-\mu[C]) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {1}
      {NHW}
    -
    \frac
      {1}
      {NHW}
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
    \frac
      {\bf{x[N][C][H][W]}-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \delta y[n][C][h][w]
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {1}
      {NHW}
    \left(
      1
      +
      \hat{x}[n][C][h][w]
      \times
      \bf{\hat{x}[N][C][H][W]}
    \right)
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
  \sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
  \delta y[n][C][h][w]
  \times
  \delta^{n,h,w}_{N,H,W}
  -
  \frac
    {1}
    {NHW}
  \left(
    \sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
    \delta y[n][C][h][w]
    +
    \bf{\hat{x}[N][C][H][W]}
    \times
    \sum_{n=1,h=1,w=1}^{\bf{N,H,W}} \left(
      \delta y[n][C][h][w]
      \times
      \hat{x}[n][C][h][w]
    \right)
  \right)
\right]\\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
  \delta y[N][C][H][W]
  -
  \frac
    {1}
    {NHW}
  \left(
    \delta \beta[C]
    +
    \bf{\hat{x}[N][C][H][W]}
    \times
    \delta \gamma[C]
  \right)
\right]\\
\end{align}

※ Batch Normalizationの3種類のモード(per_activationとspatial)について

cudnnでは,batchNormのモードとしてcudnnBatchNormMode_tに下記の3種が定義されています.

  • [1] CUDNN_BATCHNORM_PER_ACTIVATION
  • [2] CUDNN_BATCHNORM_SPATIAL
  • [3] CUDNN_BATCHNORM_SPATIAL_PERSISTENT

本稿は,[2] CUDNN_BATCHNORM_SPATIALの誤差逆伝播を導出したものです.
なお,
[1]は畳み込み層以外の出力を正規化するのに用いられます.
[2]は畳み込み層の出力を正規化するのに用いられます.一般に,batchNormalizationと言うと[2]を指します.
[3]は[2]の高速化版だそうです.chainerなどのdeeplearing frameworkでは一般に[3]をcallしています.ただし,下記のような条件が付きます.

  • cudnnBatchNormalizationBackward()に,cudnnBatchNormalizationForwardTraining()で計算したsavedMean及びsavedInvVarianceを渡す必要がある(NaNは不可)
  • overflowが起きた場合NaNが出力される.cudnnQueryRuntimeError()でoverflowが起きてないかチェックする必要がある

chainer/chainer/functions/normalization/batch_normalization.py

2
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2