Edited at

TensorFlow2.0を使ってFashion-MNISTをResNet-50で学習する


はじめに

今回はCNNの中でも比較的新しいResNetに取り組んでみたいと思います。と言っても2015年に発表されたようなので、もう4年前ですね・・・。


ResNetとは?

正確に説明する力はないのですが、Skip Connection (Shortcut) を利用することで従来よりも深い層を持つことを実現したネットワークと理解しています。日本語ですと、

あたりがわかりやすいかと思います。あとは元の論文ですね。英語ですが12ページだけなので、その気になれば読めると思います(自分はところどころしか読んでいないですが・・・)。


今回のテーマ

以前取り組んだFashion-MNISTの分類をResNet-50で実現しようと思います。今回は制約はなしにしました(ResNetの学習には時間がかかりそうだったので)。


環境


  • Google Colaboratory

  • TensorFlow 2.0 Alpha


コード

こちらです。

なぜかGitHub上ではうまく開けませんでした。GitHubのURLはこちらです。


コード解説


ResidualBlock

from tensorflow.keras.layers import Conv2D, Dense, BatchNormalization, Activation, MaxPool2D, GlobalAveragePooling2D, Add

from tensorflow.keras import Model

class ResidualBlock(Model):
def __init__(self, channel_in = 64, channel_out = 256):
super().__init__()

channel = channel_out // 4

self.conv1 = Conv2D(channel, kernel_size = (1, 1), padding = "same")
self.bn1 = BatchNormalization()
self.av1 = Activation(tf.nn.relu)
self.conv2 = Conv2D(channel, kernel_size = (3, 3), padding = "same")
self.bn2 = BatchNormalization()
self.av2 = Activation(tf.nn.relu)
self.conv3 = Conv2D(channel_out, kernel_size = (1, 1), padding = "same")
self.bn3 = BatchNormalization()
self.shortcut = self._shortcut(channel_in, channel_out)
self.add = Add()
self.av3 = Activation(tf.nn.relu)

def call(self, x):
h = self.conv1(x)
h = self.bn1(h)
h = self.av1(h)
h = self.conv2(h)
h = self.bn2(h)
h = self.av2(h)
h = self.conv3(h)
h = self.bn3(h)
shortcut = self.shortcut(x)
h = self.add([h, shortcut])
y = self.av3(h)
return y

def _shortcut(self, channel_in, channel_out):
if channel_in == channel_out:
return lambda x : x
else:
return self._projection(channel_out)

def _projection(self, channel_out):
return Conv2D(channel_out, kernel_size = (1, 1), padding = "same")

ResNetではこのブロックを積み重ねていきますので、それをクラスにします。今回はResNet-50ですので、Bottleneck Architectureを利用し一旦次元削減してから復元する処理になっています。余談ですが、Bottleneck Architectureではない通常のアーキテクチャーで実装するとResNet-34になります。

Skip Connectionはself.addの部分になります。このブロック内で計算してきたhとこのブロックの入力であるxを足し合わせています(その前のself.shortcutxの次元を合わせています)。このようにすることで逆伝播の際に勾配消失しづらくなるそうです。


ResNet50

class ResNet50(Model):

def __init__(self, input_shape, output_dim):
super().__init__()

self._layers = [
Conv2D(64, input_shape = input_shape, kernel_size = (7, 7), strides=(2, 2), padding = "same"),
BatchNormalization(),
Activation(tf.nn.relu),
MaxPool2D(pool_size = (3, 3), strides = (2, 2), padding = "same"),
ResidualBlock(64, 256),
[
ResidualBlock(256, 256) for _ in range(2)
],
Conv2D(512, kernel_size = (1, 1), strides=(2, 2)),
[
ResidualBlock(512, 512) for _ in range(4)
],
Conv2D(1024, kernel_size = (1, 1), strides=(2, 2)),
[
ResidualBlock(1024, 1024) for _ in range(6)
],
Conv2D(2048, kernel_size = (1, 1), strides=(2, 2)),
[
ResidualBlock(2048, 2048) for _ in range(3)
],
GlobalAveragePooling2D(),
Dense(1000, activation = tf.nn.relu),
Dense(output_dim, activation = tf.nn.softmax)
]

def call(self, x):
for layer in self._layers:
if isinstance(layer, list):
for l in layer:
x = l(x)
else:
x = layer(x)
return x

先ほど作成したResidualBlockや畳み込み層などを組み合わせています。論文の表が参考になるかと思います。今回は論文に忠実に実装したつもりですが、Fashion-MNISTを取り扱う場合、層の数や次元のチューニングは必要かもしれません。


モデル作成

model = ResNet50((28, 28, 1), 10)

model.build(input_shape = (None, 28, 28, 1))
model.summary()

""" 結果
Model: "res_net50"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) multiple 3200
_________________________________________________________________
batch_normalization_v2 (Batc multiple 256
_________________________________________________________________
activation (Activation) multiple 0
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple 0
_________________________________________________________________
residual_block (ResidualBloc multiple 75904
_________________________________________________________________
residual_block_1 (ResidualBl multiple 71552
_________________________________________________________________
residual_block_2 (ResidualBl multiple 71552
_________________________________________________________________
conv2d_11 (Conv2D) multiple 131584
_________________________________________________________________
residual_block_3 (ResidualBl multiple 282368
_________________________________________________________________
residual_block_4 (ResidualBl multiple 282368
_________________________________________________________________
residual_block_5 (ResidualBl multiple 282368
_________________________________________________________________
residual_block_6 (ResidualBl multiple 282368
_________________________________________________________________
conv2d_24 (Conv2D) multiple 525312
_________________________________________________________________
residual_block_7 (ResidualBl multiple 1121792
_________________________________________________________________
residual_block_8 (ResidualBl multiple 1121792
_________________________________________________________________
residual_block_9 (ResidualBl multiple 1121792
_________________________________________________________________
residual_block_10 (ResidualB multiple 1121792
_________________________________________________________________
residual_block_11 (ResidualB multiple 1121792
_________________________________________________________________
residual_block_12 (ResidualB multiple 1121792
_________________________________________________________________
conv2d_43 (Conv2D) multiple 2099200
_________________________________________________________________
residual_block_13 (ResidualB multiple 4471808
_________________________________________________________________
residual_block_14 (ResidualB multiple 4471808
_________________________________________________________________
residual_block_15 (ResidualB multiple 4471808
_________________________________________________________________
global_average_pooling2d (Gl multiple 0
_________________________________________________________________
dense (Dense) multiple 2049000
_________________________________________________________________
dense_1 (Dense) multiple 10010
=================================================================
Total params: 26,313,218
Trainable params: 26,267,778
Non-trainable params: 45,440
_________________________________________________________________
"""

このmodel.buildの意味は正直よくわかっていないのですが、公式によると、どんなinputが来るかわからないサブクラスのために必要とのことで、今回は確かにこれまで書いたコードとは違いinput shapeを引数として与えているので、そのせいかなぁと考えています。


その他


バッチサイズ

バッチサイズは128にしました。論文には256と書かれていましたが、メモリエラー寸前だったので・・・。


最適化手法

Adamにしました。論文にはSGD + Momentumと書かれており、少し試してみたのですが学習の進みが遅いように見受けられたので・・・。余裕があれば、SGD + Momentumでも試してみます。


結果

400エポック訓練した結果、Test Accuracyが91.3%と前回を下回ってしまいました。チューニングすればもう少し精度が上がるかもしれません。


参考にさせて頂いたコード


追記:SGD + Momentumの結果

400エポック訓練した結果、Test Accuracyが91.4%でした。Adamとそう変わらなかったですね。