3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

【画像生成】AutoencoderとVAEと。。。遊んでみた♬AEとその潜在空間は???

Last updated at Posted at 2019-05-24

昨夜のお約束の「DLの記憶」について書く準備として、AEと潜在空間について記事にしておこうと思います。
※DLの記憶はこの次の記事にします

###やったこと
・通常のAEでVAEと同様な絵を作成する
・潜在空間内の画像生成
・潜在空間を移動させて画像生成する
###・通常のAEでVAEと同様な絵を作成する
通常のAEを以下のように改変してVAEと同じようなことができるか試してみました。
それは、どうしてもあのLatent_dimのところでGauss関数に押し込む必要性が合点できないからです。
すなわち、あの関数はどんなものでもよく(たぶん、BackPropagationできる程度に滑らかなら)、全体のAEとしての機能は変わらないだろうと考察されるからです。逆に言えば、自然に発生する分布(関数と呼んでいいかは別として)を利用しても同じように制御できる可能性があるということで、やってみました。
これはよく説明に「AEは入力を再現するだけ、VAEは自由に画像生成することができる。」すなわち「AEは自由に画像生成できない」というのは誤解だということが元々のモチベーションです

コードは以下のとおりです。
変更点は、z_log_varを止めて、z_meanだけにしようということだけです。
この選択により、sampling関数は不要になります。
※分布は野となれ山となれということになります
すなわち、つなぎの部分が全結合になっていますが、通常のAEです。

inputs = Input(shape=input_shape, name='encoder_input')
x = Conv2D(32, (3, 3), activation='relu', strides=2, padding='same')(inputs)
x = Conv2D(64, (3, 3), activation='relu', strides=2, padding='same')(x)
shape = K.int_shape(x)
print("shape[1], shape[2], shape[3]",shape[1], shape[2], shape[3])
x = Flatten()(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
# instantiate encoder model
encoder = Model(inputs, z_mean, name='encoder')
encoder.summary()

decoderも通常のまんまです。

# build decoder model
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(shape[1] * shape[2] * shape[3], activation='relu')(latent_inputs)
x = Reshape((shape[1], shape[2], shape[3]))(x)
x = Conv2DTranspose(64, (3, 3), activation='relu', strides=2, padding='same')(x)
x = Conv2DTranspose(32, (3, 3), activation='relu', strides=2, padding='same')(x)
outputs = Conv2DTranspose(1, (3, 3), activation='sigmoid', padding='same')(x)
# instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
# instantiate VAE model
outputs = decoder(encoder(inputs))
vae = Model(inputs, outputs, name='vae_mlp')

異なるのは、この後に描画関数(おまけに掲載)をVAEと同じものを使います。
結果として以下のような図が得られます。
以下で示す図も含め、これらの図はほぼsampling関数を利用したものと変わりません。
vae_mean_allz0z1.png
######・潜在空間内の画像生成
画像生成の制御もある意味簡単にできることが分かります。
以下は、おまけに掲載したコードで生成した潜在空間内で生成された画像です。
この絵を見ると、母関数としてガウス分布を仮定しなくとも、画像はバラバラに生成されているわけではなく、分類・適度に分散されながら生成されていることが分かります。
※この部分が実はあとあと面白い解釈ができるのですが、次回の議論とします
digits_over_latent_all.png
###・潜在空間を移動させて画像生成する
そして、上記の潜在空間内を連続的に移動しながら画像生成します。
下記のように問題なく画像生成できます。
つまり、上に示した画像の潜在空間内の分布はガウス分布などの母関数を仮定しなくとも、画像生成は連続する変化となっており、途中の変化がなめらかになっているところがミソです。
z_sample_t_NoSamplingz0=[0.4,-2.3]z7=[0.4,2.7].gif
以下のパラメータ変換で、潜在空間内の点を移動させながらdecoderで画像生成しています。
【参考】以下を参考にしています
Variational Autoencoder徹底解説

z : s*z_0 + (1-t)*z_1
def plot_results2(models,
                 data,
                 batch_size=128,
                 model_name="vae_mnist"):
    z0=np.array([0.4,-2.3])
    z1=np.array([0.4,2.7])
    for t in range(50):
        s=t/50
        z_sample=np.array([s*z0+(1-s)*z1])
        x_decoded = decoder.predict(z_sample)
        plt.imshow(x_decoded.reshape(28, 28))
        plt.title("z_sample="+str(z_sample))
        plt.savefig('./mnist1000/z_sample_t{}'.format(t))
        plt.pause(0.1)
        plt.close()

※上記は50分割にしていますが、これはアップの都合でそうしているだけで、性能的には10倍くらいにしても問題なく動きます
以下のように円上を360度動かして、小さいサイズ(2.5MB位)にしたけど、やはりアップできないので、出来るようになったらアップすることとします。
円上を動かすコードは以下のとおり

def plot_results2(models, data, batch_size=128, model_name="vae_mnist"):
    for t in range(360):
        s=t/360
        z_sample=np.array([[2*np.cos(2*s*np.pi),2*np.sin(2*s*np.pi)]])
        x_decoded = decoder.predict(z_sample)
        plt.imshow(x_decoded.reshape(28, 28))
        plt.title("z_sample="+str(t))
        plt.savefig('./mnist1000/360/z_sample_t{}'.format(t))
        plt.pause(0.01)
        plt.close()

【参考】z0=np.array([0.4,-2.3])などとしている理由は以下の参考のとおりです
Why do I get TypeError: can't multiply sequence by non-int of type 'float'?
###まとめ
・AEで潜在空間のパラメータを動かして画像生成できた
・潜在空間内でパラメータを自由に動かしたとき滑らかに画像生成できることが分かった

・潜在空間内での画像生成と元の学習データの関係についてはDLの記憶とともに次回記事にします
###おまけ

def plot_results(models,
                 data,
                 batch_size=32,
                 model_name="vae_mnist"):
    """Plots labels and MNIST digits as function of 2-dim latent vector
    # Arguments
        models (tuple): encoder and decoder models
        data (tuple): test data and label
        batch_size (int): prediction batch size
        model_name (string): which model is using this function
    """
    encoder, decoder = models
    x_test, y_test = data
    os.makedirs(model_name, exist_ok=True)

    filename1 = "./mnist1000/vae_mean100L2_3"
    # display a 2D plot of the digit classes in the latent space
    z_mean = encoder.predict(x_test, batch_size=batch_size)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.savefig(filename1+"z0z1.png")
    plt.pause(1)
    plt.close()

    filename2 = "./mnist1000/digits_over_latent100L2_3"
    # display a 30x30 2D manifold of digits
    n = 30
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    # linearly spaced coordinates corresponding to the 2D plot
    # of digit classes in the latent space
    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]
    for i, yi in enumerate(grid_y):
        for j, xi in enumerate(grid_x):
            z_sample = np.array([[xi, yi]])
            x_decoded = decoder.predict(z_sample)
            digit = x_decoded[0].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit
    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename2+".png")
    plt.pause(1)
    plt.close()
3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?