Help us understand the problem. What is going on with this article?

【補足】Batch Normalization の理解

More than 1 year has passed since last update.

TL;DR

本記事は、@t-tkd3aさんが書かれたBatch Normalization の理解の補足です。

上記記事の誤差逆伝播計算式は以下の通りですが、

\begin{align}
d_x(N,D)&=
d_{3a} 
-
\frac{1}{N}
\sum ^N _{i=1} 
  d_{3a}
\\d_{3a}&=
f_7 
\circ
\left(
  d_{12a}
  -
  f_3
  \circ
  \frac{1}{f_5 + \epsilon}
  \circ 
  \frac{1}{N}
  \sum_{j=1}^{N}(
     f_3 
     \circ
     d_{12a})
\right)
\\d_γ(D,)&=
\sum ^N _{i=1}(
  \hat{x_i} 
  \circ 
  d _{out})
\\d_β(D,)&=
\sum ^N _{i=1} 
  d_{out}
\end{align}

更に式展開を進めると以下のようになります。

\begin{align}
d_x(N,D)&=
\gamma
\circ
f_7
\circ
\left( 
  d_{out}
  -
  \frac{1}{N}
    \left(
      d_{\beta} 
      +
      f_9
      \circ
      d_{\gamma}
    \right)
\right)
\\&=
\gamma
\circ
\frac{1}{\sqrt{\sigma + \epsilon}}
  \circ
  \left( 
    d_{out}
    -
    \frac{1}{N}
      \left(
        d_{\beta}
        +
        \frac
          {x_i-\mu}
          {\sqrt{\sigma + \epsilon}}
       \circ
       d_{\gamma}
  \right)
\right)  
\end{align}

一方で、上記2つのdxが等価であることをcupy実装で検証したところ、両者の計算結果に予想以上に大きな差が出ています。
単に計算誤差なのか、コードのバグか、それとも数式が間違っているのかは検証中です。

数式展開

\begin{align}
d_x(N,D)&=
d_{3a}
-
\frac{1}{N}
\sum^N _{i=1}
d_{3a} 
\\&=
\gamma
\circ 
f_7 
\circ 
\left(
  d_{out}
  -f_9
  \circ
  \frac{1}{N}
  \sum_{j=1}^{N}\left(
    \bf{f_9}
    \circ
    \bf{d_{out}}
  \right)
\right)
\\&\quad
-\frac{1}{N}
\sum_{i=1}^{N}\left(
  \gamma
  \circ
  f_7
  \circ
  \left(
    d_{out}
    -f_9
    \circ
    \frac{1}{N}
    \sum_{k=1}^{N}\left(
      \bf{f_9} 
      \circ
      \bf{d_{out}}
      \right)
  \right)
\right)
\\
&=
\gamma
\circ
f_7 
\circ 
\left[
  d_{out}
  -\frac{1}{N}
  \sum_{i=1}^{N}
    d_{out}
    -f_9
    \circ
    \frac{1}{N}
    \sum_{j=1}^{N}\left(
      \bf{f_9} 
      \circ 
      \bf{d_{out}}
    \right)
    \\+ 
    \frac{1}{N}
    \sum_{i=1}^{N}\left(
      f_9 
      \circ 
      \frac{1}{N}
      \sum_{k=1}^{N}\left(
        \bf{f_9} 
        \circ 
        \bf{d_{out}}
    \right)
  \right)
\right]
\\&=
\gamma
\circ 
f_7 
\circ \left[
  d_{out}
  -\frac{1}{N}
  \sum_{i=1}^{N}
  d_{out}
  -f_9 
  \circ
  \frac{1}{N}
  \sum_{j=1}^{N}\left(
    \bf{f_9} 
    \circ
    \bf{d_{out}}
  \right)
  +0
\right]
\\&=
\gamma
\circ
f_7 
\circ 
\left( 
  d_{out}
  -
  \frac{1}{N}\left(
    d_{\beta}
    +
    f_9
    \circ 
    d_{\gamma}
  \right)
\right) 
\\
\because d_{3a}&=
f_7
\circ 
\left(
  d_{12a}
  -f_3
  \circ
  \frac{1}{f_5 + \epsilon}
  \circ
  \frac{1}{N}
  \sum_{j=1}^{N}(
    f_3
    \circ
    d_{12a}
  )
\right)
\\&=
f_7 
\circ 
\left(
  d_{out}
  \circ
  \gamma 
  -
  \frac
    {x_i - \mu}
    {\sqrt{\sigma + \epsilon}} 
  \circ 
  \frac{1}{N}
  \sum_{j=1}^{N}\left(
    \frac
      {\bf{x_j}- \mu}
      {\sqrt{\sigma + \epsilon}} 
    \circ 
    \bf{d_{out}}
    \circ
    \gamma
  \right)
\right)
\\&= 
\gamma
\circ 
f_7 
\circ 
\left(
  d_{out}
  -f_9 
  \circ
  \frac{1}{N}
  \sum_{j=1}^{N}\left(
    \bf{f_9}
    \circ
    \bf{d_{out}}
  \right)
\right)

