2
1

More than 3 years have passed since last update.

tf.kerasの定義済みモデルの入力側のレイヤーを削除する方法

Last updated at Posted at 2020-12-14

はじめに

Kerasの定義済みモデルの入力側を変更したい場合がある。少し面倒だが不可能ではないので方法を紹介する。

動機

tf.ketasではResNetやNASnet等各種モデルを作成してくれるモジュールがある。しかし、例えばResNet50では224x224サイズの画像を前提にしているので、入力層付近(Stem)で元の画像を1/4サイズにしてしまう。その他のモデルでも程度の差はあれど同様の処理が入っている。
これは、CIFARのように32x32といった小さいサイズの画像入力で問題になる。モデル内では画像のサイズを随時縮小していくが、元画像が小さいと後のほうで有効な畳み込みができなくなるからである。
こういった問題に対処するためには、ふつうは前処理で224x224に拡大して入力するはずだが、なんだか無駄に思えるので「モデルから入力付近の処理を取り除きたい」、というのが動機。

環境

  • TensorFlow 2.3.0
  • tf.keras 2.4.0

方法

tf.kerasで用意されているResNet50V2を例にして、手順を解説する。

入力層切り落とし用の関数を用意

これはstackoverflow.com内でのコメントをもとにして作成した。
切り落とす部分の最後のレイヤー名を指定すると、そのレイヤーから先の層だけ返す。

def CutoffInputLayers(model, target_layer_name):
    def f(x):
        input_layers = {}
        target_layer_index = None
        for i, layer in enumerate(model.layers):
            if layer.name == target_layer_name:
                target_layer_index = i
            for node in layer._outbound_nodes:
                layer_name = node.outbound_layer.name
                if layer_name not in input_layers:
                    input_layers.update(
                            {layer_name: [layer.name]})
                else:
                    input_layers[layer_name].append(layer.name)

        if target_layer_index==None:
            raise ValueError(target_layer_name+" not found.")
        new_output_tensor = {}

        model_outputs = []
        new_output_tensor.update(
            {model.layers[target_layer_index].name: x})
        for layer in model.layers[target_layer_index+1:]:

            layer_input = [new_output_tensor[layer_aux] for layer_aux in input_layers[layer.name]]
            if len(layer_input) == 1:
                layer_input = layer_input[0]

            x = layer(layer_input)

            new_output_tensor.update({layer.name: x})
            if layer.name in model.output_names:
                model_outputs.append(x)
        return x
    return f

対象となるモデルを観察

どの部分まで切り落としたいか、を見極めるためにもとになるモデルを観察する。

from tensorflow.keras.applications import resnet_v2
from tensorflow.keras.utils import plot_model
resnet = resnet_v2.ResNet50V2(weights=None)
plot_model(resnet,show_shapes=True, to_file='ResNet50V2.png')

入力付近だけ拡大すると以下の通り。

ResNet50V2_head.png

residual blockの直前が'conv1_conv'とわかるので、そこまでを切り離す。
切り離した後の入力は56x56x64になっている。今回はCIFAR10を想定するので、ここが32x32x64だと後から追加する入力層ですこし楽になる。
これはモデルにもよるが、ResNet50V2の場合はinput_shape=(128,128,3)とすると、ここがちょうど32x32x64になる。
この辺の調査や調整が少し面倒なところ。

入力層を入れ替えて新しいモデルを作る

入力層を切り離したモデルでは、そのままではチャンネル数が64になっている。これはConv2Dを入れることによって変換できるので、最終的にモデルを作成するコードは以下の様になる。

import tensorflow as tf
import tensorflow.keras.layers as layers
from tensorflow.keras.applications import resnet_v2
from tensorflow.keras.utils import plot_model
from tensorflow.keras.models import Model

resnet = resnet_v2.ResNet50V2(input_shape=(128,128,3),weights=None, classes=10)

input = layers.Input(shape=(32,32,3), name='Input')
x = layers.Conv2D(64, kernel_size=(3,3), strides=(1,1), padding='same', 
    kernel_initializer='he_normal', kernel_regularizer=tf.keras.regularizers.l2(1.e-4),
    name='stem_conv2d' )(input)
x = CutoffInputLayers(resnet, 'pool1_pool')(x)

model = Model(input, x)
model.build(input_shape=(32,32,3))
plot_model(model,show_shapes=True, to_file='mod_ResNet50V2.png')

新しいモデルの様子はこちら。

replaced_ResNet50V2_head.png

入力段が差し変わり、32x32x64の入力をresidual blockに接続していることがわかる。

ちなみに、resizeを入れて内部で32x32から56x56にする手もある。リサイズが無駄なように思えるが、本来の設計に近くなるのでこちらのほうが安全かもしれない。
その場合のコードはこちら。

resnet = resnet_v2.ResNet50V2(input_shape=(224,224,3),weights=None, classes=10)

input = layers.Input(shape=(32,32,3), name='Input')
x = layers.Lambda(lambda image: tf.image.resize(image, [56, 56]), output_shape=(56,56,3), name='stem_resize')(input)
x = layers.Conv2D(64, kernel_size=(3,3), strides=(1,1), padding='same', name='stem_conv2d',
                    kernel_initializer='he_normal', kernel_regularizer=tf.keras.regularizers.l2(1.e-4))(x)
x = CutoffInputLayers(resnet, 'pool1_pool')(x)

model = Model(input, x)
model.build(input_shape=(32,32,3))
plot_model(model,show_shapes=True, to_file='mod2_ResNet50V2.png')  

注意

  • 例に示したResNet50V2ベースのモデルは素の状態から学習させることができるが、その他のモデル(NASnetやMobileNetV2)では学習が進まなかった。学習済みの重みを読み込んで使用することが前提で、kernel_initializer等が設定されていなかったりするのかもしれない。1
  • 転移学習もできなくもないはずだが、入力段が切り替わってるので、いろいろ工夫がいるかもしれない。

まとめ

ResNet50をサンプルとして、入力段を差し替える手順をしめした。
注意点に書いたように、入力段を変えたモデルが作れたからと言って素直に動作するわけではないので、実用性がどれくらいあるかは不明。
こんな方法もあるということで、何かの参考になれば。

コード


  1. 一部モデルではValidationの数値が上昇するまで時間がかかるようで、記事作成時に「学習できない」と書いてしまったが、たいていエポックが進めば学習が進むので取り消し。 

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1