LoginSignup
46
39

More than 5 years have passed since last update.

Kerasでメモリに乗りきらない量のデータを扱うための、fit_generator関数を試してみる

Posted at

画像関係のKaggleコンテストで、Kerasを使いつつコードを書いていたところ、前処理などで工夫しても厳しそうなレベルでメモリ不足に悩まされました。(しかし、一方で精度を上げるためになるべく多くのデータを使いたい)

他の人はどうやっているんだろう?と他人のカーネルを見ていたところ、KerasのSequentialクラスにfit_generator関数という、バッチ単位でデータを扱ってくれる(=瞬間的なメモリが少なくて済む)関数を使っているようでした。

過去に読んだ書籍だと、この関数は使っていなかったので、触りながら色々調べてみます。

簡単な例で試してみる。

MNISTで試してみます。モデルのコード自体は、以前書いたGoogle colaboratoryを試してみる(Keras & MNIST)のものをほぼそのまま使います。

X.shapeが(60000, 1, 28, 28)、y.shapeが(60000, 10)のデータは用意済みという想定で進めます。(また、Xの値は0~1の範囲に設定しておきます)

また、実際にKaggleで膨大なデータを扱うときのことを考慮して、メモリに乗せずにmemmapで扱う想定で進めてみます。(メモリ以外にも、ディスクサイズの制限がありますが、そちらはKaggle Kernelで前処理として複数のKernelに分割するなどして対応する想定)

import math
import os
import subprocess

import numpy as np
import pandas as pd
from keras import backend as K
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dense
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import Adam
from keras.utils import Sequence

INPUT_SHAPE = (1, 28, 28)
CLASSES_NUM = 10
EPOCH_NUM = 20
BATCH_SIZE = 128
K.set_image_dim_ordering('th')


model = Sequential()

model.add(
    Conv2D(filters=20, kernel_size=5, padding='same',
           input_shape=INPUT_SHAPE))
model.add(Activation('relu'))
model.add(
    MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

model.add(
    Conv2D(filters=50, kernel_size=5, padding='same'))
model.add(Activation('relu'))
model.add(
    MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))

model.add(Flatten())
model.add(Dense(500))
model.add(Activation('relu'))

model.add(Dense(units=CLASSES_NUM))
model.add(Activation('softmax'))

model.compile(
    loss='categorical_crossentropy', optimizer=Adam(), metrics=['accuracy'])

※データ読み込みの個所は割愛。

X.shape
(60000, 1, 28, 28)
y.shape
(60000, 10)

ここまではいつも通りです。

データの準備として、memmapのファイルパスを定義しておいて、書きこんでおきます。(この辺りは、実際のKaggleでは前処理としてカーネルを分ける想定)

# memmap用のファイルのパス。
X_MEMMAP_PATH = './X_memmap.npy'
Y_MEMMAP_PATH = './y_memmap.npy'

# Xのデータをmemmapのファイルへ書き込む。
X_memmap = np.memmap(
    filename=X_MEMMAP_PATH, dtype=np.float32, mode='w+', shape=(60000, 1, 28, 28))
X_memmap[:] = X

# yのデータをmemmapのファイルへ書き込む。
y_memmap = np.memmap(
    filename=Y_MEMMAP_PATH, dtype=np.uint8, mode='w+', shape=(60000, CLASSES_NUM))
y_memmap[:] = y

# memmapのインスタンスをdelで削除すると、memmapのファイルが閉じられます。
del X, y
del X_memmap, y_memmap

次に、バッチサイズ単位でのデータの読み込みを扱うためのクラスを定義します。
Pythonのジェネレーターで書くか、KerasのSequenceクラスを継承して作ります。(モデル作成時のSequentialクラスと紛らわしいですが、別物です)
Sequenceクラスを使うと、マルチプロセスでの制御をよしなに対応してくれたり、fit_generator関数実行時に必要な引数が少しシンプルになったりといったメリットがあるようで、今回はSequenceクラスを使っていきます。
マルチプロセスに関しては、例えばDataAugumentation的に、動的にデータを増やしたり、といったようなCPU負荷の高い処理などで効果を発揮します(※通常はUbuntu環境などで実行されると思いますが、Windowsなどの場合は恐らくマルチプロセスのものは使用できません)。

class MNISTSequence(Sequence):
    """
    学習中の、バッチ単位でのデータの取得を扱うためのクラス。

    Attributes
    ----------
    batch_size : int
        バッチ単体でのデータ件数。
    memmap_X : memmap
        入力データのmemmap配列。
    memmap_y : memmap
        教師データのmemmap配列。
    length : int
        データのインデックス件数。math.ceil(データ行数 / batch_size)の
        値が設定される。

    Parameters
    ----------
    batch_size : int
        バッチ単体でのデータ件数。
    """

    def __init__(self, batch_size):
        DATA_ROW_NUM = 60000

        self.batch_size = batch_size
        self.memmap_X = np.memmap(
            filename=X_MEMMAP_PATH, dtype=np.float32, mode='r',
            shape=(DATA_ROW_NUM, 1, 28, 28))
        self.memmap_y = np.memmap(
            filename=Y_MEMMAP_PATH, dtype=np.uint8, mode='r',
            shape=(DATA_ROW_NUM, CLASSES_NUM))
        self.length = math.ceil(DATA_ROW_NUM / batch_size)

    def __getitem__(self, idx):
        """
        対象のインデックスの、バッチ単体分のデータを取得する。

        Parameters
        ----------
        idx : int
            取得対象のインデックス番号。

        Returns
        -------
        X : memmap
            対象のインデックスの入力データ。
        y : memmap
            対象のインデックスの教師データ。
        """
        start_idx = idx * self.batch_size
        last_idx = start_idx + self.batch_size
        X = self.memmap_X[start_idx:last_idx]
        y = self.memmap_y[start_idx:last_idx]

        return X, y

    def __len__(self):
        """
        データのインデックス件数を取得する。math.ceil(データ行数 / batch_size)の
        値が設定される。

        Returns
        -------
        length : int
            データのインデックス件数。
        """
        return self.length

    def on_epoch_end(self):
        """
        1エポック分の処理が完了した際に実行される。
        属性で持っている(__getitem__関数実行後も残る)データなどの破棄処理や
        コールバックなど、必要な処理があれば記載する。
        """

        # メモリ使用量などを表示する(MB単位)。
        print(subprocess.getoutput('vmstat -S m'))

クラス内に以下の関数が最低限必要になります。

  • __init__
    • コンストラクタ。必要な初期処理などを記述します。
  • __getitem__
    • バッチサイズ分のデータを取得するためのコードを記述します。
    • 返却値はXとy(もしくはinputsとtargetsなどの名前)のtupleを設定する必要があります。
    • Xとyの行数は、バッチサイズで設定します。ただし、最後のバッチなどで件数が少なくなるなど、全てのタイミングで同じ行数が必要になるというわけではありません。
    • ただし、Xとyの行数は必ず一致させます。
  • __len__
    • バッチサイズで分割した際のデータの件数が返却されるようにします。(math.ceil(データ行数 / batch_size))
  • on_epoch_end
    • 1エポック分の処理が完了した際に必要な処理を記載します。
    • 特になければpassとだけ書いておきましょう。
    • 今回は、memmapを使っているのでメモリがほぼ変動しないのは当たり前ではありますが、一応linuxのコマンドでメモリ使用量などが出力されるようにしてみました。
    • ※Kaggle Kernel上でも使用メモリはUI上に表示されますが、あまり参考にならないときもあるため、コマンドで確認しています。

最後に学習を開始させて終わりです。

mnist_sequence = MNISTSequence(batch_size=BATCH_SIZE)
model.fit_generator(generator=mnist_sequence, epochs=EPOCH_NUM, verbose=2)
Epoch 1/20
 - 7s - loss: 0.1660 - acc: 0.9505
...
procs -----------memory---------- ---swap-- -----io---- -system-- ------cpu-----
 r  b   swpd   free   buff  cache   si   so    bi    bo   in   cs us sy id wa st
 0  0      0   3893   2075   5340    0    0    81    41  307   14  1  1 98  0  0
...

無事動いたようで、且つメモリの推移など見ていても変動していません。

最後に、fit_generator関数の他の引数(よく使いそうなもの)の説明を読んでみます。

  • use_multiprocessing : bool, default False
    • マルチプロセスで処理してほしい場合にはTrueを指定する必要があるようです。
  • workers : int, default 1
    • プロセスの上限数の指定のようです。Kaggle Kernelなどでは4コアなので4を指定したり、といった感じでしょうか。
  • validation_data
    • 今回は指定しませんでしたが、バリデーションのデータを扱う場合に指定します。
    • generator引数と同様、PythonのジェネレーターもしくはSequenceクラスを継承したクラスのインスタンスを指定してあげます。
    • 今回用意したMNISTSequenceで、コンストラクタで学習用なのかバリデーション用なのかを引数に渡すようにして、__getitem__関数を分岐させるなり、別のクラスを用意してしまうなりで対応します。

実行環境

!python -V
Python 3.6.6 :: Anaconda, Inc.
import keras
keras.__version__
'2.2.4'
import platform
platform.platform()
'Linux-4.9.0-5-amd64-x86_64-with-debian-8.9'
46
39
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
46
39