More than 5 years have passed since last update.

Notes on writing your own Callback in Keras 2.x


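When overriding Callback methods, I could not tell what the methods' arguments actually contain, so I printed them out.
For the full list of methods that can be overridden, see the Keras repository on GitHub.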

```python
from keras.callbacks import Callback


class MyCallback(Callback):
    """Prints what Keras passes to each hook during fit()."""

    def __init__(self):
        # Run the base initializer; Keras assigns self.model and
        # self.params later, before on_train_begin is called.
        super().__init__()

    def on_train_begin(self, logs=None):
        print("on_train_begin")
        print("model", self.model)
        print("params", self.params)
        print("logs", logs)
        print("=" * 20)

    def on_train_end(self, logs=None):
        print("on_train_end")
        print("logs", logs)
        print("=" * 20)

    def on_batch_begin(self, batch, logs=None):
        # batch: index of the batch within the current epoch
        print("on_batch_begin")
        print("batch", batch)
        print("logs", logs)
        print("=" * 20)

    def on_batch_end(self, batch, logs=None):
        print("on_batch_end")
        print("batch", batch)
        print("logs", logs)
        print("=" * 20)

    def on_epoch_begin(self, epoch, logs=None):
        # epoch: 0-based epoch index
        print("on_epoch_begin")
        print("epoch", epoch)
        print("logs", logs)
        print("=" * 20)

    def on_epoch_end(self, epoch, logs=None):
        print("on_epoch_end")
        print("epoch", epoch)
        print("logs", logs)
        print("=" * 20)
```
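MyCallback repeats the same print lines in six methods. A more compact variant, sketched here under the assumption that you only need the six Keras 2.x hooks used above, builds those methods with two small factory functions. `HookPrinter`, `_train_hook` and `_indexed_hook` are hypothetical names, not part of Keras, and the mixin deliberately does not import Keras so it can be tried on its own.

```python
def _train_hook(name):
    """Build an on_train_* method that prints its name and logs."""
    def hook(self, logs=None):
        print(name)
        if name == "on_train_begin":
            # Keras sets self.model / self.params before training starts;
            # getattr keeps the mixin usable on a plain object as well.
            print("model", getattr(self, "model", None))
            print("params", getattr(self, "params", None))
        print("logs", logs)
        print("=" * 20)
    hook.__name__ = name
    return hook


def _indexed_hook(name, label):
    """Build an on_epoch_* / on_batch_* method that also prints its index."""
    def hook(self, index, logs=None):
        print(name)
        print(label, index)
        print("logs", logs)
        print("=" * 20)
    hook.__name__ = name
    return hook


class HookPrinter:
    """Hypothetical mixin that prints every hook call, like MyCallback."""
    on_train_begin = _train_hook("on_train_begin")
    on_train_end = _train_hook("on_train_end")
    on_epoch_begin = _indexed_hook("on_epoch_begin", "epoch")
    on_epoch_end = _indexed_hook("on_epoch_end", "epoch")
    on_batch_begin = _indexed_hook("on_batch_begin", "batch")
    on_batch_end = _indexed_hook("on_batch_end", "batch")
```

To use it with Keras, list the mixin before the Keras base class, e.g. `class MyCallback2(HookPrinter, Callback): pass`. Its hooks then take precedence in the method resolution order, while `Callback.__init__` still runs because the mixin defines no `__init__` of its own.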

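```python
# Model construction is omitted; only the fit() call is shown.
# verbose=0 silences Keras's own progress output, so only the callback prints.
model.fit(data, label_one_hot, epochs=3, batch_size=50, validation_split=0.0,
          callbacks=[MyCallback()], verbose=0)
```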
Log output
on_train_begin
model <keras.engine.sequential.Sequential object at 0x7f114e96b400>
params {'batch_size': 50, 'epochs': 3, 'steps': None, 'samples': 90, 'verbose': 0, 'do_validation': False, 'metrics': ['loss', 'acc']}
logs {}
====================
on_epoch_begin
epoch 0
logs {}
====================
on_batch_begin
batch 0
logs {'batch': 0, 'size': 50}
====================
on_batch_end
batch 0
logs {'batch': 0, 'size': 50, 'loss': 1.1077641, 'acc': 0.42}
====================
on_batch_begin
batch 1
logs {'batch': 1, 'size': 40}
====================
on_batch_end
batch 1
logs {'batch': 1, 'size': 40, 'loss': 1.0845511, 'acc': 0.6}
====================
on_epoch_end
epoch 0
logs {'loss': 1.097447223133511, 'acc': 0.5000000033113692}
====================
on_epoch_begin
epoch 1
logs {}
====================
on_batch_begin
batch 0
logs {'batch': 0, 'size': 50}
====================
on_batch_end
batch 0
logs {'batch': 0, 'size': 50, 'loss': 1.0686755, 'acc': 0.72}
====================
on_batch_begin
batch 1
logs {'batch': 1, 'size': 40}
====================
on_batch_end
batch 1
logs {'batch': 1, 'size': 40, 'loss': 1.0953418, 'acc': 0.375}
====================
on_epoch_end
epoch 1
logs {'loss': 1.0805271996392145, 'acc': 0.5666666825612386}
====================
on_epoch_begin
epoch 2
logs {}
====================
on_batch_begin
batch 0
logs {'batch': 0, 'size': 50}
====================
on_batch_end
batch 0
logs {'batch': 0, 'size': 50, 'loss': 1.0497301, 'acc': 0.76}
====================
on_batch_begin
batch 1
logs {'batch': 1, 'size': 40}
====================
on_batch_end
batch 1
logs {'batch': 1, 'size': 40, 'loss': 1.1001902, 'acc': 0.275}
====================
on_epoch_end
epoch 2
logs {'loss': 1.0721567736731634, 'acc': 0.5444444417953491}
====================
on_train_end
logs {}
====================
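In the log, each epoch's on_epoch_end values are neither the last batch's values nor a plain mean over batches. They agree, up to float32 rounding, with a mean weighted by each batch's `size`; in Keras 2.x this accumulation is done by the built-in `BaseLogger` callback (worth confirming against the source of the version you use). The sketch below recomputes epoch 0 from its two on_batch_end logs; `weighted_epoch_logs` is a hypothetical helper, not a Keras function.

```python
def weighted_epoch_logs(batch_logs, metrics=("loss", "acc")):
    """Average each metric over one epoch's batches, weighted by 'size'."""
    seen = sum(b["size"] for b in batch_logs)  # samples seen this epoch
    return {k: sum(b[k] * b["size"] for b in batch_logs) / seen for k in metrics}


# The two on_batch_end logs of epoch 0, copied from the log above.
epoch0 = [
    {"batch": 0, "size": 50, "loss": 1.1077641, "acc": 0.42},
    {"batch": 1, "size": 40, "loss": 1.0845511, "acc": 0.6},
]
result = weighted_epoch_logs(epoch0)
print(result)
```

A plain mean, (1.1077641 + 1.0845511) / 2 ≈ 1.0962, would not match the logged 1.0974472, because the last batch holds only 40 of the 90 samples.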

Summary

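| Timing | Method | Argument | Contents |
|---|---|---|---|
| Start of training | on_train_begin | logs | {} |
| End of training | on_train_end | logs | {} |
| Start of each epoch | on_epoch_begin | epoch | int, starts at 0 |
| | | logs | {} |
| End of each epoch | on_epoch_end | epoch | int, starts at 0 |
| | | logs | {'loss': 1.093177596728007, 'acc': 0.36666665805710685} |
| Start of training on one batch | on_batch_begin | batch | int, starts at 0 (restarts every epoch) |
| | | logs | {'batch': 0, 'size': 50} |
| End of training on one batch | on_batch_end | batch | int, starts at 0 (restarts every epoch) |
| | | logs | {'batch': 0, 'size': 50, 'loss': 1.1000787, 'acc': 0.48} |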
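The Timing column can be read as a nested loop. The sketch below is a hypothetical helper, `hook_sequence`, that does not call Keras; it only reproduces the order of calls seen in the log for `samples=90`, `batch_size=50`, `epochs=3`: ceil(90 / 50) = 2 batches per epoch, the last one holding the remaining 40 samples, with the batch index restarting at 0 every epoch. It assumes no validation data (as with `validation_split=0.0` above) and leaves out the loss/acc values, which only Keras can compute.

```python
import math


def hook_sequence(samples, batch_size, epochs):
    """Return (hook, index, logs) tuples in the order observed in the log.

    Metric keys ('loss', 'acc') are omitted; only the 'batch' and 'size'
    entries are reproduced for the batch hooks.
    """
    steps = math.ceil(samples / batch_size)  # 90 / 50 -> 2 batches per epoch
    calls = [("on_train_begin", None, {})]
    for epoch in range(epochs):
        calls.append(("on_epoch_begin", epoch, {}))
        for batch in range(steps):  # batch index restarts at 0 every epoch
            size = min(batch_size, samples - batch * batch_size)  # last: 40
            calls.append(("on_batch_begin", batch, {"batch": batch, "size": size}))
            calls.append(("on_batch_end", batch, {"batch": batch, "size": size}))
        calls.append(("on_epoch_end", epoch, {}))
    calls.append(("on_train_end", None, {}))
    return calls


for name, index, logs in hook_sequence(samples=90, batch_size=50, epochs=3):
    print(name, index, logs)
```

The loop yields 20 hook calls, the same number of blocks separated by `====================` in the log.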