
Execution timing of Callbacks and Generators when running Keras fit

Posted at 2020-04-10

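When you want to add your own processing to training with Keras's fit, there are two options:
use fit_generator with your own generator, or add a Callback through the callbacks argument.
(The two have different roles, though.)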

Roles aside, some of their hooks look at first glance as if they run at the same point in training,
so I ran both and compared when each one is actually called.

Defining the Callback

MyCallback
class MyCallback(keras.callbacks.Callback):
    def on_batch_begin(self, batch, logs=None):
        print("[Callback]on_batch_begin {}".format(batch))

    def on_batch_end(self, batch, logs=None):
        print("[Callback]on_batch_end {}".format(batch))

    def on_epoch_begin(self, epoch, logs=None):
        print("[Callback]on_epoch_begin {}".format(epoch))

    def on_epoch_end(self, epoch, logs=None):
        print("[Callback]on_epoch_end {}".format(epoch))

    def on_train_begin(self, logs=None):
        print("[Callback]on_train_begin")

    def on_train_end(self, logs=None):
        print("[Callback]on_train_end")

Reference: Callback (official Keras documentation)
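MyCallback overrides hooks at three nesting levels: training, epoch, and batch. As a rough illustration of that nesting, here is a plain-Python sketch that runs without Keras. RecordingCallback and run_training_loop are hypothetical stand-ins invented for this example, not Keras APIs, and the loop only models the order in which the hooks fire, not what Keras actually does inside fit.

```python
# Hypothetical stand-ins (NOT Keras APIs) that model only the order in
# which a Keras-style training loop calls Callback hooks.


class RecordingCallback:
    """Mimics the hook names of keras.callbacks.Callback and records each call."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("on_train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append("on_epoch_begin {}".format(epoch))

    def on_batch_begin(self, batch, logs=None):
        self.events.append("on_batch_begin {}".format(batch))

    def on_batch_end(self, batch, logs=None):
        self.events.append("on_batch_end {}".format(batch))

    def on_epoch_end(self, epoch, logs=None):
        self.events.append("on_epoch_end {}".format(epoch))

    def on_train_end(self, logs=None):
        self.events.append("on_train_end")


def run_training_loop(callback, epochs, steps_per_epoch):
    """Hypothetical loop: hooks nest as train > epoch > batch."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for step in range(steps_per_epoch):
            callback.on_batch_begin(step)
            # One gradient update on one batch would happen here.
            callback.on_batch_end(step)
        callback.on_epoch_end(epoch)
    callback.on_train_end()


cb = RecordingCallback()
run_training_loop(cb, epochs=1, steps_per_epoch=2)
print(cb.events)
# ['on_train_begin', 'on_epoch_begin 0', 'on_batch_begin 0', 'on_batch_end 0',
#  'on_batch_begin 1', 'on_batch_end 1', 'on_epoch_end 0', 'on_train_end']
```

With epochs=3 and five steps per epoch (the setting used in this article), the same loop produces 2 + 3 × (1 + 10 + 1) = 38 hook calls.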

Defining the Generator

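__len__ must return the number of batches in one epoch; the batch indices passed to __getitem__ then run from 0 to __len__() - 1.
For example, with 10 samples and a batch size of 2, __len__ returns 5 and the indices are 0 to 4.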

__getitem__ returns the batch of data that corresponds to the given index.
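The index arithmetic can be checked without Keras. This is a small worked example: num_batches, num_batches_ceil and batch_slice are hypothetical helpers that mirror the article's __len__ and __getitem__; they are not part of keras.utils.Sequence.

```python
import math

# Hypothetical helpers mirroring MyGenerator (plain Python, no Keras).


def num_batches(num_samples, batch_size):
    # Floor division, like the article's __len__: a trailing partial batch is dropped.
    return num_samples // batch_size


def num_batches_ceil(num_samples, batch_size):
    # Rounding up instead keeps the trailing partial batch.
    return math.ceil(num_samples / batch_size)


def batch_slice(data, idx, batch_size):
    # Same slicing as the article's __getitem__: batch idx covers positions
    # idx*batch_size up to (but not including) idx*batch_size + batch_size.
    n = idx * batch_size
    return data[n:n + batch_size]


samples = list(range(10))  # 10 samples, batch size 2
for idx in range(num_batches(len(samples), 2)):
    print(idx, batch_slice(samples, idx, 2))
# 0 [0, 1]
# 1 [2, 3]
# 2 [4, 5]
# 3 [6, 7]
# 4 [8, 9]
```

With 11 samples, num_batches(11, 2) is still 5, so the last sample is never served; num_batches_ceil(11, 2) is 6, and batch_slice(list(range(11)), 5, 2) returns the single leftover sample [10]. For the 10-sample data in this article both versions give 5.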

MyGenerator
class MyGenerator(keras.utils.Sequence):
    def __init__(self, x_data, y_data, batch_size):
        self.x_data = x_data
        self.y_data = y_data
        self.batch_size = batch_size

    def __getitem__(self, idx):
        print("[Generator] batch begin. {}".format(idx))

        n = idx*self.batch_size
        a = self.x_data[n:n+self.batch_size]
        b = self.y_data[n:n+self.batch_size]
        return np.asarray(a), np.asarray(b)

    def __len__(self):
        # Number of batches per epoch; a trailing partial batch is dropped.
        return len(self.x_data) // self.batch_size
        
    def on_epoch_end(self):
        print("[Generator]on_epoch_end")

Reference: Sequence (official Keras documentation)

fit

call = MyCallback()
gen = MyGenerator(x_data, y_data, batch_size)

# train
model.fit_generator(gen, epochs=3, verbose=0, callbacks=[call])

Results

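These are the results with epochs=3 and batch_size=2.
The Generator and Callback calls are clearly not executed in lockstep: in every epoch, all five batches are fetched from the Generator before on_batch_end 0 is reached, yet the Callback hooks still fire strictly in order.

This is consistent with the Generator being run ahead of training by a background worker that prefetches batches into a queue (in Keras 2.x, fit_generator defaults to workers=1 and max_queue_size=10), while the Callbacks are invoked from the training loop itself. So the Callback order is not a coincidence.
The Generator indices appear in random order only because shuffle is enabled (fit_generator defaults to shuffle=True). Note also that the batch argument the Callback receives is the step number within the epoch, not the index passed to __getitem__.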

[Callback]on_train_begin
[Callback]on_epoch_begin 0
[Generator] batch begin. 2
[Generator] batch begin. 3
[Callback]on_batch_begin 0
[Generator] batch begin. 0
[Generator] batch begin. 4
[Generator] batch begin. 1
[Callback]on_batch_end 0
[Callback]on_batch_begin 1
[Callback]on_batch_end 1
[Callback]on_batch_begin 2
[Callback]on_batch_end 2
[Callback]on_batch_begin 3
[Callback]on_batch_end 3
[Callback]on_batch_begin 4
[Callback]on_batch_end 4
[Callback]on_epoch_end 0
[Generator]on_epoch_end
[Callback]on_epoch_begin 1
[Generator] batch begin. 1
[Generator] batch begin. 3
[Callback]on_batch_begin 0
[Generator] batch begin. 4
[Generator] batch begin. 0
[Generator] batch begin. 2
[Callback]on_batch_end 0
[Callback]on_batch_begin 1
[Callback]on_batch_end 1
[Callback]on_batch_begin 2
[Callback]on_batch_end 2
[Callback]on_batch_begin 3
[Callback]on_batch_end 3
[Callback]on_batch_begin 4
[Callback]on_batch_end 4
[Callback]on_epoch_end 1
[Generator]on_epoch_end
[Callback]on_epoch_begin 2
[Generator] batch begin. 4
[Generator] batch begin. 1
[Callback]on_batch_begin 0
[Generator] batch begin. 3
[Generator] batch begin. 2
[Generator] batch begin. 0
[Callback]on_batch_end 0
[Callback]on_batch_begin 1
[Callback]on_batch_end 1
[Callback]on_batch_begin 2
[Callback]on_batch_end 2
[Callback]on_batch_begin 3
[Callback]on_batch_end 3
[Callback]on_batch_begin 4
[Callback]on_batch_end 4
[Callback]on_epoch_end 2
[Generator]on_epoch_end
[Callback]on_train_end
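The interleaving in the log above can be reproduced with a plain-Python model of background prefetching. This is a simplified sketch under stated assumptions, not Keras's implementation: train is a hypothetical function, a single producer thread stands in for the worker that calls __getitem__, a bounded queue.Queue(maxsize=10) stands in for the prefetch queue, and the producer waits for the queue to drain before its own on_epoch_end. That last point is my reading of how keras.utils.OrderedEnqueuer (which Keras 2.x uses for a Sequence when workers > 0) behaves.

```python
# Hypothetical, simplified model of background batch prefetching.
# It loosely mimics a Keras 2.x fit_generator run with one worker and
# shuffle=True, but it is NOT Keras's implementation.
import queue
import random
import threading


def train(sequence_len, epochs, max_queue_size=10, seed=None):
    rng = random.Random(seed)
    log = []
    lock = threading.Lock()
    fetched = queue.Queue(maxsize=max_queue_size)  # stand-in for the prefetch queue

    def emit(line):
        with lock:  # both threads append to the same log
            log.append(line)

    def producer():
        # Stand-in for the background worker that calls __getitem__.
        for _ in range(epochs):
            order = list(range(sequence_len))
            rng.shuffle(order)  # fresh random batch order each epoch
            for idx in order:
                emit("[Generator] batch begin. {}".format(idx))
                fetched.put(idx)  # blocks only if the queue is full
            fetched.join()  # wait until the training loop consumed this epoch
            emit("[Generator]on_epoch_end")

    worker = threading.Thread(target=producer)
    worker.start()
    emit("[Callback]on_train_begin")
    for epoch in range(epochs):
        emit("[Callback]on_epoch_begin {}".format(epoch))
        for step in range(sequence_len):
            emit("[Callback]on_batch_begin {}".format(step))
            fetched.get()  # take the next prefetched batch (training would use it here)
            fetched.task_done()
            emit("[Callback]on_batch_end {}".format(step))
        emit("[Callback]on_epoch_end {}".format(epoch))
    emit("[Callback]on_train_end")
    worker.join()
    return log


for line in train(sequence_len=5, epochs=3):
    print(line)
```

Because the two threads run concurrently, the position of the [Generator] lines relative to the [Callback] lines can change from run to run, while the [Callback] lines always come out in the same order. The bounded queue is what lets the worker run ahead: it keeps fetching until the queue is full or the epoch's batches run out, and with only five batches per epoch and room for ten, it can fetch a whole epoch before the first batch finishes.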

Full code

from keras.models import Model
from keras.layers import *
from keras.utils import np_utils
import keras

import numpy as np

shape = (1,)
batch_size = 2

x_data = [ [i/10.] for i in range(10)]
y_data = [ i for i in range(10)]

# one hot encode the output variable
y_data = np_utils.to_categorical(y_data)

# create model
c = input_ = Input(shape=shape)
c = Dense(16, activation="relu")(c)
c = Dense(y_data.shape[1], activation="softmax")(c)
model = Model(input_, c)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

class MyCallback(keras.callbacks.Callback):
    def on_batch_begin(self, batch, logs=None):
        print("[Callback]on_batch_begin {}".format(batch))

    def on_batch_end(self, batch, logs=None):
        print("[Callback]on_batch_end {}".format(batch))

    def on_epoch_begin(self, epoch, logs=None):
        print("[Callback]on_epoch_begin {}".format(epoch))

    def on_epoch_end(self, epoch, logs=None):
        print("[Callback]on_epoch_end {}".format(epoch))

    def on_train_begin(self, logs=None):
        print("[Callback]on_train_begin")

    def on_train_end(self, logs=None):
        print("[Callback]on_train_end")

class MyGenerator(keras.utils.Sequence):
    def __init__(self, x_data, y_data, batch_size):
        self.x_data = x_data
        self.y_data = y_data
        self.batch_size = batch_size

    def __getitem__(self, idx):
        print("[Generator] batch begin. {}".format(idx))

        n = idx*self.batch_size
        a = self.x_data[n:n+self.batch_size]
        b = self.y_data[n:n+self.batch_size]
        return np.asarray(a), np.asarray(b)

    def __len__(self):
        # Number of batches per epoch; a trailing partial batch is dropped.
        return len(self.x_data) // self.batch_size
        
    def on_epoch_end(self):
        print("[Generator]on_epoch_end")

call = MyCallback()
gen = MyGenerator(x_data, y_data, batch_size)

# train
model.fit_generator(gen, epochs=3, verbose=0, callbacks=[call])