Deriving the Backward Pass (Backpropagation) of Batch Normalization


This article is a memo from deriving the backward (backpropagation) formulas of Batch Normalization by straightforwardly differentiating the definition of the forward pass.

Definition of the forward pass

  • Input: $x[N][C][H][W]$
    • ※ $x[N][C][H][W]$ is a 4-D tensor with N: batch size, C: number of channels, H: height, W: width
    • ※ If $x[N][C][H][W]$ holds a batch of 256 RGB images of 256 × 256 pixels, then N=256, C=3, and H=W=256
  • Output: $y[N][C][H][W]$
  • Weight parameters:
    • $\gamma[C]$
    • $\beta[C]$
  • Constant:
    • $\epsilon$
  • Forward formulas (a NumPy sketch follows this list):
    • $y[N][C][H][W] = \gamma[C]\times\hat{x}[N][C][H][W]+\beta[C]$
      • $\hat{x}[N][C][H][W]=\frac{x[N][C][H][W]-\mu[C]}{\sqrt{\sigma^2[C]+\epsilon}}$
      • $\mu[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}x[n][C][h][w]$
      • $\sigma^2[C]=\frac{1}{NHW}\sum_{n=1,h=1,w=1}^{N,H,W}\left(x[n][C][h][w]-\mu[C]\right)^2$
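To make the definitions concrete, here is a minimal NumPy sketch of the forward pass with the per-channel (spatial) statistics above. The function and variable names are my own, not from any framework.

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-5):
    """Spatial batch norm forward, following the formulas above.

    x: (N, C, H, W) input; gamma, beta: (C,) weight parameters.
    Returns y plus the intermediates needed by the backward pass.
    """
    mu = x.mean(axis=(0, 2, 3), keepdims=True)   # μ[C], averaged over N, H, W
    var = x.var(axis=(0, 2, 3), keepdims=True)   # σ²[C], biased (1/NHW) variance
    x_hat = (x - mu) / np.sqrt(var + eps)        # x̂[N][C][H][W]
    y = gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)
    return y, x_hat, mu, var
```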

Definition of the backward pass

  • Loss function: $L$
  • Input: $\delta y[N][C][H][W]:= \frac{\partial L}{\partial y[N][C][H][W]}$
  • Outputs:
    • $\delta x[N][C][H][W]:=\frac{\partial L}{\partial x[N][C][H][W]}$
    • $\delta \gamma[C]:=\frac{\partial L}{\partial \gamma[C]}$
    • $\delta \beta[C]:=\frac{\partial L}{\partial \beta[C]}$
  • Constant:
    • $\epsilon$
  • Backward formulas (see the NumPy sketch after this list):
    • $\delta x[N][C][H][W]=\frac{\gamma[C]}{\sqrt{\sigma^2[C]+\epsilon}}\left[\delta y[N][C][H][W] -\frac{1}{NHW}\left(\delta \beta[C]+\delta \gamma[C]\times\hat{x}[N][C][H][W] \right)\right]$
    • $\delta \gamma[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\left(\delta y[n][C][h][w]\times\hat{x}[n][C][h][w] \right)$
    • $\delta \beta[C] = \sum_{n=1,h=1,w=1}^{N,H,W}\delta y[n][C][h][w]$
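These three formulas translate directly into NumPy. The sketch below reuses the intermediates returned by the hypothetical batchnorm_forward above; again the names are ad hoc.

```python
import numpy as np

def batchnorm_backward(dy, x_hat, var, gamma, eps=1e-5):
    """Spatial batch norm backward, implementing the three formulas above."""
    N, C, H, W = dy.shape
    m = N * H * W                                  # NHW
    dbeta = dy.sum(axis=(0, 2, 3))                 # δβ[C] = Σ δy
    dgamma = (dy * x_hat).sum(axis=(0, 2, 3))      # δγ[C] = Σ δy · x̂
    dx = (gamma.reshape(1, -1, 1, 1) / np.sqrt(var + eps)
          * (dy - (dbeta.reshape(1, -1, 1, 1)
                   + dgamma.reshape(1, -1, 1, 1) * x_hat) / m))
    return dx, dgamma, dbeta
```

Note that δβ and δγ are computed first: the δx formula reuses both of them, which is what makes this closed form efficient.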

Derivation of the formula for δβ

\begin{align}
\delta \beta[C]
&:=
\frac
  {\partial L}
  {\partial  \beta[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \beta[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \beta[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  }
    {\partial  \beta[C]}
    \left(
      \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
    \right)
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}
  \delta y[n][C][h][w]
\end{align}
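As a sanity check, this result can be compared against a central difference. The sketch below takes the loss to be $L = \sum \delta y \odot y$ with $\delta y$ held fixed (so $\partial L/\partial y = \delta y$ by construction); all names are ad hoc.

```python
import numpy as np

# Check δβ[C] = Σ δy numerically (loss L = Σ dy ⊙ y, dy held fixed).
rng = np.random.default_rng(0)
N, C, H, W = 2, 3, 4, 4
x = rng.normal(size=(N, C, H, W))
gamma, beta = rng.normal(size=C), rng.normal(size=C)
dy = rng.normal(size=(N, C, H, W))
eps = 1e-5

mu = x.mean(axis=(0, 2, 3), keepdims=True)
var = x.var(axis=(0, 2, 3), keepdims=True)
x_hat = (x - mu) / np.sqrt(var + eps)

def loss(beta_):
    y = gamma.reshape(1, C, 1, 1) * x_hat + beta_.reshape(1, C, 1, 1)
    return np.sum(dy * y)

dbeta = dy.sum(axis=(0, 2, 3))                 # analytic δβ
h = 1e-6
for c in range(C):
    e = np.zeros(C)
    e[c] = h
    assert np.isclose(dbeta[c], (loss(beta + e) - loss(beta - e)) / (2 * h))
print("dbeta matches finite differences")
```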

Derivation of the formula for δγ

\begin{align}
\delta \gamma[C]
&:=
\frac
  {\partial L}
  {\partial \gamma[C]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\partial  \gamma[C]}
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \frac
    {\partial  }
    {\partial  \gamma[C]}
    \left(
      \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
    \right)
  \right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \delta y[n][C][h][w]
  \times
  \hat{x}[n][C][h][w]
  \right)
\end{align}
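The same finite-difference check works for δγ (again a sketch with ad hoc names and the loss $L = \sum \delta y \odot y$):

```python
import numpy as np

# Check δγ[C] = Σ δy · x̂ numerically (loss L = Σ dy ⊙ y, dy held fixed).
rng = np.random.default_rng(1)
N, C, H, W = 2, 3, 4, 4
x = rng.normal(size=(N, C, H, W))
gamma, beta = rng.normal(size=C), rng.normal(size=C)
dy = rng.normal(size=(N, C, H, W))
eps = 1e-5

mu = x.mean(axis=(0, 2, 3), keepdims=True)
var = x.var(axis=(0, 2, 3), keepdims=True)
x_hat = (x - mu) / np.sqrt(var + eps)

def loss(gamma_):
    y = gamma_.reshape(1, C, 1, 1) * x_hat + beta.reshape(1, C, 1, 1)
    return np.sum(dy * y)

dgamma = (dy * x_hat).sum(axis=(0, 2, 3))      # analytic δγ
h = 1e-6
for c in range(C):
    e = np.zeros(C)
    e[c] = h
    assert np.isclose(dgamma[c], (loss(gamma + e) - loss(gamma - e)) / (2 * h))
print("dgamma matches finite differences")
```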

Derivation of the formula for δx

\begin{align}
\delta x[N][C][H][W]
&:=
\frac
  {\partial L}
  {\partial x[N][C][H][W]} \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  y[n][C][h][w]}
    {\bf{\partial  x[N][C][H][W]}}
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \frac
    {\partial  }
    {\bf{\partial  x[N][C][H][W]}}
  \left(
    \gamma[C]\times\hat{x}[n][C][h][w]+\beta[C]
  \right)
\right) \\
&=
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \gamma[C]
  \frac
    {\partial  }
    {\bf{\partial  x[N][C][H][W]}}
  \left(
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
  \right)
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{N,H,W}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \frac
      {\partial \left(x[n][C][h][w]-\mu[C] \right) }
      {\bf{\partial  x[N][C][H][W]}}
    \times
    \frac
      {1}
      {\sqrt{\sigma^2[C]+\epsilon}}
    +
    (x[n][C][h][w]-\mu[C])
    \times
    \frac
      {\partial \left( \frac
                          {1}
                          {\sqrt{\sigma^2[C]+\epsilon}}
                \right) }
      {\bf{\partial  x[N][C][H][W]}}
  \right]
\right) \\
&=
\gamma[C]
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \left(
      \delta^{n,h,w}_{N,H,W}
      -
      \frac
        {\partial \mu[C]}
        {\bf{\partial  x[N][C][H][W]}}
    \right)
    \times
    \frac
      {1}
      {\sqrt{\sigma^2[C]+\epsilon}}
    +
    (x[n][C][h][w]-\mu[C])
    \times
    \left(
      -
      \frac
        {1}
        {2(\sqrt{\sigma^2[C]+\epsilon})^3}
      \frac
        {\partial \sigma^2[C]}
        {\bf{\partial  x[N][C][H][W]}}
    \right)
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {\partial \mu[C]}
      {\bf{\partial  x[N][C][H][W]}}
    -
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
    \frac
      {1}
      {2\sqrt{\sigma^2[C]+\epsilon}}
      \frac
        {\partial \sigma^2[C]}
        {\bf{\partial  x[N][C][H][W]}}
  \right]
\right) \\
&\because
\frac
  {\partial \mu[C]}
  {\bf{\partial  x[N][C][H][W]}}
=
\frac
  {1}
  {NHW} \\
&\because
\frac
  {\partial \sigma^2[C]}
  {\bf{\partial  x[N][C][H][W]}}
=
\frac
  {2}
  {NHW}
(\bf{x[N][C][H][W]}-\mu[C]) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \frac
    {\partial L}
    {\partial  y[n][C][h][w]}
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {1}
      {NHW}
    -
    \frac
      {1}
      {NHW}
    \frac
      {x[n][C][h][w]-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
    \frac
      {\bf{x[N][C][H][W]}-\mu[C]}
      {\sqrt{\sigma^2[C]+\epsilon}}
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\sum_{n=1,h=1,w=1}^{\bf{N,H,W}}\left(
  \delta y[n][C][h][w]
  \times
  \left[
    \delta^{n,h,w}_{N,H,W}
    -
    \frac
      {1}
      {NHW}
    \left(
      1
      +
      \hat{x}[n][C][h][w]
      \times
      \bf{\hat{x}[N][C][H][W]}
    \right)
  \right]
\right) \\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
  \sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
  \delta y[n][C][h][w]
  \times
  \delta^{n,h,w}_{N,H,W}
  -
  \frac
    {1}
    {NHW}
  \left(
    \sum_{n=1,h=1,w=1}^{\bf{N,H,W}}
    \delta y[n][C][h][w]
    +
    \bf{\hat{x}[N][C][H][W]}
    \times
    \sum_{n=1,h=1,w=1}^{\bf{N,H,W}} \left(
      \delta y[n][C][h][w]
      \times
      \hat{x}[n][C][h][w]
    \right)
  \right)
\right]\\
&=
\frac
  {\gamma[C]}
  {\sqrt{\sigma^2[C]+\epsilon}}
\times
\left[
  \delta y[N][C][H][W]
  -
  \frac
    {1}
    {NHW}
  \left(
    \delta \beta[C]
    +
    \bf{\hat{x}[N][C][H][W]}
    \times
    \delta \gamma[C]
  \right)
\right]\\
\end{align}
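Unlike β and γ, perturbing a single element of $x$ also moves $\mu[C]$ and $\sigma^2[C]$, which is exactly why the Kronecker-delta and $\frac{1}{NHW}$ terms above appear. The following sketch (ad hoc names, loss $L = \sum \delta y \odot y$) checks the final closed-form δx against central differences on a few entries.

```python
import numpy as np

# Finite-difference check of the final δx formula.
rng = np.random.default_rng(0)
N, C, H, W = 2, 3, 4, 4
x = rng.normal(size=(N, C, H, W))
gamma, beta = rng.normal(size=C), rng.normal(size=C)
dy = rng.normal(size=(N, C, H, W))
eps = 1e-5

def forward(x_):
    # μ and σ² are recomputed because they depend on x.
    mu = x_.mean(axis=(0, 2, 3), keepdims=True)
    var = x_.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x_ - mu) / np.sqrt(var + eps)
    return gamma.reshape(1, C, 1, 1) * x_hat + beta.reshape(1, C, 1, 1)

# Analytic δx from the result just derived.
mu = x.mean(axis=(0, 2, 3), keepdims=True)
var = x.var(axis=(0, 2, 3), keepdims=True)
x_hat = (x - mu) / np.sqrt(var + eps)
m = N * H * W
dbeta = dy.sum(axis=(0, 2, 3), keepdims=True)
dgamma = (dy * x_hat).sum(axis=(0, 2, 3), keepdims=True)
dx = (gamma.reshape(1, C, 1, 1) / np.sqrt(var + eps)
      * (dy - (dbeta + dgamma * x_hat) / m))

# Numerical δx on a few entries: L = Σ dy ⊙ y with dy held fixed.
h = 1e-6
for idx in [(0, 0, 0, 0), (1, 2, 3, 1), (0, 1, 2, 2)]:
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    num = (np.sum(dy * forward(xp)) - np.sum(dy * forward(xm))) / (2 * h)
    assert np.isclose(dx[idx], num, atol=1e-5, rtol=1e-4)
print("dx matches finite differences")
```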

※ On the three batch normalization modes in cuDNN (per-activation, spatial, and spatial-persistent)

In cuDNN, the following three batch normalization modes are defined in cudnnBatchNormMode_t:

  • [1] CUDNN_BATCHNORM_PER_ACTIVATION
  • [2] CUDNN_BATCHNORM_SPATIAL
  • [3] CUDNN_BATCHNORM_SPATIAL_PERSISTENT

This article derived the backpropagation for [2] CUDNN_BATCHNORM_SPATIAL.
For reference:
[1] is used to normalize the outputs of layers other than convolution layers.
[2] is used to normalize the outputs of convolution layers; in general, "batch normalization" refers to [2].
[3] is said to be a faster version of [2]. Deep learning frameworks such as Chainer generally call [3], subject to the following conditions:

  • cudnnBatchNormalizationBackward() must be passed the savedMean and savedInvVariance computed by cudnnBatchNormalizationForwardTraining() (NaN is not allowed)
  • If an overflow occurs, NaN is output, so you need to check for overflow with cudnnQueryRuntimeError()

See also: chainer/chainer/functions/normalization/batch_normalization.py
