35
34

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Res-Netsの有効性をCIFAR-10で確認する

Last updated at Posted at 2018-07-18

畳み込みニューラルネットワーク(CNN)で成功を収めたネットワークの1つにRes-Nets(Residual Networks、残差ネットワーク)があります。一般にRes-Netsを使うと勾配消失問題に強いとされます。CIFAR-10データセットを使ってRes-Netsの有効性を軽く確認してみました。

元の論文

K.He, X.Zhang, S.Ren, J.Sun, "Deep Residual Learning for Image Recognition", 2015
https://arxiv.org/pdf/1512.03385.pdf

残差ブロック(Residual Block)とは

このようにショートカット構造のあるレイヤーの組み合わせ(ブロック)のこと。K.He et al.(2015)より
res-01.png

分岐前の活性化関数適用後の値を$a^{[l]}$とすると、合流後の活性化関数適用前の値は$z^{[l+2]}$となるので、

$$z^{[l+2]} = (w^{[l+2]}a^{[l+1]}+b^{[l+2]}) + a^{[l]}$$

で表されます。カッコ内はメインルートの値、右側はショートカットの値です。仮にメインルートの値が0になったとしても、ショートカット側の値が代入されるので、勾配消失問題に強いという理屈です。

そもそも勾配消失(爆発)問題とは

ネットワークが深くなると発生しやすくなる問題です。一般的な全結合ニューラルネットワークでの計算は、

$$a^{[l+1]}=g(w^{[l+1]}a^{[l]}+b^{[l+1]}) $$

で表されます。今わかりやすいように、活性化関数を線形活性化関数、b=0、wの係数を全て同一とすると、

$$a^{[l]} = w^lx$$

という指数関数になります。つまりネットワークが深くなればなるほど(lの値が大きくなるほど)、wが0.99という1に近い数字でも出力はほぼ0に近くなり、wが1.01なら出力は無限大になり……と出力をコントロールするのがとても難しくなります。

勾配の消失(爆発)はCNNだけではなく、全結合のNNやRNNでも発生する共通の問題です。一般に、勾配の爆発より消失のほうが深刻な問題であると言われています。無限大に爆発したものはクリッピング(max, min関数)で一定範囲内に収められますが、0になってしまったものは復活させようがなく、そこで学習が終わってしまうためです。

要はRes-Netsのメリットって何?

勾配消失問題を気にすることなく、モデルをどんどん深くできること。モデルを深く大きくすると、巨大な訓練データを精度に反映できるためです。

KerasでのResNet-50の実装を見る

Kerasには既に訓練済みのResNet-50のモデルがあり、簡単に読み込むことができます。summaryで表示し、ResBlockの実装を確認します。

from keras.applications.resnet50 import ResNet50

resnet = ResNet50()
with open("resnet50.txt", "w") as fp:
    resnet.summary(print_fn=lambda x: fp.write(x + "\r\n"))

ResBlock1つを取り出すとこうなります。最初のMaxPoolingはResBlockに関係ありません。

max_pooling2d_1 (MaxPooling2D)  (None, 55, 55, 64)   0           activation_1[0][0]               

__________________________________________________________________________________________________

res2a_branch2a (Conv2D)         (None, 55, 55, 64)   4160        max_pooling2d_1[0][0]            

__________________________________________________________________________________________________

