Keras の学習時に使用する fit ですが、独自に処理を追加したい場合、
fit_generator を使う場合と引数の Callback を追加する方法があります。
(それぞれ役割が違いますけど…)
それぞれの役割はともかく実行タイミングはぱっと見同じタイミングに見えるものがあるので、
実際に実行して違いを見てみました。
Callbackの定義
MyCallback
class MyCallback(keras.callbacks.Callback):
def on_batch_begin(self, batch, logs=None):
print("[Callback]on_batch_begin {}".format(batch))
def on_batch_end(self, batch, logs=None):
print("[Callback]on_batch_end {}".format(batch))
def on_epoch_begin(self, epoch, logs=None):
print("[Callback]on_epoch_begin {}".format(epoch))
def on_epoch_end(self, epoch, logs=None):
print("[Callback]on_epoch_end {}".format(epoch))
def on_train_begin(self, logs=None):
print("[Callback]on_train_begin")
def on_train_end(self, logs=None):
print("[Callback]on_train_end")
Generatorの定義
__len__
は1epoch内でのバッチの index を返す必要があります。
例えばデータ数が10でバッチ数が2だった場合、index は 0~4 になります。
__getitem__
ではindexに対応したバッチのデータを返します。
MyGenerator
class MyGenerator(keras.utils.Sequence):
def __init__(self, x_data, y_data, batch_size):
self.x_data = x_data
self.y_data = y_data
self.batch_size = batch_size
def __getitem__(self, idx):
print("[Generator] batch begin. {}".format(idx))
n = idx*self.batch_size
a = self.x_data[n:n+self.batch_size]
b = self.y_data[n:n+self.batch_size]
return np.asarray(a), np.asarray(b)
def __len__(self):
return int(len(self.x_data) / self.batch_size)
def on_epoch_end(self):
print("[Generator]on_epoch_end")
fit
call = MyCallback()
gen = MyGenerator(x_data, y_data, batch_size)
# train
model.fit_generator(gen, epochs=3, verbose=0, callbacks=[call])
結果
epochs=3、batch_size=2 の実行結果です。
どうやら Generator と Callback の実行タイミングは連動しているわけではなさそうです。
Generator は非同期に実行されてその後に Callback が実行されていそうですね。
Callback の実行順序が正しいのはたまたま?
Generator の index は shuffle が有効になっているのでランダムな順番になるだけでした。。。
[Callback]on_train_begin
[Callback]on_epoch_begin 0
[Generator] batch begin. 2
[Generator] batch begin. 3
[Callback]on_batch_begin 0
[Generator] batch begin. 0
[Generator] batch begin. 4
[Generator] batch begin. 1
[Callback]on_batch_end 0
[Callback]on_batch_begin 1
[Callback]on_batch_end 1
[Callback]on_batch_begin 2
[Callback]on_batch_end 2
[Callback]on_batch_begin 3
[Callback]on_batch_end 3
[Callback]on_batch_begin 4
[Callback]on_batch_end 4
[Callback]on_epoch_end 0
[Generator]on_epoch_end
[Callback]on_epoch_begin 1
[Generator] batch begin. 1
[Generator] batch begin. 3
[Callback]on_batch_begin 0
[Generator] batch begin. 4
[Generator] batch begin. 0
[Generator] batch begin. 2
[Callback]on_batch_end 0
[Callback]on_batch_begin 1
[Callback]on_batch_end 1
[Callback]on_batch_begin 2
[Callback]on_batch_end 2
[Callback]on_batch_begin 3
[Callback]on_batch_end 3
[Callback]on_batch_begin 4
[Callback]on_batch_end 4
[Callback]on_epoch_end 1
[Generator]on_epoch_end
[Callback]on_epoch_begin 2
[Generator] batch begin. 4
[Generator] batch begin. 1
[Callback]on_batch_begin 0
[Generator] batch begin. 3
[Generator] batch begin. 2
[Generator] batch begin. 0
[Callback]on_batch_end 0
[Callback]on_batch_begin 1
[Callback]on_batch_end 1
[Callback]on_batch_begin 2
[Callback]on_batch_end 2
[Callback]on_batch_begin 3
[Callback]on_batch_end 3
[Callback]on_batch_begin 4
[Callback]on_batch_end 4
[Callback]on_epoch_end 2
[Generator]on_epoch_end
[Callback]on_train_end
コード全体
from keras.models import Model
from keras.layers import *
from keras.utils import np_utils
import keras
import numpy as np
shape = (1,)
batch_size = 2
x_data = [ [i/10.] for i in range(10)]
y_data = [ i for i in range(10)]
# one hot encode the output variable
y_data = np_utils.to_categorical(y_data)
# create model
c = input_ = Input(shape=shape)
c = Dense(16, activation="relu")(c)
c = Dense(y_data.shape[1], activation="softmax")(c)
model = Model(input_, c)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
class MyCallback(keras.callbacks.Callback):
def on_batch_begin(self, batch, logs=None):
print("[Callback]on_batch_begin {}".format(batch))
def on_batch_end(self, batch, logs=None):
print("[Callback]on_batch_end {}".format(batch))
def on_epoch_begin(self, epoch, logs=None):
print("[Callback]on_epoch_begin {}".format(epoch))
def on_epoch_end(self, epoch, logs=None):
print("[Callback]on_epoch_end {}".format(epoch))
def on_train_begin(self, logs=None):
print("[Callback]on_train_begin")
def on_train_end(self, logs=None):
print("[Callback]on_train_end")
class MyGenerator(keras.utils.Sequence):
def __init__(self, x_data, y_data, batch_size):
self.x_data = x_data
self.y_data = y_data
self.batch_size = batch_size
def __getitem__(self, idx):
print("[Generator] batch begin. {}".format(idx))
n = idx*self.batch_size
a = self.x_data[n:n+self.batch_size]
b = self.y_data[n:n+self.batch_size]
return np.asarray(a), np.asarray(b)
def __len__(self):
return int(len(self.x_data) / self.batch_size)
def on_epoch_end(self):
print("[Generator]on_epoch_end")
call = MyCallback()
gen = MyGenerator(x_data, y_data, batch_size)
# train
model.fit_generator(gen, epochs=3, verbose=0, callbacks=[call])