LoginSignup
6
4

More than 5 years have passed since last update.

CapsNetで遊んでみた♬~ブレイクスルーだよ??~

Last updated at Posted at 2018-03-19

前回、CapsNetで遊んでみた♬で見たように、CapsNetはMNISTでは効果的に収束したが、Cifar10では過学習でありしかもTestデータの精度が上がらないし、得られる生成画像もピンボケ画像で今一つだった。

今回は、CNNで得られた知識を駆使して、過学習の回避と精度向上、生成画像鮮明化を目指そうと思う。
結論から先に記載すると、
①過学習は改善し、精度は一つのブレイクスルー75.5%超えを達成
②生成画像鮮明化は、プログラム的には行けそうだけど、結果はお預けな状態
コードは以下に置きました
MuAuan/capsNet

GANの構造は、機能として大きく分けて以下の3つの部分からなっている。
①入力画像の特徴量を分析するレイヤー
 :CNNでは繰り込みを利用して各レイヤー毎に特徴量を抽出する
  今回のCapsNetでは、CapsuleLayerやベクトル構造によって画像を理解する部分を強調するためか、CNNでの一つの重要な方法論である繰り込みの方法を採用していない
②画像の広域を構成する特徴を分析するレイヤー
 :CNNではこれまであまり考慮されていなかったが、pix2pixや今回のCapsuleLayerで考慮された
③Genratorとして画像生成する部分
 :通常のGANではあまり考慮されていなかったが、pix2pixでは同一特徴量を写像して学習効率を大幅に改善した
  今回のCapsNetでは簡単なDenseで画像生成しており、あまりに貧弱である

ということを踏まえて、今回は
①入力側を多層レイヤーにして詳細部分の特徴量を把握できるようにした。
②ベクトル構造での演算及びCapsuleLayer構造はそのまま残すこととした
③画像生成(Generator)部分は、パラメータが膨大になる傾向もあるが、最も簡単なGANの通常の構造を持たせた

入力側を多層レイヤーにして詳細部分の特徴量を把握できるようにした。

すなわち、以下のとおりprimarycapsの直前にConvLayerを3層導入して特徴量を把握できるようにした。ここで、過学習等の対策としてBN層やDropoutを置きたいところであるが、ここではDropoutがあまり効果なかったので(まだまだ実験がやり切れていないのでこれは経過的な印象レベルでしかない)導入しなかった。

def CapsNet(input_shape, n_class, num_routing):
    axis_num = -1
    x = layers.Input(shape=input_shape)

    conv1 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='same', activation='relu', name='conv1')(x)
    conv1 = BatchNormalization(axis=axis_num)(conv1)   #add
    #conv1 = Dropout(0.5)(conv1) #add
    conv2 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='same', activation='relu', name='conv2')(conv1)
    conv2 = BatchNormalization(axis=axis_num)(conv2)   #add
    #conv2 = Dropout(0.5)(conv2) #add
    conv3 = layers.Conv2D(filters=256, kernel_size=9, strides=1, padding='valid', activation='relu', name='conv3')(conv2)
    conv3 = BatchNormalization(axis=axis_num)(conv3)   #add
    #conv3 = Dropout(0.5)(conv3) #add

    primarycaps = PrimaryCap(conv3, dim_vector=8, n_channels=32, kernel_size=9, strides=2, padding='valid')

画像生成(Generator)部分は、パラメータが膨大になる傾向もあるが、最も簡単なGANの(通常の)構造を持たせた

もともと画像生成ルーチンは以下ように三層のDenseで構成されたものであった。

    x_recon = layers.Dense(512, activation='relu')(masked)
    x_recon = layers.Dense(1024, activation='relu')(x_recon)
    x_recon = layers.Dense(np.prod(input_shape), activation='sigmoid')(x_recon)
    x_recon = layers.Reshape(target_shape=input_shape, name='out_recon')(x_recon)

    return models.Model([x, y], [out_caps, x_recon])

