
Deriving the Backward Pass (Backpropagation) of Batch Normalization


This article is my notes from deriving the backward (backpropagation) formulas of
Batch Normalization by straightforwardly differentiating its forward definition.

Forward definition

  • Input: $x[N][C][H][W]$
    • Note: $x[N][C][H][W]$ is a 4-D tensor with N: batch size, C: number of channels, H: height, W: width
    • Note: if $x[N][C][H][W]$ holds a batch of 256 RGB images of 256 × 256 pixels, then N=256, C=3, H=W=256
  • Output: $y[N][C][H][W]$
  • Weight parameters:
    • $\gamma[C]$
    • $\beta[C]$
  • Constant:
    • $\epsilon$
  • Forward formulas:
    • $y[N][C][H][W] = \gamma[C]\times\hat{x}[N][C][H][W]+\beta[C]$
      • $\hat{x}[N][C][H][W]=\frac{x[N][C][H][W]-\mu[C]}{\sqrt{\sigma^2[C]+\epsilon}}$
      • $\mu[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}x[n][C][h][w]$
      • $\sigma^2[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}\left(x[n][C][h][w]-\mu[C]\right)^2$
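As a sanity check, the forward definition above can be sketched directly in NumPy. This is a minimal sketch with my own function and variable names, not code from any particular framework:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Spatial batch normalization forward pass per the formulas above.

    x: (N, C, H, W) input; gamma, beta: (C,) parameters.
    mu[C] and sigma^2[C] are taken over the N, H, W axes of each channel.
    """
    mu = x.mean(axis=(0, 2, 3), keepdims=True)       # mu[C]
    var = x.var(axis=(0, 2, 3), keepdims=True)       # sigma^2[C]
    x_hat = (x - mu) / np.sqrt(var + eps)            # normalized input
    y = gamma[None, :, None, None] * x_hat + beta[None, :, None, None]
    return y, x_hat, var
```

Per channel, `x_hat` has zero mean and (up to $\epsilon$) unit variance, so the per-channel mean of `y` equals $\beta[C]$.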

Backward definition

  • Loss function: $L$
  • Input: $\delta y[N][C][H][W]:= \frac{\partial L}{\partial y[N][C][H][W]}$
  • Output:
    • $\delta x[N][C][H][W]:=\frac{\partial L}{\partial x[N][C][H][W]}$
    • $\delta \gamma[C]:=\frac{\partial L}{\partial \gamma[C]}$
    • $\delta \beta[C]:=\frac{\partial L}{\partial \beta[C]}$
  • Constant:
    • $\epsilon$
  • Backward formulas:
    • $\delta x[N][C][H][W]=\frac{\gamma[C]}{\sqrt{\sigma^2[C]+\epsilon}}\left[\delta y[N][C][H][W] -\frac{1}{NHW}\left(\delta \beta[C]+\delta \gamma[C]\times\hat{x}[N][C][H][W] \right)\right]$
    • $\delta \gamma[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\left(\delta y[n][C][h][w]\times\hat{x}[n][C][h][w] \right)$
    • $\delta \beta[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\delta y[n][C][h][w]$
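The three backward formulas translate directly into NumPy. A minimal sketch with my own naming (a real layer would cache `x_hat` and `var` from the forward pass rather than recomputing them here):

```python
import numpy as np

def batchnorm_backward(dy, x, gamma, eps=1e-5):
    """Backward pass implementing the three formulas above.

    dy: upstream gradient dL/dy, shape (N, C, H, W).
    Returns (dx, dgamma, dbeta).
    """
    N, C, H, W = x.shape
    m = N * H * W                                    # NHW: elements per channel
    mu = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    dbeta = dy.sum(axis=(0, 2, 3))                   # delta beta[C]
    dgamma = (dy * x_hat).sum(axis=(0, 2, 3))        # delta gamma[C]
    dx = (gamma[None, :, None, None] / np.sqrt(var + eps)
          * (dy - (dbeta[None, :, None, None]
                   + dgamma[None, :, None, None] * x_hat) / m))
    return dx, dgamma, dbeta
```

One consequence of the $\delta x$ formula: summing it over the N, H, W axes gives zero for every channel, because the $\frac{1}{NHW}\delta\beta[C]$ correction exactly cancels $\sum\delta y$ and $\hat{x}$ sums to zero.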

Derivation of the δβ formula

\begin{align}
\delta \beta[C]
&:=
\frac
  {\partial L}
  {\partial  \beta[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \beta[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \beta[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  }
    {\partial  \beta[C]}
    \left(
      \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
    \right)
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}
  \delta y[n][C][h][w]
\end{align}
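A quick finite-difference check of this result, using the scalar loss $L=\sum\left(\delta y\times y\right)$ so that $\partial L/\partial y=\delta y$. All names below are illustrative, not from any framework:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
gamma = rng.standard_normal(3)
beta = rng.standard_normal(3)
dy = rng.standard_normal(x.shape)
eps = 1e-5

def loss(beta_):
    # forward pass followed by L = sum(dy * y), so dL/dy = dy
    mu = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    y = gamma[None, :, None, None] * x_hat + beta_[None, :, None, None]
    return (dy * y).sum()

dbeta = dy.sum(axis=(0, 2, 3))                       # the derived formula
h = 1e-6
dbeta_num = np.array([(loss(beta + h * np.eye(3)[c])
                       - loss(beta - h * np.eye(3)[c])) / (2 * h)
                      for c in range(3)])            # central differences
```

Since $y$ is linear in $\beta$, the two agree to roundoff.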

Derivation of the δγ formula

\begin{align}
\delta \gamma[C]
&:=
\frac
  {\partial L}
  {\partial \gamma[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \gamma[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  }
    {\partial  \gamma[C]}
    \left(
      \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
    \right)
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \hat{x}[n][C][h][w]
  \right)
\end{align}
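The same finite-difference check works for $\delta\gamma$, perturbing $\gamma[C]$ under the loss $L=\sum\left(\delta y\times y\right)$ (illustrative names again):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
gamma = rng.standard_normal(3)
beta = rng.standard_normal(3)
dy = rng.standard_normal(x.shape)
eps = 1e-5

def loss(gamma_):
    # forward pass followed by L = sum(dy * y), so dL/dy = dy
    mu = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    y = gamma_[None, :, None, None] * x_hat + beta[None, :, None, None]
    return (dy * y).sum()

mu = x.mean(axis=(0, 2, 3), keepdims=True)
var = x.var(axis=(0, 2, 3), keepdims=True)
x_hat = (x - mu) / np.sqrt(var + eps)
dgamma = (dy * x_hat).sum(axis=(0, 2, 3))            # the derived formula
h = 1e-6
dgamma_num = np.array([(loss(gamma + h * np.eye(3)[c])
                        - loss(gamma - h * np.eye(3)[c])) / (2 * h)
                       for c in range(3)])           # central differences
```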

Derivation of the δx formula

\begin{align}
\delta x[N][C][H][W]
&:=
\frac
  {\partial L}
  {\partial x[N][C][H][W]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\bf{\partial  x[N][C][H][W]}}
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  }
    {\bf{\partial  x[N][C][H][W]}}
  \left(
    \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
  \right)
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \gamma[C]
  \frac
    {\partial  }
    {\bf{\partial  x[N][C][H][W]}}
  \left(
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
  \right)
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \frac
      {\partial \left(x[n][C][h][w]-\mu[C] \right) }
      {\bf{\partial  x[N][C][H][W]}}
    \times
    \frac
      {1}
      {\sqrt{\sigma^2[C]+\epsilon}}
    +
    (x[n][C][h][w]-\mu[C])
    \times
    \frac
      {\partial \left( \frac
                          {1}
                          {\sqrt{\sigma^2[C]+\epsilon}}
                \right) }
      {\bf{\partial  x[N][C][H][W]}}
  \right]
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \left(
      \delta^{n,h,w}_{N,H,W}
      -
      \frac
        {\partial \mu[C]}
        {\bf{\partial  x[N][C][H][W]}}
    \right)
    \times
    \frac
      {1}
      {\sqrt{\sigma^2[C]+\epsilon}}
    +
    (x[n][C][h][w]-\mu[C])
    \times
    \left(
      -
      \frac
        {1}
        {2(\sqrt{\sigma^2[C]+\epsilon})^3}
      \frac
        {\partial \sigma^2[C]}
        {\bf{\partial  x[N][C][H][W]}}
    \right)
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {\partial \mu[C]}
      {\bf{\partial  x[N][C][H][W]}}
    -
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
    \frac
      {1}
      {2\sqrt{\sigma^2[C]+\epsilon}}
      \frac
        {\partial \sigma^2[C]}
        {\bf{\partial  x[N][C][H][W]}}
  \right]
\right) \\
&\because
\frac
  {\partial \mu[C]}
  {\bf{\partial  x[N][C][H][W]}}
=
\frac
  {1}
  {NHW} \\
&\because
\frac
  {\partial \sigma^2[C]}
  {\bf{\partial  x[N][C][H][W]}}
=
\frac
  {2}
  {NHW}
(\bf{x[N][C][H][W]}-\mu[C]) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {1}
      {NHW}
    -
    \frac
      {1}
      {NHW}
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
    \frac
      {\bf{x[N][C][H][W]}-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \delta y[n][C][h][w]
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {1}
      {NHW}
    \left(
      1
      +
      \hat{x}[n][C][h][w]
      \times
      \bf{\hat{x}[N][C][H][W]}
    \right)
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
  \sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
  \delta y[n][C][h][w]
  \times
  \delta^{n,h,w}_{N,H,W}
  -
  \frac
    {1}
    {NHW}
  \left(
    \sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
    \delta y[n][C][h][w]
    +
    \bf{\hat{x}[N][C][H][W]}
    \times
    \sum_{n=1,h=1,w=1}^{\bf{N,H,W}} \left(
      \delta y[n][C][h][w]
      \times
      \hat{x}[n][C][h][w]
    \right)
  \right)
\right]\\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
  \delta y[N][C][H][W]
  -
  \frac
    {1}
    {NHW}
  \left(
    \delta \beta[C]
    +
    \bf{\hat{x}[N][C][H][W]}
    \times
    \delta \gamma[C]
  \right)
\right]\\
\end{align}
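Finally, the derived $\delta x$ formula can be verified element by element against central differences of the same scalar loss $L=\sum\left(\delta y\times y\right)$. This is illustrative check code, not framework code:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
gamma = rng.standard_normal(3)
beta = rng.standard_normal(3)
dy = rng.standard_normal(x.shape)
eps = 1e-5

def loss(x_):
    # forward pass followed by L = sum(dy * y), so dL/dy = dy
    mu = x_.mean(axis=(0, 2, 3), keepdims=True)
    var = x_.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x_ - mu) / np.sqrt(var + eps)
    y = gamma[None, :, None, None] * x_hat + beta[None, :, None, None]
    return (dy * y).sum()

# analytic dx from the final formula above
mu = x.mean(axis=(0, 2, 3), keepdims=True)
var = x.var(axis=(0, 2, 3), keepdims=True)
x_hat = (x - mu) / np.sqrt(var + eps)
m = x.shape[0] * x.shape[2] * x.shape[3]             # NHW
dbeta = dy.sum(axis=(0, 2, 3), keepdims=True)
dgamma = (dy * x_hat).sum(axis=(0, 2, 3), keepdims=True)
dx = (gamma[None, :, None, None] / np.sqrt(var + eps)
      * (dy - (dbeta + dgamma * x_hat) / m))

# numerical dx via central differences, one input element at a time;
# perturbing x[idx] also moves mu and sigma^2, so loss() recomputes both
h = 1e-5
dx_num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    dx_num[idx] = (loss(xp) - loss(xm)) / (2 * h)
```

Note that because $\mu$ and $\sigma^2$ depend on every element of $x$, the derivation really does need the Kronecker-delta and $\partial\mu/\partial x$, $\partial\sigma^2/\partial x$ terms; dropping them makes this check fail.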

Note: the three Batch Normalization modes (per-activation, spatial, and spatial-persistent)

cuDNN defines the following three batch-normalization modes in cudnnBatchNormMode_t:

  • [1] CUDNN_BATCHNORM_PER_ACTIVATION
  • [2] CUDNN_BATCHNORM_SPATIAL
  • [3] CUDNN_BATCHNORM_SPATIAL_PERSISTENT

This article derived the backward pass for [2] CUDNN_BATCHNORM_SPATIAL.
For reference:
[1] is used to normalize the outputs of layers other than convolution layers.
[2] is used to normalize the outputs of convolution layers. "Batch normalization" generally refers to [2].
[3] is reportedly a faster version of [2]. Deep learning frameworks such as Chainer generally call [3], subject to the following conditions:

  • cudnnBatchNormalizationBackward() must be passed the savedMean and savedInvVariance computed by cudnnBatchNormalizationForwardTraining() (NaN is not allowed)
  • If an overflow occurs, NaN is output; cudnnQueryRuntimeError() must be used to check whether an overflow has occurred

https://github.com/chainer/chainer/blob/72d7238d61fe2ca9e49f2cf5901a0c5c37a5ef86/chainer/functions/normalization/batch_normalization.py#L680

chainer/chainer/functions/normalization/batch_normalization.py
