GANとpix2pixについて
GAN(敵対的生成ネットワーク)
-
Generator(生成器) 学習データの分布, 生成データの確率分布の2つが近づくよう学習
-
Discriminator(識別器) Gから生成された物とXが本物か偽物かを判別
\underset{G}{min}\,\underset{D}{max}V(D,G) = \mathbb{E}_{x \sim pdata(x)}[logD(x)]+\mathbb{E}_{z \sim p_z(z)}[log(1-D(G(z))]
- Gはノイズ$z$を入力(pix2pixの場合は画像(ピクセル同士の対応関係ラベル))
- $P_z(z)$は出力分布
- $P_{data}(x)$はdataの分布
- DがmaxにGがminになるよう更新
pix2pixの損失関数
CGANの目的関数
\mathcal{L}_{cGAN} = \mathbb{E}_{x,y}[logD(x,y)]+\mathbb{E}_{x,y}[log(1-D(x,G(x,z))]
L1損失関数
\mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}[\|y-G(x,z)\|_1]
pix2pixの目的関数
G^{*} = arg\,\underset{G}{min}\,\underset{D}{max}\,\mathcal{L}_{cGAN}(G,D)+\lambda\mathcal{L}_{L1}(G)
- $\lambda$はL1損失と交差エントロピーの比率を決定するパラメータ
https://arxiv.org/pdf/1411.1784.pdf
https://arxiv.org/pdf/1611.07004v1.pdf
https://arxiv.org/pdf/1505.04597.pdf
実装参考
https://elix-tech.github.io/ja/2017/02/06/gan.html
https://qiita.com/triwave33/items/f6352a40bcfbfdea0476
https://qiita.com/ichi_pg/items/1dca539f9b734c0389eb
導入
今更感は凄く感じますが楽しそうだったのでやって見ました。
動けばいいやとあまり深くは理解できていません。あとGANの学習を収束させるのはやはり難しいみたいです。
CycleGANでは対になるデータセットを用意しなくても良いみたいですがpix2pixは対になるデータが必要です。
まず手始めに256*256のサイズで38,000枚のペアで試して見ました。しかしGPUのメモリに入り切らなかったです。方法はあるのでしょうが分からないためサイズを小さくしました。
今回はデータ数も少なく汎化性能は期待出来そうにありません。
データセット
- 画像データはスクレイピングで12,000枚を用意
- 学習データ10,000枚とテストデータ2,000枚に分割
- 画像は縦横短い方を128pxにリサイズ、128*128にクロップ
- 線画データの作成はopencvでグレースケール変換し白部分の膨張からdiffを取る方法で作成
- 着色データと線画データのペアをそれぞれ用意
実装について
実装参考を元にpix2pixを作成。U-Netの理解に苦しみます。
Discriminatorが強くなり過ぎる件について
タイトルにもある通りかなりの時間を無駄にしました。
まずGeneratorとDiscriminatorは交互に学習させなければなりません。
しかしGANの学習が上手く進みません。
学習が上手く進まない理由として以下が上げられるそうです。
- ナッシュ均衡しない
- モード崩壊
- Discriminatorの圧勝により勾配消失
他にもあるとは思いますが「Discriminatorの圧勝により勾配消失」、完全こちらに苦しめられました。
参考
以下画像の様になります。
油断するとDiscriminatorが強くなり全て偽物と判断し、Generatorの学習が進みません。油断するも何もどうにも出来ません。
そこでDiscriminatorの学習を何回かの割合で飛ばします。具体的には
Generatorに対し2:1だったり3:1の割合で学習させました。
これに根拠があるのか不明です......。
しかし、最初は上の画像に直ぐ陥り頭を抱えていましたが、徐々に塗れるようになります。以下ランダムに抜き出し保存。
※webからスクレイピングして頂いた画像ですのでライセンス表記が難しいです。
7時間ほどで学習を止めました。
結果
- バッチサイズ 100
- 最適化手法 Adam(α=0.0002, β1=0.5, β2=0.9)
- Loss(G) = Loss(GAN)+ 100*Loss(L1)
以前に私が描いた画像で試します。
うーむ、あまり良くない。でもキャラの輪郭は検出してそうです。
256*256の線画で無理やり試して見ます。
錆や玉ボケが逆にすごい。
学習に使用したデータについては着色出来ています。
おわりに
まだあまり知識が深まっていないのでちゃんと理解出来たら良いのですが先は長そうです。
そして次はヒントを与える着色を実装して見たいです。
以下参考