5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

【エラー】Cannot interpret feed_dict key as Tensor: Tensor Tensor... is not an element of this graph

Last updated at Posted at 2018-08-11

エラーが出たとき用に足跡を残す。

fit_generatorを使いたいというときがある。kerasのデータ拡張機能を使いたいとき、あるいは色々書き方の問題でメモリエラーになる場合。
fit用にリストに画像とか学習用データを溜め込んで学習するコードを書いていた場合、
データが多くなると当然メモリエラーになる。
全データをバイナリに固めるのがいい場合もあるが、
手っ取り早くジェネレータにしてbatch分だけ取得するようにすることもできる。

fit_generatorを使うときにエラーが出て、結構ハマったが落ち着いて考えると、「そりゃそうだ」というエラーだった。

そもそも計算グラフが構築されてないとグラフないエラー出る。
あとは何かしらの問題で計算グラフ死んでた場合はエラーなる。

ちなみにfit_generatorのサンプルはこちら
https://gist.github.com/Hironsan/e041d6606164bc14c50aa56b989c5fc0

#何かしらの問題でグラフ死んでる場合

グラフできてから回そう

global graph
with graph.as_default():
    (... do inference here ...)
 with self.graph.as_default():
    labels = self.model.predict(data)

https://github.com/keras-team/keras/issues/2397
https://github.com/keras-team/keras/issues/6124

不正にセッションが残ることがあるので、最後に消すんですと。

from keras.backend import tensorflow_backend as backend
backend.clear_session()

なので初めにmodel.compileしてbackend.clear_session()してからfitすると、
グラフないエラーになる。

#スレッドセーフでない系の場合

Kerasはスレッドセーフでないからです。(バグ)
Graphを別に用意してください。

結論今回はこれで困ってたわけではなかった。

参考
https://teratail.com/questions/117352
https://github.com/keras-team/keras/issues/5896

#ジェネレータ部分でモデル部分を使ってしまってる場合

今回ハマってたのはこの例だった。
fit_generatorはデータセットをジェネレータ形式で作ってkerasに渡す。
yield書いとけばいい。

class ImageDataGenerator(object):
    def __init__(self):
        self.reset()
        self.IR2 = InceptionResNetV2(weights='imagenet', include_top=False)

    def reset(self):
        self.features = []

    def flow_from_directory(self, path_train, batch_size=32):
        while True:
            for path in pathlib.Path(path_train+'/images').iterdir():
                image_path = str(path)
                img = img_to_array(load_img(image_path, target_size=(299, 299)))
                img = np.array(img, dtype=float)
                img = preprocess_input(img)
                img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2])
                feature = self.IR2.predict(img)
                self.features.append(feature)
                if len(self.features) == batch_size:
                    input_features = np.array(self.features)
                    targets = np.array(1 or 0ラベル入れる処理入れる)
                    self.reset()
                    yield input_features, targets

image_gen = ImageDataGenerator()


model.fit_generator(
    generator=image_gen.flow_from_directory(path_train, IR2, batch_size=batch_size),
    steps_per_epoch=int(np.ceil(len(list(train_dir.iterdir())) / batch_size)),
    epochs=epoch,
    verbose=1)

generator部分でself.IR2.predict(img)してるが、コォ言う書き方はダメ。
おそらく前処理部分に位置してて計算グラフ構築前なのでエラーになる。
self.IR2.predict(img)した特徴量を学習に使いたい場合は特徴量を保存していて再利用しよう。

for path in pathlib.Path(path_train+'/images').iterdir():
    path = str(path)
    image_path = path
    print(image_path)
    img = img_to_array(load_img(image_path, target_size=(299, 299)))
    img = np.array(img, dtype=float)
    img = preprocess_input(img)
    img = img.reshape(1, img.shape[0], img.shape[1], img.shape[2])
    feature = IR2.predict(img)
    print(feature.shape)
    feature = feature.reshape(feature.shape[1], feature.shape[2], feature.shape[3])
    print(feature.shape)
    image_filename = os.path.basename(image_path)
    print(image_filename)
    np.save('image/{image}.npy'.format(image=image_filename), feature)

別の解決策の例

conv_base.predict(inputs_batch) で特徴を取り出して、後々再利用する。

import os
import numpy as np
from keras.preprocessing.image import ImageDataGenerator

base_dir = '/Users/fchollet/Downloads/cats_and_dogs_small'

train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')
test_dir = os.path.join(base_dir, 'test')

datagen = ImageDataGenerator(rescale=1./255)
batch_size = 20

def extract_features(directory, sample_count):
    features = np.zeros(shape=(sample_count, 4, 4, 512))
    labels = np.zeros(shape=(sample_count))
    generator = datagen.flow_from_directory(
        directory,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary')
    i = 0
    for inputs_batch, labels_batch in generator:
        features_batch = conv_base.predict(inputs_batch) #ここ!!!!!!!!!
        features[i * batch_size : (i + 1) * batch_size] = features_batch
        labels[i * batch_size : (i + 1) * batch_size] = labels_batch
        i += 1
        if i * batch_size >= sample_count:
            # Note that since generators yield data indefinitely in a loop,
            # we must `break` after every image has been seen once.
            break
    return features, labels

train_features, train_labels = extract_features(train_dir, 2000)
validation_features, validation_labels = extract_features(validation_dir, 1000)
test_features, test_labels = extract_features(test_dir, 1000)

5
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?