Qiita Teams that are logged in
You are not logged in to any team

Log in to Qiita Team
Community
OrganizationAdvent CalendarQiitadon (β)
Service
Qiita JobsQiita ZineQiita Blog
14
Help us understand the problem. What is going on with this article?
@MuAuan

pix2pixの出現で。。。~検証編2 U-net構造を理解する~

More than 3 years have passed since last update.

前回、pix2pixは、過学習ではないかということを書いたが、実際過学習なのかどうか、もう少し深堀して、本当に過学習なのかどうかを見てみたいと思う。

そもそも過学習というのは、学習データのみを学習して、汎用的なデータに対しては全く無力、そもそも再現できない状態を指す。
通常は何万何百万ものデータを学習して、未知画像などへも同じように対応させているから、ここで見たような100個に満たないデータで何らかの結論を引き出すのは論外なことかもしれない。
 そういう事情はあるが、そこは無視して議論を進めたいと思う。

改変したmodels.pyコードは以下に置いた
https://github.com/MuAuan/pix2pix

pix2pixの特徴は以下の三点である

(1)pix2pixは、cGANである
(2)Generatorにu-netを使っている
(3)patchGANを利用して、L1相関を強化

ここで、過学習の原因があるとすれば、Generatorにu-netを使っていることが原因の可能性が高いと考えられる。そこで、u-netについてよく見てみようと思う。

今回実施したこと

(1)u-netの構造の確認

u-netは、以下のような構造である。
u-net_Fig1.png
U-Net: Convolutional Networks for Biomedical Image Segmentationより
上記の図からわかるように、このモデルの特徴は以下の3点ある

①u-netの特徴はU構造である、最下段で全結合層がない

 つまり、底が平坦

②下段に行くほど画像領域が縮小max-pool(2x2)し、上段に戻るときup-conv(2x2)される

 つまり、対称的で無駄がない

③一番重要な特徴は、同じサイズの入力側から出力側の同形のテンソルに直接入力されること

 つまり、入力に強く依存する

上記の3つの特徴から、過学習回避のための方策を考察する。
①は外していいと思う。どこまで画像の繰り込み(特徴の局所化)をやるかは一つの戦略になるが、ここでは採用しないこととする(後でやるかもしれないが。。)
②についても途中の層を増やすことも考えられるが、無暗に層を増やして不安定化させる必要も無いと判断した。
③は、同形テンソル間で特徴を直接concatenateにより、伝達する部分が学習効率を高めており、逆にテンソルのパラメータフィッティングに対しては、過学習の原因として一番怪しいと考えた。

(2)各層のconcatenateしているのを段階的に止める

つまり、上記の三つ目の特徴に着目し、この結合を実施しないことにより、通常のGenerator(encoder-decoder)に近づけてみようということである。

(3)Dropoutを変える

やはり、通常のモデルと同じように過学習を止めるには、Dropoutの導入により、途中の層間の伝達情報を減少させて過学習を避けることとする。もっとも元のモデルで下位層(i<2)において、Dropout(0.5)であり十分な回避になっているので、上位層までDropoutを導入することにした。

結果

まず、コードであるが、以下の通り変更した。

def generator_unet_upsampling(img_shape, disc_img_shape, model_name="generator_unet_upsampling"):
 # Decoder
    first_up_conv = up_conv_block_unet(list_encoder[-1], list_encoder[-2],
                        list_filters_num[0], "unet_upconv2D_1", axis_num, dropout=True)
    list_decoder = [first_up_conv]
    for i, f in enumerate(list_filters_num[1:]):
        name = "unet_upconv2D_" + str(i+2)
        if i<2:  #2これを変更してドロップアウトとconcatenateの操作をした
            d = True
            #up_conv = up_conv_block_unet(list_decoder[-1], list_encoder[-(i+3)], f, name, axis_num, dropout=d)
        else:
            d = False
            #up_conv = up_conv_block_unet_alt(list_decoder[-1], list_encoder[-(i+3)], f, name, axis_num, dropout=d)
        up_conv = up_conv_block_unet_alt(list_decoder[-1], list_encoder[-(i+3)], f, name, axis_num, dropout=d)
        list_decoder.append(up_conv)

また、以下のとおり、concatenateを削除することにより、結合を解いた。

