# [Supplement] Understanding Batch Normalization


# TL;DR
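
The input gradient of Batch Normalization can be written in two equivalent ways. The first block below states the gradients as they come out of the computational graph node by node; the second rewrites $d_x$ in closed form. The derivation is given in the next section.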

\begin{align}
d_x(N,D)&=
d_{3a} - \frac{1}{N}\sum ^N _{i=1} d_{3a}
\\d_{3a}&=
f_7 \circ \left( d_{12a} - f_3 \circ \frac{1}{f_5 + \epsilon} \circ \frac{1}{N}\sum_{j=1}^{N}\left( f_3 \circ d_{12a} \right) \right)
\\d_{\gamma}(D,)&=
\sum ^N _{i=1}\left( \hat{x_i} \circ d_{out} \right)
\\d_{\beta}(D,)&=
\sum ^N _{i=1} d_{out}
\end{align}
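
The symbols follow the node-by-node decomposition of the forward graph. Reading them off from the substitutions in the derivation below (and from the verification code), they can be interpreted as follows, with $\circ$ the element-wise product broadcast over the batch axis and $d_{out}$ the upstream gradient:

\begin{align}
f_3 &= x_i - \mu \\
f_5 &= \sigma \quad (\text{batch variance}) \\
f_7 &= \frac{1}{\sqrt{\sigma + \epsilon}} \\
f_9 &= \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}} = \hat{x_i} \\
d_{12a} &= \gamma \circ d_{out}
\end{align}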

\begin{align}
d_x(N,D)&=
\gamma \circ f_7 \circ \left( d_{out} - \frac{1}{N}\left( d_{\beta} + f_9 \circ d_{\gamma} \right) \right)
\\&=
\gamma \circ \frac{1}{\sqrt{\sigma + \epsilon}} \circ \left( d_{out} - \frac{1}{N}\left( d_{\beta} + \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}} \circ d_{\gamma} \right) \right)
\end{align}
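
As a minimal sketch of this closed form in vectorized NumPy (the function name and the vectorized style are mine; the verification script below uses explicit loops instead):

```python
import numpy as np

def batchnorm_backward(dout, x, gamma, eps=2e-05):
    """Closed-form BatchNorm backward; dout, x: (N, D), gamma: (D,)."""
    N = x.shape[0]
    mu = x.mean(axis=0)                          # batch mean, shape (D,)
    inv_std = 1. / np.sqrt(x.var(axis=0) + eps)  # f_7, shape (D,)
    xhat = (x - mu) * inv_std                    # f_9, shape (N, D)
    dgamma = (xhat * dout).sum(axis=0)           # d_gamma, shape (D,)
    dbeta = dout.sum(axis=0)                     # d_beta, shape (D,)
    dx = gamma * inv_std * (dout - (dbeta + xhat * dgamma) / N)
    return dx, dgamma, dbeta

# example call on random inputs
dx, dgamma, dbeta = batchnorm_backward(
    np.random.random((512, 4)), np.random.random((512, 4)), np.random.random(4))
```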

# Derivation
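
Substituting $d_{3a}$ into $d_x$ and expanding, the double-sum term vanishes because the normalized inputs sum to zero over the batch (shown in the second $\because$ step below), which leaves the closed form: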

\begin{align}
d_x(N,D)&=
d_{3a} - \frac{1}{N}\sum^N _{i=1} d_{3a}
\\&=
\gamma \circ f_7 \circ \left( d_{out} - f_9 \circ \frac{1}{N}\sum_{j=1}^{N}\left( \mathbf{f_9} \circ \mathbf{d_{out}} \right) \right)
- \frac{1}{N}\sum_{i=1}^{N}\left( \gamma \circ f_7 \circ \left( d_{out} - f_9 \circ \frac{1}{N}\sum_{k=1}^{N}\left( \mathbf{f_9} \circ \mathbf{d_{out}} \right) \right) \right)
\\&=
\gamma \circ f_7 \circ \Big[ d_{out} - \frac{1}{N}\sum_{i=1}^{N} d_{out} - f_9 \circ \frac{1}{N}\sum_{j=1}^{N}\left( \mathbf{f_9} \circ \mathbf{d_{out}} \right)
\\&\qquad
+ \frac{1}{N}\sum_{i=1}^{N}\left( f_9 \circ \frac{1}{N}\sum_{k=1}^{N}\left( \mathbf{f_9} \circ \mathbf{d_{out}} \right) \right) \Big]
\\&=
\gamma \circ f_7 \circ \left[ d_{out} - \frac{1}{N}\sum_{i=1}^{N} d_{out} - f_9 \circ \frac{1}{N}\sum_{j=1}^{N}\left( \mathbf{f_9} \circ \mathbf{d_{out}} \right) + 0 \right]
\\&=
\gamma \circ f_7 \circ \left( d_{out} - \frac{1}{N}\left( d_{\beta} + f_9 \circ d_{\gamma} \right) \right)
\\ \because d_{3a}&=
f_7 \circ \left( d_{12a} - f_3 \circ \frac{1}{f_5 + \epsilon} \circ \frac{1}{N}\sum_{j=1}^{N}\left( f_3 \circ d_{12a} \right) \right)
\\&=
f_7 \circ \left( d_{out} \circ \gamma - \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}} \circ \frac{1}{N}\sum_{j=1}^{N}\left( \frac{\mathbf{x_j} - \mu}{\sqrt{\sigma + \epsilon}} \circ \mathbf{d_{out}} \circ \gamma \right) \right)
\\&=
\gamma \circ f_7 \circ \left( d_{out} - f_9 \circ \frac{1}{N}\sum_{j=1}^{N}\left( \mathbf{f_9} \circ \mathbf{d_{out}} \right) \right)
\\ \because \frac{1}{N}&\sum_{i=1}^{N}\left( f_9 \circ \frac{1}{N}\sum_{k=1}^{N}\left( \mathbf{f_9} \circ \mathbf{d_{out}} \right) \right)
\\&=
\frac{1}{N}\sum_{i=1}^{N}\left( f_9 \circ \frac{1}{N} d_{\gamma} \right)
\\&=
\frac{d_{\gamma}}{N^2} \circ \sum_{i=1}^{N} f_9
\\&=
\frac{d_{\gamma}}{N^2} \circ \sum_{i=1}^{N}\left( \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}} \right)
\\&=
\frac{d_{\gamma}}{N^2\sqrt{\sigma + \epsilon}} \circ \sum_{i=1}^{N} \left( x_i - \mu \right)
\\&=
\frac{d_{\gamma}}{N^2\sqrt{\sigma + \epsilon}} \circ 0
= 0
\end{align}
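
The last equality is just the fact that deviations from the batch mean sum to zero. A quick numerical check (a hypothetical snippet, independent of the verification script below):

```python
import numpy as np

# The batch sum of the normalized input f_9 is zero, which is
# what makes the double-sum term vanish above.
x = np.random.random((512, 4))
f9 = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 2e-5)
print(f9.sum(axis=0))  # ~1e-12 per feature: zero up to rounding error
```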

# Verification with a Python implementation

## Code
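
In the script below, `backward_impl0` evaluates $d_x$ node by node through the graph (the first TL;DR block) and `backward_impl1` uses the closed form (the second block); `main` draws random `x` and `dy` and prints `dx1 - dx0` to compare the two.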

"""
Created on Mon Nov 12 02:47:15 2018

@author: ryotakata
"""
USE_GPU = True

if USE_GPU:
import cupy as xp
from cupy.random import *
else:
import numpy as xp
from numpy.random import *

def main():
size = (512, 4) # batch_size = 512
bn = BatchNorm(size)
x = random(size)
dy = random(size)
dx0, _, _ = bn.backward_impl0(dy, x)
dx1, _, _ = bn.backward_impl1(dy, x)
print ("x: \n", x)
print ("dy: \n", dy)
print ("dx1 - dx0: \n", (dx1 - dx0))
return

class BatchNorm(object):
"""size = (batch_size, n_model)"""
def __init__(self,size, eps = 2e-05):
assert len(size) == 2
self.size = size
self.eps = eps
N_, D_ = size
self.xhat = xp.ones(size,dtype=float) * xp.nan
self.gamma = random(D_)
self.beta = random(D_)

def forward(self,x):
N_, D_ = self.size
assert x.ndim == 2
assert x.shape == (N_, D_)
self.saved_mean = x.mean(axis = 0)
assert self.saved_mean.shape == (D_,)
self.saved_inv_var = 1./ xp.sqrt(x.var(axis = 0) + self.eps)
assert self.saved_inv_var.shape == (D_,)
y = xp.ones(x.shape) * xp.nan
self.xhat = xp.ones(x.shape) * xp.nan
for (i, xi) in enumerate(x):
self.xhat[i] = (xi - self.saved_mean) / self.saved_inv_var
y[i] = self.gamma * self.xhat[i] + self.beta
return y

def backward_impl0(self, dy, x):
N_, D_ = self.size
assert dy.ndim == x.ndim == 2
assert dy.shape == x.shape == (N_, D_)
_ = self.forward(x)
"""calculate the gradient with respect to gamma and beta"""
dgamma = self.xhat * dy
dgamma = dgamma.sum(axis = 0)
assert dgamma.shape == (D_,)
dbeta = dy.sum(axis = 0)
assert dbeta.shape == (D_,)
"""f3, f5, f7, d12a"""
f7 = self.saved_inv_var
f3 = xp.ones(x.shape) * xp.nan
for (i, xi) in enumerate(x):
f3[i] = xi - self.saved_mean
d12a = xp.ones(dy.shape) * xp.nan
for (i, dyi) in enumerate(dy):
d12a[i] = self.gamma * dyi
inv_f5_eps = f7 ** 2
assert inv_f5_eps.shape == (D_,)
"""d3a"""
d3a = xp.ones(dy.shape) * xp.nan
tmp_ = f3 * d12a
assert tmp_.shape == (N_, D_)
tmp_ = tmp_.sum(axis = 0) / N_
assert tmp_.shape == (D_,)
for i in range(N_):
d3a[i] = f7 * (d12a[i] - f3[i] * inv_f5_eps * tmp_)
"""dx"""
dx = xp.ones(x.shape) * xp.nan
tmp_ = d3a.sum(axis = 0)
assert tmp_.shape == (D_,)
for (i, d3ai) in enumerate(d3a):
dx[i] = d3ai - 1 / N_ * tmp_
return dx, dgamma, dbeta

def backward_impl1(self, dy, x):
N_, D_ = self.size
assert dy.ndim == x.ndim == 2
assert dy.shape == x.shape == (N_, D_)
_ = self.forward(x)
"""calculate the gradient with respect to gamma and beta"""
dgamma = self.xhat * dy
dgamma = dgamma.sum(axis = 0)
assert dgamma.shape == (D_,)
dbeta = dy.sum(axis = 0)
assert dbeta.shape == (D_,)
dx = xp.ones(x.shape) * xp.nan
for (i, dyi) in enumerate(dy):
dx[i] = dyi - 1 / N_ * (dbeta + self.xhat[i] * dgamma)
dx[i] = self.gamma * self.saved_inv_var * dx[i]
return dx, dgamma, dbeta

if __name__ == '__main__':
main()
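
With `USE_GPU = True` the comparison runs on CuPy arrays; setting it to `False` runs the identical check with NumPy on the CPU.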

## Results

```
x:
[[0.90067645 0.05195989 0.94890906 0.39859517]
[0.44419045 0.96537073 0.66998211 0.3254723 ]
[0.14176828 0.99931822 0.0932397  0.6887744 ]
...
[0.81677439 0.79172529 0.97273544 0.3386522 ]
[0.13069407 0.1792883  0.68598364 0.13730587]
[0.75877388 0.91132649 0.84422827 0.34956074]]
dy:
[[0.68581415 0.60774472 0.17332992 0.2146162 ]
[0.85103707 0.56038386 0.85837978 0.37975325]
[0.67934268 0.75050501 0.71611867 0.37727164]
...
[0.19264311 0.25868408 0.91539923 0.47523352]
[0.62664839 0.40625006 0.47672807 0.93141321]
[0.53504761 0.53021432 0.56164742 0.66057609]]
dx1 - dx0:
[[-0.01040809  0.00645451 -0.03261807 -0.00284718]
[ 0.00159596 -0.00626283 -0.01279289 -0.00477366]
[ 0.00954865 -0.00673548  0.02820001  0.00479781]
...
[-0.00820175 -0.00384518 -0.03431157 -0.00442643]
[ 0.00983987  0.00468173 -0.01393022 -0.00973105]
[-0.00667653 -0.00551038 -0.02517772 -0.00413903]]
```