LoginSignup
3

More than 5 years have passed since last update.

【補足】Batch Normalization の理解

Posted at

TL;DR

本記事は、@t-tkd3aさんが書かれたBatch Normalization の理解の補足です。

上記記事の誤差逆伝播計算式は以下の通りですが、

\begin{align}
d_x(N,D)&=
d_{3a} 
-
\frac{1}{N}
\sum ^N _{i=1} 
  d_{3a}
\\d_{3a}&=
f_7 
\circ
\left(
  d_{12a}
  -
  f_3
  \circ
  \frac{1}{f_5 + \epsilon}
  \circ 
  \frac{1}{N}
  \sum_{j=1}^{N}(
     f_3 
     \circ
     d_{12a})
\right)
\\d_γ(D,)&=
\sum ^N _{i=1}(
  \hat{x_i} 
  \circ 
  d _{out})
\\d_β(D,)&=
\sum ^N _{i=1} 
  d_{out}
\end{align}

更に式展開を進めると以下のようになります。

\begin{align}
d_x(N,D)&=
\gamma
\circ
f_7
\circ
\left( 
  d_{out}
  -
  \frac{1}{N}
    \left(
      d_{\beta} 
      +
      f_9
      \circ
      d_{\gamma}
    \right)
\right)
\\&=
\gamma
\circ
\frac{1}{\sqrt{\sigma + \epsilon}}
  \circ
  \left( 
    d_{out}
    -
    \frac{1}{N}
      \left(
        d_{\beta}
        +
        \frac
          {x_i-\mu}
          {\sqrt{\sigma + \epsilon}}
       \circ
       d_{\gamma}
  \right)
\right)  
\end{align}

一方で、上記2つのdxが等価であることをcupy実装で検証したところ、両者の計算結果に予想以上に大きな差が出ています。
単に計算誤差なのか、コードのバグか、それとも数式が間違っているのかは検証中です。

数式展開

\begin{align}
d_x(N,D)&=
d_{3a}
-
\frac{1}{N}
\sum^N _{i=1}
d_{3a} 
\\&=
\gamma
\circ 
f_7 
\circ 
\left(
  d_{out}
  -f_9
  \circ
  \frac{1}{N}
  \sum_{j=1}^{N}\left(
    \bf{f_9}
    \circ
    \bf{d_{out}}
  \right)
\right)
\\&\quad
-\frac{1}{N}
\sum_{i=1}^{N}\left(
  \gamma
  \circ
  f_7
  \circ
  \left(
    d_{out}
    -f_9
    \circ
    \frac{1}{N}
    \sum_{k=1}^{N}\left(
      \bf{f_9} 
      \circ
      \bf{d_{out}}
      \right)
  \right)
\right)
\\
&=
\gamma
\circ
f_7 
\circ 
\left[
  d_{out}
  -\frac{1}{N}
  \sum_{i=1}^{N}
    d_{out}
    -f_9
    \circ
    \frac{1}{N}
    \sum_{j=1}^{N}\left(
      \bf{f_9} 
      \circ 
      \bf{d_{out}}
    \right)
    \\+ 
    \frac{1}{N}
    \sum_{i=1}^{N}\left(
      f_9 
      \circ 
      \frac{1}{N}
      \sum_{k=1}^{N}\left(
        \bf{f_9} 
        \circ 
        \bf{d_{out}}
    \right)
  \right)
\right]
\\&=
\gamma
\circ 
f_7 
\circ \left[
  d_{out}
  -\frac{1}{N}
  \sum_{i=1}^{N}
  d_{out}
  -f_9 
  \circ
  \frac{1}{N}
  \sum_{j=1}^{N}\left(
    \bf{f_9} 
    \circ
    \bf{d_{out}}
  \right)
  +0
\right]
\\&=
\gamma
\circ
f_7 
\circ 
\left( 
  d_{out}
  -
  \frac{1}{N}\left(
    d_{\beta}
    +
    f_9
    \circ 
    d_{\gamma}
  \right)
\right) 
\\
\because d_{3a}&=
f_7
\circ 
\left(
  d_{12a}
  -f_3
  \circ
  \frac{1}{f_5 + \epsilon}
  \circ
  \frac{1}{N}
  \sum_{j=1}^{N}(
    f_3
    \circ
    d_{12a}
  )
\right)
\\&=
f_7 
\circ 
\left(
  d_{out}
  \circ
  \gamma 
  -
  \frac
    {x_i - \mu}
    {\sqrt{\sigma + \epsilon}} 
  \circ 
  \frac{1}{N}
  \sum_{j=1}^{N}\left(
    \frac
      {\bf{x_j}- \mu}
      {\sqrt{\sigma + \epsilon}} 
    \circ 
    \bf{d_{out}}
    \circ
    \gamma
  \right)
\right)
\\&= 
\gamma
\circ 
f_7 
\circ 
\left(
  d_{out}
  -f_9 
  \circ
  \frac{1}{N}
  \sum_{j=1}^{N}\left(
    \bf{f_9}
    \circ
    \bf{d_{out}}
  \right)
\right)

