21
Help us understand the problem. What are the problem?

posted at

updated at

kerasでGAN(mnist)動かしてみた

はじめに

[2021/2追記] Githubにコード公開しました。参考にしてみてください

リポジトリ内ではGAN以外にDCGANとCGANも公開しています。
この記事で日本語でリポジトリの解説をしています。

今回はGAN(Generative Adversarial Network)をこちらの本で勉強したのでまとめていきたいと思います。
何回かに分けて書いた後に、最後にしっかりまとめるつもりです。
なので、この記事はとてもざっくりとしたものになっています。
今回の記事は本書で紹介されていた、GANの簡単な実装コードの解説とまとめについて書かせていただきます。
GANの詳しい説明については他サイトに譲らせていただいて本稿では簡単に概要のみに止まらせていただきます。(需要がありそうなら後々まとめたものを投稿したいと思います。)
GAN初心者の方の参考になりそうなGANのサイトを載せておきます。
GAN:敵対的生成ネットワークとは何か ~「教師なし学習」による画像生成

GANとは

軽く説明させていただきます。

GANとは日本語で敵対敵生成ネットワークといいます。DNNを用いて画像や音楽、文章の生成に用いられており、今人工知能の分野で最も活発に研究されてる分野の一つです!

入力されたデータの特徴を学んで、入力データに似た何かを生成します。
データは音声、テキスト、画像等なんでもありです。
例えば、猫の画像を大量に入力とすれば出力されるのは猫の画像になります(うまく学習できていれば)

アルゴリズムとしてはふたつのDNNを用意してあげて、それぞれ画像を生成する係と画像が本物か生成された画像かどうかを見分けてあげる係に分けてあげます。この二つのモデルを競わせることで入力された画像に近い画像がアウトプットされる仕組みです。

結果

今回はmnist手書き文字データセットを用いて実装しました。
最初に結果を載せさせて頂きます。
これが生成された画像です。
download.png
それに対してこちらが入力されたデータです。
download.png

このコードはとてもシンプルなDNNで行われたのでまだまだ改善の余地がありますが、それにしても、少ない層のmodeldでもここまで表現できているのは驚きました。
次回の記事で改善版を載せるつもりでいます。

コード

コードになります。
もっと詳しく説明してくれと言われそうです。
しっかりと本を読み終えてから改てまとめようと思っているのでそれまで待ってください!
(いいねくれると。とっっっても励みになります)

simple_gan.py
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

from keras.datasets import mnist
from keras.layers import Dense, Flatten, Reshape
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential
from keras.optimizers import Adam

#mnistの形状[28, 28, 1]を定義
img_rows = 28
img_cols = 28
channels = 1
img_shape = (img_rows, img_cols, channels)
#generatorが画像を生成するために入力させてあげるノイズの次元
z_dim = 100

#generator(生成器)の定義するための関数
def build_generator(img_shape, z_dim):
  model = Sequential()
  model.add(Dense(128, input_dim=z_dim))
  model.add(LeakyReLU(alpha=0.01))
  model.add(Dense(28*28*1, activation='tanh'))
  model.add(Reshape(img_shape))
  return model

#discriminator(識別器)の定義するための関数
def build_discriminatior(img_shape):
  model = Sequential()
  model.add(Flatten(input_shape=img_shape))
  model.add(Dense(128))
  model.add(LeakyReLU(alpha=0.01))
  model.add(Dense(1, activation='sigmoid'))
  return model

#Ganのモデル定義する(生成器と識別器をつなげてあげる)ための関数
def build_gan(generator, discriminator):
  model = Sequential()
  model.add(generator)
  model.add(discriminator)
  return model

#実際関数を呼び出してにGANのモデルをコンパイルしてあげる
discriminator = build_discriminatior(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])
generator = build_generator(img_shape, z_dim)

#識別器の学習機能をオフにしてあげる。こうすることで、識別器と生成者を別々に学習させてあげられる
discriminator.trainable = False 

gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam())


losses = []
accuracies = []
iteration_checkpoint = []
#学習させてあげるための関数。イテレーション数、バッチサイズ、 何イテレーションで画像を生成して可視化するかを引数にとる
def train(iterations, batch_size, sample_interval):
  (x_train, _), (_, _) = mnist.load_data()

  x_train = x_train / 127.5 - 1
  x_train = np.expand_dims(x_train, axis=3)

  real = np.ones((batch_size, 1))
  fake = np.zeros((batch_size, 1))

  for iteration in range(iterations):

    idx = np.random.randint(0, x_train.shape[0], batch_size)
    imgs = x_train[idx]
    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    d_loss_real = discriminator.train_on_batch(imgs, real)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
    d_loss, acc = 0.5 * np.add(d_loss_real, d_loss_fake)

    z = np.random.normal(0, 1, (batch_size, 100))
    gen_imgs = generator.predict(z)

    g_loss = gan.train_on_batch(z, real)
#sample_intervalごとに損失値と精度、チェックポイントを保存
    if (iteration+1) % sample_interval == 0:
      losses.append((d_loss, g_loss))
      accuracies.append(acc)
      iteration_checkpoint.append(iteration+1)
#画像を生成
      sample_images(generator)

#サンプルとして画像を生成するための関数
def sample_images(generator, image_grid_rows =4, image_grid_colmuns=4):
  z = np.random.normal(0, 1, (image_grid_rows*image_grid_colmuns, z_dim))
  gen_images = generator.predict(z)

  gen_images = 0.5 * gen_images + 0.5

  fig, axs = plt.subplots(image_grid_rows, image_grid_colmuns, figsize=(4,4), sharex=True, sharey=True)

  cnt = 0
  for i in range(image_grid_rows):
    for j in range(image_grid_colmuns):
      axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap='gray')
      axs[i, j].axis('off')
      cnt += 1

Register as a new user and use Qiita more conveniently

  1. You can follow users and tags
  2. you can stock useful information
  3. You can make editorial suggestions for articles
What you can do with signing up
21
Help us understand the problem. What are the problem?