1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

[Keras,Tensorflow]Model.predict()をループの中で呼び出すと落ちる

Last updated at Posted at 2017-11-13

前提:KerasをTensorflowバックエンドで使っている

やりたいこと

複数の入力データをある学習済モデル(ここではsome_modelとする)に入れたときのそれぞれの推定結果を得たい。

問題

forループでmodel.predict()を多数回呼ぶと、GPUのメモリが不足して落ちる。

以下のコードでは、for文の中でフォルダinput_dirの中にある画像を読み込み、それにモデルを適用した結果をresultsの中に逐一加えている。

import keras
import numpy as np
from keras.preprocessing import image
import os

...

results = []
filenames = os.listdir(input_dir)
for name in filenames:
    img = image.load_img(input_dir+'/'+name, target_size=(256,256))
    x = image.img_to_array(img)
    x = np.expand_dims(x,0)
    result = some_model.predict(x)
    results.append(result)

解決策

一旦入力データをPython listの中に保持し、numpy.stack()で1つのNumpy arrayにまとめてから一回だけmodel.predictを呼ぶ。

images = []
filenames = os.listdir(input_dir)
for name in filenames:
    img = image.load_img(input_dir+'/'+name, target_size=(256,256))
    x = image.img_to_array(img)
    images.append(x)

images_array = np.stack(images)
results = some_model.predict(images_array,batch_size=1)
1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?