5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

DCGANによるMNIST画像生成プログラム作成(tensorflowチュートリアル)

Posted at

GAN及びDCGANとは

 深層学習の分野で無くてはならなくなっているアルゴリズムであるGAN(Generative adversarial networks)について今回学びました。
 GANとは、2014年にlan Goodfellowというアメリカの研究者によって生み出された技術で、2つのネットワークを敵対的に訓練させていくことで、本物と見分けがつかないようなデータを生成していくネットワークのことです。
 日本語では敵対性生成ネットワークと直訳されていますが、名前が厨2っぽさがあり、カッコよくて好きです。

 tensorflowのURLでチュートリアルとして挙げられているDCGANについて私なりに解釈したものを今回まとめました。
https://www.tensorflow.org/tutorials/generative/dcgan?hl=ja
 このチュートリアルでは、DCGAN(Deep Convolutional GAN)という手法でMNIST(手書き数字)を生成させるアルゴリズムを取り扱っています。

 DCGANとはICLR2016(AI分野の会議)で発表された論文で提案された生成モデルのことです。
 所謂GANとの差は、全結合層(Affine)を行わずDeep Convolutional=畳み込みを利用している点にあります。
 全結合層モデルは重みづけ係数が非常に大きく過学習になりやすい特徴がありますが、畳み込みのみで構成すると過学習を防ぐことができるようです。一方で収束が遅くなる傾向があるとのことです。

001.jpg

 

GAN全般に参考になったURL
https://blog.negativemind.com/2019/09/07/deep-convolutional-gan/

#__future__モジュールとは

gan.py

from __future__ import absolute_import, division, print_function, unicode_literals

 Python2.6以降ではこのコードによりPython3系の挙動に一部の関数や命令の挙動を変更できるというものです。Python2系で3系の関数を使いたい場合に読み込むと理解しました。

#読み込んだライブラリ、モジュールについて

gan.py
import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

 画像を処理するためのライブラリとしてimageioとPIL(Pillow)を読み込みました。画像処理系のライブラリというとOpenCVが有名ですが、Pillowのほうがコードがシンプルで分かりやすいとのことです。
 また、ファイルパス名を取得できるモジュールとしてglobを読み込みました。ワイルドカード*などの特殊文字を使って条件を満たすファイル名、フォルダ名などのパスをリストやイテレータで取得できます。

 次にIPython.displayモジュールですが、これは音声や動画をNotebook上に埋め込むことができる機能があります。

#データセットの読み込み

gan.py
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]

 今回はMNSIT(手書き数字画像)をkerasのデータセットから読み込みます。
 それらは数列として記載されており、標準化を行います。

gan.py

BUFFER_SIZE =60000
BATCH_SIZE =256
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

Dataset化

#Generatorモデルの定義

gan.py

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

#Discriminatorモデルの定義

gan.py

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

https://www.slideshare.net/HiroyaKato1/gandcgan-188544721
https://www.hellocybernetics.tech/entry/2018/05/28/180012
https://keras.io/ja/getting-started/sequential-model-guide/

#損失(loss)関数の定義

gan.py

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

 一般的にニューラルネットワークモデルを鍛えるにはこの損失関数の重みパラメータの勾配が小さくなるように重みパラメータを調整していきます。
 Discriminator側が本物と偽物を見分ける能力を上げさせ、Generator側がDiscriminatorを騙す能力を上げさせる関数を定義します。
 この部分がGAN特有の技術を表す定義ですね。

#訓練関数の定義

gan.py

EPOCHS =50
noise_dim = 100
num_examples_to_generate = 16
seed= tf.random.normal([num_examples_to_generate,noise_dim])

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # Save the model every 15 epochs
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # Generate after the final epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

#計算した後の画像

10epochs後
image_at_epoch_0010.png

30epochs後
image_at_epoch_0030.png

50epochs後
image_at_epoch_0050.png

 30epochsを超えてきたあたりから、かなり数字に見える画像となっています。私の環境※だと1epochに5分ほどかかったので、250分程度かかったことになります。
 GUIを使用することができるGoogle Colabを活用しましょう。。

※現PC環境
PC:Windows 10 Home
CPU:Intel Core i7 3.6GHz
RAM:8GB

#終わりに

 今回は初めて実装してみたGANですが、考え方や計算含めてコンピュータが得意とする計算機能を存分に発揮しているプログラムであり楽しかったです。
 生成器や識別器、損失関数の考え方などは多くのモデルが提唱されているため、一つ一つ学んでまとめていければと思います。

 全コードはこちらに置きました。
https://github.com/Fumio-eisan/dcgan20200306

5
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
6

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?