bn2a_branch2a (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2a[0][0]             

__________________________________________________________________________________________________

activation_2 (Activation)       (None, 55, 55, 64)   0           bn2a_branch2a[0][0]              

__________________________________________________________________________________________________

res2a_branch2b (Conv2D)         (None, 55, 55, 64)   36928       activation_2[0][0]               

__________________________________________________________________________________________________

bn2a_branch2b (BatchNormalizati (None, 55, 55, 64)   256         res2a_branch2b[0][0]             

__________________________________________________________________________________________________

activation_3 (Activation)       (None, 55, 55, 64)   0           bn2a_branch2b[0][0]              

__________________________________________________________________________________________________

res2a_branch2c (Conv2D)         (None, 55, 55, 256)  16640       activation_3[0][0]               

__________________________________________________________________________________________________

res2a_branch1 (Conv2D)          (None, 55, 55, 256)  16640       max_pooling2d_1[0][0]            

__________________________________________________________________________________________________

bn2a_branch2c (BatchNormalizati (None, 55, 55, 256)  1024        res2a_branch2c[0][0]             

__________________________________________________________________________________________________

bn2a_branch1 (BatchNormalizatio (None, 55, 55, 256)  1024        res2a_branch1[0][0]              

__________________________________________________________________________________________________

add_1 (Add)                     (None, 55, 55, 256)  0           bn2a_branch2c[0][0]              

                                                                 bn2a_branch1[0][0]               

__________________________________________________________________________________________________

activation_4 (Activation)       (None, 55, 55, 256)  0           add_1[0][0]                      

__________________________________________________________________________________________________

まとめると次のとおりです。分岐以降、

  • メインルート:Conv2d→BN→Activation → Conv2d→BN→Activation → Conv2d→BN
  • ショートカット:Conv2d→BN

このあとAddで合流させます。これが1つのRes-Blockです。3レイヤーで1ResBlockを構成していますが、2レイヤー(メインルートの2回目のCov2d→BN→Activationを外す)にしても特に問題はなさそうです。今回の実験では簡単にするために1Res-Blockは2レイヤーとします。

派生研究として、合流後にReLUをするのではなく、ReLUをかけてから分岐させ合流後にはReLUをかけないというのもあります。下記資料によると、現在ではこっちがRes-Netsと呼ばれるそうです。

実験設定

CIFAR-10データを使い、深いモデルを定義し、Res Blockの有無で学習の過程を比較します。CIFAR-10程度のデータ数(5万)だと通常はここまで深いモデルは必要ありません。

元の論文がVGGを意識したモデルなので、これを踏襲し、簡素化したモデルにします。K.He et al.(2015)より
res-02.png

全てカーネルサイズ3x3の畳み込み(Conv2d)をやっているものの、フィルター数がだんだん増えていくモデルですね。元のVGGはフィルター数が変わるタイミングでPoolingを挟んでいます。

元の論文で示されていた用語ではありませんが、わかりやすいように同一のフィルター数の**レイヤー/ブロックの集まりを「1フェーズ」**と呼ぶことにします。上の図では緑色のレイヤーの集まりで1フェーズ、紫色のレイヤーの集まりで別の1フェーズ……となります。また、**ResBlockの1つを「1ブロック」**と呼びます。1ブロックは2つの畳み込みレイヤーで構成されるため、1ブロックのレイヤー数は2となります。

細かな設定は以下の通りです。フィルター数の変化はResNetよりもVGGを参考にしました。

  • 1フェーズは3ブロックで構成される。フェーズの最初に2×2のAveragePoolingを入れる。またショートカットとメインルートのフィルター数を統一するため、AveragePoolingの後に1回だけ3×3の活性化関数なしのConv2dを入れる。したがって1フェーズは7層。
  • 1ブロックは2レイヤーから構成され、メインルートはConv2d→BN→Activation→Conv2d→BN、ショートカットはBNだけ入れてAddで合流させる。合流後にActivationをかける。
  • フィルター数は初期値16として、フェーズが変わるたびに倍々になっていく
  • 畳み込みは全て3x3、same paddingとする。つまり、フェーズ間ではテンソルのサイズは同一
  • 入力の次元は(32,32,3)なので、フェーズごとに(32,32,16)→(16,16,32)→(8,8,64)→(4,4,128)となる
  • ResBlockの有無、フェーズ数が1~4の合計8パターンを試して比較する。ResBlockを使わない場合は、単に畳み込みフィルターを7個並べたものが1フェーズとなる。

字面にするとわかりづらいですが、コードを見たほうがわかりやすいと思います。

コード

import numpy as np
import pickle
import os
from keras.layers import Input, Conv2D, AveragePooling2D, BatchNormalization, Add, Activation, Flatten, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import cifar10
from keras.utils import to_categorical

class TestModel:
    def __init__(self, use_resblock, nb_blocks):
        self.use_resblock = use_resblock
        self.nb_blocks = nb_blocks
        # モデルの作成
        self.model = self._create_model()
        # モデル名
        self.name = ""
        if use_resblock: self.name += "use_res_"
        else: self.name += "no_res_"
        self.name = f"{self.name}{self.nb_blocks:02d}"

    def _create_model(self):
        input = Input(shape=(32, 32, 3))
        X = input
        n_filter = 16
        for i in range(self.nb_blocks):
            # 3ブロック単位でAveragePoolingを入れる、フィルター数を倍にする
            if i % 3 == 0 and i != 0:
                X = AveragePooling2D((2,2))(X)
                n_filter *= 2
            # ショートカットとメインのフィルター数を揃えるために活性化関数なしの畳込みレイヤーを作る
            if i % 3 == 0:
                X = Conv2D(n_filter, (3,3), padding="same")(X)
            # 1ブロック単位の処理
            if self.use_resblock:
                # ショートカット:ショートカット→BatchNorm(ResBlockを使う場合のみ)
                shortcut = X
                shortcut = BatchNormalization()(shortcut)
            # メイン
            # 畳み込み→BatchNorm→活性化関数
            X = Conv2D(n_filter, (3,3), padding="same")(X)
            X = BatchNormalization()(X)
            X = Activation("relu")(X)
            # 畳み込み→BatchNorm
            X = Conv2D(n_filter, (3,3), padding="same")(X)
            X = BatchNormalization()(X)
            if self.use_resblock:
                # ショートカットとマージ(ResBlockを使う場合のみ)
                X = Add()([X, shortcut])
            # 活性化関数
            X = Activation("relu")(X)
        # 全結合
        X = Flatten()(X)
        y = Dense(10, activation="softmax")(X)
        # モデル
        model = Model(inputs=input, outputs=y)
        return model

    def train(self, Xtrain, ytrain, Xval, yval, nb_epoch=100, learning_rate=0.01):
        self.model.compile(optimizer=Adam(lr=learning_rate), loss="categorical_crossentropy", metrics=["accuracy"])
        history = self.model.fit(Xtrain, ytrain, batch_size=128, epochs=nb_epoch, validation_data=(Xval, yval)).history
        # historyの保存
        if not os.path.exists("history"): os.mkdir("history")
        with open(f"history/{self.name}.dat", "wb") as fp:
            pickle.dump(history, fp)

if __name__ == "__main__":
    # データの読み込み
    (X_train, y_train), (X_test, y_test) = cifar10.load_data()
    X_train, X_test = X_train / 255.0, X_test / 255.0
    y_train, y_test = to_categorical(y_train), to_categorical(y_test)
    # テストパターン
    resflag = [False, True]
    nb_blocks = [3, 6, 9, 12]
    # モデルの作成
    for res in resflag:
        for nb in nb_blocks:
            print("Testing model... / ", res, nb)
            model = TestModel(res, nb)
            model.train(X_train, y_train, X_test, y_test, nb_epoch=100)

ちなみにモデルが一番大きくなる「Res-Blockあり、4フェーズ(12ブロック)」の場合のsummaryは次のようになりました(パラメーター数部分だけ抜粋)。

Total params: 1,303,050
Trainable params: 1,298,730
Non-trainable params: 4,320

130万個のパラメーターです。自分のCPUだと1epoch40分近くかかりますが、Google ColabのGPUを使ったら1epoch1分強でできました。100epochなので最悪1時間半ぐらいですね。ちなみに4フェーズの場合は28層のモデルです。1フェーズだけだと7層なのでもっと早く終わります。

ちなみに本家のResNet-50の場合はパラメーターが2563万もあります。これはちょっと自分では訓練したくないです。

結果

縦軸を訓練精度、横軸をepoch数としてプロットしました。
res-03.png

訓練精度0.9~1で拡大すると次のとおりです。
res-04.png

ResBlockなし、フェーズ=1、2(ブロック数3、6)の場合は勾配が消失して学習が進みませんでした。ResBlockありの場合は、どのフェーズ数でもきちんと学習が進んでいることが確認できます。また、ResBlockを使った場合はブロック数(モデルの深さ)が大きくなるほど、訓練精度が上がっていることが確認できます

なお勾配消失はモデルが深くなるほど起こりやすいと言われていますが、なぜResBlockなしの場合浅いモデルのほうで勾配が消失して、深いモデルでは消失しなかったかはよくわかりません。しかし、ResBlockを使えばどの場合も勾配が消失していないというのは確かです。

同様に、交差検証データに対する精度を縦軸にしてプロットしてみます。ここではcifar10を読み込んだ際のテストデータを交差検証データとして用いています。

res-05.png

交差検証の精度を0.6~0.9で拡大すると次の通りです。

res-06.png

概ね交差検証の精度は0.8弱というところでしょうか。訓練精度が98%近いので、かなりオーバーフィッティングしています。それは半ば当たり前で、CIFAR-10程度でここまで深いモデルは必要なく、モデルが訓練データの数に対して大きすぎるのでオーバーフィッティングしがちということです。データの増強やドロップアウト使うとかなりいい結果になると思います。

まとめ

実験の結果をまとめます。

  • Res-Netは勾配消失問題に対して有効と言われているが、この実験を通じてそれを確認できた。ResBlockを使わない場合は勾配消失が起こったが、ResBlockを使った場合はどれも勾配消失は起こらなかった。
  • ResBlockを使うと、モデルを大きくするほど訓練誤差が下がり、より理想的な結果になりやすい。

データはgithubから見れます
https://github.com/koshian2/ResNet-CIFAR10

35
34
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
35
34

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?