3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

ColaboratoryでGANを実行してみた

Last updated at Posted at 2019-11-04

はじめに

以前から気になっていたColaboratoryというJupyter Notebook環境で、これまた以前から気になっていたGANを実行してみただけという内容です。

Colavoratoryについては、【秒速で無料GPUを使う】深層学習実践Tips on Colaboratory の記事が参考になりました。

GANについては、論文Generative Adversarial Networks をさらっと読みました。GANは手元のデータの確率分布(一様分布とみなす)を近似する方法の一種で、2つのネットワークGとDをいい感じに学習させると、Gによって生成されるデータの確率分布は、手元のデータの確率分布と一致するそうです。いい感じに学習できない場合があるので、目下そのための工夫が研究されています。

実行

GANのコードはこちらを参考にさせていただきました。GANがkerasを使ってシンプルにコーディングされていて、勉強になりました。

2つのMLPを定義して(これがGとDです)、Gの出力をDに与えて、Adamで交互に学習させています。Dは、「手元のデータ」と「Gの出力」を見分けられるように学習します。Gは、Dの判別結果が「手元のデータ」となるよう、教師データを操作して学習させます。このときDは学習させません。学習データはMNISTです。

見慣れないReLUがあると思ったら、最近よく使われているLeakyReLUというものらしいです。(参考:活性化関数ReLUについてとReLU一族【追記あり】)ReLUと違い、活性化関数の入力xが0以下でもx*αの値が出力されます。wikiには効果ないとあるのですが、勾配消失が少しでも軽減されるんですかね。よく分かりません。

コードは特に問題なく実行できましたが、fitではなく、自前のループ内でtrain_on_batchを実行して学習を進めているので、historyが返ってきません。lossとaccを可視化したいので、インスタンス変数として保存するコードや可視化用の、コードを追加します。

 # save
      self.all_d_loss_real.append(d_loss_real)
      self.all_d_loss_fake.append(d_loss_fake)
      self.all_g_loss.append(g_loss)
      
      if epoch % sample_interval == 0:
        self.sample_images(epoch)
        np.save('d_loss_real.npy', self.all_d_loss_real)
        np.save('d_loss_fake.npy', self.all_d_loss_fake)
        np.save('g_loss.npy', self.all_g_loss)

realは手元のデータのDのlossで、fakeはGが生成したデータのDのlossです。
ローカルに保存するコードです。

from google.colab import files
import os

file_list = os.listdir("images")

for file in file_list:
    files.download("images"+os.sep+file)

files.download('d_loss_real.npy')
files.download('d_loss_fake.npy')
files.download('g_loss.npy')

lossとかをプロットするコードです。

import numpy as np
import pylab as plt


t1 = np.load('d_loss_real.npy')
t2 = np.reshape(np.load('d_loss_fake.npy'),[np.shape(t1)[0],2])
g_loss = np.load('g_loss.npy')

t = (t1+t2)/2
d_loss = t[:,0]
acc = t[:,1]
d_loss_real = t1[:,0]
d_loss_fake = t2[:,0]
acc_real = t1[:,1]
acc_fake = t2[:,1]


n_epoch = 29801

x = np.linspace(1,n_epoch,n_epoch)
plt.plot(x, acc, label='acc')
plt.plot(x, d_loss, label='d_loss')
plt.plot(x, g_loss, label='g_loss')
plt.plot(x, d_loss_real, label='d_loss_real')
plt.plot(x, d_loss_fake, label='d_loss_fake')
plt.legend()
plt.ylim([0, 2])
plt.grid()
plt.show()

# 移動平均
num=100#移動平均の個数
b=np.ones(num)/num
acc2=np.convolve(acc, b, mode='same')
d_loss2=np.convolve(d_loss, b, mode='same')
d_loss_real2=np.convolve(d_loss_real, b, mode='same')
d_loss_fake2=np.convolve(d_loss_fake, b, mode='same')
g_loss2=np.convolve(g_loss, b, mode='same')

x = np.linspace(1,n_epoch,n_epoch)
plt.plot(x, acc2, label='acc')
plt.plot(x, d_loss2, label='d_loss')
plt.plot(x, g_loss2, label='g_loss')
plt.plot(x, d_loss_real2, label='d_loss_real')
plt.plot(x, d_loss_fake2, label='d_loss_fake')
plt.legend()
plt.ylim([0,1.2])
plt.grid()
plt.show()

