8
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

TensorFlow2.0を使ってFashion-MNISTをResNet-50で学習する

Last updated at Posted at 2019-05-05

はじめに

今回はCNNの中でも比較的新しいResNetに取り組んでみたいと思います。と言っても2015年に発表されたようなので、もう4年前ですね・・・。

ResNetとは?

正確に説明する力はないのですが、Skip Connection (Shortcut) を利用することで従来よりも深い層を持つことを実現したネットワークと理解しています。日本語ですと、

あたりがわかりやすいかと思います。あとは元の論文ですね。英語ですが12ページだけなので、その気になれば読めると思います(自分はところどころしか読んでいないですが・・・)。

今回のテーマ

以前取り組んだFashion-MNISTの分類をResNet-50で実現しようと思います。今回は制約はなしにしました(ResNetの学習には時間がかかりそうだったので)。

環境

  • Google Colaboratory
  • TensorFlow 2.0 Alpha

コード

こちらです。
なぜかGitHub上ではうまく開けませんでした。GitHubのURLはこちらです。

コード解説

ResidualBlock

from tensorflow.keras.layers import Conv2D, Dense, BatchNormalization, Activation, MaxPool2D, GlobalAveragePooling2D, Add
from tensorflow.keras import Model

class ResidualBlock(Model):
    def __init__(self, channel_in = 64, channel_out = 256):
        super().__init__()
        
        channel = channel_out // 4
        
        self.conv1 = Conv2D(channel, kernel_size = (1, 1), padding = "same")
        self.bn1 = BatchNormalization()
        self.av1 = Activation(tf.nn.relu)
        self.conv2 = Conv2D(channel, kernel_size = (3, 3), padding = "same")
        self.bn2 = BatchNormalization()
        self.av2 = Activation(tf.nn.relu)
        self.conv3 = Conv2D(channel_out, kernel_size = (1, 1), padding = "same")
        self.bn3 = BatchNormalization()
        self.shortcut = self._shortcut(channel_in, channel_out)
        self.add = Add()
        self.av3 = Activation(tf.nn.relu)
        
    def call(self, x):
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.av1(h)
        h = self.conv2(h)
        h = self.bn2(h)
        h = self.av2(h)
        h = self.conv3(h)
        h = self.bn3(h)
        shortcut = self.shortcut(x)
        h = self.add([h, shortcut])
        y = self.av3(h)
        return y
    
    def _shortcut(self, channel_in, channel_out):
        if channel_in == channel_out:
            return lambda x : x
        else:
            return self._projection(channel_out)
        
    def _projection(self, channel_out):
        return Conv2D(channel_out, kernel_size = (1, 1), padding = "same")

ResNetではこのブロックを積み重ねていきますので、それをクラスにします。今回はResNet-50ですので、Bottleneck Architectureを利用し一旦次元削減してから復元する処理になっています。余談ですが、Bottleneck Architectureではない通常のアーキテクチャーで実装するとResNet-34になります。

Skip Connectionはself.addの部分になります。このブロック内で計算してきたhとこのブロックの入力であるxを足し合わせています(その前のself.shortcutxの次元を合わせています)。このようにすることで逆伝播の際に勾配消失しづらくなるそうです。

ResNet50

class ResNet50(Model):
    def __init__(self, input_shape, output_dim):
        super().__init__()                
        
        self._layers = [
            Conv2D(64, input_shape = input_shape, kernel_size = (7, 7), strides=(2, 2), padding = "same"),
            BatchNormalization(),
            Activation(tf.nn.relu),
            MaxPool2D(pool_size = (3, 3), strides = (2, 2), padding = "same"),
            ResidualBlock(64, 256),
            [
                ResidualBlock(256, 256) for _ in range(2)                
            ],
            Conv2D(512, kernel_size = (1, 1), strides=(2, 2)),
            [
                ResidualBlock(512, 512) for _ in range(4)                
            ],
            Conv2D(1024, kernel_size = (1, 1), strides=(2, 2)),
            [
                ResidualBlock(1024, 1024) for _ in range(6)                
            ],
            Conv2D(2048, kernel_size = (1, 1), strides=(2, 2)),
            [
                ResidualBlock(2048, 2048) for _ in range(3)
            ],
            GlobalAveragePooling2D(),
            Dense(1000, activation = tf.nn.relu),
            Dense(output_dim, activation = tf.nn.softmax)
        ]
        
    def call(self, x):
        for layer in self._layers:
            if isinstance(layer, list):
                for l in layer:
                    x = l(x)    
            else:
                x = layer(x)
        return x

先ほど作成したResidualBlockや畳み込み層などを組み合わせています。論文の表が参考になるかと思います。今回は論文に忠実に実装したつもりですが、Fashion-MNISTを取り扱う場合、層の数や次元のチューニングは必要かもしれません。
image.png

モデル作成

model = ResNet50((28, 28, 1), 10)
model.build(input_shape = (None, 28, 28, 1))
model.summary()

""" 結果
Model: "res_net50"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              multiple                  3200      
_________________________________________________________________
batch_normalization_v2 (Batc multiple                  256       
_________________________________________________________________
activation (Activation)      multiple                  0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
residual_block (ResidualBloc multiple                  75904     
_________________________________________________________________
residual_block_1 (ResidualBl multiple                  71552     
_________________________________________________________________
residual_block_2 (ResidualBl multiple                  71552     
_________________________________________________________________
conv2d_11 (Conv2D)           multiple                  131584    
_________________________________________________________________
residual_block_3 (ResidualBl multiple                  282368    
_________________________________________________________________
residual_block_4 (ResidualBl multiple                  282368    
_________________________________________________________________
residual_block_5 (ResidualBl multiple                  282368    
_________________________________________________________________
residual_block_6 (ResidualBl multiple                  282368    
_________________________________________________________________
conv2d_24 (Conv2D)           multiple                  525312    
_________________________________________________________________
residual_block_7 (ResidualBl multiple                  1121792   
_________________________________________________________________
residual_block_8 (ResidualBl multiple                  1121792   
_________________________________________________________________
residual_block_9 (ResidualBl multiple                  1121792   
_________________________________________________________________
residual_block_10 (ResidualB multiple                  1121792   
_________________________________________________________________
residual_block_11 (ResidualB multiple                  1121792   
_________________________________________________________________
residual_block_12 (ResidualB multiple                  1121792   
_________________________________________________________________
conv2d_43 (Conv2D)           multiple                  2099200   
_________________________________________________________________
residual_block_13 (ResidualB multiple                  4471808   
_________________________________________________________________
residual_block_14 (ResidualB multiple                  4471808   
_________________________________________________________________
residual_block_15 (ResidualB multiple                  4471808   
_________________________________________________________________
global_average_pooling2d (Gl multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  2049000   
_________________________________________________________________
dense_1 (Dense)              multiple                  10010     
=================================================================
Total params: 26,313,218
Trainable params: 26,267,778
Non-trainable params: 45,440
_________________________________________________________________
"""

このmodel.buildの意味は正直よくわかっていないのですが、公式によると、どんなinputが来るかわからないサブクラスのために必要とのことで、今回は確かにこれまで書いたコードとは違いinput shapeを引数として与えているので、そのせいかなぁと考えています。

その他

バッチサイズ

バッチサイズは128にしました。論文には256と書かれていましたが、メモリエラー寸前だったので・・・。

最適化手法

Adamにしました。論文にはSGD + Momentumと書かれており、少し試してみたのですが学習の進みが遅いように見受けられたので・・・。余裕があれば、SGD + Momentumでも試してみます。

結果

400エポック訓練した結果、Test Accuracyが91.3%と前回を下回ってしまいました。チューニングすればもう少し精度が上がるかもしれません。

image.png

参考にさせて頂いたコード

追記:SGD + Momentumの結果

400エポック訓練した結果、Test Accuracyが91.4%でした。Adamとそう変わらなかったですね。
image.png

8
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?