###概要
ResNet論文にあるアーキテクチャに従い、ResNet50を実装しました。
ResNetの**Shortcut Connection(Skip Connection)**という手法は他のネットワークモデルでもよく使われる手法ですので、実装法を知っておこうと思いやってみました。
###環境
-Software-
Windows 10 Home
Anaconda3 64-bit(Python3.7)
VScode
-Library-
Tensorflow 2.2.0
-Hardware-
CPU: Intel core i9 9900K
GPU: NVIDIA GeForce RTX2080ti
RAM: 16GB 3200MHz
###参考
サイト
・ResNet論文
・Residual Network(ResNet)の理解とチューニングのベストプラクティス
・TensorFlow2.0を使ってFashion-MNISTをResNet-50で学習する
↑ほとんどこの方が実装したコードと同じです
###プログラム
Githubに上げておきます。
https://github.com/himazin331/ResNet
今回は、データセットにFashion-MNISTを使いました。
###ソースコード
import tensorflow as tf
import tensorflow.keras.layers as kl
import argparse as arg
import os
import numpy as np
import matplotlib.pyplot as plt
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# 残差ブロック(Bottleneckアーキテクチャ)
class Res_Block(tf.keras.Model):
def __init__(self, in_channels, out_channels):
super().__init__()
bneck_channels = out_channels // 4
self.bn1 = kl.BatchNormalization()
self.av1 = kl.Activation(tf.nn.relu)
self.conv1 = kl.Conv2D(bneck_channels, kernel_size=1,
strides=1, padding='valid', use_bias=False)
self.bn2 = kl.BatchNormalization()
self.av2 = kl.Activation(tf.nn.relu)
self.conv2 = kl.Conv2D(bneck_channels, kernel_size=3,
strides=1, padding='same', use_bias=False)
self.bn3 = kl.BatchNormalization()
self.av3 = kl.Activation(tf.nn.relu)
self.conv3 = kl.Conv2D(out_channels, kernel_size=1,
strides=1, padding='valid', use_bias=False)
self.shortcut = self._scblock(in_channels, out_channels)
self.add = kl.Add()
# Shortcut Connection
def _scblock(self, in_channels, out_channels):
if in_channels != out_channels:
self.bn_sc1 = kl.BatchNormalization()
self.conv_sc1 = kl.Conv2D(out_channels, kernel_size=1,
strides=1, padding='same', use_bias=False)
return self.conv_sc1
else:
return lambda x: x
def call(self, x):
out1 = self.conv1(self.av1(self.bn1(x)))
out2 = self.conv2(self.av2(self.bn2(out1)))
out3 = self.conv3(self.av3(self.bn3(out2)))
shortcut = self.shortcut(x)
out4 = self.add([out3, shortcut])
return out4
# ResNet50(Pre Activation)
class ResNet(tf.keras.Model):
def __init__(self, input_shape, output_dim):
super().__init__()
self._layers = [
kl.BatchNormalization(),
kl.Activation(tf.nn.relu),
kl.Conv2D(64, kernel_size=7, strides=2, padding="same", use_bias=False, input_shape=input_shape),
kl.MaxPool2D(pool_size=3, strides=2, padding="same"),
Res_Block(64, 256),
[
Res_Block(256, 256) for _ in range(2)
],
kl.Conv2D(512, kernel_size=1, strides=2),
[
Res_Block(512, 512) for _ in range(4)
],
kl.Conv2D(1024, kernel_size=1, strides=2, use_bias=False),
[
Res_Block(1024, 1024) for _ in range(6)
],
kl.Conv2D(2048, kernel_size=1, strides=2, use_bias=False),
[
Res_Block(2048, 2048) for _ in range(3)
],
kl.GlobalAveragePooling2D(),
kl.Dense(1000, activation="relu"),
kl.Dense(output_dim, activation="softmax")
]
def call(self, x):
for layer in self._layers:
if isinstance(layer, list):
for _layer in layer:
x = _layer(x)
else:
x = layer(x)
return x
# 学習
class trainer(object):
def __init__(self):
self.resnet = ResNet((28, 28, 1), 10)
self.resnet.build(input_shape=(None, 28, 28, 1))
self.resnet.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.9),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
def train(self, train_img, train_lab, test_images, test_labels, out_path, batch_size, epochs):
print("\n\n___Start training...")
his = self.resnet.fit(train_img, train_lab, batch_size=batch_size, epochs=epochs)
graph_output(his, out_path) # グラフ出力
print("___Training finished\n\n")
self.resnet.evaluate(test_images, test_labels) # テストデータ推論
print("\n___Saving parameter...")
out_path = os.path.join(out_path, "resnet.h5")
self.resnet.save_weights(out_path) # パラメータ保存
print("___Successfully completed\n\n")
# accuracy, lossグラフ
def graph_output(history, out_path):
plt.plot(history.history['accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train'], loc='upper left')
plt.savefig(os.path.join(out_path, "acc_graph.jpg"))
plt.show()
plt.plot(history.history['loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train'], loc='upper left')
plt.savefig(os.path.join(out_path, "loss_graph.jpg"))
plt.show()
def main():
# コマンドラインオプション作成
parser = arg.ArgumentParser(description='ResNet50')
parser.add_argument('--out', '-o', type=str,
default=os.path.dirname(os.path.abspath(__file__)),
help='パラメータの保存先指定(デフォルト値=./resnet.h5')
parser.add_argument('--batch_size', '-b', type=int, default=256,
help='ミニバッチサイズの指定(デフォルト値=256)')
parser.add_argument('--epoch', '-e', type=int, default=40,
help='学習回数の指定(デフォルト値=40)')
args = parser.parse_args()
# 設定情報出力
print("=== Setting information ===")
print("# Output folder: {}".format(args.out))
print("# Minibatch-size: {}".format(args.batch_size))
print("# Epoch: {}".format(args.epoch))
print("===========================")
# 出力フォルダの作成(フォルダが存在する場合は作成しない)
os.makedirs(args.out, exist_ok=True)
# Fashion-MNIST読込
f_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = f_mnist.load_data()
# 画像データ加工
train_imgs = train_images / 255.0
train_imgs = train_imgs[:, :, :, np.newaxis]
test_imgs = test_images / 255.0
test_imgs = test_imgs[:, :, :, np.newaxis]
Trainer = trainer()
Trainer.train(train_imgs, train_labels, test_imgs, test_labels, args.out, args.batch_size, args.epoch)
if __name__ == "__main__":
main()
###実行結果
今回はFashion-MNISTをミニバッチ数256で40エポック学習させました。
学習終了後に汎化性能をテストします。
コマンド
python resnet.py -b <ミニバッチサイズ> -e <学習回数> (-o <保存先>)
###説明
簡単な説明をしていきます。
####ResNet
# ResNet50(Pre Activation)
class ResNet(tf.keras.Model):
def __init__(self, input_shape, output_dim):
super().__init__()
self._layers = [
kl.BatchNormalization(),
kl.Activation(tf.nn.relu),
kl.Conv2D(64, kernel_size=7, strides=2, padding="same", use_bias=False, input_shape=input_shape),
kl.MaxPool2D(pool_size=3, strides=2, padding="same"),
Res_Block(64, 256),
[
Res_Block(256, 256) for _ in range(2)
],
kl.Conv2D(512, kernel_size=1, strides=2),
[
Res_Block(512, 512) for _ in range(4)
],
kl.Conv2D(1024, kernel_size=1, strides=2, use_bias=False),
[
Res_Block(1024, 1024) for _ in range(6)
],
kl.Conv2D(2048, kernel_size=1, strides=2, use_bias=False),
[
Res_Block(2048, 2048) for _ in range(3)
],
kl.GlobalAveragePooling2D(),
kl.Dense(1000, activation="relu"),
kl.Dense(output_dim, activation="softmax")
]
def call(self, x):
for layer in self._layers:
if isinstance(layer, list):
for _layer in layer:
x = _layer(x)
else:
x = layer(x)
return x
ResNet50
クラスでネットワークモデルの定義を行っています。
アーキテクチャは論文にある通りに決定しました。
####Residual Block
# 残差ブロック(Bottleneckアーキテクチャ)
class Res_Block(tf.keras.Model):
def __init__(self, in_channels, out_channels):
super().__init__()
bneck_channels = out_channels // 4
self.bn1 = kl.BatchNormalization()
self.av1 = kl.Activation(tf.nn.relu)
self.conv1 = kl.Conv2D(bneck_channels, kernel_size=1,
strides=1, padding='valid', use_bias=False)
self.bn2 = kl.BatchNormalization()
self.av2 = kl.Activation(tf.nn.relu)
self.conv2 = kl.Conv2D(bneck_channels, kernel_size=3,
strides=1, padding='same', use_bias=False)
self.bn3 = kl.BatchNormalization()
self.av3 = kl.Activation(tf.nn.relu)
self.conv3 = kl.Conv2D(out_channels, kernel_size=1,
strides=1, padding='valid', use_bias=False)
self.shortcut = self._scblock(in_channels, out_channels)
self.add = kl.Add()
# Shortcut Connection
def _scblock(self, in_channels, out_channels):
if in_channels != out_channels:
self.bn_sc1 = kl.BatchNormalization()
self.conv_sc1 = kl.Conv2D(out_channels, kernel_size=1,
strides=1, padding='same', use_bias=False)
return self.conv_sc1
else:
return lambda x: x
def call(self, x):
out1 = self.conv1(self.av1(self.bn1(x)))
out2 = self.conv2(self.av2(self.bn2(out1)))
out3 = self.conv3(self.av3(self.bn3(out2)))
shortcut = self.shortcut(x)
out4 = self.add([out3, shortcut])
return out4
Res_Block
クラスは肝となる残差ブロックです。
_scblock
メソッドでShortcut Connectionを実装しています。
アーキテクチャはPre Activation(参照)で実装しています。
####trainer
# 学習
class trainer(object):
def __init__(self):
self.resnet = ResNet((28, 28, 1), 10)
self.resnet.build(input_shape=(None, 28, 28, 1))
self.resnet.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.9),
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
論文中ではSGD+Momentumを用いたとあったので、準拠してSGD+Momentumを用いています。
###おわりに
勉強でResNetを実装しただけなので、特段説明すること無いです...
TensorFlow2.0を使ってFashion-MNISTをResNet-50で学習するで紹介されているコードをみて、大変勉強になりました。
リストとfor文を使って層を展開していく発想いいなーって思い、今後真似できる場面があったら真似してみたいと思いました。