4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

ResNetを実装してみただけ (TF2.0)

Last updated at Posted at 2020-08-08

###概要
ResNet論文にあるアーキテクチャに従い、ResNet50を実装しました。
ResNetの**Shortcut Connection(Skip Connection)**という手法は他のネットワークモデルでもよく使われる手法ですので、実装法を知っておこうと思いやってみました。

###環境
-Software-
Windows 10 Home
Anaconda3 64-bit(Python3.7)
VScode
-Library-
Tensorflow 2.2.0
-Hardware-
CPU: Intel core i9 9900K
GPU: NVIDIA GeForce RTX2080ti
RAM: 16GB 3200MHz

###参考
サイト
ResNet論文
Residual Network(ResNet)の理解とチューニングのベストプラクティス
TensorFlow2.0を使ってFashion-MNISTをResNet-50で学習する
↑ほとんどこの方が実装したコードと同じです

###プログラム
Githubに上げておきます。
https://github.com/himazin331/ResNet

今回は、データセットにFashion-MNISTを使いました。

###ソースコード

resnet.py
import tensorflow as tf
import tensorflow.keras.layers as kl

import argparse as arg
import os

import numpy as np
import matplotlib.pyplot as plt

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


# 残差ブロック(Bottleneckアーキテクチャ)
class Res_Block(tf.keras.Model):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        bneck_channels = out_channels // 4

        self.bn1 = kl.BatchNormalization()
        self.av1 = kl.Activation(tf.nn.relu)
        self.conv1 = kl.Conv2D(bneck_channels, kernel_size=1,
                                strides=1, padding='valid', use_bias=False)
        
        self.bn2 = kl.BatchNormalization()
        self.av2 = kl.Activation(tf.nn.relu)
        self.conv2 = kl.Conv2D(bneck_channels, kernel_size=3,
                                strides=1, padding='same', use_bias=False)

        self.bn3 = kl.BatchNormalization()
        self.av3 = kl.Activation(tf.nn.relu)
        self.conv3 = kl.Conv2D(out_channels, kernel_size=1,
                                strides=1, padding='valid', use_bias=False)
        
        self.shortcut = self._scblock(in_channels, out_channels)
        self.add = kl.Add()

    # Shortcut Connection
    def _scblock(self, in_channels, out_channels):
        if in_channels != out_channels:
            self.bn_sc1 = kl.BatchNormalization()
            self.conv_sc1 = kl.Conv2D(out_channels, kernel_size=1,
                                        strides=1, padding='same', use_bias=False)
            return self.conv_sc1
        else:
            return lambda x: x

    def call(self, x):
        out1 = self.conv1(self.av1(self.bn1(x)))
        out2 = self.conv2(self.av2(self.bn2(out1)))
        out3 = self.conv3(self.av3(self.bn3(out2)))
        shortcut = self.shortcut(x)
        out4 = self.add([out3, shortcut])

        return out4


# ResNet50(Pre Activation)
class ResNet(tf.keras.Model):
    def __init__(self, input_shape, output_dim):
        super().__init__()

        self._layers = [

            kl.BatchNormalization(),
            kl.Activation(tf.nn.relu),
            kl.Conv2D(64, kernel_size=7, strides=2, padding="same", use_bias=False, input_shape=input_shape),
            kl.MaxPool2D(pool_size=3, strides=2, padding="same"),
            Res_Block(64, 256),
            [
                Res_Block(256, 256) for _ in range(2)
            ],
            kl.Conv2D(512, kernel_size=1, strides=2),
            [
                Res_Block(512, 512) for _ in range(4)
            ],
            kl.Conv2D(1024, kernel_size=1, strides=2, use_bias=False),
            [
                Res_Block(1024, 1024) for _ in range(6)
            ],
            kl.Conv2D(2048, kernel_size=1, strides=2, use_bias=False),
            [
                Res_Block(2048, 2048) for _ in range(3)
            ],
            kl.GlobalAveragePooling2D(),
            kl.Dense(1000, activation="relu"),
            kl.Dense(output_dim, activation="softmax")
        ]

    def call(self, x):
        for layer in self._layers:
            if isinstance(layer, list):
                for _layer in layer:
                    x = _layer(x)
            else:
                x = layer(x)
        
        return x