\\ \because 
\frac{1}{N}&
\sum_{i=1}^{N}\left(
  f_9
  \circ
  \frac{1}{N}
  \sum_{k=1}^{N}\left(
    \bf{f_9}
    \circ
    \bf{d_{out}}
  \right)
\right)
\\&=
\frac{1}{N}
\sum_{i=1}^{N}\left(
  f_9 
  \circ
  \frac{1}{N}
  d_{\gamma}
\right)
\\&=
\frac
  {d_{\gamma}}
  {N^2}
\circ
\sum_{i=1}^{N}\left(
  f_9 
\right)
\\&=
\frac
  {d_{\gamma}}
  {N^2}
\circ
\sum_{i=1}^{N}\left(
  \frac
    {x_i - \mu}
    {\sqrt{\sigma + \epsilon}}
\right)
\\&=
\frac
  {d_{\gamma}}
  {N^2\sqrt{\sigma + \epsilon}}
\circ
\sum_{i=1}^{N}
  \left(x_i - \mu \right)
\\&=
\frac
  {d_{\gamma}}
  {N^2\sqrt{\sigma + \epsilon}}
\circ
0 
= 0
\end{align}

python実装による検証

コード

"""
Created on Mon Nov 12 02:47:15 2018

@author: ryotakata
"""
USE_GPU = True

if USE_GPU:
    import cupy as xp
    from cupy.random import *
else:
    import numpy as xp
    from numpy.random import *

def main():
    size = (512, 4) # batch_size = 512
    bn = BatchNorm(size)
    x = random(size)
    dy = random(size)
    dx0, _, _ = bn.backward_impl0(dy, x)
    dx1, _, _ = bn.backward_impl1(dy, x)
    print ("x: \n", x)
    print ("dy: \n", dy)
    print ("dx1 - dx0: \n", (dx1 - dx0))
    return

