突然ですが
爆速でGANをキャッチアップしたい時ってありますよね
ただ、論文をがっつり読むのって時間かかりますよね
GANってコードを読んだ方が早く理解できますよね
ということでコードから理解していこうと思います
より上から直線的に読めるという理由からKerasを選び、Keras-GANレポジトリからサンプルコードを拝借して理解いきます。
どうしてもPytorchじゃないとダメという人はPytorch-GANという、同じくミニマムな実装を公開しているレポジトリがあるのでそちらを参照してください。
※実装コードの設定は元論文と厳密に同じでない場合があるので、鵜呑みにしすぎないように注意はしておくべきです。
例えば今回のContext Encoderであればこのissueによれば、OptimizerはGeneratorとDiscriminatorで学習率が違うはずだが同じにしているとの指摘がされています。
Context Encoderとは
画像をみれば一発でわかると思いますが、一般的にImage Inpainting
と呼ばれるタスクに対するモデルです。
今回のためのミニマムな問題設定
ColabのファイルをColabで開き、ドライブにコピーを保存(下図参考)をするといいと思います。
- cifar10(サイズ: 32x32)に対して
- 8x8のマスクをかけ
- Generator: 切り取られた画像からそのマスクの中を予測(生成)する
- NNのサイズは32x32x3->8x8x3
- Discriminator: 生成されたマスクの中の部分の真贋判定を行う
- NNのサイズは8x8x3->1
データの用意
Cifar10のデータセットから、犬と猫のみを抽出し、データセットとしている
data.py
(X_train, y_train), (_, _) = cifar10.load_data()
print(f'X_train.shape: {X_train.shape}')
# Extract dogs and cats
X_cats = X_train[(y_train == 3).flatten()]
X_dogs = X_train[(y_train == 5).flatten()]
X_train = np.vstack((X_cats, X_dogs))
# Rescale -1 to 1
X_train = X_train / 127.5 - 1.
y_train = y_train.reshape(-1, 1)
# Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
print(f'X_train.shape: {X_train.shape}')
print(f'y_train.shape: {y_train.shape}')
print(f'X_cats.shape: {X_cats.shape}')
#=> X_train.shape: (50000, 32, 32, 3) # 全cifar10データ
#=> X_train.shape: (10000, 32, 32, 3) # 犬猫のみ
#=> y_train.shape: (50000, 1)
#=> X_cats.shape: (5000, 32, 32, 3)
Generator
- 同じパラメータの場合横に広いのは見にくくなる原因なのでまとめている
- 一度2x2までサイズをおとしてからUpSamplingをしている
- Encoderの部分では
LeakyReLU
だがDecoderではただのReLUを使っている - 最後の活性化関数は
Tanh
。
generator.py
# generatorのConv層の共通パラメータ
gen_params = {
'kernel_size': 3,
'padding': 'same'
}
def build_generator(**params):
model = Sequential()
# Encoder
model.add(Conv2D(32, strides=2, input_shape=img_shape, **params)) # 32x32x3 -> 16x16x32
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(64, strides=2, **params)) # 8x8x64
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, strides=2, **params)) # 4x4x128
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(512, kernel_size=1, strides=2, padding="same")) # 2x2x512
model.add(LeakyReLU(alpha=0.2))
model.add(Dropout(0.5))
# Decoder
model.add(UpSampling2D()) # 4x4x512
model.add(Conv2D(128, **params)) # 4x4x128
model.add(Activation('relu'))
model.add(BatchNormalization(momentum=0.8))
model.add(UpSampling2D()) # 8x8x128
model.add(Conv2D(64, **params)) # 8x8x64
model.add(Activation('relu'))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(channels, **params)) # 8x8x3
model.add(Activation('tanh'))
# model.summary()は確認用
# model.summary()
masked_img = Input(shape=img_shape)
gen_missing = model(masked_img)
return Model(masked_img, gen_missing)
Discriminator
- 識別モデルで
LeakyReLU
を使っている - BN層の
momentum
はGeneratorも全く同じ - チャンネル数は256まで上がっている。そんなに必要なのかという気持ち
- Activationはシグモイド。出力数は1なのでそうですよねという感じ
discriminator.py
# discriminatorのConv層の共通パラメータ
dis_params = {
'kernel_size': 3,
'padding': 'same'
}
def build_discriminator(**params):
model = Sequential()
model.add(Conv2D(64, strides=2, input_shape=missing_shape, **params)) # 8x8x3 -> 4x4x64
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(128, strides=2, **params)) # 2x2x128
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Conv2D(256, **params)) # 2x2x256
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Flatten()) # 1024
model.add(Dense(1, activation='sigmoid')) # 1
# model.summary()
img = Input(shape=missing_shape)
validity = model(img)
return Model(img, validity)
結果
32x32だとめちゃめちゃ見にくいですが、
1行目が真、2行目がマスク、3行目が予測結果です。
圧倒的に見にくいですが、なんとなく画像のcontextを考慮できていることがわかります。