Help us understand the problem. What is going on with this article?

KerasでLSGAN書く

More than 3 years have passed since last update.

KerasでLSGAN書いてみました。

元論文:Least Squares Generative Adversarial Networks

この記事のリポジトリ:https://github.com/t-ae/watch-generator-keras

LSGANとは?

二乗誤差を使うことで普通のDCGANより本物に近い画像が得られるらしい。
GANのバリエーションはいろいろありますがLSGANについて特筆すべきことは実装が簡単ということです。

準備

データセット

Gressiveから集めた腕時計画像1300枚ほど。
正面向きのに限定しています。元のサイズは148x190。

画像の入出力サイズ

コンテンツ的に縦長の画像を作りたいので適切なサイズを考えます。
どのような手段を使うとしてもアップサンプリングの倍率が大きいのはよろしくないので、
散々悩んだ末6x7を元に2倍アップサンプリングを4回かけて得られる96x112を使うことにしました。
正方形の画像を使っているケースをよく見ますがこういうめんどくささを避けるためでもあるのかなぁと思いました。

実装

目的関数

LSGANの目的関数は以下です。

\min_D V_{LSGAN}(D) = \frac{1}{2}\mathbb{E}_{{\bf x} \sim p_{data}({\bf x})}[(D({\bf x}) - b)^2] + \frac{1}{2}\mathbb{E}_{{\bf z} \sim p_z(\bf{z})}[(D(G({\bf z})) - a)^2] 
\min_G V_{LSGAN}(G) = \frac{1}{2}\mathbb{E}_{{\bf z} \sim p_z({\bf z})}[(D(G({\bf z})) - c)^2] 

a, b, cは今回は0, 1, 1にしました(詳細については論文を参照してください)。
GeneratorのほうはMSEを半分にするだけでいいのですが、Discriminatorのほうはラベルごとに平均する処理が必要になります。

def create_lsgan_d_loss(a, b):
    def loss_func(y_true, y_pred):
        a_mask = K.cast(K.equal(y_true, a), K.floatx())
        b_mask = K.cast(K.equal(y_true, b), K.floatx())
        a_loss = K.sum((y_pred * a_mask - a) ** 2) / K.sum(a_mask)
        b_loss = K.sum((y_pred * b_mask - b) ** 2) / K.sum(b_mask)
        return (a_loss + b_loss) / 2
    return loss_func

(他の方の記事にロスはサンプルごとに返すと書いてあるのですが、ソースのこのへんを読んだ感じだとウェイトやマスクをかけるために用意されている関数がそうなっているだけで、逆にそれらを使わないなら今回のようにサンプルをまたいで計算しちゃっても良いんじゃないかと思ってます。)

ネットワーク

Generator

def create_generator():

    normalization = BatchNormalization

    inp = Input([Z_DIMENSION])
    x = Dense(7*6*256, use_bias=False)(inp)
    x = normalization()(x)
    x = ELU()(x)
    x = Reshape([7, 6, 256])(x)
    x = Conv2DTranspose(256, 3, padding="same", strides=2, use_bias=False)(x)  # 14x12
    x = normalization()(x)
    x = ELU()(x)
    x = Conv2DTranspose(256, 3, padding="same", strides=2, use_bias=False)(x)  # 28x24
    x = normalization()(x)
    x = ELU()(x)
    x = Conv2DTranspose(128, 3, padding="same", strides=2, use_bias=False)(x)  # 56x48
    x = normalization()(x)
    x = ELU()(x)
    x = Conv2DTranspose(64, 3, padding="same", strides=2, use_bias=False)(x)  # 112x96
    x = normalization()(x)
    x = ELU()(x)
    x = Conv2DTranspose(3, 5, padding="same", strides=1)(x)
    x = Activation("tanh")(x)
    return Model(inp, x, name="generator")

特にいうことなし。

Discriminator

def create_discriminator(out="linear"):

    normalization = InstanceNormalization

    return Sequential([
        InputLayer([112, 96, 3]),
        Conv2D(32, 7, padding="same", use_bias=False),
        normalization(),
        ELU(),
        AvgPool2D(),
        Conv2D(64, 5, padding="same", use_bias=False),
        normalization(),
        ELU(),
        AvgPool2D(),
        Conv2D(128, 3, padding="same", use_bias=False),
        normalization(),
        ELU(),
        AvgPool2D(),
        Conv2D(256, 3, padding="same", use_bias=False),
        normalization(),
        ELU(),
        AvgPool2D(),
        Flatten(),
        Dense(1),
        Activation(out)
    ])

こちらではInstance Normalizationを使っています(コード)。
以前も書きましたがBatchNormalizationだと本物と偽の画像を同時に突っ込むと両者間に隔たりがありすぎて学習不可能になってしまうので、それの回避のためです。一応本物でtrain, 偽でtrainと2ステップに分けて学習することはできるのですが、LSGANだと目的関数が両方を同時に要求するのでこうなっています。
ちなみに前の記事ではNormalizationを入れないで学習していましたが、入れたほうが長期的にG-D間の均衡が保てるようでした。

Generatorの画像を保存しておく

Apple Machine Learning JournalによるとGeneratorが生成した画像の履歴を使うとDiscriminatorをより良く学習できる的なことが書いてあったので試しに入れてみてます。とはいえこの記事を読んだだけで、深く考えず適当に実装したせいか効果は正直分からなかったです。
データセットと同じサイズのバッファを用意し、毎エポック100枚をランダムに除去して新しい100枚を入れるというふうになっています。
(追記:論文読んできたらエポックごとじゃなくてステップごとに入れ替えてるみたいでした)

その他

こちらの記事の方法をいくつか取り込んでます。
1. zは正規分布から採取
2. 本物画像にエポックごとに減衰するガウシアンノイズをのせる
3. 説明済みのInstance Normalizationとか

またデータが1300枚だと少ないので縦4ピクセル、横2ピクセルまでランダムに移動するようにしました。

学習結果

Dのロスが1を超え続けてバグってる臭いのですがそれっぽいものは一応出てました。
すべて同じz群から採取した結果で、上の数字がDiscriminatorの出力になります。
GTX1060 6GBにて2日弱ほど回しました。

0エポック
0

100エポック
100

200エポック
100

500エポック
100

1000エポック
1000

3000エポック
1000

このへんが一番綺麗に出力できてます。インデックスの方向もそれっぽいです。

6000エポック
1000

黒に偏ってたり形状が怪しかったり。

6300エポック
1000

完全にぶっ壊れました。

まとめ

ケースの形状はかなり早い段階から取れていて、丸が綺麗に出力されるところなどは素晴らしいのですが、最もこだわりたい文字盤上の表現が全然でした。モデルを変更しまくっていてどれだったか覚えていないのですが、スモセコやオフセンターが出ているようなのも一応出来ました。アスペクト比すら今の値でなく、しかもMode collapse臭いですが画像が残っていたので貼っておきます。
old.png

最初は高画質なのができたら記事にしようと思っていたのですが、まだ満足なところまで辿り着いていません。
いろいろ試しているとKerasでは自由度が低かったり、ソースをたどりに行く必要があったりして結構手間だったので、PyTorchに移ることにしました。
今回の記事はとりあえずKerasでのまとめということで、PyTorchでもう少し高画質化を狙ってみようと思います。

t-ae
qoncept
リアルタイム画像認識を専門にした会社です。近年はスポーツにおける認識技術の応用に力を入れています。
https://qoncept.co.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした