前回、pix2pixは、過学習ではないかということを書いたが、実際過学習なのかどうか、もう少し深堀して、本当に過学習なのかどうかを見てみたいと思う。
そもそも過学習というのは、学習データのみを学習して、汎用的なデータに対しては全く無力、そもそも再現できない状態を指す。
通常は何万何百万ものデータを学習して、未知画像などへも同じように対応させているから、ここで見たような100個に満たないデータで何らかの結論を引き出すのは論外なことかもしれない。
そういう事情はあるが、そこは無視して議論を進めたいと思う。
改変したmodels.pyコードは以下に置いた
https://github.com/MuAuan/pix2pix
pix2pixの特徴は以下の三点である
(1)pix2pixは、cGANである
(2)Generatorにu-netを使っている
(3)patchGANを利用して、L1相関を強化
ここで、過学習の原因があるとすれば、Generatorにu-netを使っていることが原因の可能性が高いと考えられる。そこで、u-netについてよく見てみようと思う。
今回実施したこと
(1)u-netの構造の確認
u-netは、以下のような構造である。
U-Net: Convolutional Networks for Biomedical Image Segmentationより
上記の図からわかるように、このモデルの特徴は以下の3点ある
①u-netの特徴はU構造である、最下段で全結合層がない
つまり、底が平坦
②下段に行くほど画像領域が縮小max-pool(2x2)し、上段に戻るときup-conv(2x2)される
つまり、対称的で無駄がない
③一番重要な特徴は、同じサイズの入力側から出力側の同形のテンソルに直接入力されること
つまり、入力に強く依存する
上記の3つの特徴から、過学習回避のための方策を考察する。
①は外していいと思う。どこまで画像の繰り込み(特徴の局所化)をやるかは一つの戦略になるが、ここでは採用しないこととする(後でやるかもしれないが。。)
②についても途中の層を増やすことも考えられるが、無暗に層を増やして不安定化させる必要も無いと判断した。
③は、同形テンソル間で特徴を直接concatenateにより、伝達する部分が学習効率を高めており、逆にテンソルのパラメータフィッティングに対しては、過学習の原因として一番怪しいと考えた。
(2)各層のconcatenateしているのを段階的に止める
つまり、上記の三つ目の特徴に着目し、この結合を実施しないことにより、通常のGenerator(encoder-decoder)に近づけてみようということである。
(3)Dropoutを変える
やはり、通常のモデルと同じように過学習を止めるには、Dropoutの導入により、途中の層間の伝達情報を減少させて過学習を避けることとする。もっとも元のモデルで下位層(i<2)において、Dropout(0.5)であり十分な回避になっているので、上位層までDropoutを導入することにした。
結果
まず、コードであるが、以下の通り変更した。
def generator_unet_upsampling(img_shape, disc_img_shape, model_name="generator_unet_upsampling"):
# Decoder
first_up_conv = up_conv_block_unet(list_encoder[-1], list_encoder[-2],
list_filters_num[0], "unet_upconv2D_1", axis_num, dropout=True)
list_decoder = [first_up_conv]
for i, f in enumerate(list_filters_num[1:]):
name = "unet_upconv2D_" + str(i+2)
if i<2: #2これを変更してドロップアウトとconcatenateの操作をした
d = True
#up_conv = up_conv_block_unet(list_decoder[-1], list_encoder[-(i+3)], f, name, axis_num, dropout=d)
else:
d = False
#up_conv = up_conv_block_unet_alt(list_decoder[-1], list_encoder[-(i+3)], f, name, axis_num, dropout=d)
up_conv = up_conv_block_unet_alt(list_decoder[-1], list_encoder[-(i+3)], f, name, axis_num, dropout=d)
list_decoder.append(up_conv)
また、以下のとおり、concatenateを削除することにより、結合を解いた。
def up_conv_block_unet_alt(x, x2, f, name, bn_axis, bn=True, dropout=False):
x = Activation('relu')(x)
x = UpSampling2D(size=(2, 2))(x)
x = Conv2D(f, (3,3), name=name, padding='same')(x)
if bn: x = BatchNormalization(axis=bn_axis)(x)
if dropout: x = Dropout(0.5)(x)
#x = Concatenate(axis=bn_axis)([x, x2])
return x
と2つのペアとTrain、Testデータとして何を選ぶかが問題であるが、前回同様顔データとファッションのデータ、物体、そしてそれらの組み合わせとした。
この結果の代表的なものは、以下のとおりである。
①Train:fashion Test:fashion モデル:pix2pixオリジナル
Train
Test
前回は、これで不満と言ったが、。。。
しかし、faceの輪郭画像のTestは相変わらず壊れて見せられない
②Train:物体 Test:物体 モデル:encoder-decoder
Canny画像だと、10時間で収束が今一つだった。
Train
Test
③対象Train:face Test:fashion モデル:pix2pix
Train
Test
まあ、いろいろ見てるとこれもまとも。。
しかし、色が一色で不満
③対象Train:face Test:fashion モデル:encoder-decoder
Train
Test
色がと言っていたが、encoder-decoderで以下の画像が出てきた。
これって、何が原因だろう。普通のGANだとやはりこの程度には汎用的に覚えてくれるんじゃなかろうか。
この図柄は、もちろん学習データから学習したテンソルが精いっぱい入力画像の特徴を生かして出力している。局所的な特徴と画面全体の大域的特徴からうまく描いている。顔しか想定しないところでこれらの入力があった時の反応として正しいような気がする。
この真ん中の画像をよく観察しつつ、少し離れたところから上の入力を見ると、真ん中のような画像が出てくる理由を理解してもらえると思う。
例えば、人面に見える建物を見たときに顔を意識したとたん顔にしか見えない感覚と似ているように見える。
つまり、u-netは大域的特徴の表現力が落ちているのじゃないかということだ。これは入力画像の全体的特徴に対して過学習となっており、つまり塗り絵になってしまっているし、その色も今一つ表現力がないということになっている。
以下、同じような図柄をいくつか並べてみようと思う。
実に表現力が豊かなような気がする。
最後にこのシリーズのTrainデータの再現性を見ておこう
最後の絵は学習に12時間程度かかったが、それなりの画像が得られている。
もっとも、pix2pixは10分もかからずにこの精度が出るからそれぞれの使い方があると思う。
まとめ
・pix2pixのGeneratorは大域構造の表現力が減少しているという意味で過学習となっていると言える
・pix2pixのGeneratoerに通常のencoder-decoderを導入すると、大域的な特徴表現が増加して、入力に関連した面白い画像を創造できる
課題
・上記と異なる言い方だが、表現力は落ちているが、過学習とまでは言えず、そのあたりの対象と限界をもっと見える化しないとわからない
・今回、u-netのほんの一部のモデル要素を変更して、出力の変化を見たが、それ以外のpatchGANの部分やu-netの他の要素を変更すると何が起こるのかまだまだ不明である
・今回のアプリを使うと適当な画像から思ってもみない画像を創造して見せてくれるがそのバリエーションはどこまで作れるのか
・例えば、グレー画像をTwitterで送信すると、おかしな人面画像(に限らないが)に変換して返信するアプリを作成できそうだ
。。。
generator_unet のモデルサマリ
generator_unet = Model(input=[unet_input], outputs=[x])
____________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
====================================================================================================
unet_input (InputLayer) (None, 128, 128, 1) 0
____________________________________________________________________________________________________
unet_conv2D_1 (Conv2D) (None, 64, 64, 64) 640 unet_input[0][0]
____________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU) (None, 64, 64, 64) 0 unet_conv2D_1[0][0]
____________________________________________________________________________________________________
unet_conv2D_2 (Conv2D) (None, 32, 32, 128) 73856 leaky_re_lu_1[0][0]
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 32, 32, 128) 512 unet_conv2D_2[0][0]
____________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU) (None, 32, 32, 128) 0 batch_normalization_1[0][0]
____________________________________________________________________________________________________
unet_conv2D_3 (Conv2D) (None, 16, 16, 256) 295168 leaky_re_lu_2[0][0]
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 16, 16, 256) 1024 unet_conv2D_3[0][0]
____________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU) (None, 16, 16, 256) 0 batch_normalization_2[0][0]
____________________________________________________________________________________________________
unet_conv2D_4 (Conv2D) (None, 8, 8, 512) 1180160 leaky_re_lu_3[0][0]
____________________________________________________________________________________________________
batch_normalization_3 (BatchNorm (None, 8, 8, 512) 2048 unet_conv2D_4[0][0]
____________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU) (None, 8, 8, 512) 0 batch_normalization_3[0][0]
____________________________________________________________________________________________________
unet_conv2D_5 (Conv2D) (None, 4, 4, 512) 2359808 leaky_re_lu_4[0][0]
____________________________________________________________________________________________________
batch_normalization_4 (BatchNorm (None, 4, 4, 512) 2048 unet_conv2D_5[0][0]
____________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 4, 4, 512) 0 batch_normalization_4[0][0]
____________________________________________________________________________________________________
unet_conv2D_6 (Conv2D) (None, 2, 2, 512) 2359808 leaky_re_lu_5[0][0]
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 2, 2, 512) 2048 unet_conv2D_6[0][0]
____________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 2, 2, 512) 0 batch_normalization_5[0][0]
____________________________________________________________________________________________________
unet_conv2D_7 (Conv2D) (None, 1, 1, 512) 2359808 leaky_re_lu_6[0][0]
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 1, 1, 512) 2048 unet_conv2D_7[0][0]
____________________________________________________________________________________________________
activation_1 (Activation) (None, 1, 1, 512) 0 batch_normalization_6[0][0]
____________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D) (None, 2, 2, 512) 0 activation_1[0][0]
____________________________________________________________________________________________________
unet_upconv2D_1 (Conv2D) (None, 2, 2, 512) 2359808 up_sampling2d_1[0][0]
____________________________________________________________________________________________________
batch_normalization_7 (BatchNorm (None, 2, 2, 512) 2048 unet_upconv2D_1[0][0]
____________________________________________________________________________________________________
dropout_1 (Dropout) (None, 2, 2, 512) 0 batch_normalization_7[0][0]
____________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 2, 2, 1024) 0 dropout_1[0][0]
batch_normalization_5[0][0]
____________________________________________________________________________________________________
activation_2 (Activation) (None, 2, 2, 1024) 0 concatenate_1[0][0]
____________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D) (None, 4, 4, 1024) 0 activation_2[0][0]
____________________________________________________________________________________________________
unet_upconv2D_2 (Conv2D) (None, 4, 4, 512) 4719104 up_sampling2d_2[0][0]
____________________________________________________________________________________________________
batch_normalization_8 (BatchNorm (None, 4, 4, 512) 2048 unet_upconv2D_2[0][0]
____________________________________________________________________________________________________
dropout_2 (Dropout) (None, 4, 4, 512) 0 batch_normalization_8[0][0]
____________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 4, 4, 1024) 0 dropout_2[0][0]
batch_normalization_4[0][0]
____________________________________________________________________________________________________
activation_3 (Activation) (None, 4, 4, 1024) 0 concatenate_2[0][0]
____________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D) (None, 8, 8, 1024) 0 activation_3[0][0]
____________________________________________________________________________________________________
unet_upconv2D_3 (Conv2D) (None, 8, 8, 256) 2359552 up_sampling2d_3[0][0]
____________________________________________________________________________________________________
batch_normalization_9 (BatchNorm (None, 8, 8, 256) 1024 unet_upconv2D_3[0][0]
____________________________________________________________________________________________________
dropout_3 (Dropout) (None, 8, 8, 256) 0 batch_normalization_9[0][0]
____________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 8, 8, 768) 0 dropout_3[0][0]
batch_normalization_3[0][0]
____________________________________________________________________________________________________
activation_4 (Activation) (None, 8, 8, 768) 0 concatenate_3[0][0]
____________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D) (None, 16, 16, 768) 0 activation_4[0][0]
____________________________________________________________________________________________________
unet_upconv2D_4 (Conv2D) (None, 16, 16, 128) 884864 up_sampling2d_4[0][0]
____________________________________________________________________________________________________
batch_normalization_10 (BatchNor (None, 16, 16, 128) 512 unet_upconv2D_4[0][0]
____________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 16, 16, 384) 0 batch_normalization_10[0][0]
batch_normalization_2[0][0]
____________________________________________________________________________________________________
activation_5 (Activation) (None, 16, 16, 384) 0 concatenate_4[0][0]
____________________________________________________________________________________________________
up_sampling2d_5 (UpSampling2D) (None, 32, 32, 384) 0 activation_5[0][0]
____________________________________________________________________________________________________
unet_upconv2D_5 (Conv2D) (None, 32, 32, 64) 221248 up_sampling2d_5[0][0]
____________________________________________________________________________________________________
batch_normalization_11 (BatchNor (None, 32, 32, 64) 256 unet_upconv2D_5[0][0]
____________________________________________________________________________________________________
concatenate_5 (Concatenate) (None, 32, 32, 192) 0 batch_normalization_11[0][0]
batch_normalization_1[0][0]
____________________________________________________________________________________________________
activation_6 (Activation) (None, 32, 32, 192) 0 concatenate_5[0][0]
____________________________________________________________________________________________________
up_sampling2d_6 (UpSampling2D) (None, 64, 64, 192) 0 activation_6[0][0]
____________________________________________________________________________________________________
unet_upconv2D_6 (Conv2D) (None, 64, 64, 64) 110656 up_sampling2d_6[0][0]
____________________________________________________________________________________________________
batch_normalization_12 (BatchNor (None, 64, 64, 64) 256 unet_upconv2D_6[0][0]
____________________________________________________________________________________________________
concatenate_6 (Concatenate) (None, 64, 64, 128) 0 batch_normalization_12[0][0]
unet_conv2D_1[0][0]
____________________________________________________________________________________________________
activation_7 (Activation) (None, 64, 64, 128) 0 concatenate_6[0][0]
____________________________________________________________________________________________________
up_sampling2d_7 (UpSampling2D) (None, 128, 128, 128) 0 activation_7[0][0]
____________________________________________________________________________________________________
last_conv (Conv2D) (None, 128, 128, 3) 3459 up_sampling2d_7[0][0]
____________________________________________________________________________________________________
activation_8 (Activation) (None, 128, 128, 3) 0 last_conv[0][0]
====================================================================================================
Total params: 19,303,811
Trainable params: 19,295,875
Non-trainable params: 7,936
____________________________________________________________________________________________________