ユーザーからのアノテーション(重み付け)情報を受けた場合、処理を切り替えるDeep LearningのネットワークモデルをKerasで実装するにはどうすればよいか試行錯誤したので、内容をまとめてみます。
最近はPyTorchでの実装がメインだったので、記述方法の違いに戸惑うところがありました。
Kerasで複雑めなネットワークを記述するにはfunctional APIを使います。
参考:keras functional APIの使い方メモ のQiitaページ
Functional APIではkeras.layersで定義される層をつなげていく必要があります。
今回のように独自の処理の層を入れるにはLambdaを使って実装する必要があります。
図のような画像認識タスクを想定したネットワークのコード例を下に示します。
通常であれば、ネットワークに対して入力データとして原画像が与えられます。今回はそれに加え、原画像に対応した重み付けマップ、その重み付けマップを使うかのフラグ(0 or 1)で、計3つの入力を与えます。
原画像、重み付けマップ、重み付けマップを使うかのフラグは入力データセットによりサイズが変わるので、keras.layers.Inputで定義されます。そのため、単純にIf文などで判定して処理を切り替えるわけにいきません。
ネットワークモデルのコード
from keras.models import Model
from keras.layers import Conv2D, Activation, BatchNormalization, GlobalAveragePooling2D, Dense, Input, Lambda, Add, Multiply
from keras.backend import switch as k_switch
from keras.backend import equal as k_equal
import numpy as np
def net(x, user_weight_map, user_weight_map_flg, feature_ch=16):
"""
x: 原画像
user_weight_map: ユーザーから与えられた重み付けマップ
user_weight_map_flg: ユーザーから与えられた重み付けマップを使うかのフラグ
"""
# 4回Convolutionをかける
h = Conv2d(feature_ch, 3, strides=2, padding='same')(x)
h = BatchNormalization()(h)
h = Activation(activation='relu')(h)
h = Conv2d(feature_ch*2, 3, strides=2, padding='same')(x)
h = BatchNormalization()(h)
h = Activation(activation='relu')(h)
h = Conv2d(feature_ch*4, 3, strides=2, padding='same')(x)
h = BatchNormalization()(h)
h = Activation(activation='relu')(h)
h = Conv2d(feature_ch*8, 3, strides=2, padding='same')(x)
h = BatchNormalization()(h)
h = Activation(activation='relu')(h)
#---------------------
# 分岐ネットワークで重み付けマップ(Self Attention)を計算する
bh = Conv2D(feature_ch*4, 3, strides=1, padding='same')(h)
bh = BatchNormalization()(bh)
bh = Activation(activation='relu')(bh)
bh = Conv2D(feature_ch*2, 3, strides=1, padding='same')(bh)
bh = BatchNormalization()(bh)
bh = Activation(activation='relu')(bh)
bh = Conv2D(2, 1, strides=1, padding='same')(bh)
bh = BatchNormalization()(bh)
bh = Activation(activation='relu')(bh)
model_weight = Conv2D(1, 3, strides=1, padding='same')(bh)
model_weight = BatchNormalization()(bh)
model_weight = Activation(activation='sigmoid', name='model_weight_output')(bh)
bh = Conv2D(2, 1, strides=1, padding='same')(bh)
bh = GlobalAveragePooling2D()(ah)
bh = Dense(1000)(bh)
#---------------------
# フラグ情報を読み取り、ネットワークから算出された重み付けマップを使うか、ユーザーの作った重み付けマップを使うか切り替える
weight_h = Lambda(lambda x: switch_weight_map(x), name='swith_weight_map')([h, model_weight, user_weight_map, user_weight_map_flg])
h = Add(name='weight_map_add')([h, weight_h])
h = Conv2d(feature_ch*16, 3, strides=2, padding='same')(h)
h = BatchNormalization()(h)
h = Activation(activation='relu')(h)
h = Conv2d(feature_ch*32, 3, strides=2, padding='same')(h)
h = BatchNormalization()(h)
h = Activation(activation='relu')(h)
h = GlobalAveragePooling2D()(h)
h = Dense(1000)(h)
return h, bh, model_weight
def switch_weight_map(inputs):
feature_map = inputs[0]
model_weight_map = inputs[1]
user_weight_map = inputs[2]
user_weight_map_flg = inputs[3]
model_weight = Multiply()([feature_map, model_weight_map])
user_weight = Multiply()([feature_map, user_weight_map])
weight_cond = k_equal(user_weight_map_flg, 0)
weight_h = k_switch(weight_cond, model_weight, user_weight)
return weight_h
# Save Network Architecture
def save_network_param(save_path, feature_ch):
param = {'base_feature_num':feature_ch}
with open(save_path, 'w') as f:
yaml.dump(param, f, default_flow_style=False)
# Load Network Architecture
def load_network_param(load_path):
with open(load_path) as f:
param = yaml.load(f)
return param
学習処理を回しているときに、コールバック関数keras.callbacks.ModelCheckpoint()で引数をsave_weights_only=Falseにしてepochごとにモデルをsaveしようとすると、エラーメッセージとして"can't pickle _thread.RLock objects"のようなものが出ました。
また、model.to_json()やmodel.to_yaml()でモデルを書き出しておこうとしても同様のエラーが出ました。
Lambdaに入力データが与えられるまで不定形のInputがあることなどでpickleのSerializeができないようなことが起こっているようでした。
keras.callbacks.ModelCheckpoint()では引数をsave_weights_only=Trueで保存しておきます。save_network_param()とload_network_param()を用意し、trainで作ったモデルをpredictで使うにはネットワークのコードと書き出したyamlファイルでネットワーク構造を再現し、model.load_weights()で各層の重みをセットします。
Lambdaを使った実装において、引数xに対して[h, model_weight, user_weight_map, user_weight_map_flg]のようにリスト化して与えるのがコツでした。
以下のようにLambdaの引数xに対してuser_weight_map_flgだけをとるようにしてしまうと、Kerasがネットワーク構造を解釈し、saveやloadをする際にmodel_weightが他の層に繋がるか判別できないようでうまくいきませんでした。
weight_h = Lambda(lambda x:k_switch(k_equal(x, 0), model_weight, user_weight), name='switch_weight_map')(user_weight_map_flg)
参考情報
https://stackoverflow.com/questions/52448652/attributeerror-nonetype-object-has-no-attribute-inbound-nodes-while-trying
https://stackoverflow.com/questions/44855603/typeerror-cant-pickle-thread-lock-objects-in-seq2seq
https://github.com/keras-team/keras/issues/8343
https://github.com/matterport/Mask_RCNN/issues/1126
https://stackoverflow.com/questions/53212672/read-only-mode-in-keras
https://stackoverflow.com/questions/47066635/checkpointing-keras-model-typeerror-cant-pickle-thread-lock-objects/55229794#55229794
https://blog.shikoan.com/lambda_arguments/
https://github.com/keras-team/keras/issues/6621
https://stackoverflow.com/questions/59635570/keras-backend-k-switch-for-loss-function-error