1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Kerasでメモリが足りないときの逐次処理方法

Posted at

概要

model.fit_generatorを使います。
https://keras.io/ja/models/model/#fit_generator

環境

  • Keras: 2.4.3
  • python: 3.6.9

使い方

メモリが少なくなった時、一番お手軽にできる対処法はbatch_sizeを減らすことですが、paddingなどの前処理をした時点で気絶したり、そもそもデータが巨大すぎて全て乗らず一度に前処理すらできないこともあるかと思います。
その場合はモデルの訓練時に、batchごとにデータロード→前処理をしなければなりません。

fit_generatorでは、通常データを指定すべきところに関数を指定することができます。
また、validation_data=で検証用データに対しても関数を指定できます。
この関数内 (generate_arrays) で上記の処理を行います。generate_arraysは後程定義します。

batch_size = 8 # 一度にメモリに乗せるデータ
model.fit_generator(
    generate_arrays(x_train, y_train, batch_size),
    validation_data=generate_arrays(x_test, y_test, batch_size),
    epochs=3,
    steps_per_epoch=len(x_train) // batch_size,
    validation_steps=len(x_test) // batch_size
)

steps_per_epochで1エポックで何ステップ実行するかを決定します。これを決めないと終わりません。通常は、データ全てを見てほしいのでデータ全体の長さをバッチサイズで割ります。validation_stepsは検証データに関してのものです。

generate_arraysの中身は下のようになります。

def generate_arrays(x, y, batch_size=32):
    i = 0
    while True:
        batch_df = x[i * batch_size : (i + 1) * batch_size]
        batch_y = y[i * batch_size : (i + 1) * batch_size]
        if (i + 1) * batch_size >= len(x):
            i = 0 # iのリセットは必要
        else:
            i += 1
        yield process_data(batch_df, batch_y)

# 前処理の一例
def process_data(batch_df, batch_y):
    arr = {}   
    # パディングして履歴の長さをそろえる
    for c in batch_df.columns:
        arr[c] = pad_sequences(batch_df[c], dtype='float32', maxlen=MAX_RES_TOKENS)
    
    return (arr, batch_y)

今回はデータ自体はメモリに乗ったが前処理を一度にできないパターンでした。x, yを受け取った後はまずDataFrameをバッチサイズに切り出します。
pythonのyieldはreturnと似ているのですが、ループが終わらずにまた実行されます。下記サイトが参考になります。
http://ailaby.com/yield/
次のepochに行ってもgenerate_arraysのiはリセットされないので、データサイズを超えるindexを見ようとした場合は0に手動でリセットする必要があります。

yieldするデータとしては(batch_x, batch_y)のタプルです。
もし、モデルの入力や出力に名前を付けていた場合、

(
    {
        'in1': batch_in1, 'in2': batch_in2
    },
    {
        'out1': batch_out1, 'out2': batch_out2
    }
)

の形でyieldすれば大丈夫です。

データが一度で乗らないとき

csvファイルなどが連番で大量にあるときは、事前に検証用データを分けておくなどして

def generate_arrays(file_path, max_file_count):
    i = 0
    while True:
        batch_df = pd.read_csv(f"{file_path}/{i}.csv")
        batch_y = batch_df['answer']
        if i >= max_file_count:
            i = 0
        else:
            i += 1
        yield process_data(batch_df, batch_y)

などとすれば大丈夫です。ファイルごとにデータ数が変わる場合はcsvを読み込んだ後さらに切り出しを行います。

参考サイト

参考になりました、ありがとうございます。この記事は↓の内容をシンプルにしただけです。
Kerasくんとgeneratorの魔法

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?