LoginSignup
15
17

More than 3 years have passed since last update.

tf.kerasで学習中の進捗表示をカスタマイズする (GoogleColaboratoryのセルあふれ対策)

Last updated at Posted at 2019-12-24

【内容】

同僚がKeras(tf.keras)を使ってGoogleColaboratory上で数万Epochの学習をしていたら、ブラウザは重くなるし、挙句の果てに表示が更新されなくなってしまったと嘆いていました。
原因は model.fit 中に verbose=1 を指定して進捗ログを表示していたのですが、そのログが肥大化して重くなり、ある閾値(?)に達すると表示が更新されなくなってしまいました。(動作は継続している)
verbose=0 でログを止めて回せば事足りますが、それでは進捗状況が確認できなくなってしまいます。

そういえば自分も過去に同じ現象に悩んだ挙げ句、Callback関数を使って解決したこと思い出したので共有したいと思います。

【tf.kerasのCallback関数について】

model.fit の引数 callbackstf.keras.callbacks.Callback クラスを継承したクラスを指定することで、学習中の振舞いをカスタマイズできます。
詳細は公式のドキュメントを確認してください。
【tf.keras.callbacks.Callback - TensorFlow】
【コールバック - Keras Documentation】

tf.keras.callbacks.Callback にはメソッドがいくつか用意されていますが、それらはあるタイミングで呼び出されるようになっています。
これらのメソッドをオーバライドすることで、学習時の振る舞いを変更できます。
今回は以下のメソッドをオーバライドしました。

メソッド 呼び出されるタイミング
on_train_begin 学習開始時
on_train_end 学習終了時
on_batch_begin Batch開始時
on_batch_end Batch終了時
on_epoch_begin Epoch開始時
on_epoch_end Epoch終了時

上記以外にも推論時やテスト時などに呼び出されるメソッドが用意されています。

【方針】

学習中の進捗表示を改行せずに同一行に上書きし続けることで、出力セルが肥大化してあふれないようにします。
同一行で上書きし続けるためには以下のコードを使います。

print('\rTest Print', end='')

上記コードの \r はCarriage Return(CR)を意味していて、カーソルを行の先頭に移動することが出来ます。
これにより表示済みの行の上書きが可能になります。

ただし、このままだとprint文を実行するたびに改行されてしまします。
そこでprint文の引数として end='' を指定します。
要は第一引数を出力後になにも出力しないように指定することで、改行を抑止します。
なお、print文はデフォルトで end='\n' が指定されています。
\n はLine Feed(LF)を意味して、カーソルを新しい行に送ります (つまり改行されます)。

試しに下記のコードを実行すると 0 ~ 9 を上書きし続け、カウントアップしているように表現できます。

上書きサンプル
from time import sleep
for i in range(10):
  print('\r%d' % i, end='')
  sleep(1)

ここで、ふと思います。
わざわざ '\r' をプリントするのではなく end='\r' にすれば良いようにも感じます。

しかし、この試みはうまくいきません。
なぜならPythonでは '\r' が出力される際に、それまで出力された内容がクリアされてしまうようです。
例えば print('Test Print', end='\r') を実行すると見かけ上はなにも表示されず、今回の用途では都合が悪いです。
なので、文字出力直前に '\r' を出力した後に、出力したい文字列を出力する方法で行くしかありません。

というわけで、上記の手法を使って以下の方針でコーディングします。

学習開始/終了時

開始/終了を表示するとともに、実行された時間を表示します。
ここは普通に改行します。

Batch完了時 および Epoch完了時

Epoch数や処理したデータ数、accやlossを表示します。
この表示は改行せずに上書きすることによって、出力セルのサイズを抑制します。

【コーディング】

上記の方針をもとに、実装します。
モデル部分はTensorFlowのチュートリアがベースになっています。
【TensorFlow 2 quickstart for beginners】

