Batch Normalization Notes

I had been using Batch Normalization without really understanding what it computes, so I tried it out a little in Chainer.

For the formulas and a detailed explanation, the article 多層ニューラルネットでBatch Normalizationの検証 (verifying Batch Normalization in multi-layer neural networks) is a good reference.

The programs below also worked on Chainer Playground β when simply copy-pasted.

Roughly speaking, a mean and a standard deviation are computed for each channel, and all data belonging to the same channel appear to be normalized with that same mean and standard deviation.
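
As a rough sketch of that operation in plain NumPy (assuming it mirrors Chainer's training-mode behaviour; the small eps term and the learnable scale/shift parameters gamma and beta are left out), the per-channel normalization would look like this:

import numpy as np

# Normalize each channel (here a single column) with its own batch mean
# and (population) standard deviation; eps and gamma/beta omitted.
x = np.arange(8, dtype=np.float32).reshape(8, 1)   # batch=8, 1 channel
mean = x.mean(axis=0)
std = x.std(axis=0)
print((x - mean) / std)   # roughly [-1.5275, -1.0911, ..., 1.5275]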

import

import chainer.links as L
import numpy as np

batch=8, 1ch data

a = np.asarray(range(0, 8), dtype=np.float32)
a = a.reshape((8, 1))            # batch=8, 1 channel
print(a.shape)
print(a)
bn = L.BatchNormalization(1)     # BN over 1 channel
c = bn(a)
print(c.data)


    (8, 1)
    [[ 0.]
     [ 1.]
     [ 2.]
     [ 3.]
     [ 4.]
     [ 5.]
     [ 6.]
     [ 7.]]
    [[-1.52752233]
     [-1.09108734]
     [-0.65465242]
     [-0.21821748]
     [ 0.21821748]
     [ 0.65465242]
     [ 1.09108734]
     [ 1.52752233]]

batch=8, 2ch data

a1 = np.asarray(range(0, 8), dtype=np.float32)
a2 = np.asarray(range(8, 16), dtype=np.float32)
a1 = a1.reshape((8, 1))
a2 = a2.reshape((8, 1))
a = np.concatenate((a1, a2), axis=1)   # batch=8, 2 channels
print(a.shape)
print(a)
bn = L.BatchNormalization(2)           # BN over 2 channels
c = bn(a)
print(c.data)


    (8, 2)
    [[  0.   8.]
     [  1.   9.]
     [  2.  10.]
     [  3.  11.]
     [  4.  12.]
     [  5.  13.]
     [  6.  14.]
     [  7.  15.]]
    [[-1.52752233 -1.52752233]
     [-1.09108734 -1.09108734]
     [-0.65465242 -0.65465242]
     [-0.21821748 -0.21821748]
     [ 0.21821748  0.21821748]
     [ 0.65465242  0.65465242]
     [ 1.09108734  1.09108734]
     [ 1.52752233  1.52752233]]
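
Checking this output with NumPy alone (eps ignored again): each column gets its own mean and standard deviation, and since 0..7 and 8..15 have the same spread, the two channels end up with identical normalized values.

import numpy as np

a = np.stack([np.arange(0, 8), np.arange(8, 16)], axis=1).astype(np.float32)
m = a.mean(axis=0)   # per-channel means: [ 3.5  11.5]
s = a.std(axis=0)    # per-channel stds: both about 2.2913
print((a - m) / s)   # both columns give roughly -1.5275 ... 1.5275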

batch=1, 1ch, 2*4 data

a = np.asarray(range(0, 8), dtype=np.float32)
a = a.reshape((1, 1, 2, 4))      # batch=1, 1 channel, 2x4
print(a.shape)
print(a)
bn = L.BatchNormalization(1)
c = bn(a)
print(c.data)


    (1, 1, 2, 4)
    [[[[ 0.  1.  2.  3.]
       [ 4.  5.  6.  7.]]]]
    [[[[-1.52752233 -1.09108734 -0.65465242 -0.21821748]
       [ 0.21821748  0.65465242  1.09108734  1.52752233]]]]

batch=1, 2ch, 2*4 data

a1 = np.asarray(range(0, 8), dtype=np.float32)
a2 = np.asarray(range(8, 16), dtype=np.float32)
a1 = a1.reshape((1, 1, 2, 4))
a2 = a2.reshape((1, 1, 2, 4))
a = np.concatenate((a1, a2), axis=1)   # batch=1, 2 channels, 2x4
print(a.shape)
print(a)
bn = L.BatchNormalization(2)
c = bn(a)
print(c.data)


    (1, 2, 2, 4)
    [[[[  0.   1.   2.   3.]
       [  4.   5.   6.   7.]]

      [[  8.   9.  10.  11.]
       [ 12.  13.  14.  15.]]]]
    [[[[-1.52752233 -1.09108734 -0.65465242 -0.21821748]
       [ 0.21821748  0.65465242  1.09108734  1.52752233]]

      [[-1.52752233 -1.09108734 -0.65465242 -0.21821748]
       [ 0.21821748  0.65465242  1.09108734  1.52752233]]]]

batch=2, 1ch, 2*4 data

a1 = np.asarray(range(0, 8), dtype=np.float32)
a2 = np.asarray(range(8, 16), dtype=np.float32)
a1 = a1.reshape((1, 1, 2, 4))
a2 = a2.reshape((1, 1, 2, 4))
a = np.concatenate((a1, a2), axis=0)   # batch=2, 1 channel, 2x4
print(a.shape)
print(a)
bn = L.BatchNormalization(1)
c = bn(a)
print(c.data)


    (2, 1, 2, 4)
    [[[[  0.   1.   2.   3.]
       [  4.   5.   6.   7.]]]


     [[[  8.   9.  10.  11.]
       [ 12.  13.  14.  15.]]]]
    [[[[-1.6269778  -1.41004741 -1.19311702 -0.97618663]
       [-0.7592563  -0.54232591 -0.32539555 -0.10846519]]]


     [[[ 0.10846519  0.32539555  0.54232591  0.7592563 ]
       [ 0.97618663  1.19311702  1.41004741  1.6269778 ]]]]
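
This last result suggests that for a 4D (N, C, H, W) input the statistics are pooled per channel over the batch and the spatial axes, i.e. over all 16 values here. A small NumPy check (assuming this matches Chainer's training-mode computation, eps ignored) reproduces the numbers:

import numpy as np

a = np.arange(16, dtype=np.float32).reshape(2, 1, 2, 4)   # N=2, C=1, H=2, W=4
mean = a.mean(axis=(0, 2, 3), keepdims=True)   # 7.5
std = a.std(axis=(0, 2, 3), keepdims=True)     # about 4.61
print((a - mean) / std)   # first element about -1.6270, as in the output above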