def up_conv_block_unet_alt(x, x2, f, name, bn_axis, bn=True, dropout=False):
    x = Activation('relu')(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = Conv2D(f, (3,3), name=name, padding='same')(x)
    if bn: x = BatchNormalization(axis=bn_axis)(x)
    if dropout: x = Dropout(0.5)(x)
    #x = Concatenate(axis=bn_axis)([x, x2])
    return x

と2つのペアとTrain、Testデータとして何を選ぶかが問題であるが、前回同様顔データとファッションのデータ、物体、そしてそれらの組み合わせとした。
この結果の代表的なものは、以下のとおりである。

①Train:fashion Test:fashion モデル:pix2pixオリジナル

Train
current_batch_training190.png
Test
current_batch_validation170.png
前回は、これで不満と言ったが、。。。
しかし、faceの輪郭画像のTestは相変わらず壊れて見せられない

②Train:物体 Test:物体 モデル:encoder-decoder

Canny画像だと、10時間で収束が今一つだった。
Train
current_batch_training990.png
Test
current_batch_validation990.png

③対象Train:face Test:fashion モデル:pix2pix

Train
current_batch_training160.png
Test
current_batch_validation160.png
まあ、いろいろ見てるとこれもまとも。。
しかし、色が一色で不満

③対象Train:face Test:fashion モデル:encoder-decoder

Train
current_batch_training190.png
Test
色がと言っていたが、encoder-decoderで以下の画像が出てきた。
current_batch_validation190.png
これって、何が原因だろう。普通のGANだとやはりこの程度には汎用的に覚えてくれるんじゃなかろうか。
この図柄は、もちろん学習データから学習したテンソルが精いっぱい入力画像の特徴を生かして出力している。局所的な特徴と画面全体の大域的特徴からうまく描いている。顔しか想定しないところでこれらの入力があった時の反応として正しいような気がする。
この真ん中の画像をよく観察しつつ、少し離れたところから上の入力を見ると、真ん中のような画像が出てくる理由を理解してもらえると思う。
例えば、人面に見える建物を見たときに顔を意識したとたん顔にしか見えない感覚と似ているように見える。
つまり、u-netは大域的特徴の表現力が落ちているのじゃないかということだ。これは入力画像の全体的特徴に対して過学習となっており、つまり塗り絵になってしまっているし、その色も今一つ表現力がないということになっている。
以下、同じような図柄をいくつか並べてみようと思う。
current_batch_validation120.png
current_batch_validation180.png
current_batch_validation480.png
実に表現力が豊かなような気がする。
最後にこのシリーズのTrainデータの再現性を見ておこう
current_batch_training640.png
最後の絵は学習に12時間程度かかったが、それなりの画像が得られている。
もっとも、pix2pixは10分もかからずにこの精度が出るからそれぞれの使い方があると思う。

まとめ

・pix2pixのGeneratorは大域構造の表現力が減少しているという意味で過学習となっていると言える
・pix2pixのGeneratoerに通常のencoder-decoderを導入すると、大域的な特徴表現が増加して、入力に関連した面白い画像を創造できる

課題

・上記と異なる言い方だが、表現力は落ちているが、過学習とまでは言えず、そのあたりの対象と限界をもっと見える化しないとわからない
・今回、u-netのほんの一部のモデル要素を変更して、出力の変化を見たが、それ以外のpatchGANの部分やu-netの他の要素を変更すると何が起こるのかまだまだ不明である
・今回のアプリを使うと適当な画像から思ってもみない画像を創造して見せてくれるがそのバリエーションはどこまで作れるのか
・例えば、グレー画像をTwitterで送信すると、おかしな人面画像(に限らないが)に変換して返信するアプリを作成できそうだ
。。。

generator_unet のモデルサマリ
  generator_unet = Model(input=[unet_input], outputs=[x])
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
unet_input (InputLayer)          (None, 128, 128, 1)   0
____________________________________________________________________________________________________
unet_conv2D_1 (Conv2D)           (None, 64, 64, 64)    640         unet_input[0][0]
____________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)        (None, 64, 64, 64)    0           unet_conv2D_1[0][0]
____________________________________________________________________________________________________
unet_conv2D_2 (Conv2D)           (None, 32, 32, 128)   73856       leaky_re_lu_1[0][0]
____________________________________________________________________________________________________
batch_normalization_1 (BatchNorm (None, 32, 32, 128)   512         unet_conv2D_2[0][0]
____________________________________________________________________________________________________
leaky_re_lu_2 (LeakyReLU)        (None, 32, 32, 128)   0           batch_normalization_1[0][0]
____________________________________________________________________________________________________
unet_conv2D_3 (Conv2D)           (None, 16, 16, 256)   295168      leaky_re_lu_2[0][0]
____________________________________________________________________________________________________
batch_normalization_2 (BatchNorm (None, 16, 16, 256)   1024        unet_conv2D_3[0][0]
____________________________________________________________________________________________________
leaky_re_lu_3 (LeakyReLU)        (None, 16, 16, 256)   0           batch_normalization_2[0][0]
____________________________________________________________________________________________________
unet_conv2D_4 (Conv2D)           (None, 8, 8, 512)     1180160     leaky_re_lu_3[0][0]
____________________________________________________________________________________________________
batch_normalization_3 (BatchNorm (None, 8, 8, 512)     2048        unet_conv2D_4[0][0]
____________________________________________________________________________________________________
leaky_re_lu_4 (LeakyReLU)        (None, 8, 8, 512)     0           batch_normalization_3[0][0]
____________________________________________________________________________________________________
unet_conv2D_5 (Conv2D)           (None, 4, 4, 512)     2359808     leaky_re_lu_4[0][0]
____________________________________________________________________________________________________
batch_normalization_4 (BatchNorm (None, 4, 4, 512)     2048        unet_conv2D_5[0][0]
____________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)        (None, 4, 4, 512)     0           batch_normalization_4[0][0]
____________________________________________________________________________________________________
unet_conv2D_6 (Conv2D)           (None, 2, 2, 512)     2359808     leaky_re_lu_5[0][0]
____________________________________________________________________________________________________
batch_normalization_5 (BatchNorm (None, 2, 2, 512)     2048        unet_conv2D_6[0][0]
____________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)        (None, 2, 2, 512)     0           batch_normalization_5[0][0]
____________________________________________________________________________________________________
unet_conv2D_7 (Conv2D)           (None, 1, 1, 512)     2359808     leaky_re_lu_6[0][0]
____________________________________________________________________________________________________
batch_normalization_6 (BatchNorm (None, 1, 1, 512)     2048        unet_conv2D_7[0][0]
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 1, 1, 512)     0           batch_normalization_6[0][0]
____________________________________________________________________________________________________
up_sampling2d_1 (UpSampling2D)   (None, 2, 2, 512)     0           activation_1[0][0]
____________________________________________________________________________________________________
unet_upconv2D_1 (Conv2D)         (None, 2, 2, 512)     2359808     up_sampling2d_1[0][0]
____________________________________________________________________________________________________
batch_normalization_7 (BatchNorm (None, 2, 2, 512)     2048        unet_upconv2D_1[0][0]
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 2, 2, 512)     0           batch_normalization_7[0][0]
____________________________________________________________________________________________________
concatenate_1 (Concatenate)      (None, 2, 2, 1024)    0           dropout_1[0][0]
                                                                   batch_normalization_5[0][0]
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 2, 2, 1024)    0           concatenate_1[0][0]
____________________________________________________________________________________________________
up_sampling2d_2 (UpSampling2D)   (None, 4, 4, 1024)    0           activation_2[0][0]
____________________________________________________________________________________________________
unet_upconv2D_2 (Conv2D)         (None, 4, 4, 512)     4719104     up_sampling2d_2[0][0]
____________________________________________________________________________________________________
batch_normalization_8 (BatchNorm (None, 4, 4, 512)     2048        unet_upconv2D_2[0][0]
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 4, 4, 512)     0           batch_normalization_8[0][0]
____________________________________________________________________________________________________
concatenate_2 (Concatenate)      (None, 4, 4, 1024)    0           dropout_2[0][0]
                                                                   batch_normalization_4[0][0]
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 4, 4, 1024)    0           concatenate_2[0][0]
____________________________________________________________________________________________________
up_sampling2d_3 (UpSampling2D)   (None, 8, 8, 1024)    0           activation_3[0][0]
____________________________________________________________________________________________________
unet_upconv2D_3 (Conv2D)         (None, 8, 8, 256)     2359552     up_sampling2d_3[0][0]
____________________________________________________________________________________________________
batch_normalization_9 (BatchNorm (None, 8, 8, 256)     1024        unet_upconv2D_3[0][0]
____________________________________________________________________________________________________
dropout_3 (Dropout)              (None, 8, 8, 256)     0           batch_normalization_9[0][0]
____________________________________________________________________________________________________
concatenate_3 (Concatenate)      (None, 8, 8, 768)     0           dropout_3[0][0]
                                                                   batch_normalization_3[0][0]
____________________________________________________________________________________________________
activation_4 (Activation)        (None, 8, 8, 768)     0           concatenate_3[0][0]
____________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)   (None, 16, 16, 768)   0           activation_4[0][0]
____________________________________________________________________________________________________
unet_upconv2D_4 (Conv2D)         (None, 16, 16, 128)   884864      up_sampling2d_4[0][0]
____________________________________________________________________________________________________
batch_normalization_10 (BatchNor (None, 16, 16, 128)   512         unet_upconv2D_4[0][0]
____________________________________________________________________________________________________
concatenate_4 (Concatenate)      (None, 16, 16, 384)   0           batch_normalization_10[0][0]
                                                                   batch_normalization_2[0][0]
____________________________________________________________________________________________________
activation_5 (Activation)        (None, 16, 16, 384)   0           concatenate_4[0][0]
____________________________________________________________________________________________________
up_sampling2d_5 (UpSampling2D)   (None, 32, 32, 384)   0           activation_5[0][0]
____________________________________________________________________________________________________
unet_upconv2D_5 (Conv2D)         (None, 32, 32, 64)    221248      up_sampling2d_5[0][0]
____________________________________________________________________________________________________
batch_normalization_11 (BatchNor (None, 32, 32, 64)    256         unet_upconv2D_5[0][0]
____________________________________________________________________________________________________
concatenate_5 (Concatenate)      (None, 32, 32, 192)   0           batch_normalization_11[0][0]
                                                                   batch_normalization_1[0][0]
____________________________________________________________________________________________________
activation_6 (Activation)        (None, 32, 32, 192)   0           concatenate_5[0][0]
____________________________________________________________________________________________________
up_sampling2d_6 (UpSampling2D)   (None, 64, 64, 192)   0           activation_6[0][0]
____________________________________________________________________________________________________
unet_upconv2D_6 (Conv2D)         (None, 64, 64, 64)    110656      up_sampling2d_6[0][0]
____________________________________________________________________________________________________
batch_normalization_12 (BatchNor (None, 64, 64, 64)    256         unet_upconv2D_6[0][0]
____________________________________________________________________________________________________
concatenate_6 (Concatenate)      (None, 64, 64, 128)   0           batch_normalization_12[0][0]
                                                                   unet_conv2D_1[0][0]
____________________________________________________________________________________________________
activation_7 (Activation)        (None, 64, 64, 128)   0           concatenate_6[0][0]
____________________________________________________________________________________________________
up_sampling2d_7 (UpSampling2D)   (None, 128, 128, 128) 0           activation_7[0][0]
____________________________________________________________________________________________________
last_conv (Conv2D)               (None, 128, 128, 3)   3459        up_sampling2d_7[0][0]
____________________________________________________________________________________________________
activation_8 (Activation)        (None, 128, 128, 3)   0           last_conv[0][0]
====================================================================================================
Total params: 19,303,811
Trainable params: 19,295,875
Non-trainable params: 7,936
____________________________________________________________________________________________________
14
Help us understand the problem. What is going on with this article?
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
MuAuan
2021年為になる記事にする 記事420いいね2000フォロワー200 2020年;いい記事を書く 記事359/350いいね1590/1500フォロワ ー144/150 2019年 記事275/300いいね1035/1000フォロワー97/100 2018年 記事140/200いいね423/500フォロワー48/50 7/8/2018 記事90いいね227フォロワー25

Comments

No comments
Sign up for free and join this conversation.
Sign Up
If you already have a Qiita account Login
14
Help us understand the problem. What is going on with this article?