KerasのCallbackを継承してメソッドをオーバーライドする際、各メソッドに渡される引数の中身が分からなかったので、ログに出力して確認してみた。
オーバーライドできるメソッドの一覧は、GitHub上のKerasのソースコード（keras/callbacks.py の Callback クラス）を参照のこと。
from keras.callbacks import Callback
class MyCallback(Callback):
    """Keras callback that prints the arguments Keras passes to each hook.

    Used to see what ``on_train_*``, ``on_epoch_*`` and ``on_batch_*``
    receive when overridden. Every hook prints its own name, then one
    ``label value`` line per argument, then a separator line.
    """

    # Line printed after each hook's output to separate calls visually.
    _SEPARATOR = "=" * 20

    def __init__(self):
        # An overriding __init__ needs a body; delegate to Callback so the
        # base class still performs its own initialisation.
        super().__init__()

    def _dump(self, hook_name, *fields):
        """Print *hook_name*, each ``(label, value)`` pair, then a separator.

        Args:
            hook_name: Name of the hook being reported, printed first.
            *fields: ``(label, value)`` tuples, printed in the given order
                as ``print(label, value)``.
        """
        print(hook_name)
        for label, value in fields:
            print(label, value)
        print(self._SEPARATOR)

    def on_train_begin(self, logs=None):
        # The sample log shows self.model and self.params already populated
        # at this point, so they are printed alongside logs.
        self._dump(
            "on_train_begin",
            ("model", self.model),
            ("params", self.params),
            ("logs", logs),
        )

    def on_train_end(self, logs=None):
        self._dump("on_train_end", ("logs", logs))

    def on_batch_begin(self, batch, logs=None):
        self._dump("on_batch_begin", ("batch", batch), ("logs", logs))

    def on_batch_end(self, batch, logs=None):
        self._dump("on_batch_end", ("batch", batch), ("logs", logs))

    def on_epoch_begin(self, epoch, logs=None):
        self._dump("on_epoch_begin", ("epoch", epoch), ("logs", logs))

    def on_epoch_end(self, epoch, logs=None):
        self._dump("on_epoch_end", ("epoch", epoch), ("logs", logs))
# Model construction is omitted here; only the fit() call is shown.
# verbose=0 keeps Keras' own progress output out of the way, so the log
# below contains only what MyCallback prints.
model.fit(
    data,
    label_one_hot,
    epochs=3,
    batch_size=50,
    validation_split=0.0,
    callbacks=[MyCallback()],
    verbose=0,
)
ログ
on_train_begin
model <keras.engine.sequential.Sequential object at 0x7f114e96b400>
params {'batch_size': 50, 'epochs': 3, 'steps': None, 'samples': 90, 'verbose': 0, 'do_validation': False, 'metrics': ['loss', 'acc']}
logs {}
====================
on_epoch_begin
epoch 0
logs {}
====================
on_batch_begin
batch 0
logs {'batch': 0, 'size': 50}
====================
on_batch_end
batch 0
logs {'batch': 0, 'size': 50, 'loss': 1.1077641, 'acc': 0.42}
====================
on_batch_begin
batch 1
logs {'batch': 1, 'size': 40}
====================
on_batch_end
batch 1
logs {'batch': 1, 'size': 40, 'loss': 1.0845511, 'acc': 0.6}
====================
on_epoch_end
epoch 0
logs {'loss': 1.097447223133511, 'acc': 0.5000000033113692}
====================
on_epoch_begin
epoch 1
logs {}
====================
on_batch_begin
batch 0
logs {'batch': 0, 'size': 50}
====================
on_batch_end
batch 0
logs {'batch': 0, 'size': 50, 'loss': 1.0686755, 'acc': 0.72}
====================
on_batch_begin
batch 1
logs {'batch': 1, 'size': 40}
====================
on_batch_end
batch 1
logs {'batch': 1, 'size': 40, 'loss': 1.0953418, 'acc': 0.375}
====================
on_epoch_end
epoch 1
logs {'loss': 1.0805271996392145, 'acc': 0.5666666825612386}
====================
on_epoch_begin
epoch 2
logs {}
====================
on_batch_begin
batch 0
logs {'batch': 0, 'size': 50}
====================
on_batch_end
batch 0
logs {'batch': 0, 'size': 50, 'loss': 1.0497301, 'acc': 0.76}
====================
on_batch_begin
batch 1
logs {'batch': 1, 'size': 40}
====================
on_batch_end
batch 1
logs {'batch': 1, 'size': 40, 'loss': 1.1001902, 'acc': 0.275}
====================
on_epoch_end
epoch 2
logs {'loss': 1.0721567736731634, 'acc': 0.5444444417953491}
====================
on_train_end
logs {}
====================
まとめ
| タイミング | メソッド名 | 引数 | 内容 |
|---|---|---|---|
| 訓練開始時 | on_train_begin | logs | {} |
| 訓練終了時 | on_train_end | logs | {} |
| epoch開始時 | on_epoch_begin | epoch | int（0始まり） |
| | | logs | {} |
| epoch終了時 | on_epoch_end | epoch | int（0始まり） |
| | | logs | {'loss': 1.097447223133511, 'acc': 0.5000000033113692} |
| 1バッチの訓練開始時 | on_batch_begin | batch | int（0始まり。epochごとに0に戻る） |
| | | logs | {'batch': 0, 'size': 50} |
| 1バッチの訓練終了時 | on_batch_end | batch | int（0始まり。epochごとに0に戻る） |
| | | logs | {'batch': 0, 'size': 50, 'loss': 1.1077641, 'acc': 0.42} |