はじめに
Normalizationについてよく見る下記の図について理解が不十分だったので、具体的な数値を使って検証した。
まとめ
Batch Normalization, Layer Normalization, Instance Normalizationの計算式の形は式(1)で変化しない。変化するのは平均と分散を計算する際の対象。
\begin{equation}
y = \frac{x - \text{E}[x]}{\sqrt{\text{Var}[x] + \epsilon}} \cdot \gamma + \beta \tag{1}
\end{equation}
$x$: 入力
$y$: 出力
$\epsilon$: 0割りを防止するための小さな値
$\gamma$: Weight(学習で算出)
$\beta$: Bias(学習で算出)
平均と分散の計算対象
- Batch Normalization: ミニバッチ内の同じチャンネルごとに空間次元(画像なら縦×横)で平均と分散を計算
- Layer Normalization: ミニバッチ内の全チャンネルで空間次元(画像なら縦×横)で平均と分散を計算
- Instance Normalization: ミニバッチ内の1バッチかつチャンネルごとの空間次元(画像なら縦×横)で平均と分散を計算
ただし、ここでは学習時における平均と分散を対象としている。
推論時も同様に平均と分散を計算するが、バッチサイズが1となり分散を計算できない等のため、移動平均を利用して更新する。今回はこの点についてまでは検証しない。
検証
入力データの形式をBCHW (Batch, Channel, Height, Width) にして、Batch Normalization、Layer Normalization、Instance Normalizationの違いを具体例で示す。この例では、バッチサイズが2で、各バッチに2チャンネルがあり、各チャンネルが2x2ピクセルで構成されているとする。
更に式(1)を単純化するため、$\epsilon=0$, $\gamma=1$, $\beta=0$とする。
入力データ(BCHW形式)
- バッチ 1:
チャンネル 1:
\begin{bmatrix}
55 & 4 \\
50 & 22
\end{bmatrix}
チャンネル 2:
\begin{bmatrix}
63 & 73 \\
44 & 7
\end{bmatrix}
- バッチ 2:
チャンネル 1:
\begin{bmatrix}
53 & 95 \\
64 & 8
\end{bmatrix}
チャンネル 2:
\begin{bmatrix}
0 & 15 \\
25 & 82
\end{bmatrix}
Batch Normalization
Batch Normalizationでは、ミニバッチ内の同じチャンネルごとに平均と分散を計算する。
平均の計算
- チャンネル 1の平均 ($\mu_{\text{ch1}}$)の計算
\begin{align}
\mu_{\text{ch1}} &= \frac{55 + 4 + 50 + 22 + 53 + 95 + 64 + 8}{8} \\
&= 43.875
\end{align}
- チャンネル 2の平均 ($\mu_{\text{ch2}}$)の計算
\begin{align}
\mu_{\text{ch2}} &= \frac{63 + 73 + 44 + 7 + 0 + 15 + 25 + 82}{8} \\
&= 38.625
\end{align}
分散の計算
- チャンネル 1の分散 ($\sigma^2_{\text{ch1}}$) の計算
\begin{align}
\sigma^2_{\text{ch1}} &= \frac{(55-43.875)^2 + (4-43.875)^2 + (50-43.875)^2 + (22-43.875)^2 + (53-43.875)^2 + (95-43.875)^2 + (64-43.875)^2 + (8-43.875)^2}{8} \\
&= 827.3594
\end{align}
- チャンネル 2の分散 ($\sigma^2_{\text{ch2}}$) の計算
\begin{align}
\sigma^2_{\text{ch2}} &= \frac{(63-38.625)^2 + (73-38.625)^2 + (44-38.625)^2 + (7-38.625)^2 + (0-38.625)^2 + (15-38.625)^2 + (25-38.625)^2 + (82-38.625)^2}{8} \\
&= 865.2344
\end{align}
正規化されたデータ
チャンネルごとに同じ平均と分散を利用して正規化する。
\begin{align*}
\text{Batch 1, Channel 1:} \quad &
\begin{bmatrix}
\frac{55 - 43.875}{\sqrt{827.3594}} & \frac{4 - 43.875}{\sqrt{827.3594}} \\
\frac{50 - 43.875}{\sqrt{827.3594}} & \frac{22 - 43.875}{\sqrt{827.3594}}
\end{bmatrix}
=
\begin{bmatrix}
0.3868 & -1.3863 \\
0.2129 & -0.7605
\end{bmatrix} \\
\text{Batch 1, Channel 2:} \quad &
\begin{bmatrix}
\frac{63 - 38.625}{\sqrt{865.2344}} & \frac{73 - 38.625}{\sqrt{865.2344}} \\
\frac{44 - 38.625}{\sqrt{865.2344}} & \frac{7 - 38.625}{\sqrt{865.2344}}
\end{bmatrix}
=
\begin{bmatrix}
0.8287 & 1.1686 \\
0.1827 & -1.0751
\end{bmatrix} \\
\text{Batch 2, Channel 1:} \quad &
\begin{bmatrix}
\frac{53 - 43.875}{\sqrt{827.3594}} & \frac{95 - 43.875}{\sqrt{827.3594}} \\
\frac{64 - 43.875}{\sqrt{827.3594}} & \frac{8 - 43.875}{\sqrt{827.3594}}
\end{bmatrix}
=
\begin{bmatrix}
0.3172 & 1.7774\\
0.6997 & -1.2472
\end{bmatrix} \\
\text{Batch 2, Channel 2:} \quad &
\begin{bmatrix}
\frac{0 - 38.625}{\sqrt{865.2344}} & \frac{15 - 38.625}{\sqrt{865.2344}} \\
\frac{25 - 38.625}{\sqrt{865.2344}} & \frac{82 - 38.625}{\sqrt{865.2344}}
\end{bmatrix}
=
\begin{bmatrix}
-1.3131 & -0.8032\\
-0.4632 & 1.4746
\end{bmatrix}
\end{align*}
コード
import torch
import torch.nn as nn
# Define the data as given in the BCHW format
data = torch.tensor([
[
[[55, 4], [50, 22]], # Batch 1, Channel 1
[[63, 73], [44, 7]] # Batch 1, Channel 2
],
[
[[53, 95], [64, 8]], # Batch 2, Channel 1
[[0, 15], [25, 82]] # Batch 2, Channel 2
]
], dtype=torch.float32)
# Normalization
bn = nn.BatchNorm2d(num_features=2, affine=False, eps=0, momentum=1) # No learnable parameters
# Apply Normalizations
batch_norm_output = bn(data)
print("Batch Normalization Output:", batch_norm_output)
出力結果
Batch Normalization Output:
tensor([[[[ 0.3868, -1.3863],
[ 0.2129, -0.7605]],
[[ 0.8287, 1.1686],
[ 0.1827, -1.0751]]],
[[[ 0.3172, 1.7774],
[ 0.6997, -1.2472]],
[[-1.3131, -0.8032],
[-0.4632, 1.4746]]]])
Layer Normalization
Layer Normalizationでは、ミニバッチ内の全チャンネルで平均と分散を計算する。
平均の計算
- バッチ 1の平均 ($\mu$)の計算
\begin{align}
\mu_{\text{batch1}} &= \frac{55 + 4 + 50 + 22 + 63 + 73 + 44 + 7}{8} \\
&= 39.75 \\
\end{align}
- バッチ 2の平均 ($\mu$)の計算
\begin{align}
\mu_{\text{batch2}} &= \frac{53 + 95 + 64 + 8 + 0 + 15 + 25 + 82}{8} \\
&= 42.75 \\
\end{align}
分散の計算
- バッチ 1の分散 ($\sigma^2$)の計算
\begin{align}
\sigma^2_{\text{batch1}} &= \frac{(55-39.75)^2 + (4-39.75)^2 + (50-39.75)^2 + (22-39.75)^2 + (63-39.75)^2 + (73-39.75)^2 + (44-39.75)^2 + (7-39.75)^2}{8} \\
&= 583.4375
\end{align}
- バッチ 2の分散 ($\sigma^2$)の計算
\begin{align}
\sigma^2_{\text{batch2}} &= \frac{(53-42.75)^2 + (95-42.75)^2 + (64-42.75)^2 + (8-42.75)^2 + (0-42.75)^2 + (15-42.75)^2 + (25-42.75)^2 + (82-42.75)^2}{8} \\
&= 1118.4375
\end{align}
正規化されたデータ
バッチごとに同じ平均と分散を利用して正規化する。
\begin{align*}
\text{Batch 1, Channel 1:} \quad &
\begin{bmatrix}
\frac{55 - 39.75}{\sqrt{583.4375}} & \frac{4 - 39.75}{\sqrt{583.4375}} \\
\frac{50 - 39.75}{\sqrt{583.4375}} & \frac{22 - 39.75}{\sqrt{583.4375}}
\end{bmatrix}
=
\begin{bmatrix}
0.6314 & -1.4801 \\
0.4244 & -0.7349
\end{bmatrix} \\
\text{Batch 1, Channel 2:} \quad &
\begin{bmatrix}
\frac{63 - 39.75}{\sqrt{583.4375}} & \frac{73 - 39.75}{\sqrt{583.4375}} \\
\frac{44 - 39.75}{\sqrt{583.4375}} & \frac{7 - 39.75}{\sqrt{583.4375}}
\end{bmatrix}
=
\begin{bmatrix}
0.9626 & 1.3766\\
0.1760 & -1.3559
\end{bmatrix} \\
\text{Batch 2, Channel 1:} \quad &
\begin{bmatrix}
\frac{53 - 42.75}{\sqrt{1118.4375}} & \frac{95 - 42.75}{\sqrt{1118.4375}} \\
\frac{64 - 42.75}{\sqrt{1118.4375}} & \frac{8 - 42.75}{\sqrt{1118.4375}}
\end{bmatrix}
=
\begin{bmatrix}
0.3065 & 1.5624\\
0.6354 & -1.0391
\end{bmatrix} \\
\text{Batch 2, Channel 2:} \quad &
\begin{bmatrix}
\frac{0 - 42.75}{\sqrt{1118.4375}} & \frac{15 - 42.75}{\sqrt{1118.4375}} \\
\frac{25 - 42.75}{\sqrt{1118.4375}} & \frac{82 - 42.75}{\sqrt{1118.4375}}
\end{bmatrix}
=
\begin{bmatrix}
-1.2783 & -0.8298\\
-0.5308 & 1.1736
\end{bmatrix}
\end{align*}
コード
import torch
import torch.nn as nn
# Define the data as given in the BCHW format
data = torch.tensor([
[
[[55, 4], [50, 22]], # Batch 1, Channel 1
[[63, 73], [44, 7]] # Batch 1, Channel 2
],
[
[[53, 95], [64, 8]], # Batch 2, Channel 1
[[0, 15], [25, 82]] # Batch 2, Channel 2
]
], dtype=torch.float32)
# Normalization
ln = nn.LayerNorm([2, 2, 2], elementwise_affine=False, eps=0) # Normalizing over channel dimension
# Apply Normalizations
layer_norm_output = ln(data)
print("Layer Normalization Output:", layer_norm_output)
出力結果
Layer Normalization Output:
tensor([[[[ 0.6314, -1.4801],
[ 0.4244, -0.7349]],
[[ 0.9626, 1.3766],
[ 0.1760, -1.3559]]],
[[[ 0.3065, 1.5624],
[ 0.6354, -1.0391]],
[[-1.2783, -0.8298],
[-0.5308, 1.1736]]]])
Instance Normalization
Instance Normalizationでは、ミニバッチ内の1バッチかつチャンネルごとに平均と分散を計算する。
平均の計算
- バッチ 1, チャンネル 1の平均 ($\mu$)の計算
\begin{aligned}
\mu_{\text{batch1, channel1}} &= \frac{55 + 4 + 50 + 22}{4} \\
&= 32.75
\end{aligned}
- バッチ 1, チャンネル 2の平均 ($\mu$)の計算
\begin{aligned}
\mu_{\text{batch1, channel2}} &= \frac{63 + 73 + 44 + 7}{4} \\
&= 46.75
\end{aligned}
- バッチ 2, チャンネル 1の平均 ($\mu$)の計算
\begin{aligned}
\mu_{\text{batch2, channel1}} &= \frac{53 + 95 + 64 + 8}{4} \\
&= 55
\end{aligned}
- バッチ 2, チャンネル 2の平均 ($\mu$)の計算
\begin{aligned}
\mu_{\text{batch2, channel2}} &= \frac{0 + 15 + 25 + 82}{4} \\
&= 30.5
\end{aligned}
分散の計算
- バッチ 1, チャンネル 1の分散 ($\sigma^2$)の計算
\begin{aligned}
\sigma^2_{\text{batch1, channel1}} &= \frac{(55-32.75)^2 + (4-32.75)^2 + (50-32.75)^2 + (22-32.75)^2}{4} \\
&= 578.25
\end{aligned}
- バッチ 1, チャンネル 2の分散 ($\sigma^2$)の計算
\begin{aligned}
\sigma^2_{\text{batch1, channel2}} &= \frac{(63-46.75)^2 + (73-46.75)^2 + (44-46.75)^2 + (7-46.75)^2}{4} \\
&= 846.9167
\end{aligned}
- バッチ 2, チャンネル 1の分散 ($\sigma^2$)の計算
\begin{aligned}
\sigma^2_{\text{batch2, channel1}} &= \frac{(53-55)^2 + (95-55)^2 + (64-55)^2 + (8-55)^2}{4} \\
&= 1298
\end{aligned}
- バッチ 2, チャンネル 2の分散 ($\sigma^2$)の計算
\begin{aligned}
\sigma^2_{\text{batch2, channel2}} &= \frac{(0-30.5)^2 + (15-30.5)^2 + (25-30.5)^2 + (82-30.5)^2}{4} \\
&= 1284.3334
\end{aligned}
正規化されたデータ
1バッチかつチャンネルごとに平均と分散を利用して正規化する。
\begin{aligned}
\text{Batch 1, Channel 1:} \quad &
\begin{bmatrix}
\frac{55 - 32.75}{\sqrt{578.25}} & \frac{4 - 32.75}{\sqrt{578.25}} \\
\frac{50 - 32.75}{\sqrt{578.25}} & \frac{22 - 32.75}{\sqrt{578.25}}
\end{bmatrix}
=
\begin{bmatrix}
1.0684 & -1.3805 \\
0.8283 & -0.5162
\end{bmatrix} \\
\text{Batch 1, Channel 2:} \quad &
\begin{bmatrix}
\frac{63 - 46.75}{\sqrt{846.9167}} & \frac{73 - 46.75}{\sqrt{846.9167}} \\
\frac{44 - 46.75}{\sqrt{846.9167}} & \frac{7 - 46.75}{\sqrt{846.9167}}
\end{bmatrix}
=
\begin{bmatrix}
0.6448 & 1.0415 \\
-0.1091 & -1.5772
\end{bmatrix} \\
\text{Batch 2, Channel 1:} \quad &
\begin{bmatrix}
\frac{53 - 55}{\sqrt{1298}} & \frac{95 - 55}{\sqrt{1298}} \\
\frac{64 - 55}{\sqrt{1298}} & \frac{8 - 55}{\sqrt{1298}}
\end{bmatrix}
=
\begin{bmatrix}
-0.0641 & 1.2820 \\
0.2885 & -1.5064
\end{bmatrix} \\
\text{Batch 2, Channel 2:} \quad &
\begin{bmatrix}
\frac{0 - 30.5}{\sqrt{1284.3334}} & \frac{15 - 30.5}{\sqrt{1284.3334}} \\
\frac{25 - 30.5}{\sqrt{1284.3334}} & \frac{82 - 30.5}{\sqrt{1284.3334}}
\end{bmatrix}
=
\begin{bmatrix}
-0.9827 & -0.4994 \\
-0.1772 & 1.6593
\end{bmatrix}
\end{aligned}
コード
import torch
import torch.nn as nn
# Define the data as given in the BCHW format
data = torch.tensor([
[
[[55, 4], [50, 22]], # Batch 1, Channel 1
[[63, 73], [44, 7]] # Batch 1, Channel 2
],
[
[[53, 95], [64, 8]], # Batch 2, Channel 1
[[0, 15], [25, 82]] # Batch 2, Channel 2
]
], dtype=torch.float32)
# Normalization
inorm = nn.InstanceNorm2d(num_features=2, affine=False, eps=0, momentum=1) # No learnable parameters
# Apply Normalizations
instance_norm_output = inorm(data)
print("Instance Normalization Output:", instance_norm_output)
出力結果
Instance Normalization Output:
tensor([[[[ 1.0684, -1.3805],
[ 0.8283, -0.5162]],
[[ 0.6448, 1.0415],
[-0.1091, -1.5772]]],
[[[-0.0641, 1.2820],
[ 0.2885, -1.5064]],
[[-0.9827, -0.4994],
[-0.1772, 1.6593]]]])