Help us understand the problem. What is going on with this article?

VAEの学習にジェネレータを使う

概要

『PythonとKerasによるディープラーニング』8章のVAEをやってみました。コードはここに公開されています。この通りやって何の問題もありませんでしたが、自分の画像でVAEを学習させようとした場合、データのサイズが大きいのでジェネレータを使う必要がありますが、ふつうにImageDataGeneratorを準備してvae.fit_generatorとしたのではうまくいきませんでした。

結論としては2通りの解決方法が見つかりました。
1) fit_generatorではなく、test_on_batchを使う
2) モデルを変更し、コンパイル時に損失関数を指定する形にする

なぜすんなりfit_generatorが使えないのかというと、サンプルコードのモデルでは損失関数をカスタム層に入れ込んで、モデルのコンパイル時にvae.compile(loss=None)としているためのようです。
fit_genenratorの公式ドキュメントを見ると、fit_generatorではジェネレータの出力に(inputs, targets)または(inputs, targets, sample_weights)のタプルのいずれかしか受け付けないことになっていますので、出力に正解ラベルが出力されるようにしてImageDataGeneratorを与える必要がありますが、モデル側が正解ラベルを受け付けるように作られていないのでエラーとなります。
なのでサンプルコードのモデルのままやるなら、fit_generatorではなくtest_on_batchを使うことで学習できるようになりました。

では、fit_generatorを使うことはできないのかというと、コンパイル時に損失関数を指定するようにモデルを改造してやればfit_generatorで学習ができるようになりました。

1) fit_generatorではなく、test_on_batchを使う方法

モデルはサンプルコードのモデルをそのまま使用しますので、モデルのコンパイルまでは同じです。
データの供給にImageDataGeneratorを使用します。

from keras.preprocessing.image import ImageDataGenerator
train_dir = './MNIST-train'
test_dir = './MNIST-test'

# すべての画像を1/255でスケーリングする
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# ImageDataGeneratorを使ってディレクトリから画像を読み込む
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(28, 28),
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode=None)          # モデルが正解ラベルを受け付けないのでNone
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(28, 28),
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode=None)

ImageDataGeneratorの出力にはモデルへの入力となる画像だけが必要なので、class_mode=Noneと指定します。これで正解ラベルが出力されません。
ジェネレータの出力がこのような形になるとfit_generatorは使えないので、例えば次のような学習ループを自分で作ります。

import math
steps_per_epoch = math.ceil(train_generator.samples / batch_size)       # エポックあたりのバッチ数
iters_verbose = 200                                                     # Lossの表示頻度
curr_epoch = 0
for i, x_train in enumerate(train_generator):
    # エポック数の更新
    if i % steps_per_epoch == 0:
        curr_epoch += 1

    # 学習
    loss = vae.train_on_batch(x_train, y=None)

    # 学習経過の表示
    if i % iters_verbose == 0:
        print('epoch:{}, iters:{}, loss:{:.3f}'.format(curr_epoch, i, loss))

    # 指定したエポック数に達したら終了
    if curr_epoch == 10:
        break

2) モデルを変更し、コンパイル時に損失関数を指定する形にする方法

サンプルコードのモデルは最後にカスタム層をくっつけていますが、これをやめます。
デコーダーの定義以降のコードを下記のように変更します。

decoder = Model(decoder_input, x, name='decoder')
z_decoded = decoder(z)

'''
カスタム層をやめて、VAEの損失関数の関数を定義してcompile時に渡す
モデルの定義も上記カスタム層yを含まない形に変更する
自作損失関数の引数は、(y_true, y_pred)の形をとっている必要があると思う
'''
def vae_loss(x, z_decoded):
    x = K.flatten(x)
    z_decoded = K.flatten(z_decoded)
    xent_loss = keras.metrics.binary_crossentropy(x, z_decoded)
    kl_loss = -5e-4 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=1)
    return K.mean(xent_loss + kl_loss)

vae = Model(input_img, z_decoded)
vae.compile(optimizer='rmsprop', loss=vae_loss)
vae.summary()

上記コードにより出来上がるモデルは下記のようになります。

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 28, 28, 32)   320         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 14, 14, 64)   18496       conv2d_1[0][0]                   
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_2[0][0]                   
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 14, 14, 64)   36928       conv2d_3[0][0]                   
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 12544)        0           conv2d_4[0][0]                   
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 32)           401440      flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 2)            66          dense_1[0][0]                    
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 2)            66          dense_1[0][0]                    
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 2)            0           dense_2[0][0]                    
                                                                 dense_3[0][0]                    
__________________________________________________________________________________________________
decoder (Model)                 (None, 28, 28, 1)    56385       lambda_1[0][0]                   
==================================================================================================
Total params: 550,629
Trainable params: 550,629
Non-trainable params: 0
__________________________________________________________________________________________________
Backend TkAgg is interactive backend. Turning interactive mode on.

入力と出力の形状が一致していてモデルとしてはこちらのほうがわかりやすい気がします。

損失関数vae_lossの引数が正解ラベルとしてデコーダーの出力すべき画像を必要とする形になっています。なので、ジェネレータは次のようにします。
(また、fit_generatorもジェネレータの出力に(x, y)形式のタプルを要求します)

from keras.preprocessing.image import ImageDataGenerator
train_dir = './MNIST-train'
test_dir = './MNIST-test'

# すべての画像を1/255でスケーリングする
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

# ImageDataGeneratorを使ってディレクトリから画像を読み込む
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(28, 28),
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode='input')         # fit_generatorを使う場合、ジェネレータの出力は(x, y)のTupleになるようにする
test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(28, 28),
    color_mode='grayscale',
    batch_size=batch_size,
    class_mode='input')

そしてfit_generatorを使って学習します。

vae.fit_generator(train_generator,
                  epochs=10,
                  validation_data=test_generator)

まとめ

どちらの方法でも学習はうまくいきました。
どちらかというと、モデルがきれいな形になるのと、学習ループを自分で書かなくていいのとで、2)の方がいいような気がします。

参考

元サンプルコード
keras blog
損失関数の指定
ジェネレータの自作

Why do not you register as a user and use Qiita more conveniently?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away