7
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Tensorflow.kerasの事前学習済みモデルで、どうしても入力チャンネル数を変えたい!

Last updated at Posted at 2021-03-21

はじめに

tensorflow.kerasには事前学習済みモデルがあります。
tf.keras.applications.EfficientNetB0(weights="imagenet")といった感じで、モデルが作れるので、ちょっとしたときにささっと活用できます。

ですが、入力チャンネルのデフォルトはRGB3チャンネルなので、

  • グレースケース1chで控えめにいきたいとき
  • 多チャンネル入力でガンガンいきたいとき

に対応することができません。

inputs = Input([None,None,5]) pretrained_model = EfficientNetB0(weights="imagenet", input_tensor=inputs)
と試したところで、きっちりエラーをはいてくれます。

そんなこんなでお悩みな初心者向けの記事になります。

チャンネル数を減らしたい場合

『入力1チャンネルでいいんですよ』というケース。
これは比較的簡単に解決できます。

① チャンネルをコピーして増やす

同じチャンネルを3つに増やして、無理やり辻褄を合わせます。
冗長な行列になりますが、numpy やtfのtileを使うと簡単にできます。たとえば、

inputs = Input([None,None,1])
x = Lambda(lambda x: tf.tile(x, [1,1,1,3]))(inputs)
pretrained_model = EfficientNetB0(weights="imagenet", input_tensor=x)

② 最初に畳み込みを一個かます

事前学習済みモデルの入力前に出力チャンネルを3にする畳み込みをかまします。
任意のチャンネル数の入力で使えます。

inputs = Input([None,None,2])
x = Conv2D(filters=3, kernel_size=1)(inputs)
pretrained_model = EfficientNetB0(weights="imagenet", input_tensor=None)
outputs = pretrained_model(x)
new_model = Model(inputs, outputs)

チャンネル数を増やしたい場合

『私の入力は53チャンネルです』というケース。
これはちょっと厄介です。上述②の畳み込みも使えますが、チャンネル数が無駄に多いと事前学習済みモデルの入力前に情報が大幅に圧縮されてしまうのが難点です。

個人的には事前学習済みモデルの中の最初に出てくる畳み込みを書き換えるぐらいの気持ちでどうにかしていただきたいところですが、私のテクニックtf.keras理解不足か、なかなかうまくいきません。
こうなったら全部レイヤーばらして再構成してやりましょう。

レイヤーの入出力関係を記録しておいて、最初の畳み込みだけ再定義して繋ぎなおせばいけるはずです。

def rebuild_model(pretrained_model, inputs):
    name_to_num = {}
    for i, layer in enumerate(pretrained_model.layers):
        name = layer.name
        name_to_num[name] = i
    
    # 最初の畳み込みを書き換えるので、パラメータを保管しておく。ついでに前処理レイヤも保管。
    preprocessing_layers = []
    for i, layer in enumerate(pretrained_model.layers):
        if "preprocessing" in str(layer.__class__):
            preprocessing_layers.append(i)
        if "Conv2D" in str(layer.__class__):# rebuild first conv2d
            rebuild_layer_num = i 
            rebuild_layer_params = {"filters":layer.filters,
                                    "kernel_size": layer.kernel_size,
                                    "strides": layer.strides,
                                    "activation": layer.activation,
                                    "use_bias": layer.use_bias,
                                    "name": layer.name}
            break
    
    # あとから再構成できるように、各レイヤーの入力を読んでおく。
    # ResNetやEfficientNetは分岐があるので、複数入力になるケースがあることに注意。
    in_layer_nums = []
    for i, layer in enumerate(pretrained_model.layers):
        layer_in = layer.input
        if not type(layer_in)==list:
            in_names = layer_in.name.split("/")[0].split(":")[0]
            in_nums = name_to_num[in_names]
        else:
            in_names = [l_in.name.split("/")[0].split(":")[0] for l_in in layer_in]
            in_nums = [name_to_num[in_name] for in_name in in_names]
        in_layer_nums.append(in_nums)
    
    # 再構成する。最初の畳み込みだけ作り直し。
    # 前処理レイヤは邪魔なので消す。(最初の0-255の変換処理でチャンネル数が違うことによるエラーが起きるので)
    new_layers = [inputs]
    for i, [layer, in_nums] in enumerate(zip(pretrained_model.layers, in_layer_nums)):
        if i==0:
            continue
        elif i in preprocessing_layers:
            print("skip preprocessing_layers: ", layer.name)
            x = new_layers[in_nums]
        elif i == rebuild_layer_num:
            print("rebuild conv2d layers: ", layer.name)
            x = Conv2D(**rebuild_layer_params)(new_layers[in_nums])
        else:
            if type(in_nums)==list:
                x = layer([new_layers[num] for num in in_nums])
            else:
                x = layer(new_layers[in_nums])
        new_layers.append(x)
    
    model = Model(inputs, x)
    return model

という感じに作ってやれば、あとは簡単。

pretrained_model = EfficientNetB0(weights="imagenet" ,input_shape=[None,None, 3], input_tensor=None)
inputs = Input([None,None,5])
new_model = rebuild_model(pretrained_model, inputs)

tf.kerasの事前学習済みモデルには255除算のような正規化処理が含まれており、チャンネルが変わると怒られるのでそれを消しています。あとは畳み込みを再定義しなおしていますのでここのweightは初期化されています。ご留意のほど。

ということで、任意の入力チャンネル数にすることができました。めでたしめでたし!

7
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
7
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?