ポイント
- Layer Normalization を実装し、具体的な数値で確認。
レファレンス
数式
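(参照論文より引用)
下記サンプルコードの実装に対応させて書くと、最後の軸（ユニット数 $H$ = n_units）について次のように正規化する。
$$\mu = \frac{1}{H}\sum_{i=1}^{H} x_i, \qquad \sigma^2 = \frac{1}{H}\sum_{i=1}^{H} (x_i - \mu)^2$$
$$y = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \qquad (\epsilon = 10^{-8})$$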
サンプルコード
def layer_norm_3d(x, batch_size, length, n_units):
    """Layer-normalize a tensor along axis 2 and apply a learned scale/shift.

    Mean and variance are computed over axis 2 only, so every
    (batch, position) slice is normalized independently of the others.

    Args:
        x: Input tensor. The accompanying example passes a float32 tensor
            of shape [batch_size, length, n_units]; the parameters created
            here must broadcast against it.
        batch_size: Not used by the computation; kept as part of the
            signature.
        length: Size of axis 1, used for the parameter shape.
        n_units: Size of axis 2 (the normalized axis), used for the
            parameter shape.

    Returns:
        ``gamma * normalized + beta`` where ``normalized`` is ``x`` with the
        axis-2 mean subtracted and divided by ``sqrt(variance + 1e-8)``.

    Side effects:
        Creates the variables 'gamma' (initialized to 1.0) and 'beta'
        (initialized to 0.0) via ``tf.get_variable`` in the current
        variable scope, each of shape [1, length, n_units].
    """
    feature_axis = 2
    param_shape = [1, length, n_units]

    # tf.nn.moments drops the reduced axis; put it back so the statistics
    # broadcast against x.
    mu, sigma_sq = tf.nn.moments(x, axes=[feature_axis])
    mu = tf.expand_dims(mu, axis=feature_axis)
    sigma_sq = tf.expand_dims(sigma_sq, axis=feature_axis)

    # Scale starts at 1 and shift at 0, so the layer initially returns the
    # plain normalized values.
    gamma = tf.get_variable(
        'gamma',
        shape=param_shape,
        initializer=tf.constant_initializer(value=1.0, dtype=tf.float32),
    )
    beta = tf.get_variable(
        'beta',
        shape=param_shape,
        initializer=tf.constant_initializer(value=0.0, dtype=tf.float32),
    )

    # The small constant keeps the division finite when the variance is 0.
    x_hat = (x - mu) / tf.sqrt(sigma_sq + 1e-8)
    return gamma * x_hat + beta
def layer_norm_4d(x, batch_size, hight, width, n_units):
    """Layer-normalize a tensor along axis 3 and apply a learned scale/shift.

    Mean and variance are computed over axis 3 only, so every
    (batch, row, column) slice is normalized independently of the others.

    Args:
        x: Input tensor. The accompanying example passes a float32 tensor
            of shape [batch_size, hight, width, n_units]; the parameters
            created here must broadcast against it.
        batch_size: Not used by the computation; kept as part of the
            signature.
        hight: Size of axis 1, used for the parameter shape.
        width: Size of axis 2, used for the parameter shape.
        n_units: Size of axis 3 (the normalized axis), used for the
            parameter shape.

    Returns:
        ``gamma * normalized + beta`` where ``normalized`` is ``x`` with the
        axis-3 mean subtracted and divided by ``sqrt(variance + 1e-8)``.

    Side effects:
        Creates the variables 'gamma' (initialized to 1.0) and 'beta'
        (initialized to 0.0) via ``tf.get_variable`` in the current
        variable scope, each of shape [1, hight, width, n_units].
    """
    feature_axis = 3
    param_shape = [1, hight, width, n_units]

    # tf.nn.moments drops the reduced axis; put it back so the statistics
    # broadcast against x.
    mu, sigma_sq = tf.nn.moments(x, axes=[feature_axis])
    mu = tf.expand_dims(mu, axis=feature_axis)
    sigma_sq = tf.expand_dims(sigma_sq, axis=feature_axis)

    # Scale starts at 1 and shift at 0, so the layer initially returns the
    # plain normalized values.
    gamma = tf.get_variable(
        'gamma',
        shape=param_shape,
        initializer=tf.constant_initializer(value=1.0, dtype=tf.float32),
    )
    beta = tf.get_variable(
        'beta',
        shape=param_shape,
        initializer=tf.constant_initializer(value=0.0, dtype=tf.float32),
    )

    # The small constant keeps the division finite when the variance is 0.
    x_hat = (x - mu) / tf.sqrt(sigma_sq + 1e-8)
    return gamma * x_hat + beta
# Demo: run layer normalization on a small ramp tensor and print the input
# and the normalized output.
# NOTE(review): assumes `import tensorflow as tf` (TF 1.x graph API) is in
# scope; the import is not visible in this part of the file.
batch_size = 2
length = 5
hight = 2
width = 2
n_units = 3

# for 3d
# The number of generated elements is derived from the target dimensions:
# tf.reshape requires the element count to equal their product
# (2 * 5 * 3 = 30), which a hard-coded 15 does not satisfy.
x = tf.reshape(
    tf.range(batch_size * length * n_units, dtype=tf.float32),
    [batch_size, length, n_units])

# for 4d
# Every line of the alternative is commented out; a backslash inside a
# comment does not continue the comment onto the next line.
# x = tf.reshape(
#     tf.range(batch_size * hight * width * n_units, dtype=tf.float32),
#     [batch_size, hight, width, n_units])

with tf.variable_scope('test'):
    # for 3d
    y = layer_norm_3d(x, batch_size, length, n_units)
    # for 4d
    # y = layer_norm_4d(x, batch_size, hight, width, n_units)

with tf.Session() as sess:
    # gamma / beta must be initialized before y can be evaluated.
    sess.run(tf.global_variables_initializer())
    print('x: ')
    print(sess.run(x))
    print()
    print('y: ')
    print(sess.run(y))
結果
3d
x:
[[[ 0. 1. 2.]
[ 3. 4. 5.]
[ 6. 7. 8.]
[ 9. 10. 11.]
[12. 13. 14.]]
[[15. 16. 17.]
[18. 19. 20.]
[21. 22. 23.]
[24. 25. 26.]
[27. 28. 29.]]]
y:
[[[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]]
[[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]
[-1.2247448 0. 1.2247448]
[-1.2247448 0. 1.2247448]]]
4d
x:
[[[[ 0. 1. 2.]
[ 3. 4. 5.]]
[[ 6. 7. 8.]
[ 9. 10. 11.]]]
[[[12. 13. 14.]
[15. 16. 17.]]
[[18. 19. 20.]
[21. 22. 23.]]]]
y:
[[[[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]]
[[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]]]
[[[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.2247449]]
[[-1.2247449 0. 1.2247449]
[-1.2247449 0. 1.22474