# 学習
class trainer(object):
    def __init__(self):
        self.resnet = ResNet((28, 28, 1), 10)
        self.resnet.build(input_shape=(None, 28, 28, 1))
        self.resnet.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.9),
                            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                            metrics=['accuracy'])

    def train(self, train_img, train_lab, test_images, test_labels, out_path, batch_size, epochs):
        print("\n\n___Start training...")

        his = self.resnet.fit(train_img, train_lab, batch_size=batch_size, epochs=epochs)
        
        graph_output(his, out_path)  # グラフ出力

        print("___Training finished\n\n")

        self.resnet.evaluate(test_images, test_labels)  # テストデータ推論

        print("\n___Saving parameter...")
        out_path = os.path.join(out_path, "resnet.h5")
        self.resnet.save_weights(out_path)  # パラメータ保存
        print("___Successfully completed\n\n")


# accuracy, lossグラフ
def graph_output(history, out_path):
    plt.plot(history.history['accuracy'])
    plt.title('Model accuracy')
    plt.ylabel('Accuracy')
    plt.xlabel('Epoch')
    plt.legend(['Train'], loc='upper left')
    plt.savefig(os.path.join(out_path, "acc_graph.jpg"))
    plt.show()

    plt.plot(history.history['loss'])
    plt.title('Model loss')
    plt.ylabel('Loss')
    plt.xlabel('Epoch')
    plt.legend(['Train'], loc='upper left')
    plt.savefig(os.path.join(out_path, "loss_graph.jpg"))
    plt.show()


def main():
    # コマンドラインオプション作成
    parser = arg.ArgumentParser(description='ResNet50')
    parser.add_argument('--out', '-o', type=str,
                        default=os.path.dirname(os.path.abspath(__file__)),
                        help='パラメータの保存先指定(デフォルト値=./resnet.h5')
    parser.add_argument('--batch_size', '-b', type=int, default=256,
                        help='ミニバッチサイズの指定(デフォルト値=256)')
    parser.add_argument('--epoch', '-e', type=int, default=40,
                        help='学習回数の指定(デフォルト値=40)')
    args = parser.parse_args()
        
    # 設定情報出力
    print("=== Setting information ===")
    print("# Output folder: {}".format(args.out))
    print("# Minibatch-size: {}".format(args.batch_size))
    print("# Epoch: {}".format(args.epoch))
    print("===========================")

    # 出力フォルダの作成(フォルダが存在する場合は作成しない)
    os.makedirs(args.out, exist_ok=True)

    # Fashion-MNIST読込
    f_mnist = tf.keras.datasets.fashion_mnist
    (train_images, train_labels), (test_images, test_labels) = f_mnist.load_data()

    # 画像データ加工
    train_imgs = train_images / 255.0
    train_imgs = train_imgs[:, :, :, np.newaxis]
    test_imgs = test_images / 255.0
    test_imgs = test_imgs[:, :, :, np.newaxis]
    
    Trainer = trainer()
    Trainer.train(train_imgs, train_labels, test_imgs, test_labels, args.out, args.batch_size, args.epoch)


if __name__ == "__main__":
    main()

###実行結果
今回はFashion-MNISTをミニバッチ数256で40エポック学習させました。

コンソール画面
コンソール画面

学習終了後に汎化性能をテストします。

予測精度グラフ
予測精度グラフ

損失値グラフ
損失値グラフ

コマンド
python resnet.py -b <ミニバッチサイズ> -e <学習回数> (-o <保存先>)

###説明
簡単な説明をしていきます。

####ResNet

