Python
ディープラーニング
Keras

Kerasで大容量データをModel.fit_generatorを使って学習する

画像用のネットワークを調整する際、入力サイズを増やしたら、画像読込時にメモリに乗り切らなくなった問題を解消した際のメモです。


前提環境


  • Ubuntu 16.04

  • Keras 2.1.6


課題

学習に使う画像データの総容量が大きくなり、一度に読込できなくなった。

そのため、一定サイズ毎に区切りながらデータを読み込む必要が発生した。


概要

Model.fitの代わりに、Model.fit_generatorメソッドを使って学習する。

fit_generatorメソッドには、学習・検証データとして、Generatorオブジェクトを渡す。

Generatorオブジェクトは、バッチ単位にデータを提供する仕組みを実装する。


実装例


my_generator.py

class MyGenerator(Sequence):

"""Custom generator"""

def __init__(self, data_paths, data_classes,
batch_size=1, width=256, height=256, ch=3, num_of_class=2):
"""construction

:param data_paths: List of image file
:param data_classes: List of class
:param batch_size: Batch size
:param width: Image width
:param height: Image height
:param ch: Num of image channels
:param num_of_class: Num of classes
"""

self.data_paths = data_paths
self.data_classes = data_classes
self.length = len(data_paths)
self.batch_size = batch_size
self.width = width
self.height = height
self.ch = ch
self.num_of_class = num_of_class
self.num_batches_per_epoch = int((self.length - 1) / batch_size) + 1

def __getitem__(self, idx):
"""Get batch data

:param idx: Index of batch

:return imgs: numpy array of images
:return labels: numpy array of label
"""

start_pos = self.batch_size * idx
end_pos = start_pos + self.batch_size
if end_pos > self.length:
end_pos = self.length
item_paths = self.data_paths[start_pos : end_pos]
item_classes = self.data_classes[start_pos : end_pos]
imgs = np.empty((len(item_paths), self.height, self.width, self.ch), dtype=np.float32)
labels = np.empty((len(item_paths), num_of_class), dtype=np.float32)

for i, (item_path, item_class) in enumerate(zip(item_paths, item_classes)):
img, label = _load_data(item_path, item_class, self.width, self.height, self.ch)
imgs[i, :] = img
labels[i] = label

return imgs, labels

def __len__(self):
"""Batch length"""

return self.num_batches_per_epoch

def on_epoch_end(self):
"""Task when end of epoch"""
pass



train.py

# train_paths: list of image file path[xxx.png, yyy.png, zzz.png...]

# train_classes: list of class[0,1,0,1...]
val_count = int(len(train_paths) * 0.2)
train_gen = my_generator.DataGenerator(
train_paths[val_count:],
train_classes[val_count:],
batch_size=50)
val_gen = my_generator.DataGenerator(
train_paths[:val_count],
train_classes[:val_count],
batch_size=50)
model.fit_generator(
train_gen,
steps_per_epoch=train_gen.num_batches_per_epoch,
validation_data=val_gen,
validation_steps=val_gen.num_batches_per_epoch,
epochs=100,
shuffle=True)


説明


MyGenerator

以下のメソッドの実装が必要です。



  • __getitem__

    idxの指定にあわせて、バッチ単位のデータとラベルを返します。

    形式は、fitに渡すものと同じです。



  • __len__

    1epochのバッチ数を返します。



  • on_epoch_end

    1epochが終わった時の処理を実装します。


model.fit_generator

学習、検証用のGeneratorをそれぞれ作成し、引数に渡します。

shuffleは、バッチをシャッフルする設定です。

データ全体のシャッフルをする場合は、on_epoch_endで処理することで、epoch毎にシャッフルできます。


補足

評価(evaluate)、予測(predict)にも、同様にevaluate_generatorやpredict_generatorがあるので、Generatorが使えます。


参考

Keras ModelクラスAPI