DeepLearning
Keras
Autoencoder
GANs
CapsNet

高精度で画像が綺麗なclassifier_Autoencoderができた♬

CapsNetを深堀してきたが、ちょっとがっかり、。。。しかし、ここで以下のような大切な結論に到達した
①Autoencoderでは中間層が充実すると生成画像が綺麗になる
②CapsNetをEncoder-Decoderに分解すると、いわゆるClassifier_Autoencoderが作成できた
③画像生成と精度は必ずしも一致しない特性である

そして、どうせならVGG16などの高精度なClassifierを利用して、中間層の充実した所謂
高精度で画像が綺麗なClassifier_Autoencoderを作成しよう。

そこで、VGG16でググってみると、なんとおあえつらいむきなモデルが出てきた。
※Cifar10では精度が93%以上出ているという。
【参考】
geifmany/cifar-vgg

The CIFAR-10 reaches a validation accuracy of 93.56% CIFAR-100 reaches validation accuracy of 70.48%.

ある意味、やればできる感じ。。。ということで、やってみました。
目指せ、最高精度のClassifier_Autoencoder‼

コードは以下に置きました。
AutoEncoder/vgg16_class_AE.py

modelの構造

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_2 (InputLayer)            (None, 32, 32, 3)    0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 32, 64)   1792        input_2[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 32, 32, 64)   256         conv2d_1[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 32, 32, 64)   0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 32, 32, 64)   36928       dropout_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 32, 32, 64)   256         conv2d_2[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 32, 32, 64)   0           batch_normalization_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D)  (None, 16, 16, 64)   0           dropout_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 16, 16, 128)  73856       max_pooling2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, 16, 16, 128)  512         conv2d_3[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, 16, 16, 128)  0           batch_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, 16, 16, 128)  147584      dropout_3[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 16, 16, 128)  512         conv2d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D)  (None, 8, 8, 128)    0           batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, 8, 8, 256)    295168      max_pooling2d_2[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 8, 8, 256)    1024        conv2d_5[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, 8, 8, 256)    0           batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 8, 8, 256)    590080      dropout_4[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 8, 8, 256)    1024        conv2d_6[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, 8, 8, 256)    0           batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 8, 8, 256)    590080      dropout_5[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 8, 8, 256)    1024        conv2d_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 4, 4, 256)    0           batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 4, 4, 512)    1180160     max_pooling2d_3[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 4, 4, 512)    2048        conv2d_8[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, 4, 4, 512)    0           batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 4, 4, 512)    2359808     dropout_6[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 4, 4, 512)    2048        conv2d_9[0][0]
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, 4, 4, 512)    0           batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 4, 4, 512)    2359808     dropout_7[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 4, 4, 512)    2048        conv2d_10[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 2, 2, 512)    0           batch_normalization_10[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 2, 2, 512)    2359808     max_pooling2d_4[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 2, 2, 512)    2048        conv2d_11[0][0]
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, 2, 2, 512)    0           batch_normalization_11[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 2, 2, 512)    2359808     dropout_8[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 2, 2, 512)    2048        conv2d_12[0][0]
__________________________________________________________________________________________________
dropout_9 (Dropout)             (None, 2, 2, 512)    0           batch_normalization_12[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 2, 2, 512)    2359808     dropout_9[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 2, 2, 512)    2048        conv2d_13[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D)  (None, 1, 1, 512)    0           batch_normalization_13[0][0]
__________________________________________________________________________________________________
dropout_10 (Dropout)            (None, 1, 1, 512)    0           max_pooling2d_5[0][0]
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 512)          0           dropout_10[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 512)          262656      flatten_1[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 512)          2048        dense_1[0][0]
__________________________________________________________________________________________________
dropout_11 (Dropout)            (None, 512)          0           batch_normalization_14[0][0]
__________________________________________________________________________________________________
up_sampling2d_4 (UpSampling2D)  (None, 32, 32, 64)   0           max_pooling2d_1[0][0]
__________________________________________________________________________________________________
softmax (Dense)                 (None, 10)           5130        dropout_11[0][0]
__________________________________________________________________________________________________
conv_out (Conv2D)               (None, 32, 32, 3)    1731        up_sampling2d_4[0][0]
==================================================================================================
Total params: 15,003,149
Trainable params: 14,993,677
Non-trainable params: 9,472
__________________________________________________________________________________________________