結果

Gの生成画像
epoch=0
0.png

epoch=200
200.png

epoch=1000
1000.png

epoch=3000
3000.png

epoch=7000
6600.png

epoch=10000
9800.png

epoch=20000
20000.png

epoch=30000
29800.png

epochが増えるにしたがってMNISTに似た画像が生成されるようになりますが、epoch 7000くらいからは特に変化がなさそうです。

正解率とloss
t.png

上図の移動平均(n=100, 両端ゼロ埋め)
t2.png

epoch 7千くらいからは、acc 0.63, d_loss(realとfakeも) 0.63, g_loss 1.02 ~ 1.08(微増) です(d_lossとg_lossは二値交差エントロピー)。realは手元のデータのDのlossで、fakeはGが生成したデータのDのloss、d_lossはその平均です。

loss は下式のように定義されます。

\textrm{loss} = -\frac{1}{N}\sum_{n=1}^{N}\bigl( y_n\log{p_n}+(1-y_n)\log{(1-p_n)}\bigr)

Nはデータ数で、yはラベル、pはDの出力値(0,1)です。

$\log$が入っているのでややこしいですが、やっていることは平均的なDの出力$\bigl(\prod_{n=1}^{N}p_n^{y_n}\bigr)^{\frac{1}{N}}$を$\log$にして、0~1の$\log$は負の数で見づらいのでマイナスをつけて正の数にしただけです。

\begin{align}
\textrm{loss} &= -\frac{1}{N}\sum_{n=1}^{N}\bigl( y_n\log{p_n}+(1-y_n)\log{(1-p_n)}\bigr) \\
&= -\log{\bigl( \prod_{n=1}^{N}p_n^{y_n}\bigr)^{\frac{1}{N}}} -\log{\bigl( \prod_{n=1}^{N}(1-p_n)^{y_n}\bigr)^{\frac{1}{N}}}
\end{align}

◯ epoch 25000付近のlossとか

loss 平均的なDの出力
g_loss 1.06 0.35
d_loss 0.63 0.53
d_loss_real 0.63 0.53
d_loss_fake 0.63 0.47

考察

ラベルと出力が一致すればするほど、lossは小さくなります。GANはlossを小さくすることが目的ではないので、下がらないのは特に問題ではありません。

学習がうまくいき、手元のデータとGの生成するデータが全く見分けがつかない状態(GANの目的はこの状態になること)であれば、acc=0.5となるはずですが、結果を見る限りそうなっていません。

Gの生成した画像を見ると、明らかに手書き数字じゃなさそうなのがあるため、それがaccが高くなってしまっている原因でしょう。パラメータを弄ればもう少し良くなるかもしれませんが、追い込むことが目的じゃないため、ひとまずこのあたりで止めておきます。

g_lossの値が意味するところは、g_lossの値が低いほど、Gの生成した画像をDが真と判定――つまり、Dを騙せているということ。逆にg_lossの値が高いほど、Dを騙せていないという意味です。仮に、g_lossの平均的なDの出力が0.5となることを目標とするなら、そのときのg_lossは0.7のため、もう少し下がって欲しいところです。

accがd_lossと一致しているのはたまたまだと思います。

epoch7000 ~ にかけて、g_lossの増加量に比べ、d_loss_fakeの低下量が小さいのは気になるところです。平均的なDの出力にしても10倍ほど差があります。Dの学習→Gの学習という順になっているため、それがモロに効いているのでしょうか。

終わりに

やってみたらできたという感じでした。特に詰まるところはないと思いますが、colaboratoryがあまり安定していないため、計算が途中でクラッシュしたり、画面が再読み込みされたと思ったらなぜか過去のノートブックが表示されてしまい、気付かないまま上書き保存して、泣く泣くコード書き直したりもしました。

2.png

画面再読み込み後、下からこのポップアップが出た場合は要注意です。よくよくコードを見るとColaboratoryを開いた直後の編集前のコードで、保存したら編集後のコードに上書きされました。

対策としては、ページを再読み込みすれば良いと思います。私のブラウザはSafariなのですが、Ctrl-rを押してページ再読み込みしたら編集後のコードが表示されて、実行後の変数等もキープされておりました。このポップアップが出たら、慌てて上書き保存しないほうが無難だと思います。

計算のクラッシュについては、定期的にバックアップを取るしかないと思います。

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?