これを、以下のように変更した。これが正解というわけではないし、最後のところは適当に辻褄合わせてパラメータ削減してるので後ろめたいが、上の構造よりは2次元的な学習ができるような気がする。(ちなみに、この構造は先日のACGANのGeneratorからの転用である)

    y = layers.Input(shape=(n_class,))
    masked = Mask()([digitcaps, y])
    x_recon = layers.Dense(512, activation='relu')(masked)
    x_recon = layers.Reshape((8, 8, 8))(x_recon)

    # upsample to (..., 16, 16)
    x_recon = layers.UpSampling2D(size=(2, 2))(x_recon)
    x_recon = layers.Conv2D(64, (3, 3), padding='same', activation='relu', kernel_initializer='glorot_normal')(x_recon)

    # upsample to (..., 32, 32)
    x_recon = layers.UpSampling2D(size=(2, 2))(x_recon)
    x_recon = layers.Conv2D(16, (3, 3), padding='same', activation='relu', kernel_initializer='glorot_normal')(x_recon)

    x_recon = layers.Conv2D(3, (2, 2), padding='same', activation='tanh', kernel_initializer='glorot_normal',name='out_recon')(x_recon)

    return models.Model([x, y], [out_caps, x_recon])

その他の改善?

それ以外にも気になるところがたくさんあって、すべてを試せるほど時間もマシンもないのだが、一番気になったのはLoss関数のWeightとLoss関数そのものの選択だ。
基本は、元のLoss関数を利用したが、それぞれ何を使うべきかゆっくり考えたいと思う。

結果

過学習は改善し、精度は一つのブレイクスルー75.5%超えを達成

いろいろいじってもなかなか69%を超えられず、もう多層にするしかという決断をして実施した。
もともと多層じゃないネットワークであり、過学習の改善もDropoutとBatchNormalizationを入れるところは限られていて、多層(パラメータが多い)ということが過学習の原因になっているわけでもなさそうなのでお手上げでした。
たぶん、ベクトルで特徴固定にすることが過学習を招いているというpix2pix(unet版)と同じような状況の予感(構造的過学習)がします。

しかし、逆に入力側に3層入れてBatchNormalizationを入れるとそれで少し過学習が改善し、しかもval_accが75.5%超えを達成できました。ということで以下Trainデータ50000個、Testデータ10000個で実施したときの経緯を示します。このデータはPreTrainingは実施していません。

epoch loss reconloss acc val_los valreconloss val_acc
0 0.748304 0.035237 0.426380 0.616297 0.027571 0.510700
1 0.592471 0.028237 0.566860 0.596677 0.027988 0.553400
2 0.537946 0.027219 0.643800 0.522961 0.026648 0.648500
3 0.494449 0.026103 0.696280 0.507356 0.025010 0.658500
4 0.461882 0.025524 0.739620 0.465719 0.024205 0.711300
5 0.438031 0.025114 0.771700 0.447826 0.023558 0.722200
6 0.414638 0.024783 0.800660 0.431824 0.022865 0.741400
7 0.386893 0.024404 0.839500 0.432268 0.022986 0.738300
8 0.367982 0.024162 0.862140 0.436246 0.023553 0.736800
9 0.346711 0.023973 0.888080 0.406177 0.022033 0.755200
10 0.319535 0.023750 0.919740 0.416420 0.022683 0.751300

生成画像の改善について

これは、今回出力前の3連Dense構造をGeneratorの構造に置き換えたが、あまり改善しなかったので、次回以降の報告とする

そもそも。。。

実は、ブレイクスルーと称したが、75.5%って上記のCNNの例を見てもわかるように、単純なCNNと同程度の精度であり、CapsuleLayerの寄与があるとは言い難い。ということで、次回の課題として本当にCapusuleLayer(ベクトル計算)は寄与しているのかについて取り上げる予定である。

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4