1. takata150802
Changes in body
Source | HTML | Preview

本記事は,Batch Normalizationのfowradの定義を愚直に微分して,
backward(誤差逆伝播)の計算式を導出した際のメモです.

※ Batch Normalizationの2種類のモード(per_activationとspatial)について
cudnnでは,batchNormのモードとしてcudnnBatchNormMode_tに下記の3種が定義されています.

  • [1] CUDNN_BATCHNORM_PER_ACTIVATION
  • [2] CUDNN_BATCHNORM_SPATIAL
  • [3] CUDNN_BATCHNORM_SPATIAL_PERSISTENT

本稿は,[2] CUDNN_BATCHNORM_SPATIALの誤差逆伝播を導出したものです.
なお,
[1]は畳み込み層以外の出力を正規化するのに用いられます.
[2]は畳み込み層の出力を正規化するのに用いられます.一般に,batchNormalizationと言うと[2]を指します.
[3]は[2]の高速化版だそうです.chainerなどのdeeplearing frameworkでは一般に[3]をcallしています.ただし,下記のような条件が付きます.

  • cudnnBatchNormalizationBackward()に,cudnnBatchNormalizationForwardTraining()で計算したsavedMean及びsavedInvVarianceを渡す必要がある(NaNは不可)
  • overflowが起きた場合NaNが出力される.cudnnQueryRuntimeError()でoverflowが起きてないかチェックする必要がある

forwadの定義

  • 入力:$x[N][C][H][W]$
    • ※ x[N][C][H][W]はN:バッチサイズ, C:チャンネル数, H:hight, W:widthの4次元テンソル
    • ※ x[N][C][H][W]がバッチサイズ=256, 256 x 256pixelのRBG画像である場合,N=256, C=3,H=W=256となる
  • 出力:$y[N][C][H][W]$
  • 重みパラメータ:
    • $\gamma[C]$
    • $\beta[C]$
  • 定数:
    • $\epsilon$
  • foward計算式:
    • $y[N][C][H][W] = \gamma[C]\times\hat{x}[N][C][H][W]+\beta[C]$
      • $\hat{x}[N][C][H][W]=\frac{x[N][C][H][W]-\mu[C]}{\sqrt{\sigma^2[C]+\epsilon}}$
      • $\mu[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}x[n][C][h][w]$
      • $\sigma^2[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}\left(x[n][C][h][w]-\mu[C]\right)^2$

backwardの定義

  • 損失関数 (loss function):$L$
  • 入力:$\delta y[N][C][H][W]:= \frac{\partial L}{\partial y[N][C][H][W]}$
  • 出力:
    • $\delta x[N][C][H][W]:=\frac{\partial L}{\partial x[N][C][H][W]}$
    • $\delta \gamma[C]:=\frac{\partial L}{\partial \gamma[C]}$
    • $\delta \beta[C]:=\frac{\partial L}{\partial \beta[C]}$
  • 定数:

    • $\epsilon$
  • backwardの計算式

    • $\delta x[N][C][H][W]=\frac{\gamma[C]}{\sqrt{\sigma^2[C]+\epsilon}}\left[\delta y[N][C][H][W] -\frac{1}{NHW}\left(\delta \beta[C]+\delta \gamma[C]\times\hat{x}[N][C][H][W] \right)\right]$
    • $\delta \gamma[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\left(\delta y[n][C][h][w]\times\hat{x}[n][C][h][w] \right)$
    • $\delta \beta[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\delta y[n][C][h][w]$

δβの計算式の導出

\begin{align}
\delta \beta[C]
&:=
\frac
  {\partial L}
  {\partial  \beta[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \beta[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \beta[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  }
    {\partial  \beta[C]}
    \left(
      \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
    \right)
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}
  \delta y[n][C][h][w]
\end{align}

δγの計算式の導出

\begin{align}
\delta \gamma[C]
&:=
\frac
  {\partial L}
  {\partial \gamma[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \gamma[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  }
    {\partial  \gamma[C]}
    \left(
      \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
    \right)
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \hat{x}[n][C][h][w]
  \right)
\end{align}

δxの計算式の導出

\begin{align}
\delta x[N][C][H][W]
&:=
\frac
  {\partial L}
  {\partial x[N][C][H][W]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\bf{\partial  x[N][C][H][W]}}
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  }
    {\bf{\partial  x[N][C][H][W]}}
  \left(
    \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
  \right)
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \gamma[C]
  \frac
    {\partial  }
    {\bf{\partial  x[N][C][H][W]}}
  \left(
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
  \right)
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \frac
      {\partial \left(x[n][C][h][w]-\mu[C] \right) }
      {\bf{\partial  x[N][C][H][W]}}
    \times
    \frac
      {1}
      {\sqrt{\sigma^2[C]+\epsilon}}
    +
    (x[n][C][h][w]-\mu[C])
    \times
    \frac
      {\partial \left( \frac
                          {1}
                          {\sqrt{\sigma^2[C]+\epsilon}}
                \right) }
      {\bf{\partial  x[N][C][H][W]}}
  \right]
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \left(
      \delta^{n,h,w}_{N,H,W}
      -
      \frac
        {\partial \mu[C]}
        {\bf{\partial  x[N][C][H][W]}}
    \right)
    \times
    \frac
      {1}
      {\sqrt{\sigma^2[C]+\epsilon}}
    +
    (x[n][C][h][w]-\mu[C])
    \times
    \left(
      -
      \frac
        {1}
        {2(\sqrt{\sigma^2[C]+\epsilon})^3}
      \frac
        {\partial \sigma^2[C]}
        {\bf{\partial  x[N][C][H][W]}}
    \right)
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {\partial \mu[C]}
      {\bf{\partial  x[N][C][H][W]}}
    -
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
    \frac
      {1}
      {2\sqrt{\sigma^2[C]+\epsilon}}
      \frac
        {\partial \sigma^2[C]}
        {\bf{\partial  x[N][C][H][W]}}
  \right]
\right) \\
&\because
\frac
  {\partial \mu[C]}
  {\bf{\partial  x[N][C][H][W]}}
=
\frac
  {1}
  {NHW} \\
&\because
\frac
  {\partial \sigma^2[C]}
  {\bf{\partial  x[N][C][H][W]}}
=
\frac
  {2}
  {NHW}
(\bf{x[N][C][H][W]}-\mu[C]) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {1}
      {NHW}
    -
    \frac
      {1}
      {NHW}
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
    \frac
      {\bf{x[N][C][H][W]}-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \delta y[n][C][h][w]
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {1}
      {NHW}
    \left(
      1
      +
      \hat{x}[n][C][h][w]
      \times
      \bf{\hat{x}[N][C][H][W]}
    \right)
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
  \sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
  \delta y[n][C][h][w]
  \times
  \delta^{n,h,w}_{N,H,W}
  -
  \frac
    {1}
    {NHW}
  \left(
    \sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
    \delta y[n][C][h][w]
    +
    \bf{\hat{x}[N][C][H][W]}
    \times
    \sum_{n=1,h=1,w=1}^{\bf{N,H,W}} \left(
      \delta y[n][C][h][w]
      \times
      \hat{x}[n][C][h][w]
    \right)
  \right)
\right]\\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
  \delta y[N][C][H][W]
  -
  \frac
    {1}
    {NHW}
  \left(
    \delta \beta[C]
    +
    \bf{\hat{x}[N][C][H][W]}
    \times
    \delta \gamma[C]
  \right)
\right]\\
\end{align}