Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
Help us understand the problem. What is going on with this article?

KerasでBEGAN(Boundary Equilibrium Generative Adversarial Networks)を実装する

More than 3 years have passed since last update.

はじめに

GAN(Generative Adversarial Networks: 敵対的生成ネットワーク)という生成モデルの中でも面白い構造の仕組みを最近よく見かけます。Generator(生成者)とDiscriminator(識別者)が互いに競い合って精度を上げていく構造は、美術界での贋作者(=Generator)と鑑定士(Discriminator)の勝負のようでギャラリーフェイクのような世界を彷彿とさせます。仕組み自体に浪漫を感じます。さらにそれで精度良い生成ができるのだから本当に不思議です。

ただ、GANは学習が難しいというのが課題だったようなのですが、この BEGAN というのは二者の対立のバランスを取りながら学習することでこの課題を解決するとのこと。BEGANは結構内容がシンプルで理解しやすかったので、Kerasで実装してみました。

BEGANの面白いところ

色々なGANの改良を引き継ぎつつ、バランスを取る仕組みが入っているみたいなのですが、論文に書いてある特徴(それ以前に別の人が考えたものも含む)から個人的に面白いと思った部分として、

  • Discriminator は AutoEncoder の Loss関数でできている
  • そのAutoEncoderのLossは Wasserstein distance という小難しい概念だが、実際は近似して Mean Absolute Error (つまり $|InputX - AutoEncoder(InputX)|$ ) である
  • 多様性を表現する $\gamma \in [0, 1]$ というパラメータで「サンプルからどれだけ離れた生成をさせるか」を調整できる
  • 学習の進捗を 実際の画像を目視で確認することなく 評価できる計算式を提示している

というところもありつつ、一番興味深いのは、

DiscriminatorのAutoEncoderのLoss関数を $L(x)$ とした時に、
DiscriminatorのLoss関数を$L_{D}(x)$、GeneratorのLoss関数を$L_{G}(x)$とすると、

$L_{D}(x) = L(真の画像) - k_{t} * L(Generatorが生成した画像)$
$L_{G}(x) = L(Generatorが生成した画像)$
$k_{t+1} = k_{t} + \lambda_{k}(\gamma * L(真の画像) - L(Generatorが生成した画像)) $

という形で学習させるところです。
この$k_{t}$というのは最初は0から始まって徐々に大きくなる感じです。

DiscriminatorはLossを減らすために「真の画像のAutoEncodeを頑張る(Lossを小さくする)」のと「偽の画像のAutoEncoderを頑張らない(Lossを大きくする)」という行為を迫られます。最初はk=0なので真の画像のAutoEncoderを最適化していきますが、徐々にkが大きくなりGeneratorの画像のLossを大きくする努力も併せて行います。最後は $(\gamma * L(真の画像) - L(Generatorが生成した画像))=0$になる辺りでkは平衡状態に至ります(そのとき $\gamma$が効いてくるわけですね)。

Generatorは常にAutoEncoderのLossが小さくなる用に生成画像を工夫していくので、徐々に争いが高度になっていくということになります。

このジレンマのような仕組みがとても面白く、それがスッキリ表現されていて、おおー、と思いました(まあ、他のGANのことはよく知らないのですが...)。

実装

ソースコード

https://github.com/mokemokechicken/keras_BEGAN
に置いてあります。

まあまあそれっぽい画像が生成できているのでだいたい実装としてはあっているんじゃないかと思っていますが...

Kerasだとこういう一風変わったモデルや学習を行うのが難しい(あまりサンプルがない)ですが、書き方がわかってくると実はそれほど難しくはなく、作れてしまえばモジュール性の高さから読みやすく、いろいろ応用しやすい良さがあります。

学習の進行

1 batch毎の各種Lossなどの値をPlotすると以下のようになりました。
$\gamma=0.5$ で学習させています。

training.png

それぞれの値の意味は以下のとおりです。

  • m_global: 学習の進行度合いを示す値。これが収束すると良いらしい。
  • k: 上で説明したDiscriminator と Generatorのバランスをとる値
  • loss_discriminator: $L_{D}(x)$
  • loss_generator: $L_{G}(x)$
  • loss_real_x: L(真の画像)
  • loss_gen_x: L(Generatorが生成した画像)

パット見て思うのは、

  • kは一旦大きくなって、その後小さくなるようです。
  • loss_real_x * gamma = loss_gen_x にちゃんと収束しています
  • m_global は徐々に低減して、そろそろ収束しかかっています
  • 他の事例をみると、このあともう少し我慢して、m_globalがほぼ横ばいになったころに終わりにする感じでしょうか

実行後の生成例

サンプル画像

64x64 Pixelの正方形画像ならなんでも良いのですが、サンプル画像として、http://vis-www.cs.umass.edu/lfw/

[new] All images aligned with deep funneling 
(111MB, md5sum 68331da3eb755a505a502b5aacb3c201)

を使わせてもらいました。グレイスケール画像を除くと 13194サンプルあります。

生成画像

学習の進行別に生成された画像を並べてみるとこんな感じでした。

Epoch 1
Epoch 25
Epoch 50
Epoch 75
Epoch 100
Epoch 125
Epoch 150
Epoch 175
Epoch 200
Epoch 215

顔写真としては、epoch125くらいまでのが結構良い感じです。
それ以降は、背景も取り込もうとしたからか、顔の部分の乱れがすごいことになっています。
顔に注目して生成したければ、背景を潰して顔だけにしたやつを使えばもう少し綺麗になるのかもしれないですね。
どれくらい綺麗な画像になるかはModelのConv Layerの数なども関係するそうですし、ちょっと不足気味だったのかもしれません。

実行時間

以下のマシンスペックで約680秒/epoch でした。

  • Linux
  • Dataset: All images aligned with deep funneling (13194 samples)
  • Intel(R) Core(TM) i7-7700K CPU @ 4.20GHz
  • GeForce GTX 1080
  • Environment Variables

    KERAS_BACKEND=theano
    THEANO_FLAGS=device=gpu,floatX=float32,lib.cnmem=1.0
    

さいごに

やっとGANについて少しわかった気がします。

mokemokechicken
お気楽極楽会社員です。気ままに投稿しています。
sprocket
"Sprocket(スプロケット)は、Webサイトにおけるコンバージョン(購入・入会・資料請求・問合せ等)を増やしたい企業様向けに、自社開発のWeb接客ツールの導入及びコンバージョン改善コンサルティングを行っている会社です。 "
https://www.sprocket.bz/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away