はじめに
モデルを訓練する際に,MNIST程度の画像サイズと枚数のデータセットであれば最初にすべてメモリに読み込むことができるので高速にイテレーションを回すことができます。
しかし,自前のそれなりに大きいサイズと枚数の画像データセットでモデルを訓練しようとすると,データがすべてメモリに乗り切らなくなってしまうことがよくあります。
ChainerのImageDataset
Chainerには自前の画像データセットを扱うためにImageDataset, LabeledImageDatasetなどのクラスが用意されています。
これらのクラスは画像のパスやラベルだけを保持しており,画像データ自体はイテレータなどからget_example()
が呼び出されるたびにその都度読み込まれます。
そのため,メモリに乗り切らないような大規模なデータセットも扱うことができますが,画像を参照するたびに読み込みのオーバーヘッドが発生します。
画像一枚一枚を読み込むのにかかる時間は微々たるものです。
しかし,ミニバッチ学習を行う場合,たいてい数十枚の画像を毎イテレーション読み込むことになるので,その時間は無視できなくなってきます。
作ったもの
画像データセットの一部だけをメモリにキャッシュしておくことで,読み込みにかかる時間を短縮するデータセットのクラスを作りました。
以下にコードを示します。
# -*- coding: utf-8 -*-
import numpy as np
import os
from PIL import Image
from chainer.datasets.image_dataset import _read_image_as_array, _postprocess_image
class CacheImageDataset(chainer.datasets.ImageDataset):
def __init__(self, paths, root=".", dtype=np.float32, cache_num=0):
super().__init__(paths, root, dtype)
self.cache_images = [None]*len(self)
for i in range(min(len(self), cache_num)):
self.cache_images[i] = super().get_example(i)
def get_example(self, i):
if self.cache_images[i] is None:
img = super().get_example(i)
else:
img = self.cache_images[i]
return img
使い方
基本的にはChainerのImageDatasetクラスと同じように使えます。
コンストラクタの引数cache_num
にメモリにキャッシュしておきたい枚数を指定することで,インスタンス生成時にその枚数分だけ読み込みます。
chainer.datasets.ImageDataset
クラスを継承しているので,Chainer標準のデータセットクラスと同じようにイテレータに渡したり,TransformDataset
に渡してData augmentationを行ったりできます。
LabeledImageDataset
のようにラベル付きで使いたい場合はTupleDataset
を活用してください。
おわりに
メモリが許す限り多くキャッシュするほど,回すエポック数が多いほど学習ループ中の画像読み込み処理が減るので高速化が期待されます。
上記ではキャッシュ保存時のdtype
をfloat32
にしてしまっていますが,uint8
にするなどの工夫をすればさらにキャッシュ枚数を増やせるのではないかと思います。