機械学習
ディープラーニング
Keras
MNIST
NNコードゴルフ

線形識別よりパラメータ数の少ないCNN

はじめに

ディープラーニングのモデルはパラメータ数が膨大になりがちです。この膨大なパラメータが表現力につながり高度な処理が可能になる、とも考えられますが、本当にそうでしょうか。
今回は、パラメータ数をどこまで減らせるか、MNISTを題材にやってみた結果を簡単に紹介します。
実験用コードはこちら:
https://gist.github.com/stnk20/dbf26e86d2fe6adba1b49408a2b87992

モデル構築方針

一番の肝になるアイデアは、「同じ荷重を共有する層を繰り返し使う」ということです。こうすることでパラメータ数を増やさずに層の数は増やすことができます。そんなこと出来るのか?とも思いますが、CNNでは空間方向、RNNでは時間方向での荷重共有がなされていますので、層方向(チャンネル方向ではありません)の荷重共有が成り立ってもおかしくありません。
ポイントをまとめておきます:

  • 同じ荷重を共有する畳み込み層を繰り返し使います。
  • 層を積んでいくと学習が進みにくくなるので、ResNetで提唱されたSkip-Connectionを使います。
  • パラメータ数を減らすためにSeparableConvを使います。

なお、学習の高速化のためBatchNormalizationも使いたいのですが、試してみるとテストエラーが安定しなかったので使っていません。

パラメータや層の数は、テスト誤差1%未満を目指して調整しました。
学習後、ちゃんと達成できています:

Test loss: 0.029511989441869082
Test accuracy: 0.9908

Kerasで書いたモデルは下記のようになります。荷重を使い回すところ以外は特に変わった点はないかと思います。

def TightResNet(dim,loop,num_classes,input_shape,dropout=0.1):
    x = Input(shape=input_shape)
    h = Conv2D(dim,(5,5),strides=(2,2),padding="valid")(x)

    common_conv = SeparableConv2D(dim,(3,3),padding="same")
    for i in range(loop):
        b = h
        b = Activation("relu")(b)
        b = common_conv(b)
        h = Add()([h,b])

    h = AveragePooling2D((3,3))(h)
    h = Cropping2D(1)(h)
    h = Flatten()(h)
    h = Dropout(dropout)(h)
    y = Dense(num_classes, activation='softmax')(h)

    return Model(inputs=x,outputs=y)

モデルサマリ:

_________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 28, 28, 1)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 12, 12, 20)   520         input_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 12, 12, 20)   0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
separable_conv2d_1 (SeparableCo (None, 12, 12, 20)   600         activation_1[0][0]               
                                                                 activation_2[0][0]               
                                                                 activation_3[0][0]               
                                                                 activation_4[0][0]               
                                                                 activation_5[0][0]               
                                                                 activation_6[0][0]               
                                                                 activation_7[0][0]               
                                                                 activation_8[0][0]               
__________________________________________________________________________________________________
add_1 (Add)                     (None, 12, 12, 20)   0           conv2d_1[0][0]                   
                                                                 separable_conv2d_1[0][0]         
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 12, 12, 20)   0           add_1[0][0]                      
__________________________________________________________________________________________________
add_2 (Add)                     (None, 12, 12, 20)   0           add_1[0][0]                      
                                                                 separable_conv2d_1[1][0]         
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 12, 12, 20)   0           add_2[0][0]                      
__________________________________________________________________________________________________
add_3 (Add)                     (None, 12, 12, 20)   0           add_2[0][0]                      
                                                                 separable_conv2d_1[2][0]         
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 12, 12, 20)   0           add_3[0][0]                      
__________________________________________________________________________________________________
add_4 (Add)                     (None, 12, 12, 20)   0           add_3[0][0]                      
                                                                 separable_conv2d_1[3][0]         
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 12, 12, 20)   0           add_4[0][0]                      
__________________________________________________________________________________________________
add_5 (Add)                     (None, 12, 12, 20)   0           add_4[0][0]                      
                                                                 separable_conv2d_1[4][0]         
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 12, 12, 20)   0           add_5[0][0]                      
__________________________________________________________________________________________________
add_6 (Add)                     (None, 12, 12, 20)   0           add_5[0][0]                      
                                                                 separable_conv2d_1[5][0]         
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 12, 12, 20)   0           add_6[0][0]                      
__________________________________________________________________________________________________
add_7 (Add)                     (None, 12, 12, 20)   0           add_6[0][0]                      
                                                                 separable_conv2d_1[6][0]         
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 12, 12, 20)   0           add_7[0][0]                      
__________________________________________________________________________________________________
add_8 (Add)                     (None, 12, 12, 20)   0           add_7[0][0]                      
                                                                 separable_conv2d_1[7][0]         
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 4, 4, 20)     0           add_8[0][0]                      
__________________________________________________________________________________________________
cropping2d_1 (Cropping2D)       (None, 2, 2, 20)     0           average_pooling2d_1[0][0]        
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 80)           0           cropping2d_1[0][0]               
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 80)           0           flatten_1[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 10)           810         dropout_1[0][0]                  
==================================================================================================
Total params: 1,930
Trainable params: 1,930
Non-trainable params: 0
__________________________________________________________________________________________________

パラメータ数は1930個まで削減できました!
線形識別と比べてみましょう。線形識別は画素数xクラス数の荷重とバイアスの合計がパラメータの数になります。MNISTなら7850個(28*28*10+10)です。
CNNはどうでしょうか。下記リンク先にあるCNNのパラメータ数は1,199,882個です(テスト誤差0.8%)。
https://github.com/keras-team/keras/blob/master/examples/mnist_cnn.py

1930個ってすごいかも!?

まあ、所詮MNISTではあるんですが・・

ちなみに、学習の様子をみていると、もうすこし頑張れそうな感覚もあります。1000パラメータぐらいまでいけるかもしれません。
が、そこまでの気力もないので今回はここまで。

おわりに

今回の構成では計算量は減らないのであまり嬉しいことはないのですが、あえていうならメモリ消費や通信料を節約できるなどでしょうか。ハードウェア化する際に回路規模を小さくできるとかあるんでしょうか。
ニューラルネットワークは適当な制約をつけても動いてくれるので面白いですね。「NNコードゴルフ」というタグを作りました。みなさんも何かやってみませんか?