特に、重要なアイディアは中間層の次元を確保するということで、

up_sampling2d_4 (UpSampling2D)  (None, 32, 32, 64)   0           max_pooling2d_1[0][0]
conv_out (Conv2D)               (None, 32, 32, 3)    1731        up_sampling2d_4[0][0]

とかなり上部というか、max_pooling2d_1[0][0]は最初のpoolingです。
そして、それとは関係なく、精度を稼ぐために、VGGライクなかなりディープなネットワークでカテゴライズします。残念ながら、元のソースではoptimizerがSGDでしかもパラメーターをepoch毎にドロップする設計になっています。またL2正則化も取り入れています。
しかし、ここではソースの簡単化のため(実際にはやったけどうまく収束しないので)どちらも実施しないこととしました。
※もちろんもともとのソースはちゃんと収束してほぼ記載の通りの精度が出ました。

結果

まず精度は以下の通り、89%を簡単に超えています。

val_softmax_acc: 0.892663- val_conv_out_acc: 0.884400
softmax_acc: 0.991320- conv_out_acc: 0.827088

【参考】
CIFAR-10 - Object Recognition in Images
そして、画像は以下のとおりのものが得られています。

生成画像

中間層の画像(上:conv2d_9のものを拡張)と初期の近くの画像(先ほどの最初のpoolingの結果を拾って一回のUpSamplingで拡張したもの)を示します。
conv2d_9 (Conv2D) (None, 4, 4, 512) 2359808 dropout_6[0][0]
(8192次元)
conv2d_9 ある意味、普通の生成画像
conv_last.gif
max_pooling2d_1 (MaxPooling2D) (None, 16, 16, 64) 0 dropout_2[0][0]
(16384次元)と次元は前者の倍であるが、生成画像は非常に美しくしかも収束は非常に速く、早い段階から画像変化はほとんど見られない。
conv2d_2.gif

まとめ

・パラメーターフィッティングはまだまだ最適化の余地はあるが、十分に美しい画像生成88.44%と十分な精度89.26%を兼ね備えたClassifier_Autoencoderを作成できた
・画像生成は、中間層の次元が大きいほど美しい画像が得られ、精度はよりディープなCNN(max_poolingが有効)が有利なようである。

課題

・重要なことは精度と美しい画像生成とは必ずしも強い相関を持っているようには見えない。つまり、巷で言われているように、マルチタスクの方がより高い精度を達成できるかどうかは再度検証する必要があると考察される。
・どのLayerを復元すれば一番美しい画像が得られるか、そもそも各層の次元もまだまだ最適化の余地がある。特に今回次元の大きさは2倍だが、得られる画像は段違いに美しい画像であったことからまだまだ未知の要素がある可能性もある。
・また、過学習の状態も画像生成については全くなく、一方精度については若干観測されるので、調整する必要がある。また、L2正則化やDropoutも最適化できておらず、さらなる精度や美しい画像生成は十分に可能な範囲である。
・Classifier_Autoencoderとしては、ほぼ満足なものであるが、cGANやVAEなどへの拡張ができると期待できる
・物体検知などを想定すると、入力画像の大きさの自由度はないので、入力画像サイズフリーに対応する予定である

ちょっと考察

因みに、CapsNetで問題となったMax_poolingについて、今回モデルはMax_poolingを次元最小になるまで実施しており、これはちょうどあの物理の相転移などで応用された繰り込み手法と同じように特徴抽出に十分に威力を発揮しているということを示していると考察される。もちろん位置情報の記憶はどんどん失われるが、その分各階層でそれぞれ特徴を抽出していることが重要である。そして今回適当な中間層からまともな画像生成ができたのが意義深いと考える(ある意味、そこをもっと追究すれば世界最高かな)。この場合も、最適化関数がまさしくエントロピーであり、同じことを数理的に実施している(エントロピーを下げる)と言える。この考察(ヒントンさんの主張の再考察)は別途詳細に実施したいと思う。
特徴量抽出については、参考①が詳しく説明している。また、相転移への繰り込み群手法については参考②を参照されたい。
【参考】
TensorFlow Tutorialの数学的背景 − Deep MNIST for Experts(その1)
繰り込み群と物性物理学 臨界現象,そして多様な展開へ by 菊 池 誠・岡 部 豊(pdf)