class BatchNorm(object):
    """size = (batch_size, n_model)"""
    def __init__(self,size, eps = 2e-05):
        assert len(size) == 2
        self.size = size
        self.eps = eps
        N_, D_ = size
        self.xhat = xp.ones(size,dtype=float) * xp.nan
        self.gamma = random(D_)
        self.beta = random(D_)

    def forward(self,x):
        N_, D_ = self.size
        assert x.ndim == 2
        assert x.shape == (N_, D_)
        self.saved_mean = x.mean(axis = 0)
        assert self.saved_mean.shape == (D_,)
        self.saved_inv_var = 1./ xp.sqrt(x.var(axis = 0) + self.eps)
        assert self.saved_inv_var.shape == (D_,)
        y = xp.ones(x.shape) * xp.nan
        self.xhat = xp.ones(x.shape) * xp.nan
        for (i, xi) in enumerate(x):
            self.xhat[i] = (xi - self.saved_mean) / self.saved_inv_var
            y[i] = self.gamma * self.xhat[i] + self.beta
        return y

    def backward_impl0(self, dy, x):
        N_, D_ = self.size
        assert dy.ndim == x.ndim == 2
        assert dy.shape == x.shape == (N_, D_)
        _ = self.forward(x)
        """calculate the gradient with respect to gamma and beta"""
        dgamma = self.xhat * dy
        dgamma = dgamma.sum(axis = 0)
        assert dgamma.shape == (D_,)
        dbeta = dy.sum(axis = 0)
        assert dbeta.shape == (D_,)
        """calculate the gradient of x"""
        """f3, f5, f7, d12a"""
        f7 = self.saved_inv_var
        f3 = xp.ones(x.shape) * xp.nan
        for (i, xi) in enumerate(x):
            f3[i] = xi - self.saved_mean
        d12a = xp.ones(dy.shape) * xp.nan
        for (i, dyi) in enumerate(dy):
            d12a[i] = self.gamma * dyi
        inv_f5_eps = f7 ** 2
        assert inv_f5_eps.shape == (D_,)
        """d3a"""
        d3a = xp.ones(dy.shape) * xp.nan
        tmp_ = f3 * d12a
        assert tmp_.shape == (N_, D_)
        tmp_ = tmp_.sum(axis = 0) / N_
        assert tmp_.shape == (D_,)
        for i in range(N_):
            d3a[i] = f7 * (d12a[i] - f3[i] * inv_f5_eps * tmp_)
        """dx"""
        dx = xp.ones(x.shape) * xp.nan
        tmp_ = d3a.sum(axis = 0)
        assert tmp_.shape == (D_,)
        for (i, d3ai) in enumerate(d3a):
            dx[i] = d3ai - 1 / N_ * tmp_
        return dx, dgamma, dbeta

    def backward_impl1(self, dy, x):
        N_, D_ = self.size
        assert dy.ndim == x.ndim == 2
        assert dy.shape == x.shape == (N_, D_)
        _ = self.forward(x)
        """calculate the gradient with respect to gamma and beta"""
        dgamma = self.xhat * dy
        dgamma = dgamma.sum(axis = 0)
        assert dgamma.shape == (D_,)
        dbeta = dy.sum(axis = 0)
        assert dbeta.shape == (D_,)
        """calculate the gradient of x"""
        dx = xp.ones(x.shape) * xp.nan
        for (i, dyi) in enumerate(dy):
            dx[i] = dyi - 1 / N_ * (dbeta + self.xhat[i] * dgamma)
            dx[i] = self.gamma * self.saved_inv_var * dx[i]
        return dx, dgamma, dbeta

if __name__ == '__main__':
    main()

結果

x: 
 [[0.90067645 0.05195989 0.94890906 0.39859517]
 [0.44419045 0.96537073 0.66998211 0.3254723 ]
 [0.14176828 0.99931822 0.0932397  0.6887744 ]
 ...
 [0.81677439 0.79172529 0.97273544 0.3386522 ]
 [0.13069407 0.1792883  0.68598364 0.13730587]
 [0.75877388 0.91132649 0.84422827 0.34956074]]
dy: 
 [[0.68581415 0.60774472 0.17332992 0.2146162 ]
 [0.85103707 0.56038386 0.85837978 0.37975325]
 [0.67934268 0.75050501 0.71611867 0.37727164]
 ...
 [0.19264311 0.25868408 0.91539923 0.47523352]
 [0.62664839 0.40625006 0.47672807 0.93141321]
 [0.53504761 0.53021432 0.56164742 0.66057609]]
dx1 - dx0: 
 [[-0.01040809  0.00645451 -0.03261807 -0.00284718]
 [ 0.00159596 -0.00626283 -0.01279289 -0.00477366]
 [ 0.00954865 -0.00673548  0.02820001  0.00479781]
 ...
 [-0.00820175 -0.00384518 -0.03431157 -0.00442643]
 [ 0.00983987  0.00468173 -0.01393022 -0.00973105]
 [-0.00667653 -0.00551038 -0.02517772 -0.00413903]]
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした