LoginSignup
4
1

Tensorflow(keras)でもtimmみたいに任意ch入力でimagenet pretrainedモデルを使いたい!

Last updated at Posted at 2023-12-26

はじめに

みなさんゴリゴリにtensorflow kerasでdeep learningをして過ごされていることかと思います。
tensorflowは本当に最高なライブラリですが、pretrained weightを使った任意ch入力のモデルを作ろうとしたときにpytorchのtimmのようにpretrained weightをそのまま使えなくて悲しい気持ちになりますよね。
そこで、tensorflowでもtorchのtimmみたいに入力chを任意のchに変更できるようにしたい!というのが今回の記事になります。
やってみたという感じなので、間違いなどあればご指摘いただけると幸いです。

みんな大好きtimmのかっこいいch拡張

まずみんな大好きtimmがどのようにしてrgbのweightから任意chのweightへと拡張しているのかを確認してみます。
timmでは、入力chが3よりも大きいとき、chに合わせてpretrainedのconvからrgbのweightをrepeatすることで任意入力に対応しているようです。→link
つまり、chが増えていくごとにrgbrgbr...とrgbのchのconvを繰り返すことでpretrained weightをほかのchにも活用しているということですね。

やってみた

tensorflowでも同じことができると便利なのでefficientnetを例に任意ch入力ができるよう拡張してみました。

まずは元々使われているconvや入力を取り除く必要があるので、指定した層以降のみのモデルを使う関数を用意します。こちらの記事に便利な関数を掲載していただいていたのでありがたく使わせていただきます。

def get_cut_input_model(model, target_layer_name):
    def f(x):
        input_layers = {}
        target_layer_index = None
        for i, layer in enumerate(model.layers):
            if layer.name == target_layer_name:
                target_layer_index = i
            for node in layer._outbound_nodes:
                layer_name = node.outbound_layer.name
                if layer_name not in input_layers:
                    input_layers.update({layer_name: [layer.name]})
                else:
                    input_layers[layer_name].append(layer.name)

        if target_layer_index is None:
            raise ValueError(target_layer_name + " not found.")
        new_output_tensor = {}

        model_outputs = []
        new_output_tensor.update({model.layers[target_layer_index].name: x})
        for layer in model.layers[target_layer_index + 1 :]:

            layer_input = [
                new_output_tensor[layer_aux] for layer_aux in input_layers[layer.name]
            ]
            if len(layer_input) == 1:
                layer_input = layer_input[0]

            x = layer(layer_input)

            new_output_tensor.update({layer.name: x})
            if layer.name in model.output_names:
                model_outputs.append(x)
        return model_outputs

    return f

次に、efficientnetb0の最初のconvであるstem_convのweightを取り出して、指定したchまで並べたあと、以降の層は元のモデルとつながるようにしました。

def get_extended_input_model(model, input_shape, input_channel):
    extended_input = tf.keras.layers.Input(shape=(None, None, input_channel))
    x = tf.keras.layers.Rescaling(scale=1.0 / 255)(extended_input)
    x = tf.keras.layers.Normalization(axis=3)(x)
    # 任意のchに合わせて最初のconvのweightを拡張する
    stem_conv = model.get_layer(name="stem_conv")
    orig_weight = stem_conv.weights[0]
    tile_num = input_channel // 3 + 1
    tiled_weight = tf.tile(orig_weight, [1, 1, tile_num, 1])
    tiled_weight = tiled_weight[
        :, :, :input_channel, :
    ] # [filter_height, filter_width, in_channels, out_channels]
    tiled_conv = tf.keras.layers.Conv2D(
        filters=32,
        kernel_size=3,
        strides=2,
        padding="same",
        use_bias=False,
        name="stem_conv",
    )
    tiled_conv.build(input_shape=(None, input_shape[0], input_shape[1], input_channel))
    tiled_conv.set_weights([tiled_weight])
    x = tiled_conv(x)
    # stem_conv以降は元のmodelを使う
    x = get_cut_input_model(model, "stem_conv")(x)
    extended_input_model = tf.keras.Model(inputs=extended_input, outputs=x)
    extended_input_model.build(
        input_shape=(None, input_shape[0], input_shape[1], input_channel)
    )
    return extended_input_model

正しく変更できているか確認します。

input_channel = 8
input_shape = (256, 256)
effnet = tf.keras.applications.EfficientNetB0(
    include_top=False,
    weights="imagenet",
)
extended_input_effnet = get_extended_input_model(
    effnet, input_shape=input_shape, input_channel=input_channel
)
print(extended_input_effnet.summary())
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_2 (InputLayer)        [(None, None, None, 8)]      0         []                            
                                                                                                  
 rescaling_2 (Rescaling)     (None, None, None, 8)        0         ['input_2[0][0]']             
                                                                                                  
 normalization_1 (Normaliza  (None, None, None, 8)        17        ['rescaling_2[0][0]']         
 tion)                                                                                            
                                                                                                  
 stem_conv (Conv2D)          (None, None, None, 32)       2304      ['normalization_1[0][0]']     
                                                                                                  
 stem_bn (BatchNormalizatio  (None, None, None, 32)       128       ['stem_conv[0][0]']           
 n)                                                                                               
                                                                                                  
 stem_activation (Activatio  (None, None, None, 32)       0         ['stem_bn[1][0]']             
 n)                                                                                               
                                                                                                  
 block1a_dwconv (DepthwiseC  (None, None, None, 32)       288       ['stem_activation[1][0]']     
 onv2D)                                                                                          .
 .

ここでは元々あった3chから8chへと入力chが拡張されていることがわかります。
元々tensorflowのefficientnetでは入力層直後に255で除算する層(rescaling層)があります。今回はその層を残した実装になっていますが、入力を0~1にする場合には不要になるので削除する必要があります。

まとめ

timmで設定されているようにconvのpretrained weightを増えたchにも利用することで、Tensorflowでも任意のchに対してもpretrained weightを使ったモデル構築を行うことができました。
pretrained weightがちゃんと設定されているかどうかについては関数内のset_weightをするかしないかで関数を実行し、set_weightしない場合では0chと3chで差があり、set_weightすると0chと3chが同じ値になることからrgbのconvのweightを繰り返し活用していることが確認できます。
get_cut_input_modelの関数内でoutputを中間層を取り出すように変更することでsegmentationのencoderとして使うこともできます。
これからはTensorflowでもドヤ顔でtimmのようなch拡張ができます!やったね!

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1