import tensorflow as tf
# カスタム進捗表示用のCallback関数定義
"""
進捗表示用のCallback関数です。
Batch終了時とEpoch終了時にデータを収集して、表示しています。
ポイントとしては print 出力時に /r で行先頭にカーソルを戻しつつ、引数 end='' で改行を抑制している点です。
"""
import datetime

class DisplayCallBack(tf.keras.callbacks.Callback):
  # コンストラクタ
  def __init__(self):
    self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss = None, None, None, None
    self.now_batch, self.now_epoch = None, None

    self.epochs, self.samples, self.batch_size = None, None, None

  # カスタム進捗表示 (表示部本体)
  def print_progress(self):
    epoch = self.now_epoch
    batch = self.now_batch

    epochs = self.epochs
    samples = self.samples
    batch_size = self.batch_size
    sample = batch_size*(batch)

    # '\r' と end='' を使って改行しないようにする
    if self.last_val_acc and self.last_val_loss:
      # val_acc/val_loss が表示可能
      print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f - val_acc: %f val_loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss, self.last_val_acc, self.last_val_loss), end='')
    else:
      # val_acc/val_loss が表示不可
      print("\rEpoch %d/%d (%d/%d) -- acc: %f loss: %f" % (epoch+1, epochs, sample, samples, self.last_acc, self.last_loss), end='')


  # fit開始時
  def on_train_begin(self, logs={}):
    print('\n##### Train Start ##### ' + str(datetime.datetime.now()))

    # パラメータの取得
    self.epochs = self.params['epochs']
    self.samples = self.params['samples']
    self.batch_size = self.params['batch_size']

    # 標準の進捗表示をしないようにする
    self.params['verbose'] = 0


  # batch開始時
  def on_batch_begin(self, batch, logs={}):
    self.now_batch = batch

  # batch完了時 (進捗表示)
  def on_batch_end(self, batch, logs={}):
    # 最新情報の更新
    self.last_acc = logs.get('acc') if logs.get('acc') else 0.0
    self.last_loss = logs.get('loss') if logs.get('loss') else 0.0

    # 進捗表示
    self.print_progress()


  # epoch開始時
  def on_epoch_begin(self, epoch, log={}):
    self.now_epoch = epoch

  # epoch完了時 (進捗表示)
  def on_epoch_end(self, epoch, logs={}):
    # 最新情報の更新
    self.last_val_acc = logs.get('val_acc') if logs.get('val_acc') else 0.0
    self.last_val_loss = logs.get('val_loss') if logs.get('val_loss') else 0.0

    # 進捗表示
    self.print_progress()


  # fit完了時
  def on_train_end(self, logs={}):
    print('\n##### Train Complete ##### ' + str(datetime.datetime.now()))
# コールバック関数用のインスタンス生成
cbDisplay = DisplayCallBack()
# MNISTデータセットを読み込み正規化
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# tf.keras.Sequential モデルの構築
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
# モデルの学習
# ここでコールバック関数を使います
history = model.fit(x_train, y_train,
                    validation_data = (x_test, y_test),
                    batch_size=128,
                    epochs=5,
                    verbose=1,              # 標準の進捗表示はコールバック関数内で無視するようにしている
                    callbacks=[cbDisplay])  # コールバック関数としてカスタム進捗表示をセット
# モデル評価
import pandas as pd

results = pd.DataFrame(history.history)
results.plot();

【出力例】

上記を実行すると何Epoch回しても、下記の3行しか表示されません。
2行目がBatch終了時とEpoch終了時に最新の情報に書き換わり、学習が完了すると最後の行が出力されます。

##### Train Start ##### 2019-12-24 02:17:27.484038
Epoch 5/5 (59904/60000) -- acc: 0.970283 loss: 0.066101 - val_acc: 0.973900 val_loss: 0.087803
##### Train Complete ##### 2019-12-24 02:17:34.443442

image.png

15
17
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15
17