49
39

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Kerasで学習済みモデルに任意のレイヤー(BatchNorm、Dropoutなど)を差し込む方法

Last updated at Posted at 2018-11-29

転移学習として訓練済みモデルは非常に有用ですが、たまに途中にDropoutを入れたい、BatchNormを入れたいなど困ったことがおきます。今回はVGG16にBatchNormを入れる、MobileNetにDropoutを入れるを試してみます。

VGG16にBatchNormalizationを入れる

理論と実装上の注意

転移学習としてよく使われるVGG16ですが、実は古臭いモデルでBatchNormalizationが入っていません1。現在の分類問題において、よほどの理由がなければBatchNormalizationは入れるべきなので入れてみましょう2

VGG16では、「Conv→Conv→Conv→Pool」のように並んでいますが、Conv→Convを「Conv→BatchNorm→ReLU→Conv→…」と置き換えます。また元のConvにはReLUの活性化関数がついているので、Conv側の活性化関数を線形活性化関数(活性化関数なし)に置き換えます。

import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import BatchNormalization, Activation, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
import tensorflow.keras.activations as activations

def create_normal_model():
    model = VGG16(include_top=False, input_shape=(64,64,3), weights="imagenet")
    x = GlobalAveragePooling2D()(model.layers[-1].output)
    x = Dense(10, activation="softmax")(x)
    # あとでBatchNormを入れるため係数の固定はしない。初期値設定のみ転移学習とする
    return Model(model.inputs, x)

def create_batch_norm_model():
    model = create_normal_model()
    for i, layer in enumerate(model.layers):
        if i==0:
            input = layer.input
            x = input
        else:
            if "conv" in layer.name:
                layer.activation = activations.linear
                x = layer(x)
                x = BatchNormalization()(x)
                x = Activation("relu")(x)
            else:
                x = layer(x)

    bn_model = Model(input, x)
    return bn_model

CIFAR-10の分類をする問題です。今tf.kerasでやっていますが、これはColabのTPUを使うためです。元のKerasでもできます。

まずcreate_basic_model()で普通の転移学習のモデルを作って、それをcreate_batch_norm_model()でBatchNormとReLUの差し込みを行っています。layer.activation=Noneとすると元のKerasでは特にエラー出さなかったのですが、tf.kerasだと「Noneは__name__という属性を持ってねえぞ」と怒られてしまったので、keras.activationsにある活性化関数の関数オブジェクトをConvレイヤーのactivationに代入します。ここは活性化関数なしにしたいので、線形活性化関数(活性化関数なしと線形活性化関数は同義です)を代入しています。

活性化関数の変更をしたら、そのレイヤーに対して処理を行い(x=layer(x) )、あとは普通のモデルと同様にBatchNormとReLUを差し込みます。そして次のレイヤーに移ります。

BatchNormを使うと次レイヤーの入力の値が変わってしまうので、よくある転移学習みたいにレイヤーの重みを固定するのが必ずしも良いとは思えませんでした。ただし、BatchNorm自体が簡単な線形写像なので、すぐに学習できるはずですし、BatchNormを入れたからといって転移学習の意味がないというわけではないと思います。ここはいろんな意見があると思います。今回はただ初期値のみをImageNetからの重みで与える転移学習としました3

全体のコードはこちらにあります。
https://gist.github.com/koshian2/43f1c4eff7a5c24a6e80661a458b71da

そして、学習上の注意なのですが、普通の係数を固定する場合でもそうかもしれませんが、今回は初期値のみで事前学習した係数がすぐ壊れてしまいます。オプティマイザーにRMSPropを使い、学習率を1e-5という極めて低い値にしてゆっくり学習させるようにしました。

入力サイズは64×64にしています。CIFARの元の解像度は32×32ですが、KerasのVGG16の入力サイズが48以上ないといけないという制約のためです。データのジェネレーターを自作し、その中でバッチ単位でリサイズするようにしました4

summaryで確認してみる

