LoginSignup
1

More than 1 year has passed since last update.

posted at

updated at

Tensorflow.kerasで学習済みモデルのレイヤーを差し替える

このエントリは TSG Advent Calendar 2020 の7日目の記事です。

現在のtensorflow(v2.3)ではtf.keras.applicationsに学習済みモデルが定義されており、転移学習などで活用することができます。しかしその学習モデルにDropoutを差し込む、レイヤーを差し替えるなどちょっとした変更を加えたくなることがたまにあると思います。こちらの記事ではVGG-likeな一本道のモデルに対してそのような変更を加える方法が紹介されていますが、ResNetなど分岐のあるモデルには使えません。そこで本記事ではそういったモデルにも対応した動的なモデル改変ができるスクリプトを紹介します。

BatchnormをSyncBatchnormに差し替える

kaggleなどで使えるTPUは8並列になっており、Batchnormを行うとそれぞれのデバイスで統計量を計算するため、バッチサイズが1/8になり(もともとのバッチサイズがあまり大きくないと)不安定になってしまいます。これの対策の一つは単純にバッチサイズを増やすというものがありますが、統計量をデバイス全体で共有する(SyncBatchnorm)という方法もあります。しかしこの場合、転移学習ではbackbone(学習済みモデル)内のレイヤーが普通のBatchnormなので、フリーズせずに学習させたい場合ここもSyncBatchnormに差し替える必要があります。そこでResNet50を例にとって、モデル内のBatchnormを差し替えてみましょう。

(追記)
https://www.tensorflow.org/guide/keras/transfer_learning によるとそもそも学習済みモデルのBatchnormの統計量はいじっちゃだめらしいです... unfreezeしてfinetuningする際もレイヤーへの入力はtraining=Falseを指定して推論モードにしておけ、でないと急に壊れるぞだそうです。
まあモデルをいちから学習させたい&いちいちレイヤーを書き下すの面倒という場合でも使えるので一応使いどころがなくもないはず...

tf.py
from collections import defaultdict
import tensorflow as tf

def get_sync_backbone():
    backbone = tf.keras.applications.ResNet50(include_top=False, weights='imagenet')
    mapping = defaultdict() # 元モデルのレイヤー名=>改変モデルでの同じ位置のレイヤーにおける出力

    for i, layer in enumerate(backbone.layers):
        if i == 0: # 一番底のレイヤ
            inpt = layer.input  # backboneモデルの下端のテンソル(Input)
            x = layer.input
            out_name = layer.output.name
            mapping[layer.output.name] = x  # モデルの上方でこのレイヤと繋がっている場合はこのテンソルを持ってきて入力する
            continue

        # 元モデルのレイヤーに入力されるテンソルに対応した、改変後モデルにおけるテンソルを持ってくる
        if type(layer.input) is list: # layer.inputは複数入力のときだけlistになっている
            input_tensors = list(map(lambda t: mapping[t.name], layer.input))
        else:
            input_tensors = mapping[layer.input.name]

        out_name = layer.output.name
        # ここで差し替え
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            newlayer = tf.keras.layers.experimental.SyncBatchNormalization(
                momentum=0.9, # TensorflowではBatchnormのmomentumが0.99らしい。ここではPytorchと同じ0.9に
                beta_initializer=tf.initializers.constant(layer.beta.numpy()),
                gamma_initializer=tf.initializers.constant(layer.gamma.numpy()),
                moving_mean_initializer=tf.initializers.constant(layer.moving_mean.numpy()),
                moving_variance_initializer=tf.initializers.constant(layer.moving_variance.numpy()))
            x = newlayer(input_tensors)
        else:
            # 差し替えの必要がないレイヤーは再利用
            x = layer(input_tensors)
        mapping[out_name] = x
    return tf.keras.Model(inpt, x)

summaryを見てみると、

Model: "functional_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, None, None,  0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, None, None, 3 0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, None, None, 6 9472        conv1_pad[1][0]                  
__________________________________________________________________________________________________
sync_batch_normalization (SyncB (None, None, None, 6 256         conv1_conv[1][0]                 
__________________________________________________________________________________________________
conv1_relu (Activation)         (None, None, None, 6 0           sync_batch_normalization[0][0]   
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, None, None, 6 0           conv1_relu[1][0]                 
__________________________________________________________________________________________________

このようにBatchnormがあったところがSyncBatchnormに差し替えられていることが分かります。
Connected toの要素がところどころ[1][0]になっていますが、これはレイヤーを再利用したためでしょう。

最後に

本記事ではTPUなどの複数デバイスにおける転移学習を想定し、学習済みモデルのBatchnormをSyncBatchnormに変更する方法を紹介しましたが、これを行うことで必ずしも学習が改善するとは限りません。こんなことをせずとも純粋にバッチサイズを増やしてしまったほうが楽だし、色々いじる前にまずデバイスを一つに限定してみて、そもそもバッチサイズの分割が原因なのかどうかをまず確認したほうが良いでしょう。私の場合はあれこれ調べた挙句、バッチサイズ云々ではなく学習率が高すぎるのが原因でした(Radamによるwarmupで解決しました)。機械学習ナンもわからん

次のカレンダーはkcz146さんの「絶対書く なんかかく」です。お楽しみに。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
1