ResNetクラス
# ResNet50(Pre Activation)
class ResNet(tf.keras.Model):
    def __init__(self, input_shape, output_dim):
        super().__init__()

        self._layers = [

            kl.BatchNormalization(),
            kl.Activation(tf.nn.relu),
            kl.Conv2D(64, kernel_size=7, strides=2, padding="same", use_bias=False, input_shape=input_shape),
            kl.MaxPool2D(pool_size=3, strides=2, padding="same"),
            Res_Block(64, 256),
            [
                Res_Block(256, 256) for _ in range(2)
            ],
            kl.Conv2D(512, kernel_size=1, strides=2),
            [
                Res_Block(512, 512) for _ in range(4)
            ],
            kl.Conv2D(1024, kernel_size=1, strides=2, use_bias=False),
            [
                Res_Block(1024, 1024) for _ in range(6)
            ],
            kl.Conv2D(2048, kernel_size=1, strides=2, use_bias=False),
            [
                Res_Block(2048, 2048) for _ in range(3)
            ],
            kl.GlobalAveragePooling2D(),
            kl.Dense(1000, activation="relu"),
            kl.Dense(output_dim, activation="softmax")
        ]

    def call(self, x):
        for layer in self._layers:
            if isinstance(layer, list):
                for _layer in layer:
                    x = _layer(x)
            else:
                x = layer(x)
        
        return x

ResNet50クラスでネットワークモデルの定義を行っています。
アーキテクチャは論文にある通りに決定しました。
image.png

summary()でみるとこんな感じです。
image.png


####Residual Block

Res_Block
# 残差ブロック(Bottleneckアーキテクチャ)
class Res_Block(tf.keras.Model):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        bneck_channels = out_channels // 4

        self.bn1 = kl.BatchNormalization()
        self.av1 = kl.Activation(tf.nn.relu)
        self.conv1 = kl.Conv2D(bneck_channels, kernel_size=1,
                                strides=1, padding='valid', use_bias=False)
        
        self.bn2 = kl.BatchNormalization()
        self.av2 = kl.Activation(tf.nn.relu)
        self.conv2 = kl.Conv2D(bneck_channels, kernel_size=3,
                                strides=1, padding='same', use_bias=False)

        self.bn3 = kl.BatchNormalization()
        self.av3 = kl.Activation(tf.nn.relu)
        self.conv3 = kl.Conv2D(out_channels, kernel_size=1,
                                strides=1, padding='valid', use_bias=False)
        
        self.shortcut = self._scblock(in_channels, out_channels)
        self.add = kl.Add()

    # Shortcut Connection
    def _scblock(self, in_channels, out_channels):
        if in_channels != out_channels:
            self.bn_sc1 = kl.BatchNormalization()
            self.conv_sc1 = kl.Conv2D(out_channels, kernel_size=1,
                                        strides=1, padding='same', use_bias=False)
            return self.conv_sc1
        else:
            return lambda x: x

    def call(self, x):
        out1 = self.conv1(self.av1(self.bn1(x)))
        out2 = self.conv2(self.av2(self.bn2(out1)))
        out3 = self.conv3(self.av3(self.bn3(out2)))
        shortcut = self.shortcut(x)
        out4 = self.add([out3, shortcut])

        return out4

Res_Blockクラスは肝となる残差ブロックです。
_scblockメソッドでShortcut Connectionを実装しています。
アーキテクチャはPre Activation(参照)で実装しています。


####trainer

trainer
# 学習
class trainer(object):
    def __init__(self):
        self.resnet = ResNet((28, 28, 1), 10)
        self.resnet.build(input_shape=(None, 28, 28, 1))
        self.resnet.compile(optimizer=tf.keras.optimizers.SGD(momentum=0.9),
                            loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                            metrics=['accuracy'])

論文中ではSGD+Momentumを用いたとあったので、準拠してSGD+Momentumを用いています。

###おわりに
勉強でResNetを実装しただけなので、特段説明すること無いです...
TensorFlow2.0を使ってFashion-MNISTをResNet-50で学習するで紹介されているコードをみて、大変勉強になりました。
リストとfor文を使って層を展開していく発想いいなーって思い、今後真似できる場面があったら真似してみたいと思いました。

4
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?