Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
16
Help us understand the problem. What is going on with this article?
@FukuharaYohei

KerasのImageDataGeneratorを使いつつ複数Input統合モデル

More than 1 year has passed since last update.

記事「超簡単 Kerasで複数Input統合モデル」で、複数Inputに対応する方法を書きました。当記事では、複数InputでImageDataGeneratorを使う方法を紹介します。
ImageDataGeneratorは、画像水増しを手軽にできる非常に便利なクラスで、ググるとわかりやすいページがたくさん出てくるのでここでは言及しません。
ImageDataGeneratorを使うことで、下図(リンク先)のようなことがやりやすくなります。
alt

処理概要

以下の2種類のInputをするモデルを想定。
1. 16×16グレースケール画像
2. 3文字の数値(テキストを数値化した想定です。3文字という少なさが非現実的ですが、実際に使う場合には増やせばいいだけなのでこの程度)
両値の生成に乱数を使っているため、実際にデータの意味はありません。さらにはモデル定義もいい加減です。

処理プログラム

プログラム全体はGitHubを参照ください。

1. ライブラリインポート

主にnumpyとtensorflowに統合されているkerasを使います。ピュアなkerasでも問題なく、インポート元を変えるだけです。

from random import random
from math import ceil

import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, concatenate, Dense, Conv2D, MaxPooling2D, Flatten, SimpleRNN
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.preprocessing import image 
from tensorflow.python.keras.utils import np_utils
from tensorflow.python.keras.utils.vis_utils import plot_model
import numpy as np

2. 前処理

2.1. データ作成

下記の3種類のデータを4つ生成します。
説明変数1(x_image):16×16グレースケール画像
説明変数2(x_text):3文字の数値化されたテキスト
目的変数(y_train):4種類のラベルOne-Hot Encoding形式

NUM_DATA = 4 # Number of data

x_image = np.random.rand(NUM_DATA, 16, 16, 1) # dummy gray scale image as random value
x_text  = np.random.randint(0, 10, (NUM_DATA, 3, 1)) # dummy text
y_train = np.arange(NUM_DATA).reshape(NUM_DATA, 1)  # dummy label
y_train = np_utils.to_categorical(y_train) # to one hot-encoding

2.2. Generator関数定義

Generatorの関数を定義します。この関数の定義方法を調べるのに時間がかかりました。
通常、flow関数で受け取る値はラベル(目的変数)ですが、今回はこれを行数(x_index)にします。そして受け取った行数を使ってテキスト(x_text)とラベル(y_train)をこの関数で返すようにします。

# Generator for 2 input
def gen_flow_for_two_inputs(datagen, batch, x_image, x_text, y_train, shuffle=True):
    """
    Args:
        datagen(image.ImageDataGenerator): data generator
        batch(int): batch size 
        x_image(np.ndarray): image array for input
        x_text(np.ndarray): text array for input
        y_train(np.ndarray): label array for output 
        shuffle(bool): bool to shuffle data
    """
    # Create index
    x_index = np.arange(x_image.shape[0])

    # Pass index to the 2nd parameter instead of labels
    batch = datagen.flow(x_image, x_index, batch_size=batch, shuffle=shuffle)
    while True:
        batch_image, batch_index = batch.next()

        # Use index values for text(x_text) and labels(y_train)
        yield [batch_image, x_text[batch_index]], y_train[batch_index]

Generator関数をチェックしてみます。

sample_datagen = image.ImageDataGenerator() # data generator for checking generator
gen_flow = gen_flow_for_two_inputs(sample_datagen, 2, x_image, x_text, y_train, False)

# check generator
for i, batch in enumerate(gen_flow):
    print(i, '-----')
    print(batch[0][0].shape)
    print(batch[0][1].shape)
    print(batch[1].shape)
    print(batch[1])
    if i == 1:
        break

One-Hot Encoding形式のラベル(目的変数)を意味する配列が1, 2, 3, 4となっているのがわかります。

0 -----
(2, 16, 16, 1)
(2, 3, 1)
(2, 4)
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]]
1 -----
(2, 16, 16, 1)
(2, 3, 1)
(2, 4)
[[0. 0. 1. 0.]
 [0. 0. 0. 1.]]

