1
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlowの高レベルAPIを使ったBatch Normalizationの実装:Keras版

Posted at

前回の記事
「TensorFlowの高レベルAPIの使用方法:Kerasの使い方」
https://qiita.com/cometscome_phys/items/d9553fe7c92e09fc14a9
に引き続き、Kerasを使ってみる。
今回はバッチ正規化を用いる。tf.layersを用いたバッチ正規化のやり方は
「TensorFlowの高レベルAPIを使ったBatch Normalizationの実装」
https://qiita.com/cometscome_phys/items/6d5d3c74d7000382efef
にある。バッチ正規化についてはこちらの記事を参照。

バージョン

TensorFlow: 1.12.0
Keras: 2.1.6-tf

再現すべき関数

tf.layersの記事と同じ関数である。

test.py
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

n = 10
x0 = np.linspace(-2.0, 2.0, n)
a0 = 3.0
a1 = 2.0
b0 = 1.0
y0 = np.zeros((n,1))
y0[:,0] = a0*x0+a1*x0**2 + b0 + 3*np.cos(20*x0)

nm = 300
xmany = np.linspace(-2.0, 2.0, nm)
ymany = np.zeros((nm,1))
ymany[:,0] = a0*xmany+a1*xmany**2 + b0 + 3*np.cos(20*xmany)

plt.plot(xmany,ymany )
plt.show()
plt.savefig("graph_many.png")

Unknown-4.png

ランダムバッチ:バッチ正規化なし

まずはじめに、バッチ正規化なしの場合を考える。トレーニングの際にランダムにバッチを選んでくることで、特定のバッチに過学習されることを防ぐことができる。
Kerasでは、実はmodel.fitshuffletrueになっているとすでにランダムにバッチを選んできていることになる。なお、デフォルトでtrueである。

モデルの構築

前回の記事と同じである。

test.py
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import optimizers

def build_model(d_input,d_middle):
    inputs = tf.keras.Input(shape=(d_input,))  #インプットの次元を指定
    x = layers.Dense(d_middle, activation='relu')(inputs) #中間層の指定
    y = layers.Dense(1)(x) #最終層の指定
    adam = optimizers.Adam() #最適化にはAdamを使用
    model =  tf.keras.Model(inputs=inputs, outputs=y) #モデルのインプットとアウトプットを指定
    
    model.compile(optimizer=adam,
              loss='mean_squared_error') #modelの構築。平均二乗誤差をloss関数とした。
    
    return model

学習の実行

test.py
k = 6
phi = make_phi(x0,n,k)
phimany = make_phi(xmany,nm,k)
d_type = tf.float32
d_input = k
d_middle = 10

とパラメータを設定して、

test.py
model = build_model(d_input,d_middle)
history = model.fit(phimany, ymany, epochs=2000,batch_size=20,validation_data=(phi, y0))

で学習を行う。ここで、validation_dataにテストデータセットを設定しておくと、テストデータを使った時のlossも計算してくれる。

結果

学習結果は、

test.py
ytest = model.predict(phimany)

plt.plot(xmany,ymany )
plt.plot(xmany,ytest,'o')
plt.show()
plt.savefig("graph.png")

plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train','test'], loc='upper left')
plt.show()
plt.savefig("train.png")

となった。

image.png

image.png

ランダムバッチ:バッチ正規化あり

次に、バッチ正規化を使ってみる。

モデルの構築

モデルは

test.py
from tensorflow.keras.layers import BatchNormalization

def build_model_BN(d_input,d_middle):
    inputs = tf.keras.Input(shape=(d_input,))  
    x = layers.Dense(d_middle, activation='relu')(inputs)
    x = BatchNormalization()(x)
    y = layers.Dense(1)(x)
    adam = optimizers.Adam()
    model =  tf.keras.Model(inputs=inputs, outputs=y)
    
    model.compile(optimizer=adam,
              loss='mean_squared_error')
    
    return model

とした。ここで、先ほどのモデルとの違いは、たった一行であることに注目してほしい。これですでにバッチ正規化ができているはずである。

学習

学習は

test.py
model_BN = build_model_BN(d_input,d_middle)
history_BN = model_BN.fit(phimany, ymany, epochs=2000,batch_size=20,validation_data=(phi, y0))

で実行する。

結果

結果は、

test.py
ytest = model.predict(phimany)

plt.plot(xmany,ymany )
plt.plot(xmany,ytest,'o')
plt.show()
plt.savefig("graph_BN.png")

plt.yscale("log")
plt.plot(history.history['val_loss'])
plt.plot(history_BN.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['test','test with BN'], loc='upper left')
plt.show()
plt.savefig("test_BN.png")

でプロットした。バッチ正規化をしなかった場合とlogスケールで見てみたが、バッチ正規化ありの方が早くlossが減っていることがわかる。

image.png

image.png

1
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?