MachineLearning
DeepLearning
Keras
Autoencoder
SpatialPyramidPooling

SPP (SpatialPyramidPooling) lets the model accept any input image size ♬

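One issue left open in the previous post: for object recognition and object detection, we really want the input image size to be unconstrained.

So this time, I will make the Wide_resnet_Autoencoder and Wide_Resnet from the previous post accept inputs of any size.

Technically, I applied the technique from the references below, known as Spatial Pyramid Pooling.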
[References]
Original paper
Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition
Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun
(Submitted on 18 Jun 2014 (v1), last revised 23 Apr 2015 (this version, v4))

Reference source code
yhenon/keras-spp Spatial pyramid pooling layers for keras

Basically, as in the reference code, all you have to do is replace the part around Flatten as shown below... but I got stuck on a couple of things.

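from keras.models import Sequential
from keras.layers import Conv2D, Activation, MaxPooling2D, Dense
# SpatialPyramidPooling layer from yhenon/keras-spp; adjust the module path
# to wherever SpatialPyramidPooling.py lives (e.g. the MuAuan/SPP repository)
from spp.SpatialPyramidPooling import SpatialPyramidPooling

num_classes = 10  # e.g. CIFAR-10

model = Sequential()

# channels_last ordering. Height and width are left as None to allow multiple image sizes.
# (The reference code uses theano/channels_first ordering: input_shape=(3, None, None).)
model.add(Conv2D(32, (3, 3), padding='same', input_shape=(None, None, 3)))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
# Replaces Flatten: pools into 1x1, 2x2 and 4x4 bins, giving a fixed length of 64 * 21 = 1344
model.add(SpatialPyramidPooling([1, 2, 4]))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy', optimizer='sgd')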
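The replacement works because SPP turns a feature map of any height and width into a fixed-length vector, so a Dense layer can follow it. Below is a minimal NumPy sketch of spatial pyramid max pooling for a single channels_last feature map, meant to illustrate the idea rather than reproduce yhenon/keras-spp: the function name spatial_pyramid_pool is hypothetical, and the bin edges use a floor/ceil rule (the one adaptive pooling uses), which is an assumption. The 160 channels in the usage lines are chosen to match the Wide_Resnet summary at the end of this post.

```python
import math
import numpy as np

def spatial_pyramid_pool(feature_map, pool_list=(1, 2, 4)):
    """Max-pool an (H, W, C) feature map into a vector of length C * sum(n*n)."""
    h, w, _ = feature_map.shape
    pooled = []
    for n in pool_list:  # one pyramid level per entry, n x n bins
        for i in range(n):
            # Row range of bin i: floor(i*h/n) .. ceil((i+1)*h/n), never empty
            y0, y1 = (i * h) // n, math.ceil((i + 1) * h / n)
            for j in range(n):
                x0, x1 = (j * w) // n, math.ceil((j + 1) * w / n)
                # Max over the bin's spatial extent, one value per channel
                pooled.append(feature_map[y0:y1, x0:x1, :].max(axis=(0, 1)))
    return np.concatenate(pooled)

# Feature maps of different spatial sizes give the same output length,
# which is what lets a Dense layer follow without fixing the image size.
rng = np.random.default_rng(0)
for size in [(8, 8), (13, 21)]:
    fmap = rng.standard_normal((*size, 160))
    print(size, spatial_pyramid_pool(fmap).shape)  # -> (3360,) for both sizes
```

With pool_list [1, 2, 4] there are 1 + 4 + 16 = 21 bins per channel, so the output length depends only on the channel count, never on H or W.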

Where I got stuck

The first issue: with the ordering input_shape=(3, None, None), the following error is raised:
"None is not supported in input_shape [bug]"
I assumed None itself was not allowed, but a search led me to the following issue:
https://github.com/keras-team/keras/issues/5900
which gives this answer:
If your image_data_format is channels_last, your input_shape should be (None, None, 3).
So it was simply a matter of axis order.
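The fix amounts to putting the channel axis where the configured image_data_format expects it. A small helper makes this explicit; the name variable_size_input_shape is hypothetical and not part of Keras.

```python
def variable_size_input_shape(channels, data_format="channels_last"):
    """Return an input_shape whose height and width are None (free).

    Hypothetical helper: channels_last puts the channel axis at the end,
    channels_first (the old theano ordering) puts it at the front.
    """
    if data_format == "channels_last":
        return (None, None, channels)
    if data_format == "channels_first":
        return (channels, None, None)
    raise ValueError(f"unknown data_format: {data_format!r}")

print(variable_size_input_shape(3))                    # (None, None, 3)
print(variable_size_input_shape(3, "channels_first"))  # (3, None, None)
```

In Keras 2 the configured layout can be read with keras.backend.image_data_format(), so a call such as input_shape=variable_size_input_shape(3, K.image_data_format()) (with from keras import backend as K) keeps the model definition consistent with the keras.json setting.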

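The second, bigger problem:
training did not converge with the same hyperparameters.
I suspect this is because the inputs had been standardized with the mean and variance, and the SGD parameters were on the large side. Simply swapping in SPP, training did not move toward convergence at all: the loss did not decrease and the accuracy did not budge. When I replaced only Flatten (a partial version of step ③ below, keeping AveragePooling), the loss reached values such as 1e-7 that you normally never see. Removing AveragePooling as well brought the loss back to normal values.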

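After some trial and error, I settled on the following:
① Scale the inputs plainly, e.g. x_train/255.
② Change the optimizer to Adam
③ Remove AveragePooling and Flatten and replace them with SPP
① is probably the only option, since the Autoencoder also needs it to produce clean images.
For ②, I expect SGD would also converge with properly chosen parameters. For ③, since the loss otherwise takes implausible values, I believe this is the right approach.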
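Step ① contrasts two input normalizations. The sketch below shows both in NumPy; the per-channel statistics in standardize are an assumption about how the earlier version normalized its inputs, and both function names are hypothetical. If the Autoencoder's decoder ends in a sigmoid (also an assumption; the post does not say), only the [0, 1] scaling keeps the reconstruction targets inside the range that output can reach, because standardized pixels can be negative.

```python
import numpy as np

def scale_to_unit(x):
    """Step ①: map uint8 pixels in [0, 255] to floats in [0, 1]."""
    return x.astype("float32") / 255.0

def standardize(x, mean=None, std=None):
    """Per-channel (x - mean) / std over an (N, H, W, C) batch.

    Assumed stand-in for the earlier mean/variance normalization.
    Pass the training-set mean/std when normalizing test data.
    """
    x = x.astype("float32")
    if mean is None:
        mean = x.mean(axis=(0, 1, 2), keepdims=True)
    if std is None:
        std = x.std(axis=(0, 1, 2), keepdims=True)
    return (x - mean) / (std + 1e-7), mean, std

rng = np.random.default_rng(0)
batch = rng.integers(0, 256, size=(4, 32, 32, 3), dtype=np.uint8)
unit = scale_to_unit(batch)
z, _, _ = standardize(batch)
print("scaled range:", unit.min(), unit.max())   # stays within [0, 1]
print("standardized range:", z.min(), z.max())   # includes negative values
```

For step ②, the corresponding Keras change is only in compile: optimizer='adam' (or keras.optimizers.Adam()) instead of optimizer='sgd'.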

This gave the results below.

The code is available here:
MuAuan/SPP
SPP / wide_resnet_SPP.py is the input-size-free version of Wide_Resnet.
SPP / SPP_wide_resnet_ACE.py is the input-size-free version of the Autoencoder.
SPP / SpatialPyramidPooling.py is the layer that makes the input size free.

Results

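The benefit of this replacement is, of course, that the input image size is no longer fixed, so images can be analyzed at whatever size suits the subject.
In other words, the same model in the same program can handle different input sizes, and images can be analyzed without rescaling them to a fixed resolution.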

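There also seem to be some drawbacks.
One is that convergence takes somewhat longer.
Another, which is why I struggled with the optimizer, is that convergence may be somewhat less stable.
However, once the optimizer parameters are fixed, training converges along a similar trajectory, so this may not be much of a concern.
The Autoencoder accuracy is shown below; it did not change much.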

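However, when I checked the accuracy and parameter count after applying SPP to the original Wide_Resnet, I got the results below. Side by side, the accuracy also appears to be slightly lower.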

Accuracy (ACC) by epoch
epoch      AveragePooling+Flatten   SPP[1,2,4]   SPP[1]       SPP[2]       SPP[4]       SPP[8]
1          0.543000                 0.457000     0.516600     0.448700     0.517600     0.547400
10         0.766100                 0.660500     0.682900     0.687900     0.642700     0.712700
50         0.860300                 0.841000     0.846400     0.834500     0.823700     0.830000

Parameter count
params     AveragePooling+Flatten   SPP[1,2,4]   SPP[1]       SPP[2]       SPP[4]       SPP[8]
Total      2,279,882                2,311,882    2,279,882    2,284,682    2,303,882    2,380,682
Trainable  2,275,370                2,307,370    2,275,370    2,280,170    2,299,370    2,376,170

Here, SPP[n] denotes the pool_list passed to the layer; for example, SPP[1] means pool_list=[1]:

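# Alternatives: use exactly one of these lines in place of Flatten
x = SpatialPyramidPooling([1])(x)      # SPP[1]
x = SpatialPyramidPooling([2])(x)      # SPP[2]
x = SpatialPyramidPooling([4])(x)      # SPP[4]
x = SpatialPyramidPooling([8])(x)      # SPP[8]
x = SpatialPyramidPooling([1,2,4])(x)  # SPP[1,2,4]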
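The parameter counts in the table can be checked by hand. Only the Dense layer after SPP changes between variants: the feature map entering SPP has 160 channels (inferred from 3360 = 160 × 21 for pool_list [1,2,4] in the model summary below) and dense_1 has 10 outputs. The helper names below are hypothetical; the only inputs taken from the post are 2,311,882, 3360 and 10.

```python
def spp_output_dim(channels, pool_list):
    """Length of the SPP output: one value per channel per bin, n*n bins per level."""
    return channels * sum(n * n for n in pool_list)

def dense_params(in_features, units):
    """Weights plus biases of a fully connected layer."""
    return in_features * units + units

CHANNELS = 160    # inferred from 3360 / (1 + 4 + 16) in the model summary
NUM_CLASSES = 10  # dense_1 output size
# Everything except dense_1, derived from the SPP[1,2,4] total of 2,311,882
BODY_PARAMS = 2_311_882 - dense_params(spp_output_dim(CHANNELS, [1, 2, 4]), NUM_CLASSES)

for pool_list in ([1, 2, 4], [1], [2], [4], [8]):
    dim = spp_output_dim(CHANNELS, pool_list)
    total = BODY_PARAMS + dense_params(dim, NUM_CLASSES)
    print(pool_list, dim, f"{total:,}")
```

This prints 3360 / 2,311,882 for [1, 2, 4], 160 / 2,279,882 for [1], 640 / 2,284,682 for [2], 2560 / 2,303,882 for [4] and 10240 / 2,380,682 for [8], matching the Total row. AveragePooling+Flatten has the same count as SPP[1], consistent with its pooled map also contributing 160 features to the Dense layer.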

The structure and characteristics of SpatialPyramidPooling will be covered in the next post.

A model full of None (the Wide_Resnet model)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, None, None, 3 0
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, None, None, 1 448         input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, None, None, 1 64          conv2d_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, None, None, 1 0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, None, None, 4 5800        activation_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, None, None, 4 160         conv2d_2[0][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, None, None, 4 0           batch_normalization_2[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, None, None, 4 0           activation_2[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, None, None, 4 14440       dropout_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, None, None, 4 0           conv2d_1[0][0]
__________________________________________________________________________________________________
merge_1 (Merge)                 (None, None, None, 4 0           conv2d_3[0][0]
                                                                 lambda_1[0][0]
__________________________________________________________________________________________________
batch_normalization_3 (BatchNor (None, None, None, 4 160         merge_1[0][0]
__________________________________________________________________________________________________
activation_3 (Activation)       (None, None, None, 4 0           batch_normalization_3[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D)               (None, None, None, 4 14440       activation_3[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, None, None, 4 160         conv2d_4[0][0]
__________________________________________________________________________________________________
activation_4 (Activation)       (None, None, None, 4 0           batch_normalization_4[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, None, None, 4 0           activation_4[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D)               (None, None, None, 4 14440       dropout_2[0][0]
__________________________________________________________________________________________________
merge_2 (Merge)                 (None, None, None, 4 0           conv2d_5[0][0]
                                                                 merge_1[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, None, None, 4 160         merge_2[0][0]
__________________________________________________________________________________________________
activation_5 (Activation)       (None, None, None, 4 0           batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, None, None, 4 14440       activation_5[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, None, None, 4 160         conv2d_6[0][0]
__________________________________________________________________________________________________
activation_6 (Activation)       (None, None, None, 4 0           batch_normalization_6[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout)             (None, None, None, 4 0           activation_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, None, None, 4 14440       dropout_3[0][0]
__________________________________________________________________________________________________
merge_3 (Merge)                 (None, None, None, 4 0           conv2d_7[0][0]
                                                                 merge_2[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, None, None, 4 160         merge_3[0][0]
__________________________________________________________________________________________________
activation_7 (Activation)       (None, None, None, 4 0           batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, None, None, 4 14440       activation_7[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, None, None, 4 160         conv2d_8[0][0]
__________________________________________________________________________________________________
activation_8 (Activation)       (None, None, None, 4 0           batch_normalization_8[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout)             (None, None, None, 4 0           activation_8[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, None, None, 4 14440       dropout_4[0][0]
__________________________________________________________________________________________________
merge_4 (Merge)                 (None, None, None, 4 0           conv2d_9[0][0]
                                                                 merge_3[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, None, None, 4 160         merge_4[0][0]
__________________________________________________________________________________________________
activation_9 (Activation)       (None, None, None, 4 0           batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, None, None, 8 28880       activation_9[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, None, None, 8 320         conv2d_10[0][0]
__________________________________________________________________________________________________
activation_10 (Activation)      (None, None, None, 8 0           batch_normalization_10[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout)             (None, None, None, 8 0           activation_10[0][0]
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, None, None, 4 0           merge_4[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, None, None, 8 57680       dropout_5[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, None, None, 8 0           average_pooling2d_1[0][0]
__________________________________________________________________________________________________
merge_5 (Merge)                 (None, None, None, 8 0           conv2d_11[0][0]
                                                                 lambda_2[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, None, None, 8 320         merge_5[0][0]
__________________________________________________________________________________________________
activation_11 (Activation)      (None, None, None, 8 0           batch_normalization_11[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, None, None, 8 57680       activation_11[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, None, None, 8 320         conv2d_12[0][0]
__________________________________________________________________________________________________
activation_12 (Activation)      (None, None, None, 8 0           batch_normalization_12[0][0]
__________________________________________________________________________________________________
dropout_6 (Dropout)             (None, None, None, 8 0           activation_12[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, None, None, 8 57680       dropout_6[0][0]
__________________________________________________________________________________________________
merge_6 (Merge)                 (None, None, None, 8 0           conv2d_13[0][0]
                                                                 merge_5[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, None, None, 8 320         merge_6[0][0]
__________________________________________________________________________________________________
activation_13 (Activation)      (None, None, None, 8 0           batch_normalization_13[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D)              (None, None, None, 8 57680       activation_13[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, None, None, 8 320         conv2d_14[0][0]
__________________________________________________________________________________________________
activation_14 (Activation)      (None, None, None, 8 0           batch_normalization_14[0][0]
__________________________________________________________________________________________________
dropout_7 (Dropout)             (None, None, None, 8 0           activation_14[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D)              (None, None, None, 8 57680       dropout_7[0][0]
__________________________________________________________________________________________________
merge_7 (Merge)                 (None, None, None, 8 0           conv2d_15[0][0]
                                                                 merge_6[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, None, None, 8 320         merge_7[0][0]
__________________________________________________________________________________________________
activation_15 (Activation)      (None, None, None, 8 0           batch_normalization_15[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, None, None, 8 57680       activation_15[0][0]
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, None, None, 8 320         conv2d_16[0][0]
__________________________________________________________________________________________________
activation_16 (Activation)      (None, None, None, 8 0           batch_normalization_16[0][0]
__________________________________________________________________________________________________
dropout_8 (Dropout)             (None, None, None, 8 0           activation_16[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D)              (None, None, None, 8 57680       dropout_8[0][0]
__________________________________________________________________________________________________
merge_8 (Merge)                 (None, None, None, 8 0           conv2d_17[0][0]
                                                                 merge_7[0][0]
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, None, None, 8 320         merge_8[0][0]
__________________________________________________________________________________________________
activation_17 (Activation)      (None, None, None, 8 0           batch_normalization_17[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, None, None, 1 115360      activation_17[0][0]
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, None, None, 1 640         conv2d_18[0][0]
__________________________________________________________________________________________________
activation_18 (Activation)      (None, None, None, 1 0           batch_normalization_18[0][0]
__________________________________________________________________________________________________
dropout_9 (Dropout)             (None, None, None, 1 0           activation_18[0][0]
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, None, None, 8 0           merge_8[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, None, None, 1 230560      dropout_9[0][0]
__________________________________________________________________________________________________
lambda_3 (Lambda)               (None, None, None, 1 0           average_pooling2d_2[0][0]
__________________________________________________________________________________________________
merge_9 (Merge)                 (None, None, None, 1 0           conv2d_19[0][0]
                                                                 lambda_3[0][0]
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, None, None, 1 640         merge_9[0][0]
__________________________________________________________________________________________________
activation_19 (Activation)      (None, None, None, 1 0           batch_normalization_19[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, None, None, 1 230560      activation_19[0][0]
__________________________________________________________________________________________________
batch_normalization_20 (BatchNo (None, None, None, 1 640         conv2d_20[0][0]
__________________________________________________________________________________________________
activation_20 (Activation)      (None, None, None, 1 0           batch_normalization_20[0][0]
__________________________________________________________________________________________________
dropout_10 (Dropout)            (None, None, None, 1 0           activation_20[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, None, None, 1 230560      dropout_10[0][0]
__________________________________________________________________________________________________
merge_10 (Merge)                (None, None, None, 1 0           conv2d_21[0][0]
                                                                 merge_9[0][0]
__________________________________________________________________________________________________
batch_normalization_21 (BatchNo (None, None, None, 1 640         merge_10[0][0]
__________________________________________________________________________________________________
activation_21 (Activation)      (None, None, None, 1 0           batch_normalization_21[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, None, None, 1 230560      activation_21[0][0]
__________________________________________________________________________________________________
batch_normalization_22 (BatchNo (None, None, None, 1 640         conv2d_22[0][0]
__________________________________________________________________________________________________
activation_22 (Activation)      (None, None, None, 1 0           batch_normalization_22[0][0]
__________________________________________________________________________________________________
dropout_11 (Dropout)            (None, None, None, 1 0           activation_22[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, None, None, 1 230560      dropout_11[0][0]
__________________________________________________________________________________________________
merge_11 (Merge)                (None, None, None, 1 0           conv2d_23[0][0]
                                                                 merge_10[0][0]
__________________________________________________________________________________________________
batch_normalization_23 (BatchNo (None, None, None, 1 640         merge_11[0][0]
__________________________________________________________________________________________________
activation_23 (Activation)      (None, None, None, 1 0           batch_normalization_23[0][0]
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, None, None, 1 230560      activation_23[0][0]
__________________________________________________________________________________________________
batch_normalization_24 (BatchNo (None, None, None, 1 640         conv2d_24[0][0]
__________________________________________________________________________________________________
activation_24 (Activation)      (None, None, None, 1 0           batch_normalization_24[0][0]
__________________________________________________________________________________________________
dropout_12 (Dropout)            (None, None, None, 1 0           activation_24[0][0]
__________________________________________________________________________________________________
conv2d_25 (Conv2D)              (None, None, None, 1 230560      dropout_12[0][0]
__________________________________________________________________________________________________
merge_12 (Merge)                (None, None, None, 1 0           conv2d_25[0][0]
                                                                 merge_11[0][0]
__________________________________________________________________________________________________
batch_normalization_25 (BatchNo (None, None, None, 1 640         merge_12[0][0]
__________________________________________________________________________________________________
activation_25 (Activation)      (None, None, None, 1 0           batch_normalization_25[0][0]
__________________________________________________________________________________________________
spatial_pyramid_pooling_1 (Spat (None, 3360)         0           activation_25[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 10)           33610       spatial_pyramid_pooling_1[0][0]
==================================================================================================
Total params: 2,311,882
Trainable params: 2,307,370
Non-trainable params: 4,512
__________________________________________________________________________________________________