バッチ正規化 及び レイヤー正規化
最近、バッチ正規化とレイヤー正規化について勉強して、実際に微分して実装を行なったので供養します。具体的にバッチ正規化やレイヤー正規化が何かを知りたい場合は、論文を読むか、他の方の記事を参考にしてください。
それぞれの正規化の概略だけ
バッチ正規化は、各レイヤーで起きる共変量シフトによって、次の層の入力の分布が偏ってしまい、学習がうまくいかなくなることがあるという問題に対応するために用いるテクニックで、通常は活性化関数に噛ませる前に正規化を行います。
ただし、バッチサイズが小さいと、正規化中に取る平均 $\mu$ 及び分散 $\sigma$ が偏った推定値になりかねないので、性能が良くないかもしれないという欠点があります。
一方で、レイヤー正規化はこのバッチ正規化の短所を改善するように作られました。具体的には、平均や分散を取る方向を変更することによって、バッチサイズが小さくても上手く動きます。この改良によって生じる短所は、平均等をとる方向の次元が小さかった場合に、また偏った値になりかねないことです。
正規化の数式
バッチ正規化の論文1から取って来ました。これは、ミニバッチのある次元を固定した時の値です。$x_i$ は、例えば画像なら $32$pixel $\times$ $32$pixel $\times$ $3$(RGB) $= 3072 := D$ もの情報があるので、$3072$ 要素の横ベクトルになります。 バッチは、この横ベクトルを縦に $N$ 個重ねた $N \times D$ 行列 $ X $ で表されます。特に、行列の各列に対して平均を取って、標準偏差を計算して...と言っています。
一方で、レイヤー正規化の場合は、各行に対して平均を取ることになるわけです。
微分について
$\gamma$ と $ \beta $ についての微分はほとんど自明なので、省略します。
$$
\begin{align}
& \mu=\frac{1}{N}\sum_{k=1}^N x_k & v=\frac{1}{N}\sum_{k=1}^N (x_k-\mu)^2 \\
& \sigma=\sqrt{v+\epsilon} & y_i=\frac{x_i-\mu}{\sigma}
\end{align}
$$
行列計算が絡んでくるので、一息に求めたい表現を見つけようとするのは難しいです。なので行列 $X$ から一部切り出して、その挙動を考えることにします。
バッチ正規化では $X$ の列 $D = 1$ の場合の $X$ を計算することを考えます(つまり、$x_i$は長さ3072のベクトルの要素)。一方で、レイヤー正規化では転置を行なった $ D\times N$ 行列 $X^T$ における $N = 1$ の場合を考えることにします。
このように転置する理由は、平均などを取る方向を縦(axis=0)に揃えられるからです。
いずれの場合も、縦ベクトルについて考えるんだなあと思っておけば良いです。
さて、$\frac{\partial L}{\partial x_i}$ を計算することを考えます。$L$は損失のLossです。数式を打ち込むのはとても面倒くさいので、手計算を載せます(ごめんなさい)。以下の画像の最後に、 if bn というバッチ正規化専用の等式がありますが、それ以外の部分はバッチ正規化でもレイヤー正規化でも数式上はほとんど同じです。
この数式を追う上で必ず注意して欲しいことは、次の2つだけです。
- レイヤー正規化を考えている時でも $N\times D$ 行列 $X$ は転置するので $x$ は長さ $D$ の縦ベクトルではあるが、縦の長さを $N$ と書き表している。(こうすることで、バッチ正規化と式が同じになる)
- バッチ正規化では、$\gamma$ や $\beta$ は $X$ の各列について同じ値を使う。レイヤー正規化では $X$ の各行について同じ値を使うため、$\gamma_i$ などと表記して別に扱う。
$\frac{\partial L}{\partial x_i}$ が計算できたら、これを縦に並べることで $X$ もしくは $X^T$ の列についての微分がそれぞれ求められたことになります。あとは、これを実装すると次元がうまく噛み合って $\frac{\partial L}{\partial X}$ が求まってしまいます。
バッチ正規化のコード
import numpy as np
def batch_normalization_forward(X, gamma, beta, is_batch_normalization):
'''
X: (N, D)
gamma: (D, )
beta: (D, )
'''
eps = 1e-5
mean = np.mean(x, axis=0) #(D, )
var = np.var(x, axis=0) #(D, )
std = np.sqrt(var + eps) #(D, )
z = (x - mean) / std #(N, D)
out = z * gamma + beta #(N, D)
cache = {
'x': x,
'gamma': gamma,
'beta': beta,
'eps': eps,
'mean': mean,
'var': var,
'std': std,
'z': z,
'axis': 0 if is_batch_normalization else 1
}
return out, cache
def batch_normalization_backward(dL, cache):
z = cache['z']
gamma = cache['gamma']
std = cache['std']
N, D = cache['x'].shape
ax = cache['axis']
dbeta = np.sum(dL, axis=ax) # dL/dbeta of shape (D, )
dgamma = np.sum(dL * z, axis=ax) #(D, ) # dL/dgamma of shape (D, )
gL = gamma * dL #(N, D)
inparen = gout - (np.sum(gL, axis=0) + z * np.sum(gL * z, axis=0)) / N
dx = inparen / std
# または、画像の最後の行を利用すると次のようにも書ける
# レイヤー正規化の際に使うのはダメ。
# in_paren = dout - (dbeta + z * dgamma) / N
# dx = gamma / std * in_paren
return dx, dgamma, dbeta
先の説明では特に触れませんでしたが、$\frac{\partial L}{\partial \gamma}$ や $\frac{\partial L}{\partial \beta}$ を求める際に取るべき和の方向には気をつけてください。$\gamma$ や $\beta$ がどのようにレイヤーの出力に寄与しているかを考えると良いです。
レイヤー正規化のコード
バッチ正規化のコードを使って、次のようにかけます。
def layer_normalization_forward(x, gamma, beta):
'''
X: (N, D)
gamma: (D, )
beta: (D, )
'''
out, cache = batch_normalization_forward(x.T, gamma.reshape((-1, 1)), beta.reshape((-1, 1)), False)
out = out.T
return out, cache
def layer_normalization_backward(dL, cache):
dx, dgamma, dbeta = batch_normalization_backward(dL.T, cache)
dx = dx.T
return dx, dgamma, dbeta
終わりに
微分の周りの話は、自分で計算しながらやってみないとかなり難しいと思います。転置などが絡んだ時にどういう様子になるのかイメージしながら計算してみてください。一旦、バッチ正規化のコードが書ければ、レイヤー正規化も簡単にかけることがわかったと思います。