TL;DR
This article supplements @t-tkd3a's "Batch Normalization の理解" (Understanding Batch Normalization).
The backpropagation formulas given there are the following:
\begin{align}
d_x(N,D) &= d_{3a} - \frac{1}{N}\sum_{i=1}^{N} d_{3a}
\\ d_{3a} &= f_7 \circ \left( d_{12a} - f_3 \circ \frac{1}{f_5 + \epsilon} \circ \frac{1}{N}\sum_{j=1}^{N}\left( f_3 \circ d_{12a} \right) \right)
\\ d_{\gamma}(D,) &= \sum_{i=1}^{N}\left( \hat{x}_i \circ d_{out} \right)
\\ d_{\beta}(D,) &= \sum_{i=1}^{N} d_{out}
\end{align}
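For reference, here is a minimal vectorized sketch of these formulas, assuming NumPy, an (N, D) input, a per-feature gamma of shape (D,), and the intermediate names of the referenced article (f3 = x - mu, f5 = the per-feature variance, f7 = 1/sqrt(f5 + eps), d12a = gamma o d_out); the function name is my own.

import numpy as np

def batchnorm_backward_naive(dout, x, gamma, eps=2e-5):
    N = x.shape[0]
    mu = x.mean(axis=0)                  # per-feature mean, shape (D,)
    f5 = x.var(axis=0)                   # per-feature variance, shape (D,)
    f3 = x - mu                          # centered input, shape (N, D)
    f7 = 1.0 / np.sqrt(f5 + eps)         # inverse standard deviation, (D,)
    xhat = f3 * f7                       # normalized input, shape (N, D)
    d12a = gamma * dout                  # gradient through the scale, (N, D)
    # d3a = f7 o (d12a - f3 o 1/(f5 + eps) o (1/N) sum_j(f3 o d12a))
    d3a = f7 * (d12a - f3 / (f5 + eps) * (f3 * d12a).mean(axis=0))
    dx = d3a - d3a.mean(axis=0)          # dx = d3a - (1/N) sum_i d3a
    dgamma = (xhat * dout).sum(axis=0)   # shape (D,)
    dbeta = dout.sum(axis=0)             # shape (D,)
    return dx, dgamma, dbeta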
Expanding these expressions further gives:
\begin{align}
d_x(N,D) &= \gamma \circ f_7 \circ \left( d_{out} - \frac{1}{N}\left( d_{\beta} + f_9 \circ d_{\gamma} \right) \right)
\\ &= \gamma \circ \frac{1}{\sqrt{\sigma + \epsilon}} \circ \left( d_{out} - \frac{1}{N}\left( d_{\beta} + \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}} \circ d_{\gamma} \right) \right)
\end{align}
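The simplified form admits an equally short sketch under the same assumptions; the point of the expansion is that once d_gamma and d_beta are in hand, dx needs no additional sums over the batch.

import numpy as np

def batchnorm_backward_simplified(dout, x, gamma, eps=2e-5):
    # dx = gamma o 1/sqrt(var + eps) o (dout - (d_beta + f9 o d_gamma) / N)
    N = x.shape[0]
    inv_std = 1.0 / np.sqrt(x.var(axis=0) + eps)   # 1/sqrt(sigma + eps), (D,)
    f9 = (x - x.mean(axis=0)) * inv_std            # normalized input, (N, D)
    dgamma = (f9 * dout).sum(axis=0)               # (D,)
    dbeta = dout.sum(axis=0)                       # (D,)
    dx = gamma * inv_std * (dout - (dbeta + f9 * dgamma) / N)
    return dx, dgamma, dbeta

Both sketches compute the same gradient, so their outputs should agree to within float64 rounding error.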
However, checking the equivalence of these two dx expressions with the cupy implementation below produces a far larger difference than expected: gaps on the order of 1e-2, which in float64 is much too large to be accumulated rounding error. The formulas themselves are fine; the cause is the x-hat computation in forward, which divides by saved_inv_var (itself already the inverse standard deviation 1/sqrt(sigma + eps)) and therefore multiplies by sqrt(sigma + eps) instead. backward_impl0 rebuilds dx from f3 and f7 and is unaffected, while backward_impl1 relies on xhat = (x - mu)/sqrt(sigma + eps), which accounts for the differences shown in the results at the end.
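A centered finite-difference check is one way to separate rounding error from a wrong gradient. This sketch assumes NumPy (USE_GPU = False in the code below) and takes the loss to be L = sum(forward(x) o d_out), so that dL/dy equals d_out; numerical_dx is a name of my own.

import numpy as np

def numerical_dx(forward, x, dout, h=1e-5):
    """Approximate dL/dx for L = (forward(x) * dout).sum() by central differences."""
    dx = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        orig = x[idx]
        x[idx] = orig + h                        # nudge one element up
        loss_plus = (forward(x) * dout).sum()
        x[idx] = orig - h                        # and down
        loss_minus = (forward(x) * dout).sum()
        x[idx] = orig                            # restore
        dx[idx] = (loss_plus - loss_minus) / (2.0 * h)
    return dx

Comparing dx0, dx1, and numerical_dx(bn.forward, x, dy) against one another pinpoints whether a discrepancy lies in the math or in the code.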
Derivation
\begin{align}
d_x(N,D) &= d_{3a} - \frac{1}{N}\sum_{i=1}^{N} d_{3a}
\\ &= \gamma \circ f_7 \circ \left( d_{out} - f_9 \circ \frac{1}{N}\sum_{j=1}^{N}\left( \boldsymbol{f_9} \circ \boldsymbol{d_{out}} \right) \right)
\\ &\quad - \frac{1}{N}\sum_{i=1}^{N}\left( \gamma \circ f_7 \circ \left( d_{out} - f_9 \circ \frac{1}{N}\sum_{k=1}^{N}\left( \boldsymbol{f_9} \circ \boldsymbol{d_{out}} \right) \right) \right)
\\ &= \gamma \circ f_7 \circ \Bigl[ d_{out} - \frac{1}{N}\sum_{i=1}^{N} d_{out} - f_9 \circ \frac{1}{N}\sum_{j=1}^{N}\left( \boldsymbol{f_9} \circ \boldsymbol{d_{out}} \right)
\\ &\qquad + \frac{1}{N}\sum_{i=1}^{N}\left( f_9 \circ \frac{1}{N}\sum_{k=1}^{N}\left( \boldsymbol{f_9} \circ \boldsymbol{d_{out}} \right) \right) \Bigr]
\\ &= \gamma \circ f_7 \circ \left[ d_{out} - \frac{1}{N}\sum_{i=1}^{N} d_{out} - f_9 \circ \frac{1}{N}\sum_{j=1}^{N}\left( \boldsymbol{f_9} \circ \boldsymbol{d_{out}} \right) + 0 \right]
\\ &= \gamma \circ f_7 \circ \left( d_{out} - \frac{1}{N}\left( d_{\beta} + f_9 \circ d_{\gamma} \right) \right)
\\ \because d_{3a} &= f_7 \circ \left( d_{12a} - f_3 \circ \frac{1}{f_5 + \epsilon} \circ \frac{1}{N}\sum_{j=1}^{N}\left( f_3 \circ d_{12a} \right) \right)
\\ &= f_7 \circ \left( d_{out} \circ \gamma - \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}} \circ \frac{1}{N}\sum_{j=1}^{N}\left( \frac{\boldsymbol{x_j} - \mu}{\sqrt{\sigma + \epsilon}} \circ \boldsymbol{d_{out}} \circ \gamma \right) \right)
\\ &= \gamma \circ f_7 \circ \left( d_{out} - f_9 \circ \frac{1}{N}\sum_{j=1}^{N}\left( \boldsymbol{f_9} \circ \boldsymbol{d_{out}} \right) \right)
\\ \because \frac{1}{N}&\sum_{i=1}^{N}\left( f_9 \circ \frac{1}{N}\sum_{k=1}^{N}\left( \boldsymbol{f_9} \circ \boldsymbol{d_{out}} \right) \right)
\\ &= \frac{1}{N}\sum_{i=1}^{N}\left( f_9 \circ \frac{1}{N} d_{\gamma} \right)
\\ &= \frac{d_{\gamma}}{N^2} \circ \sum_{i=1}^{N} f_9
\\ &= \frac{d_{\gamma}}{N^2} \circ \sum_{i=1}^{N}\left( \frac{x_i - \mu}{\sqrt{\sigma + \epsilon}} \right)
\\ &= \frac{d_{\gamma}}{N^2 \sqrt{\sigma + \epsilon}} \circ \sum_{i=1}^{N}\left( x_i - \mu \right)
\\ &= \frac{d_{\gamma}}{N^2 \sqrt{\sigma + \epsilon}} \circ 0 = 0
\end{align}
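The cancellation in the last block rests on sum_i(x_i - mu) = 0 per feature, equivalently sum_i f_9 = 0; a quick numerical sanity check (assuming NumPy, with shapes matching the experiment below):

import numpy as np

x = np.random.random((512, 4))
f9 = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 2e-5)
print(f9.sum(axis=0))  # approximately zero for every feature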
Verification with a Python implementation
Code
"""
Created on Mon Nov 12 02:47:15 2018
@author: ryotakata
"""
USE_GPU = True
if USE_GPU:
    import cupy as xp
    from cupy.random import *
else:
    import numpy as xp
    from numpy.random import *


def main():
    size = (512, 4)  # batch_size = 512
    bn = BatchNorm(size)
    x = random(size)
    dy = random(size)
    dx0, _, _ = bn.backward_impl0(dy, x)
    dx1, _, _ = bn.backward_impl1(dy, x)
    print("x: \n", x)
    print("dy: \n", dy)
    print("dx1 - dx0: \n", (dx1 - dx0))
    return


class BatchNorm(object):
    """size = (batch_size, n_model)"""
    def __init__(self, size, eps=2e-05):
        assert len(size) == 2
        self.size = size
        self.eps = eps
        N_, D_ = size
        self.xhat = xp.ones(size, dtype=float) * xp.nan
        self.gamma = random(D_)
        self.beta = random(D_)

    def forward(self, x):
        N_, D_ = self.size
        assert x.ndim == 2
        assert x.shape == (N_, D_)
        self.saved_mean = x.mean(axis=0)
        assert self.saved_mean.shape == (D_,)
        self.saved_inv_var = 1. / xp.sqrt(x.var(axis=0) + self.eps)
        assert self.saved_inv_var.shape == (D_,)
        y = xp.ones(x.shape) * xp.nan
        self.xhat = xp.ones(x.shape) * xp.nan
        for (i, xi) in enumerate(x):
            # NOTE: saved_inv_var already holds 1/sqrt(var + eps), so this
            # division actually computes (xi - mean) * sqrt(var + eps);
            # this is the bug discussed above.
            self.xhat[i] = (xi - self.saved_mean) / self.saved_inv_var
            y[i] = self.gamma * self.xhat[i] + self.beta
        return y

    def backward_impl0(self, dy, x):
        N_, D_ = self.size
        assert dy.ndim == x.ndim == 2
        assert dy.shape == x.shape == (N_, D_)
        _ = self.forward(x)
        # calculate the gradients with respect to gamma and beta
        dgamma = self.xhat * dy
        dgamma = dgamma.sum(axis=0)
        assert dgamma.shape == (D_,)
        dbeta = dy.sum(axis=0)
        assert dbeta.shape == (D_,)
        # calculate the gradient of x
        # intermediates f3, f5, f7, d12a
        f7 = self.saved_inv_var
        f3 = xp.ones(x.shape) * xp.nan
        for (i, xi) in enumerate(x):
            f3[i] = xi - self.saved_mean
        d12a = xp.ones(dy.shape) * xp.nan
        for (i, dyi) in enumerate(dy):
            d12a[i] = self.gamma * dyi
        inv_f5_eps = f7 ** 2
        assert inv_f5_eps.shape == (D_,)
        # d3a
        d3a = xp.ones(dy.shape) * xp.nan
        tmp_ = f3 * d12a
        assert tmp_.shape == (N_, D_)
        tmp_ = tmp_.sum(axis=0) / N_
        assert tmp_.shape == (D_,)
        for i in range(N_):
            d3a[i] = f7 * (d12a[i] - f3[i] * inv_f5_eps * tmp_)
        # dx
        dx = xp.ones(x.shape) * xp.nan
        tmp_ = d3a.sum(axis=0)
        assert tmp_.shape == (D_,)
        for (i, d3ai) in enumerate(d3a):
            dx[i] = d3ai - 1 / N_ * tmp_
        return dx, dgamma, dbeta

    def backward_impl1(self, dy, x):
        N_, D_ = self.size
        assert dy.ndim == x.ndim == 2
        assert dy.shape == x.shape == (N_, D_)
        _ = self.forward(x)
        # calculate the gradients with respect to gamma and beta
        dgamma = self.xhat * dy
        dgamma = dgamma.sum(axis=0)
        assert dgamma.shape == (D_,)
        dbeta = dy.sum(axis=0)
        assert dbeta.shape == (D_,)
        # calculate the gradient of x via the simplified formula
        dx = xp.ones(x.shape) * xp.nan
        for (i, dyi) in enumerate(dy):
            dx[i] = dyi - 1 / N_ * (dbeta + self.xhat[i] * dgamma)
            dx[i] = self.gamma * self.saved_inv_var * dx[i]
        return dx, dgamma, dbeta


if __name__ == '__main__':
    main()
Results
x:
[[0.90067645 0.05195989 0.94890906 0.39859517]
[0.44419045 0.96537073 0.66998211 0.3254723 ]
[0.14176828 0.99931822 0.0932397 0.6887744 ]
...
[0.81677439 0.79172529 0.97273544 0.3386522 ]
[0.13069407 0.1792883 0.68598364 0.13730587]
[0.75877388 0.91132649 0.84422827 0.34956074]]
dy:
[[0.68581415 0.60774472 0.17332992 0.2146162 ]
[0.85103707 0.56038386 0.85837978 0.37975325]
[0.67934268 0.75050501 0.71611867 0.37727164]
...
[0.19264311 0.25868408 0.91539923 0.47523352]
[0.62664839 0.40625006 0.47672807 0.93141321]
[0.53504761 0.53021432 0.56164742 0.66057609]]
dx1 - dx0:
[[-0.01040809 0.00645451 -0.03261807 -0.00284718]
[ 0.00159596 -0.00626283 -0.01279289 -0.00477366]
[ 0.00954865 -0.00673548 0.02820001 0.00479781]
...
[-0.00820175 -0.00384518 -0.03431157 -0.00442643]
[ 0.00983987 0.00468173 -0.01393022 -0.00973105]
[-0.00667653 -0.00551038 -0.02517772 -0.00413903]]
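If the diagnosis above is correct, the remedy is a one-character change in forward: multiply by saved_inv_var instead of dividing by it. A sketch of the corrected loop; with it, dx1 - dx0 should shrink to float64 rounding error:

for (i, xi) in enumerate(x):
    # xhat = (x - mean) / sqrt(var + eps) = (x - mean) * saved_inv_var
    self.xhat[i] = (xi - self.saved_mean) * self.saved_inv_var
    y[i] = self.gamma * self.xhat[i] + self.beta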