Help us understand the problem. What is going on with this article?

TensorFlow2.0を使ってFashion-MNISTをResNet-50で学習する

More than 1 year has passed since last update.

はじめに

今回はCNNの中でも比較的新しいResNetに取り組んでみたいと思います。と言っても2015年に発表されたようなので、もう4年前ですね・・・。

ResNetとは?

正確に説明する力はないのですが、Skip Connection (Shortcut) を利用することで従来よりも深い層を持つことを実現したネットワークと理解しています。日本語ですと、

あたりがわかりやすいかと思います。あとは元の論文ですね。英語ですが12ページだけなので、その気になれば読めると思います(自分はところどころしか読んでいないですが・・・)。

今回のテーマ

以前取り組んだFashion-MNISTの分類をResNet-50で実現しようと思います。今回は制約はなしにしました(ResNetの学習には時間がかかりそうだったので)。

環境

  • Google Colaboratory
  • TensorFlow 2.0 Alpha

コード

こちらです。
なぜかGitHub上ではうまく開けませんでした。GitHubのURLはこちらです。

コード解説

ResidualBlock

from tensorflow.keras.layers import Conv2D, Dense, BatchNormalization, Activation, MaxPool2D, GlobalAveragePooling2D, Add
from tensorflow.keras import Model

class ResidualBlock(Model):
    def __init__(self, channel_in = 64, channel_out = 256):
        super().__init__()

        channel = channel_out // 4

        self.conv1 = Conv2D(channel, kernel_size = (1, 1), padding = "same")
        self.bn1 = BatchNormalization()
        self.av1 = Activation(tf.nn.relu)
        self.conv2 = Conv2D(channel, kernel_size = (3, 3), padding = "same")
        self.bn2 = BatchNormalization()
        self.av2 = Activation(tf.nn.relu)
        self.conv3 = Conv2D(channel_out, kernel_size = (1, 1), padding = "same")
        self.bn3 = BatchNormalization()
        self.shortcut = self._shortcut(channel_in, channel_out)
        self.add = Add()
        self.av3 = Activation(tf.nn.relu)

    def call(self, x):
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.av1(h)
        h = self.conv2(h)
        h = self.bn2(h)
        h = self.av2(h)
        h = self.conv3(h)
        h = self.bn3(h)
        shortcut = self.shortcut(x)
        h = self.add([h, shortcut])
        y = self.av3(h)
        return y

    def _shortcut(self, channel_in, channel_out):
        if channel_in == channel_out:
            return lambda x : x
        else:
            return self._projection(channel_out)

    def _projection(self, channel_out):
        return Conv2D(channel_out, kernel_size = (1, 1), padding = "same")

ResNetではこのブロックを積み重ねていきますので、それをクラスにします。今回はResNet-50ですので、Bottleneck Architectureを利用し一旦次元削減してから復元する処理になっています。余談ですが、Bottleneck Architectureではない通常のアーキテクチャーで実装するとResNet-34になります。

Skip Connectionはself.addの部分になります。このブロック内で計算してきたhとこのブロックの入力であるxを足し合わせています(その前のself.shortcutxの次元を合わせています)。このようにすることで逆伝播の際に勾配消失しづらくなるそうです。

ResNet50

class ResNet50(Model):
    def __init__(self, input_shape, output_dim):
        super().__init__()                

        self._layers = [
            Conv2D(64, input_shape = input_shape, kernel_size = (7, 7), strides=(2, 2), padding = "same"),
            BatchNormalization(),
            Activation(tf.nn.relu),
            MaxPool2D(pool_size = (3, 3), strides = (2, 2), padding = "same"),
            ResidualBlock(64, 256),
            [
                ResidualBlock(256, 256) for _ in range(2)                
            ],
            Conv2D(512, kernel_size = (1, 1), strides=(2, 2)),
            [
                ResidualBlock(512, 512) for _ in range(4)                
            ],
            Conv2D(1024, kernel_size = (1, 1), strides=(2, 2)),
            [
                ResidualBlock(1024, 1024) for _ in range(6)                
            ],
            Conv2D(2048, kernel_size = (1, 1), strides=(2, 2)),
            [
                ResidualBlock(2048, 2048) for _ in range(3)
            ],
            GlobalAveragePooling2D(),
            Dense(1000, activation = tf.nn.relu),
            Dense(output_dim, activation = tf.nn.softmax)
        ]

    def call(self, x):
        for layer in self._layers:
            if isinstance(layer, list):
                for l in layer:
                    x = l(x)    
            else:
                x = layer(x)
        return x

先ほど作成したResidualBlockや畳み込み層などを組み合わせています。論文の表が参考になるかと思います。今回は論文に忠実に実装したつもりですが、Fashion-MNISTを取り扱う場合、層の数や次元のチューニングは必要かもしれません。
image.png

モデル作成

model = ResNet50((28, 28, 1), 10)
model.build(input_shape = (None, 28, 28, 1))
model.summary()

""" 結果
Model: "res_net50"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              multiple                  3200      
_________________________________________________________________
batch_normalization_v2 (Batc multiple                  256       
_________________________________________________________________
activation (Activation)      multiple                  0         
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
residual_block (ResidualBloc multiple                  75904     
_________________________________________________________________
residual_block_1 (ResidualBl multiple                  71552     
_________________________________________________________________
residual_block_2 (ResidualBl multiple                  71552     
_________________________________________________________________
conv2d_11 (Conv2D)           multiple                  131584    
_________________________________________________________________
residual_block_3 (ResidualBl multiple                  282368    
_________________________________________________________________
residual_block_4 (ResidualBl multiple                  282368    
_________________________________________________________________
residual_block_5 (ResidualBl multiple                  282368    
_________________________________________________________________
residual_block_6 (ResidualBl multiple                  282368    
_________________________________________________________________
conv2d_24 (Conv2D)           multiple                  525312    
_________________________________________________________________
residual_block_7 (ResidualBl multiple                  1121792   
_________________________________________________________________
residual_block_8 (ResidualBl multiple                  1121792   
_________________________________________________________________
residual_block_9 (ResidualBl multiple                  1121792   
_________________________________________________________________
residual_block_10 (ResidualB multiple                  1121792   
_________________________________________________________________
residual_block_11 (ResidualB multiple                  1121792   
_________________________________________________________________
residual_block_12 (ResidualB multiple                  1121792   
_________________________________________________________________
conv2d_43 (Conv2D)           multiple                  2099200   
_________________________________________________________________
residual_block_13 (ResidualB multiple                  4471808   
_________________________________________________________________
residual_block_14 (ResidualB multiple                  4471808   
_________________________________________________________________
residual_block_15 (ResidualB multiple                  4471808   
_________________________________________________________________
global_average_pooling2d (Gl multiple                  0         
_________________________________________________________________
dense (Dense)                multiple                  2049000   
_________________________________________________________________
dense_1 (Dense)              multiple                  10010     
=================================================================
Total params: 26,313,218
Trainable params: 26,267,778
Non-trainable params: 45,440
_________________________________________________________________
"""

このmodel.buildの意味は正直よくわかっていないのですが、公式によると、どんなinputが来るかわからないサブクラスのために必要とのことで、今回は確かにこれまで書いたコードとは違いinput shapeを引数として与えているので、そのせいかなぁと考えています。

その他

バッチサイズ

バッチサイズは128にしました。論文には256と書かれていましたが、メモリエラー寸前だったので・・・。

最適化手法

Adamにしました。論文にはSGD + Momentumと書かれており、少し試してみたのですが学習の進みが遅いように見受けられたので・・・。余裕があれば、SGD + Momentumでも試してみます。

結果

400エポック訓練した結果、Test Accuracyが91.3%と前回を下回ってしまいました。チューニングすればもう少し精度が上がるかもしれません。

image.png

参考にさせて頂いたコード

追記:SGD + Momentumの結果

400エポック訓練した結果、Test Accuracyが91.4%でした。Adamとそう変わらなかったですね。
image.png

shoji9x9
2020年1月よりMaaS関係に従事。プライベートでは機械学習、Kaggleに取り組んでいます。
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした