LoginSignup
6
6

More than 1 year has passed since last update.

GAN(敵対的生成ネットワーク)の学習過程を可視化

Posted at

はじめに

GAN(敵対的生成ネットワーク)の学習状況を可視化できないか調べていたら、2015年のツイートを見つけましたので紹介します。

GANの学習とは

軽く、学習についての説明です。GANというか生成モデルは、「サンプルデータ(テストデータ)は何かしらの確率分布に従っている」という仮定のもとで、与えられた標本分布から母集団の確率分布をモデル化し、その分布に従った新たなデータを生成するモデルとなります。例えば画像生成の場合、入力画像データがN枚ある場合、サンプルデータの標本分布(p(画像1)、p(画像2)、p(画像3)、…p(画像N))から母集団の確率分布を学習しモデル化します。生成時はモデル化された確率分布に従う新たな画像データを出力します。

サンプリングされた画像は[0,255] ^ (チャンネル数 x 縦 x 横)の高次元な空間に存在することになります。この空間中の与えられたサンプル画像の確率分布を探索することによって、サンプル画像に類似した画像が生成されるようになります。MNISTの場合はグレー画像の縦28、横28の画像となりますので、[0,255]の(28X28=)784乗の空間に全ての画像が含まれていることになります。このように高次元な場合は説明がとても困難ですので、ここでは1次元と仮定した図で説明します。

MNISTの標本分布から母集団の確率分布を推測
GAN001.png

学習により確率分布にモデルを近似していく
GAN002.png

近似したモデルの確率分布に従った画像を出力する
GAN003.png

こんなイメージです。GANはこの学習を識別器(Discriminator、以下Dと表記)と生成器(Generator、以下Gと表記)の2つのネットワークを用いて相互に学習していくアーキティクチャとなります。

学習の流れは以下の通りです。

  1. Dの学習。テスト画像を入力したDの識別結果と正解ラベルとのクロスバイナリーエントロピー損失を計算する。
  2. Dの学習。Gが生成した画像を入力したDの識別結果と偽物ラベルとのクロスバイナリーエントロピー損失を計算する。
  3. Dの学習。1と2の損失を足し合わせて、Dに誤差逆伝播を行う
  4. Gの学習。Gが生成した画像を入力したDの識別結果と正解ラベルとのクロスエントロピー損失を計算する。
  5. Gの学習。4の損失をGの誤差逆伝播を行う。

の繰り返しとなります。それぞれの学習の際に渡す、正解/偽物ラベルの違いに注意してください。これを実装したコードは以下のようになります。

for epoch in range(EPOCHS):
    for i, (imgs, _) in enumerate(dataloader):

        # サンプルノイズの生成(正規分布に従ったランダムな値)
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], LATENT_DIM))))

        # 判定ラベル
        valid = Variable(Tensor(imgs.shape[0], 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.shape[0], 1).fill_(0.0), requires_grad=False)


        # テスト画像
        real_imgs = Variable(imgs.type(Tensor))

        # ---------------------
        #  Discriminatorの学習
        # ---------------------
        optimizer_D.zero_grad()

        # テスト画像と正解ラベルのペアで損失を計算
        real_loss = adversarial_loss(discriminator(real_imgs), valid)

        # G(z)と偽物ラベルのペアで損失を計算
        fake_imgs = generator(z)
        fake_loss = adversarial_loss(discriminator(fake_imgs.detach()), fake)

        # それぞれのLossを加算してDiscriminatorの損失
        d_loss = real_loss + fake_loss

        # 損失の誤差逆伝播
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Generatorの学習
        # -----------------
        optimizer_G.zero_grad()

        # D(G(z))と正解ラベルで損失を計算
        g_loss = adversarial_loss(discriminator(fake_imgs), valid)

        # 損失の誤差逆伝播
        g_loss.backward()
        optimizer_G.step()

        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, EPOCHS, i, len(dataloader), d_loss.item(), g_loss.item())
        )

        batches_done = epoch * len(dataloader) + i
        if batches_done % SAMPLE_INTERVAL == 0:
            save_image(fake_imgs.data[:25], IMAGES_PATH + "/%07d.png" % batches_done, nrow=5, normalize=True)
            # ログ情報の収集
            G_losses.append(g_loss.item())
            D_losses.append(d_loss.item())
            img_list.append(fake_imgs.data[:25])

実際、どんな感じで学習が進んでいくのか、可視化について調べていたら、2015年のツイートを見つけました。ソースを探していたら、親切にNotebook形式に編集してくれていた記事も見つけました。こちらです

とても興味深かったので、自分でも実行してみました。結果、GoogleColabでは仕様のためかライブラリのバーションがあわず動かせませんでしたが、ローカルのJupyter環境では無事に動かせました。

学習過程のデモをGIFにしてみました。 G(x)がP(data)に近似していく様子が可視化されています。
animation.gif

最後に

なかなか面白い動きです。ミススペルについてもそのままコピーして実行しました。

6
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6