3. モデル定義

画像とテキストの2つをInputとして統合するモデルを定義します。
今回はEmbeddingレイヤーを使っていないので大丈夫ですが、Embeddingレイヤーを使った場合、mask_zeroTrueにできないので注意してください。これはmaskingをサポートしていないFlattenレイヤーがあるためです。

# define two sets of inputs
input_image = Input(shape=(16, 16 ,1,))
input_text = Input(shape=(3, 1, ))

# image input
model_image = Conv2D(32, kernel_size=(4, 4), activation='relu')(input_image)
model_image = MaxPooling2D(pool_size=(2, 2))(model_image)
model_image = Flatten()(model_image)
model_image = Model(inputs=input_image, outputs=model_image)

# text input
# Be careful that Embedding with mask_zero and Flatten cannot coexist
model_text = SimpleRNN(16, return_sequences=True)(input_text)
model_text = Flatten()(model_text)
model_text = Model(inputs=input_text, outputs=model_text)

# Image and text combined
combined = concatenate([model_image.output, model_text.output])
final = Dense(32, activation="relu")(combined)
final = Dense(4, activation="sigmoid")(final)

model = Model(inputs=[model_image.input, model_text.input], outputs=final)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.summary()

plot_modelで出力するとこんな感じです.
image.png

4. 訓練実行

Generator関数の場合は、fit関数ではなくfit_generator関数を使って訓練実行です。steps_per_epochの値はデータ数とバッチサイズから計算しています。
申し訳程度にImageDataGeneratorで画像水増しをしています。

train_datagen = image.ImageDataGenerator(
    rotation_range=10.,  # rotation angle
    rescale=1/255
    )

EPOCH = 10
BATCH = 2

history = model.fit_generator(
    generator=gen_flow_for_two_inputs(train_datagen, BATCH, x_image, x_text, y_train),
    steps=ceil(x_image.shape[0] / BATCH),
    epochs=EPOCH
    )

5. 評価

評価時にはevaluate関数でなく、evaluate_generator関数を使います。評価時に画像水増ししません。

test_datagen = image.ImageDataGenerator(rescale=1/255)

history = model.evaluate_generator(
    generator=gen_flow_for_two_inputs(test_datagen, BATCH, x_image, x_text, y_train),
    steps=ceil(x_image.shape[0] / BATCH)
    )

6. 予測

評価時にはpredict関数ではなく、predict_generator関数を使います。
データは評価時と同じものを流用しています。また、今回はこのあと予測結果の答え合わせをするので、引数shuffleFalseにしています。

# shuffle should be false for prediction, since result goes with labels
y_pred = model.predict_generator(
    generator=gen_flow_for_two_inputs(test_datagen, BATCH, x_image, x_text, y_train, shuffle=False),
    steps=ceil(x_image.shape[0] / BATCH)
    )

予測の答え合わせ。scikit-learnを使ってClassification Reportと混合行列を出しています。

from sklearn.metrics import classification_report, confusion_matrix

# change from one-hot encoding
y_train_np = np.argmax(y_train, axis=1)
y_pred_np = np.argmax(y_pred, axis=1)

# output classification report and confusion matrix
print(classification_report(y_train_np, y_pred_np))
print(confusion_matrix(y_train_np, y_pred_np))

元データが乱数なので、無意味ですが参考までに結果記載。

              precision    recall  f1-score   support

           0       1.00      1.00      1.00         1
           1       1.00      1.00      1.00         1
           2       1.00      1.00      1.00         1
           3       1.00      1.00      1.00         1

    accuracy                           1.00         4
   macro avg       1.00      1.00      1.00         4
weighted avg       1.00      1.00      1.00         4

[[1 0 0 0]
 [0 1 0 0]
 [0 0 1 0]
 [0 0 0 1]]
16
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
FukuharaYohei
気の向いたままにいろいろと書きます。 仕事はSAP関連で、HANA、Fiori、SAPUI5、BusinessObjectsなどいろいろやっています。

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
16
Help us understand the problem. What is going on with this article?