Batch Normalizationによる収束性能向上

  • 38
    Like
  • 0
    Comment
More than 1 year has passed since last update.

概要

Neural Networkの収束性能を向上させる手法の一つであるBatch NormalizationをTensorFlowで実装し、効果を検証します。

実装

理論については他の方が丁寧に解説してくださっているので、いきなり実装から入ります。

注意点として、Convolutional Layerの扱いを以下の元論文の通りにします。

For convolutional layers, we additionally want the normalization to obey the convolutional property - so that different elements of the same feature map, at different locations, are normalized in the same way. To achieve this, we jointly normalize all the activations in a mini-batch, over all locations.
...
We learn a pair of parameters $\gamma^{(k)}$ and $\beta^{(k)}$ per feature map.

取り扱う問題は、TensorFlow公式TutorialのDeep MNIST for expertsです。

チュートリアル通りのプログラムを修正してBatch Normalizationを実装します。

batch_normalization 関数

Layer毎にbatch normalizationを適用するので、関数化しておきます。

def batch_normalization(shape, input):
  eps = 1e-5
  gamma = weight_variable([shape])
  beta = weight_variable([shape])
  mean, variance = tf.nn.moments(input, [0])
  return gamma * (input - mean) / tf.sqrt(variance + eps) + beta

epsの値は、chainerの実装を参考にさせていただきました。

batch normalization適用のための修正

以下のような修正をします。
論文に記述されている通り、バイアスの影響は平均値を引く操作によって消えるために、無視します。

before
  W_conv1 = weight_variable([5, 5, 1, 32])
  b_conv1 = bias_variable([32])
  h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
  h_pool1 = max_pool_2x2(h_conv1)
after
  W_conv1 = weight_variable([5, 5, 1, 32])
  h_conv1 = conv2d(x_image, W_conv1)
  bn1 = batch_normalization(32, h_conv1)
  h_pool1 = max_pool_2x2(tf.nn.relu(bn1))

コード全体の差分は長くなるので割愛しますが、こちらのGitHubのPull Requestとして載せておきます。

実行結果

Batch Normalizationを使った場合と使わなかった場合の収束のしかたを比較します。

Learning Rate 1E-4

まずLearning Rateを$1 \times 10^{-4}$とした場合です。

1E-4.png

あれ、あまり変化がない・・・?

Learning Rate 1E-3

ということで、試しにLearning Rateを10倍の$1 \times 10^{-3}$としてみます。

1E-3.png

圧倒的収束性能です。Batch Normalizationを使わないとLearning Rateが大きすぎて収束しないところで効果を発揮しました。

1000ステップでの精度

1000ステップでのTest Dataに対する精度を示します。

Learning Rate without BN with BN
1E-4 96% 95%
1E-3 収束せず 98%

参考資料

Batch Normalizationの恩恵を受けるDCGANのアプリケーションとして、以下の記事も紹介します。