LoginSignup
1
2

【Keras】【TensorFlow】教師あり学習fitの「batch_size」の技術検証

Posted at

検証を行うきっかけと目的

 同じデータや内容のAIの作成を行う場合でも⓵どの端末で行ったか⓶学習の変数の一つである「batch_size」が違うだけで大きく結果が異なるように見えたため。端末ごと、batch_sizeごとにどのような影響があるか検証する。

 上記の⓵と⓶以外は前回の下の記事をそのまま流用している。
AIチャレンジ1(天気予報)
https://qiita.com/horiivalue/items/639e2e9c8b83c8d4886a

fitの変数「batch_size」とは?

 そもそも、この「batch_size」は何を算定しているのか?
以下のQ&Aなどを参照してまとめてみると以下のようになる。
⓵kerasのbatch_sizeについて(分からない)
https://teratail.com/questions/191679
⓶Kerasのドキュメント
https://keras.io/ja/models/model/#:~:text=batch_size%3A%20%E6%95%B4%E6%95%B0%E3%81%BE%E3%81%9F%E3%81%AF%20None,%EF%BC%8E%E5%8B%BE%E9%85%8D%E6%9B%B4%E6%96%B0%E6%AF%8E%E3%81%AE%E3%82%B5%E3%83%B3%E3%83%97%E3%83%AB%E6%95%B0%E3%82%92%E7%A4%BA%E3%81%99%E6%95%B4%E6%95%B0%EF%BC%8E%E6%8C%87%E5%AE%9A%E3%81%97%E3%81%AA%E3%81%91%E3%82%8C%E3%81%B0%20batch_size%20%E3%81%AF%E3%83%87%E3%83%95%E3%82%A9%E3%83%AB%E3%83%88%E3%81%A732%E3%81%AB%E3%81%AA%E3%82%8A%E3%81%BE%E3%81%99%EF%BC%8E

 「batch_size」は勾配更新毎のサンプル数を示す整数であり、1回あたりの学習の大きさを示しているそうだ。一般的に以下の特徴が存在する。
1,バッチサイズを小さくする:
局所解に嵌りずらくなる。
GPU を使用する場合、計算効率が悪くなる。

2,バッチサイズを大きくする:
局所解に嵌りやすくなる。
GPU を使用する場合、一度に計算できたほうが計算効率がよい。
あまり大きい値にしすぎると、GPU のメモリに乗り切らない場合もある。

ローカルでの検証

⓵デフォルト(32)の場合

log = model.fit(X_train, Y_train, epochs=5000, batch_size=32, verbose=True,
                callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss',
                                                         min_delta=0, patience=100,
                                                         verbose=1)],
         validation_data=(X_valid, Y_valid))

image.png

⓶batch_sizeを48に変更

log = model.fit(X_train, Y_train, epochs=5000, batch_size=48, verbose=True,
                callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss',
                                                         min_delta=0, patience=100,
                                                         verbose=1)],
         validation_data=(X_valid, Y_valid))

image.png

GoogleColaboratory上の場合(batch_sizeが24)

log = model.fit(X_train, Y_train, epochs=5000, batch_size=24, verbose=True,
                callbacks=[keras.callbacks.EarlyStopping(monitor='val_loss',
                                                         min_delta=0, patience=100,
                                                         verbose=1)],
         validation_data=(X_valid, Y_valid))
              precision    recall  f1-score   support

           0       0.71      0.76      0.74        33
           1       0.89      0.87      0.88        77

    accuracy                           0.84       110
   macro avg       0.80      0.81      0.81       110
weighted avg       0.84      0.84      0.84       110

検証結果

 3つの結果を比べて、マシンの性能がディープラーニングを左右することが判明した。計算効率を重視するなら、バッチサイズを大きくすることだが精度は落ちる。逆はまたしかりということなのだろうか。

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2