8
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

GANとVAE

Posted at

##目次

  1. オートエンコーダとは
  2. 変分オートエンコーダ(VAE)
  3. GAN

##オートエンコーダとは

やってることはしたの図がわかりやすい。要は、エンコーダで画像データの次元を削減することで情報を圧縮し、デコーダで圧縮された情報を使って画像を再構築している。入力画像と出力画像の画素間の距離MAEを算出し、最適化する。教師なし学習。
image.png

さてこのオートエンコーダを使ってどのように画像を生成するのか。
まず画像を使ってオートエンコーダを訓練する。エンコーダとデコーダの二つのネットワークのパラメータが適切になるように最適化を行う。そうすると入力した画像が潜在空間でどのように表現されるかという大体の感触がわかる。これを生成モデルとして使う場合は、基本的にはエンコーダ部分は使わず、潜在空間とデコーダのみを使う。

##変分オートエンコーダ

普通のオートエンコーダは潜在空間を配列として学習しようとするが、変分エンコーダは、潜在空間の分布を定義する適切なパラメータを見つけようとする。そしてこの潜在分布から値をサンプリングして具体的な値を得てデコーダに入力することで画像が再構築されるのである。

スクリーンショット 2020-06-12 2.29.18.png
左が通常のオートエンコーダで右が変分エンコーダ。

詳しいことはこの記事がとてもわかりやすいです。

##GAN

GANの詳しい動きは他の記事に当たっていただきたい。ここでは、GANの特徴的な性質を述べる。
GANアーキテクチャにおいては、生成器も識別器も、識別器の損失関数によって訓練される。識別器自身の訓練では全ての訓練データに対して、識別器の損失を最小化しようとする。一方生成器は、自分が作る偽のサンプルに対して、識別器の損失を最大化しようとする。
つまり普通のニューラルネットワークの訓練は最適化問題であるのに対して、GANの訓練は最適化というよりも、生成器と識別器が競うあうゲーム。ナッシュ均衡に達すると安定する。

GANの訓練アルゴリズムをまとめると以下のようになる

for 訓練における各反復ステップ:
1.識別器の訓練
 a. 本物のデータからランダムにサンプルを選んで、ミニバッチXを作る
 b. 乱数ベクトルのミニバッチzを作り、偽のサンプルからなるミニバッチG(Z)=X`を作る
 c. D(x)とD(x`)に対する識別損失を計算し、全誤差を逆誤差伝搬させて識別器のパラメータを更新する

2.生成器の訓練
 a.乱数ベクトルのミニバッチzを作り、偽のサンプルからなるミニバッチG(z)=X`を作る
 b.D(x`)に対する識別損失を計算し、それを逆伝搬させて生成器のパラメータを更新することで識別誤差を最大化する。

ステップ1で識別器の訓練をするときは生成器のパラメータは更新されないことに注意!
ステップ2で生成器の訓練をするときは識別器のパラメータは更新されないことに注意!

ここまでのGANの基礎的な知識を用いて、可能な限り簡略化したGANを実装する。
(より実践的な実装は以後の記事であげます)

###簡略化したGANの実装
今回は、KerasのSequentialAPIで、MNISTデータから画像を生成するコードを実装する。
####1.import宣言

%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np

from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

####2.入力次元の設定

#入力次元の設定
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)

#生成器への入力ノイズの次元
z_dim = 100

####3.生成器の生成

#生成器
def build_generator(img_shape, z_dim):
    model = Sequential()
    model.add(Dense(128, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(28*28*1, activation='tanh'))
    model.add(Reshape(img_shape))
    return model

####4.識別器の生成

#識別器
def build_discriminator(image_shape):
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(1, activation='sigmoid'))
    return model

####5.コンパイル

#コンパイル!

def build_gan(generator, discriminator):
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model

#識別器の構築及びコンパイル
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(),
                      metrics=["accuracy"])

#生成器の構築
generator = build_generator(img_shape, z_dim)
#生成器の構築中は、識別器のパラメータは固定する
discriminator.trainable = False
#GANモデルの構築及びコンパイル
gan = build_gan(generator, discriminator)
gan.compile(loss="binary_crossentropy",optimizer=Adam())

####6.trainの設定

#訓練!
losses = []
accuracies = []
iteration_checkpoints = []

