LoginSignup
6
6

More than 3 years have passed since last update.

GANをコードで爆速キャッチアップしたい Context Encoder編

Posted at

突然ですが

爆速でGANをキャッチアップしたい時ってありますよね
ただ、論文をがっつり読むのって時間かかりますよね
GANってコードを読んだ方が早く理解できますよね

ということでコードから理解していこうと思います

より上から直線的に読めるという理由からKerasを選び、Keras-GANレポジトリからサンプルコードを拝借して理解いきます。
どうしてもPytorchじゃないとダメという人はPytorch-GANという、同じくミニマムな実装を公開しているレポジトリがあるのでそちらを参照してください。

実装コードの設定は元論文と厳密に同じでない場合があるので、鵜呑みにしすぎないように注意はしておくべきです。
例えば今回のContext Encoderであればこのissueによれば、OptimizerはGeneratorとDiscriminatorで学習率が違うはずだが同じにしているとの指摘がされています。

Context Encoderとは

元論文はこちら
スクリーンショット 2019-06-02 22.40.53.png

画像をみれば一発でわかると思いますが、一般的にImage Inpaintingと呼ばれるタスクに対するモデルです。

今回のためのミニマムな問題設定

元コード
Colabで書き直したもの

ColabのファイルをColabで開き、ドライブにコピーを保存(下図参考)をするといいと思います。
スクリーンショット 2019-06-02 23.22.29.png

  • cifar10(サイズ: 32x32)に対して
  • 8x8のマスクをかけ
  • Generator: 切り取られた画像からそのマスクの中を予測(生成)する
    • NNのサイズは32x32x3->8x8x3
  • Discriminator: 生成されたマスクの中の部分の真贋判定を行う
    • NNのサイズは8x8x3->1

切り取られた画像と切り取った画像(正解)はこんな感じ
スクリーンショット 2019-06-02 22.59.24.png

データの用意

Cifar10のデータセットから、犬と猫のみを抽出し、データセットとしている

data.py
(X_train, y_train), (_, _) = cifar10.load_data()
print(f'X_train.shape: {X_train.shape}')

# Extract dogs and cats
X_cats = X_train[(y_train == 3).flatten()]
X_dogs = X_train[(y_train == 5).flatten()]
X_train = np.vstack((X_cats, X_dogs))

# Rescale -1 to 1
X_train = X_train / 127.5 - 1.
y_train = y_train.reshape(-1, 1)

# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

print(f'X_train.shape: {X_train.shape}')
print(f'y_train.shape: {y_train.shape}')
print(f'X_cats.shape: {X_cats.shape}')

#=> X_train.shape: (50000, 32, 32, 3)  # 全cifar10データ
#=> X_train.shape: (10000, 32, 32, 3)  # 犬猫のみ
#=> y_train.shape: (50000, 1)
#=> X_cats.shape: (5000, 32, 32, 3)

Generator

  • 同じパラメータの場合横に広いのは見にくくなる原因なのでまとめている
  • 一度2x2までサイズをおとしてからUpSamplingをしている
  • Encoderの部分ではLeakyReLUだがDecoderではただのReLUを使っている
  • 最後の活性化関数はTanh
generator.py
# generatorのConv層の共通パラメータ
gen_params = {
    'kernel_size': 3,
    'padding': 'same'
}

def build_generator(**params):
        model = Sequential()

        # Encoder
        model.add(Conv2D(32, strides=2, input_shape=img_shape, **params))  # 32x32x3 -> 16x16x32
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(64, strides=2, **params))                                               # 8x8x64
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, strides=2, **params))                                            # 4x4x128
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Conv2D(512, kernel_size=1, strides=2, padding="same"))    # 2x2x512
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.5))

        # Decoder
        model.add(UpSampling2D())                                       # 4x4x512
        model.add(Conv2D(128, **params))                         # 4x4x128
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))  
        model.add(UpSampling2D())                                       # 8x8x128
        model.add(Conv2D(64, **params))                           # 8x8x64
        model.add(Activation('relu'))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(channels, **params))                 # 8x8x3
        model.add(Activation('tanh'))

        # model.summary()は確認用
        #  model.summary()

        masked_img = Input(shape=img_shape)
        gen_missing = model(masked_img)

        return Model(masked_img, gen_missing)

Discriminator

  • 識別モデルでLeakyReLUを使っている
  • BN層のmomentumはGeneratorも全く同じ
  • チャンネル数は256まで上がっている。そんなに必要なのかという気持ち
  • Activationはシグモイド。出力数は1なのでそうですよねという感じ
discriminator.py
# discriminatorのConv層の共通パラメータ
dis_params = {
    'kernel_size': 3,
    'padding': 'same'
}
def build_discriminator(**params):

    model = Sequential()

    model.add(Conv2D(64, strides=2, input_shape=missing_shape, **params))  # 8x8x3 -> 4x4x64
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(128, strides=2, **params))      # 2x2x128
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Conv2D(256, **params))                         # 2x2x256
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Flatten())                                                    # 1024
    model.add(Dense(1, activation='sigmoid'))              # 1
    # model.summary()

    img = Input(shape=missing_shape)
    validity = model(img)

    return Model(img, validity)

結果

32x32だとめちゃめちゃ見にくいですが、
1行目が真、2行目がマスク、3行目が予測結果です。
圧倒的に見にくいですが、なんとなく画像のcontextを考慮できていることがわかります。

スクリーンショット 2019-06-02 23.15.22.png

6
6
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
6