\\ \because 
\frac{1}{N}&
\sum_{i=1}^{N}\left(
  f_9
  \circ
  \frac{1}{N}
  \sum_{k=1}^{N}\left(
    \bf{f_9}
    \circ
    \bf{d_{out}}
  \right)
\right)
\\&=
\frac{1}{N}
\sum_{i=1}^{N}\left(
  f_9 
  \circ
  \frac{1}{N}
  d_{\gamma}
\right)
\\&=
\frac
  {d_{\gamma}}
  {N^2}
\circ
\sum_{i=1}^{N}\left(
  f_9 
\right)
\\&=
\frac
  {d_{\gamma}}
  {N^2}
\circ
\sum_{i=1}^{N}\left(
  \frac
    {x_i - \mu}
    {\sqrt{\sigma + \epsilon}}
\right)
\\&=
\frac
  {d_{\gamma}}
  {N^2\sqrt{\sigma + \epsilon}}
\circ
\sum_{i=1}^{N}
  \left(x_i - \mu \right)
\\&=
\frac
  {d_{\gamma}}
  {N^2\sqrt{\sigma + \epsilon}}
\circ
0 
= 0
\end{align}

python実装による検証

コード

"""
Created on Mon Nov 12 02:47:15 2018

@author: ryotakata
"""
USE_GPU = True

if USE_GPU:
    import cupy as xp
    from cupy.random import *
else:
    import numpy as xp
    from numpy.random import *

def main():
    size = (512, 4) # batch_size = 512
    bn = BatchNorm(size)
    x = random(size)
    dy = random(size)
    dx0, _, _ = bn.backward_impl0(dy, x)
    dx1, _, _ = bn.backward_impl1(dy, x)
    print ("x: \n", x)
    print ("dy: \n", dy)
    print ("dx1 - dx0: \n", (dx1 - dx0))
    return

class BatchNorm(object):
    """size = (batch_size, n_model)"""
    def __init__(self,size, eps = 2e-05):
        assert len(size) == 2
        self.size = size
        self.eps = eps
        N_, D_ = size
        self.xhat = xp.ones(size,dtype=float) * xp.nan
        self.gamma = random(D_)
        self.beta = random(D_)

    def forward(self,x):
        N_, D_ = self.size
        assert x.ndim == 2
        assert x.shape == (N_, D_)
        self.saved_mean = x.mean(axis = 0)
        assert self.saved_mean.shape == (D_,)
        self.saved_inv_var = 1./ xp.sqrt(x.var(axis = 0) + self.eps)
        assert self.saved_inv_var.shape == (D_,)
        y = xp.ones(x.shape) * xp.nan
        self.xhat = xp.ones(x.shape) * xp.nan
        for (i, xi) in enumerate(x):
            self.xhat[i] = (xi - self.saved_mean) / self.saved_inv_var
            y[i] = self.gamma * self.xhat[i] + self.beta
        return y

    def backward_impl0(self, dy, x):
        N_, D_ = self.size
        assert dy.ndim == x.ndim == 2
        assert dy.shape == x.shape == (N_, D_)
        _ = self.forward(x)
        """calculate the gradient with respect to gamma and beta"""
        dgamma = self.xhat * dy
        dgamma = dgamma.sum(axis = 0)
        assert dgamma.shape == (D_,)
        dbeta = dy.sum(axis = 0)
        assert dbeta.shape == (D_,)
        """calculate the gradient of x"""
        """f3, f5, f7, d12a"""
        f7 = self.saved_inv_var
        f3 = xp.ones(x.shape) * xp.nan
        for (i, xi) in enumerate(x):
            f3[i] = xi - self.saved_mean
        d12a = xp.ones(dy.shape) * xp.nan
        for (i, dyi) in enumerate(dy):
            d12a[i] = self.gamma * dyi
        inv_f5_eps = f7 ** 2
        assert inv_f5_eps.shape == (D_,)
        """d3a"""
        d3a = xp.ones(dy.shape) * xp.nan
        tmp_ = f3 * d12a
        assert tmp_.shape == (N_, D_)
        tmp_ = tmp_.sum(axis = 0) / N_
        assert tmp_.shape == (D_,)
        for i in range(N_):
            d3a[i] = f7 * (d12a[i] - f3[i] * inv_f5_eps * tmp_)
        """dx"""
        dx = xp.ones(x.shape) * xp.nan
        tmp_ = d3a.sum(axis = 0)
        assert tmp_.shape == (D_,)
        for (i, d3ai) in enumerate(d3a):
            dx[i] = d3ai - 1 / N_ * tmp_
        return dx, dgamma, dbeta

    def backward_impl1(self, dy, x):
        N_, D_ = self.size
        assert dy.ndim == x.ndim == 2
        assert dy.shape == x.shape == (N_, D_)
        _ = self.forward(x)
        """calculate the gradient with respect to gamma and beta"""
        dgamma = self.xhat * dy
        dgamma = dgamma.sum(axis = 0)
        assert dgamma.shape == (D_,)
        dbeta = dy.sum(axis = 0)
        assert dbeta.shape == (D_,)
        """calculate the gradient of x"""
        dx = xp.ones(x.shape) * xp.nan
        for (i, dyi) in enumerate(dy):
            dx[i] = dyi - 1 / N_ * (dbeta + self.xhat[i] * dgamma)
            dx[i] = self.gamma * self.saved_inv_var * dx[i]
        return dx, dgamma, dbeta

if __name__ == '__main__':
    main()

結果

x: 
 [[0.90067645 0.05195989 0.94890906 0.39859517]
 [0.44419045 0.96537073 0.66998211 0.3254723 ]
 [0.14176828 0.99931822 0.0932397  0.6887744 ]
 ...
 [0.81677439 0.79172529 0.97273544 0.3386522 ]
 [0.13069407 0.1792883  0.68598364 0.13730587]
 [0.75877388 0.91132649 0.84422827 0.34956074]]
dy: 
 [[0.68581415 0.60774472 0.17332992 0.2146162 ]
 [0.85103707 0.56038386 0.85837978 0.37975325]
 [0.67934268 0.75050501 0.71611867 0.37727164]
 ...
 [0.19264311 0.25868408 0.91539923 0.47523352]
 [0.62664839 0.40625006 0.47672807 0.93141321]
 [0.53504761 0.53021432 0.56164742 0.66057609]]
dx1 - dx0: 
 [[-0.01040809  0.00645451 -0.03261807 -0.00284718]
 [ 0.00159596 -0.00626283 -0.01279289 -0.00477366]
 [ 0.00954865 -0.00673548  0.02820001  0.00479781]
 ...
 [-0.00820175 -0.00384518 -0.03431157 -0.00442643]
 [ 0.00983987  0.00468173 -0.01393022 -0.00973105]
 [-0.00667653 -0.00551038 -0.02517772 -0.00413903]]

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3