LoginSignup
0
2

More than 5 years have passed since last update.

kerasでtrainとvalidationのTensorboardの出力を分ける

Last updated at Posted at 2018-09-25

これなーに

keras便利ですよね〜 
モデル作って、精度のカーブをtensorboardで確認っと・・・

スクリーンショット 2018-09-25 23.12.27.png

trainとvalが別れてグラフにプロットされてて、ちょっと見づらい。。
特にoverfitなどを確認する時は、trainとvalのグラフの精度のカーブを一緒に見たくなるので、できたら一緒にみたい!

ということで、下みたいに一緒に見れるようにするtipsです。
スクリーンショット 2018-09-25 22.44.45.png

やり方

以下のclassを作って、model.fitをする時の引数にcallbacks=[TrainValTensorBoard(write_graph=False)]を指定してあげる。

class TrainValTensorBoard(keras.callbacks.TensorBoard):
    def __init__(self, log_dir='../saved/', **kwargs):
        # Make the original `TensorBoard` log to a subdirectory 'training'
        training_log_dir = os.path.join(log_dir, '../saved/tensorboard/train')
        super(TrainValTensorBoard, self).__init__(training_log_dir, **kwargs)

        # Log the validation metrics to a separate subdirectory
        self.val_log_dir = os.path.join(log_dir, '../saved/tensorboard/val')

    def set_model(self, model):
        # Setup writer for validation metrics
        self.val_writer = tf.summary.FileWriter(self.val_log_dir)
        super(TrainValTensorBoard, self).set_model(model)

    def on_epoch_end(self, epoch, logs=None):
        # Pop the validation logs and handle them separately with
        # `self.val_writer`. Also rename the keys so that they can
        # be plotted on the same figure with the training metrics
        logs = logs or {}
        val_logs = {k.replace('val_', ''): v for k, v in logs.items() if k.startswith('val_')}
        for name, value in val_logs.items():
            summary = tf.Summary()
            summary_value = summary.value.add()
            summary_value.simple_value = value.item()
            summary_value.tag = name
            self.val_writer.add_summary(summary, epoch)
        self.val_writer.flush()

        # Pass the remaining logs to `TensorBoard.on_epoch_end`
        logs = {k: v for k, v in logs.items() if not k.startswith('val_')}
        super(TrainValTensorBoard, self).on_epoch_end(epoch, logs)

    def on_train_end(self, logs=None):
        super(TrainValTensorBoard, self).on_train_end(logs)
        self.val_writer.close()

def training():
    ~略~
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=[TrainValTensorBoard(write_graph=False)])
    ~略~

一番簡単なのは、tensorboardの利用を以下の形で定義し、fitする時のcallbacksに定義した物を突っ込みますが、その代わりに上をやると良い感じでtrainとvalが同時に見えるように出力できます。

    ~略~
    tb_cb = keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
    cbks = [tb_cb]
    ~略~
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              verbose=1,
              validation_data=(x_test, y_test),
              callbacks=cbks)
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2