Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
1
Help us understand the problem. What is going on with this article?
@chikuwa014

[Keras,Tensorflow]Model.predict()をループの中で呼び出すと落ちる

More than 3 years have passed since last update.

前提:KerasをTensorflowバックエンドで使っている

やりたいこと

複数の入力データをある学習済モデル(ここではsome_modelとする)に入れたときのそれぞれの推定結果を得たい。

問題

forループでmodel.predict()を多数回呼ぶと、GPUのメモリが不足して落ちる。

以下のコードでは、for文の中でフォルダinput_dirの中にある画像を読み込み、それにモデルを適用した結果をresultsの中に逐一加えている。

import keras
import numpy as np
from keras.preprocessing import image
import os

...

results = []
filenames = os.listdir(input_dir)
for name in filenames:
    img = image.load_img(input_dir+'/'+name, target_size=(256,256))
    x = image.img_to_array(img)
    x = np.expand_dims(x,0)
    result = some_model.predict(x)
    results.append(result)

解決策

一旦入力データをPython listの中に保持し、numpy.stack()で1つのNumpy arrayにまとめてから一回だけmodel.predictを呼ぶ。

images = []
filenames = os.listdir(input_dir)
for name in filenames:
    img = image.load_img(input_dir+'/'+name, target_size=(256,256))
    x = image.img_to_array(img)
    images.append(x)

images_array = np.stack(images)
results = some_model.predict(images_array,batch_size=1)
1
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
1
Help us understand the problem. What is going on with this article?