def train(iterations, batch_size, sample_interval):
  (X_train, Y_train), (X_test, Y_test) = mnist.load_data() #X_train.shape=(60000, 28, 28)
  X_train = X_train /127.5 - 1.0
  X_train  = np.expand_dims(X_train, axis=3)

  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size,1))

  for iteration in range(iterations):

    #本物の画像からランダムに取り出したバッチを作る
    idx = np.random.randint(0, X_train.shape[0],batch_size)
    imgs = X_train[idx]
    #偽の画像のバッチを作成する
    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    #識別器の訓練
    d_loss_real = discriminator.train_on_batch(imgs, real)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)

    #偽の画像のバッチを作成
    z = np.random.normal(0, 1, (batch_size, 100))
    ge_images = generator.predict(z)
    #生成器の訓練
    g_loss = gan.train_on_batch(z, real)

    if (iteration+1) % sample_interval == 0:

      #イテレーションお気に損失値と整合値を記録
      losses.append((d_loss, g_loss))
      accuracies.append(100 * accuracy)
      iteration_checkpoints.append(iteration+1)

      print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" %
                  (iteration + 1, d_loss, 100.0 * accuracy, g_loss))
      sample_images(generator)

####7.画像出力

def sample_images(generator, image_grid_rows=4, image_grid_columns=4):

    #ランダムノイズのサンプリング
    z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))
    gen_imgs = generator.predict(z)

    #画素値のスケール
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(image_grid_rows,
                            image_grid_columns,
                            figsize=(4, 4),
                            sharey=True,
                            sharex=True)
    cnt = 0
    for i in range(image_grid_rows):
        for j in range(image_grid_columns):
            # Output a grid of images
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1

####8.いざ学習!

iterations = 20000
batch_size = 128
sample_interval = 1000
train(iterations, batch_size, sample_interval)

##結果

1000 [D loss: 0.129656, acc.: 96.09%] [G loss: 3.387729]
2000 [D loss: 0.079047, acc.: 97.66%] [G loss: 3.964481]
3000 [D loss: 0.071152, acc.: 97.27%] [G loss: 5.072118]
4000 [D loss: 0.217956, acc.: 91.02%] [G loss: 3.993687]
5000 [D loss: 0.380112, acc.: 86.72%] [G loss: 3.941338]
6000 [D loss: 0.292950, acc.: 89.45%] [G loss: 4.491636]
7000 [D loss: 0.345073, acc.: 85.55%] [G loss: 4.056399]
8000 [D loss: 0.396545, acc.: 86.33%] [G loss: 3.101150]
9000 [D loss: 0.744731, acc.: 70.70%] [G loss: 2.761991]
10000 [D loss: 0.444913, acc.: 80.86%] [G loss: 3.474383]
11000 [D loss: 0.362310, acc.: 82.81%] [G loss: 3.101751]
12000 [D loss: 0.383188, acc.: 84.38%] [G loss: 3.111648]
13000 [D loss: 0.283140, acc.: 89.06%] [G loss: 3.082010]
14000 [D loss: 0.411019, acc.: 81.64%] [G loss: 2.747284]
15000 [D loss: 0.386751, acc.: 82.03%] [G loss: 2.795580]
16000 [D loss: 0.475734, acc.: 80.86%] [G loss: 2.436490]
17000 [D loss: 0.285364, acc.: 89.45%] [G loss: 2.764011]
18000 [D loss: 0.202013, acc.: 91.80%] [G loss: 4.058733]
19000 [D loss: 0.285773, acc.: 86.72%] [G loss: 3.038511]
20000 [D loss: 0.354960, acc.: 81.64%] [G loss: 2.719907]

↓1000iteration
スクリーンショット 2020-06-12 20.27.24.png
↓2000iteration
スクリーンショット 2020-06-12 20.27.39.png
↓10000iteration
スクリーンショット 2020-06-12 20.28.31.png
↓20000iteration
スクリーンショット 2020-06-12 20.29.26.png

学習し初めはただのノイズのような画像でしたが、最終的には単純な二層の生成器でも割とリアルな手書き文字を生成できているのではないでしょうか。
しかし、単純なGANで生成された手書き画像の背景に白い点が見えてしまい、すぐに手書きではないのがバレてしまいますね。この弱点を改善すべく、次回は畳み込みを用いたDCGANを実装していきたいと思います!

8
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?