ちゃんとBatchNormが入っているかsummaryで確認してみます。

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 64, 64, 3)         0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 64, 64, 64)        1792      
_________________________________________________________________
batch_normalization (BatchNo (None, 64, 64, 64)        256       
_________________________________________________________________
activation (Activation)      (None, 64, 64, 64)        0         
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 64, 64, 64)        36928     
_________________________________________________________________
batch_normalization_1 (Batch (None, 64, 64, 64)        256       

確かに入っています。長いので省略しました。全体が見たい方はこちらをどうぞ。

結果

そのまま転移学習した例と、BatchNormを差し込んで転移学習した例です。
layer_01.png
同じ転移学習といえど、BatchNormを入れたほうが明らかに良かったです。収束の速度も精度も全然違います。他に代替手段があれば別にしても、やっぱり今どきBatchNormを入れないモデルはダサいのでは?:thinking:

MobileNetにDropoutを追加する

同じく転移学習として、MobileNetのActivationの後にDropoutを追加してみます。MobileNetは通常のConvとDepthwise Convの2種類を使ったVGGライクなモデルです5

def create_normal_model():
    model = MobileNet(include_top=False, weights="imagenet", input_shape=(128,128,3))
    x = GlobalAveragePooling2D()(model.layers[-1].output)
    x = Dense(10, activation="softmax")(x)

    return Model(model.inputs, x)

def create_dropout_model():
    model = create_normal_model()
    # ActivationのあとにDropoutを入れる
    for i, layer in enumerate(model.layers):
        if i==0:
            input = layer.input
            x = input
        else:
            if "relu" in layer.name:
                x = layer(x)
                x = Dropout(0.01)(x)
            else:
                x = layer(x)
    drop_model = Model(input, x)
    return drop_model

ちなみにこれは全てのActivationの後にDropoutを入れているのですが、Dropoutが強くなりすぎてるので1%にして弱くしています。1%でも26回入れれば、出力層に行くまでには3割近く落ちてしまいますので。

コードはこちらにあります。
https://gist.github.com/koshian2/f592f342634e815ad81998b51dd656a4

入力が128×128からしか受け付けてくれないので、4倍の引き伸ばしをしています。

結果

layer_02.png
Dropoutを反映することができました。ただ精度のほうはDropoutがないほうが良いという結果になっています。これはDropoutが強すぎて訓練速度が遅くなっているためです。1%でも全レイヤーに入れると強すぎるのですね。
ただ、Dropoutなしの場合は6~7%オーバーフィッティングしていたのが、Dropoutを入れると4~5%のオーバーフィッティングになったので、オーバーフィッティングの解消という点では効果が出ています。Dropoutとはそういうものなので、Dropoutがないほうが精度が良くなってはいますが、結果としては間違いではないです。

しかし、MobileNetを使えばCIFAR-10で93.5%も出るのですね。これはびっくりしました。

全てうまく行くとは限らない

このように学習済みモデルに任意のレイヤーを挿入することは可能ですが、Dropoutで見たように全てうまくいくとは限りません。他にも例えば、ResNetやInceptionのように、分岐や結合があるケース。これはできないわけではないですが、AddやMergeの引数が複数になるので動的な差し込みが少し難しくなります。今回見たようなVGG16やMobileNetのように直線的なモデルの場合は特に意識しなくていいです。

他には差し込む側のレイヤーの問題で、訓練可能な重みを持っているレイヤーの挿入。BatchNormやDropoutの場合は気にしなくてもOKですが、訓練可能な重みを持っている場合はまっさらなレイヤーを突っ込むことになるので、うまくいくかどうかかなり怪しくなります(失敗するというわけではありません)。レイヤー間で訓練にムラが出るので、そこはわかった上で使ってみてください。

以上、ここまで見ように、学習済みモデルに対して意外と簡単に任意のレイヤーを挿入することができました。

  1. Summaryはこちら https://blog.shikoan.com/keras-vgg/

  2. ただし、元のVGGがBatchNormなしで訓練しているので、後付でBatchNormを入れれば必ず精度が上がるかというと、その保証はできません

  3. これも転移学習といいます。転移学習だからといって係数を固定しなければいけないという制約はどこにもありません。

  4. 全体をリサイズしてfloat32にキャストするとメモリ効率が最悪なのでこうしています。全体をリサイズするなら、まだuint8でリサイズして持ってるほうがいいです。

  5. サマリーはこちら https://blog.shikoan.com/keras-mobilenet/

49
39
6

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
49
39

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?