Batch Normalizationがどんな計算をしているのかよくわからないまま使っていたので、Chainerで少し試してみました。
数式などの解説は「多層ニューラルネットでBatch Normalizationの検証」が参考になります。
以下のプログラムはChainer Playground βでもコピペすれば動作しました。
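ざっくりいうと、channelごとに平均と標準偏差が求められ、同じchannelに属するデータはすべて同じ平均と標準偏差で正規化されているようでした。平均と標準偏差は、バッチ方向と（画像のような入力では）高さ・幅方向の値をまとめて計算されます。式で書くと $y = \gamma \frac{x - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta$（$\mu_c, \sigma_c^2$ はchannel $c$ の平均と分散）です。初期状態では $\gamma = 1,\ \beta = 0$ なので、以下の出力は標準化そのものになっています。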
import chainer
import chainer.links as L
import numpy as np

# Batch of 8 samples, 1 channel: input shape (8, 1), values 0..7 as float32.
a = np.arange(0, 8, dtype=np.float32).reshape((8, 1))
print(a.shape)
print(a)

# One BatchNormalization link for a single channel. The link is called on the
# ndarray directly and the normalized values are read from c.data.
# In the printed result each value equals (x - 3.5) / 2.2913, i.e. the
# channel's mean and the square root of its variance computed with N (not N-1)
# in the denominator: 0 -> -1.5275, 7 -> 1.5275.
bn = L.BatchNormalization(1)
c = bn(a)
print(c.data)
(8, 1)
[[ 0.]
[ 1.]
[ 2.]
[ 3.]
[ 4.]
[ 5.]
[ 6.]
[ 7.]]
[[-1.52752233]
[-1.09108734]
[-0.65465242]
[-0.21821748]
[ 0.21821748]
[ 0.65465242]
[ 1.09108734]
[ 1.52752233]]
# Batch of 8 samples, 2 channels: input shape (8, 2).
# Channel 0 holds 0..7 and channel 1 holds 8..15.
a1 = np.arange(0, 8, dtype=np.float32).reshape((8, 1))
a2 = np.arange(8, 16, dtype=np.float32).reshape((8, 1))
a = np.concatenate((a1, a2), axis=1)  # join along the channel axis
print(a.shape)
print(a)

# size=2: the link normalizes each channel (axis 1) with its own mean and
# standard deviation, so both columns come out identical even though their
# raw values differ by 8.
bn = L.BatchNormalization(2)
c = bn(a)
print(c.data)
(8, 2)
[[ 0. 8.]
[ 1. 9.]
[ 2. 10.]
[ 3. 11.]
[ 4. 12.]
[ 5. 13.]
[ 6. 14.]
[ 7. 15.]]
[[-1.52752233 -1.52752233]
[-1.09108734 -1.09108734]
[-0.65465242 -0.65465242]
[-0.21821748 -0.21821748]
[ 0.21821748 0.21821748]
[ 0.65465242 0.65465242]
[ 1.09108734 1.09108734]
[ 1.52752233 1.52752233]]
# Batch of 1 sample, 1 channel, 2x4 spatial map: input shape (1, 1, 2, 4).
a = np.arange(0, 8, dtype=np.float32).reshape((1, 1, 2, 4))
print(a.shape)
print(a)

# Even with a batch size of 1, the channel's statistics are taken over its
# 2x4 spatial positions, so the result matches the (8, 1) case value for value.
bn = L.BatchNormalization(1)
c = bn(a)
print(c.data)
(1, 1, 2, 4)
[[[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]]]]
[[[[-1.52752233 -1.09108734 -0.65465242 -0.21821748]
[ 0.21821748 0.65465242 1.09108734 1.52752233]]]]
# Batch of 1 sample, 2 channels, 2x4 spatial map: input shape (1, 2, 2, 4).
# Channel 0 holds 0..7 and channel 1 holds 8..15.
a1 = np.arange(0, 8, dtype=np.float32).reshape((1, 1, 2, 4))
a2 = np.arange(8, 16, dtype=np.float32).reshape((1, 1, 2, 4))
a = np.concatenate((a1, a2), axis=1)  # join along the channel axis
print(a.shape)
print(a)

# Each channel is normalized over its own 2x4 map, so both channels produce
# the same normalized pattern.
bn = L.BatchNormalization(2)
c = bn(a)
print(c.data)
(1, 2, 2, 4)
[[[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]]
[[ 8. 9. 10. 11.]
[ 12. 13. 14. 15.]]]]
[[[[-1.52752233 -1.09108734 -0.65465242 -0.21821748]
[ 0.21821748 0.65465242 1.09108734 1.52752233]]
[[-1.52752233 -1.09108734 -0.65465242 -0.21821748]
[ 0.21821748 0.65465242 1.09108734 1.52752233]]]]
# Batch of 2 samples, 1 channel, 2x4 spatial map: input shape (2, 1, 2, 4).
# Sample 0 holds 0..7 and sample 1 holds 8..15.
a1 = np.arange(0, 8, dtype=np.float32).reshape((1, 1, 2, 4))
a2 = np.arange(8, 16, dtype=np.float32).reshape((1, 1, 2, 4))
a = np.concatenate((a1, a2), axis=0)  # join along the batch axis
print(a.shape)
print(a)

# The single channel's statistics pool the batch axis AND the spatial axes:
# all 16 values (0..15, mean 7.5) share one mean/std, e.g. 0 -> -1.627.
bn = L.BatchNormalization(1)
c = bn(a)
print(c.data)
(2, 1, 2, 4)
[[[[ 0. 1. 2. 3.]
[ 4. 5. 6. 7.]]]
[[[ 8. 9. 10. 11.]
[ 12. 13. 14. 15.]]]]
[[[[-1.6269778 -1.41004741 -1.19311702 -0.97618663]
[-0.7592563 -0.54232591 -0.32539555 -0.10846519]]]
[[[ 0.10846519 0.32539555 0.54232591 0.7592563 ]
[ 0.97618663 1.19311702 1.41004741 1.6269778 ]]]]
最後の例では、バッチ方向と空間方向（2*4）をまとめた16個の値（0〜15、平均7.5）から、1組の平均と標準偏差が計算されていることがわかります。