はじめに
以前から気になっていたColaboratoryというJupyter Notebook環境で、これまた以前から気になっていたGANを実行してみただけという内容です。
Colavoratoryについては、【秒速で無料GPUを使う】深層学習実践Tips on Colaboratory の記事が参考になりました。
GANについては、論文Generative Adversarial Networks をさらっと読みました。GANは手元のデータの確率分布(一様分布とみなす)を近似する方法の一種で、2つのネットワークGとDをいい感じに学習させると、Gによって生成されるデータの確率分布は、手元のデータの確率分布と一致するそうです。いい感じに学習できない場合があるので、目下そのための工夫が研究されています。
実行
GANのコードはこちらを参考にさせていただきました。GANがkerasを使ってシンプルにコーディングされていて、勉強になりました。
2つのMLPを定義して(これがGとDです)、Gの出力をDに与えて、Adamで交互に学習させています。Dは、「手元のデータ」と「Gの出力」を見分けられるように学習します。Gは、Dの判別結果が「手元のデータ」となるよう、教師データを操作して学習させます。このときDは学習させません。学習データはMNISTです。
見慣れないReLUがあると思ったら、最近よく使われているLeakyReLUというものらしいです。(参考:活性化関数ReLUについてとReLU一族【追記あり】)ReLUと違い、活性化関数の入力xが0以下でもx*αの値が出力されます。wikiには効果ないとあるのですが、勾配消失が少しでも軽減されるんですかね。よく分かりません。
コードは特に問題なく実行できましたが、fitではなく、自前のループ内でtrain_on_batchを実行して学習を進めているので、historyが返ってきません。lossとaccを可視化したいので、インスタンス変数として保存するコードや可視化用の、コードを追加します。
# save
self.all_d_loss_real.append(d_loss_real)
self.all_d_loss_fake.append(d_loss_fake)
self.all_g_loss.append(g_loss)
if epoch % sample_interval == 0:
self.sample_images(epoch)
np.save('d_loss_real.npy', self.all_d_loss_real)
np.save('d_loss_fake.npy', self.all_d_loss_fake)
np.save('g_loss.npy', self.all_g_loss)
realは手元のデータのDのlossで、fakeはGが生成したデータのDのlossです。
ローカルに保存するコードです。
from google.colab import files
import os
file_list = os.listdir("images")
for file in file_list:
files.download("images"+os.sep+file)
files.download('d_loss_real.npy')
files.download('d_loss_fake.npy')
files.download('g_loss.npy')
lossとかをプロットするコードです。
import numpy as np
import pylab as plt
t1 = np.load('d_loss_real.npy')
t2 = np.reshape(np.load('d_loss_fake.npy'),[np.shape(t1)[0],2])
g_loss = np.load('g_loss.npy')
t = (t1+t2)/2
d_loss = t[:,0]
acc = t[:,1]
d_loss_real = t1[:,0]
d_loss_fake = t2[:,0]
acc_real = t1[:,1]
acc_fake = t2[:,1]
n_epoch = 29801
x = np.linspace(1,n_epoch,n_epoch)
plt.plot(x, acc, label='acc')
plt.plot(x, d_loss, label='d_loss')
plt.plot(x, g_loss, label='g_loss')
plt.plot(x, d_loss_real, label='d_loss_real')
plt.plot(x, d_loss_fake, label='d_loss_fake')
plt.legend()
plt.ylim([0, 2])
plt.grid()
plt.show()
# 移動平均
num=100#移動平均の個数
b=np.ones(num)/num
acc2=np.convolve(acc, b, mode='same')
d_loss2=np.convolve(d_loss, b, mode='same')
d_loss_real2=np.convolve(d_loss_real, b, mode='same')
d_loss_fake2=np.convolve(d_loss_fake, b, mode='same')
g_loss2=np.convolve(g_loss, b, mode='same')
x = np.linspace(1,n_epoch,n_epoch)
plt.plot(x, acc2, label='acc')
plt.plot(x, d_loss2, label='d_loss')
plt.plot(x, g_loss2, label='g_loss')
plt.plot(x, d_loss_real2, label='d_loss_real')
plt.plot(x, d_loss_fake2, label='d_loss_fake')
plt.legend()
plt.ylim([0,1.2])
plt.grid()
plt.show()
結果
epochが増えるにしたがってMNISTに似た画像が生成されるようになりますが、epoch 7000くらいからは特に変化がなさそうです。
epoch 7千くらいからは、acc 0.63, d_loss(realとfakeも) 0.63, g_loss 1.02 ~ 1.08(微増) です(d_lossとg_lossは二値交差エントロピー)。realは手元のデータのDのlossで、fakeはGが生成したデータのDのloss、d_lossはその平均です。
loss は下式のように定義されます。
\textrm{loss} = -\frac{1}{N}\sum_{n=1}^{N}\bigl( y_n\log{p_n}+(1-y_n)\log{(1-p_n)}\bigr)
Nはデータ数で、yはラベル、pはDの出力値(0,1)です。
$\log$が入っているのでややこしいですが、やっていることは平均的なDの出力$\bigl(\prod_{n=1}^{N}p_n^{y_n}\bigr)^{\frac{1}{N}}$を$\log$にして、0~1の$\log$は負の数で見づらいのでマイナスをつけて正の数にしただけです。
\begin{align}
\textrm{loss} &= -\frac{1}{N}\sum_{n=1}^{N}\bigl( y_n\log{p_n}+(1-y_n)\log{(1-p_n)}\bigr) \\
&= -\log{\bigl( \prod_{n=1}^{N}p_n^{y_n}\bigr)^{\frac{1}{N}}} -\log{\bigl( \prod_{n=1}^{N}(1-p_n)^{y_n}\bigr)^{\frac{1}{N}}}
\end{align}
◯ epoch 25000付近のlossとか
loss | 平均的なDの出力 | |
---|---|---|
g_loss | 1.06 | 0.35 |
d_loss | 0.63 | 0.53 |
d_loss_real | 0.63 | 0.53 |
d_loss_fake | 0.63 | 0.47 |
考察
ラベルと出力が一致すればするほど、lossは小さくなります。GANはlossを小さくすることが目的ではないので、下がらないのは特に問題ではありません。
学習がうまくいき、手元のデータとGの生成するデータが全く見分けがつかない状態(GANの目的はこの状態になること)であれば、acc=0.5となるはずですが、結果を見る限りそうなっていません。
Gの生成した画像を見ると、明らかに手書き数字じゃなさそうなのがあるため、それがaccが高くなってしまっている原因でしょう。パラメータを弄ればもう少し良くなるかもしれませんが、追い込むことが目的じゃないため、ひとまずこのあたりで止めておきます。
g_lossの値が意味するところは、g_lossの値が低いほど、Gの生成した画像をDが真と判定――つまり、Dを騙せているということ。逆にg_lossの値が高いほど、Dを騙せていないという意味です。仮に、g_lossの平均的なDの出力が0.5となることを目標とするなら、そのときのg_lossは0.7のため、もう少し下がって欲しいところです。
accがd_lossと一致しているのはたまたまだと思います。
epoch7000 ~ にかけて、g_lossの増加量に比べ、d_loss_fakeの低下量が小さいのは気になるところです。平均的なDの出力にしても10倍ほど差があります。Dの学習→Gの学習という順になっているため、それがモロに効いているのでしょうか。
終わりに
やってみたらできたという感じでした。特に詰まるところはないと思いますが、colaboratoryがあまり安定していないため、計算が途中でクラッシュしたり、画面が再読み込みされたと思ったらなぜか過去のノートブックが表示されてしまい、気付かないまま上書き保存して、泣く泣くコード書き直したりもしました。
画面再読み込み後、下からこのポップアップが出た場合は要注意です。よくよくコードを見るとColaboratoryを開いた直後の編集前のコードで、保存したら編集後のコードに上書きされました。
対策としては、ページを再読み込みすれば良いと思います。私のブラウザはSafariなのですが、Ctrl-rを押してページ再読み込みしたら編集後のコードが表示されて、実行後の変数等もキープされておりました。このポップアップが出たら、慌てて上書き保存しないほうが無難だと思います。
計算のクラッシュについては、定期的にバックアップを取るしかないと思います。