1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Kerasでユーザーからのアノテーション情報を受け入れるネットワークを記述する

Posted at

ユーザーからのアノテーション(重み付け)情報を受けた場合、処理を切り替えるDeep LearningのネットワークモデルをKerasで実装するにはどうすればよいか試行錯誤したので、内容をまとめてみます。
最近はPyTorchでの実装がメインだったので、記述方法の違いに戸惑うところがありました。
Kerasで複雑めなネットワークを記述するにはfunctional APIを使います。
参考:keras functional APIの使い方メモ のQiitaページ

Functional APIではkeras.layersで定義される層をつなげていく必要があります。
今回のように独自の処理の層を入れるにはLambdaを使って実装する必要があります。
図のような画像認識タスクを想定したネットワークのコード例を下に示します。
SelfAttentionNetwork.png

通常であれば、ネットワークに対して入力データとして原画像が与えられます。今回はそれに加え、原画像に対応した重み付けマップ、その重み付けマップを使うかのフラグ(0 or 1)で、計3つの入力を与えます。
原画像、重み付けマップ、重み付けマップを使うかのフラグは入力データセットによりサイズが変わるので、keras.layers.Inputで定義されます。そのため、単純にIf文などで判定して処理を切り替えるわけにいきません。

ネットワークモデルのコード

from keras.models import Model
from keras.layers import Conv2D, Activation, BatchNormalization, GlobalAveragePooling2D, Dense, Input, Lambda, Add, Multiply
from keras.backend import switch as k_switch
from keras.backend import equal as k_equal
import numpy as np

def net(x, user_weight_map, user_weight_map_flg, feature_ch=16):
 """
    x: 原画像
    user_weight_map: ユーザーから与えられた重み付けマップ
    user_weight_map_flg: ユーザーから与えられた重み付けマップを使うかのフラグ
    """
    # 4回Convolutionをかける
    h = Conv2d(feature_ch, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)

    h = Conv2d(feature_ch*2, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2d(feature_ch*4, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2d(feature_ch*8, 3, strides=2, padding='same')(x)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    #---------------------
    # 分岐ネットワークで重み付けマップ(Self Attention)を計算する
    bh = Conv2D(feature_ch*4, 3, strides=1, padding='same')(h)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
    bh = Conv2D(feature_ch*2, 3, strides=1, padding='same')(bh)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
    bh = Conv2D(2, 1, strides=1, padding='same')(bh)
    bh = BatchNormalization()(bh)
    bh = Activation(activation='relu')(bh)
    
    model_weight = Conv2D(1, 3, strides=1, padding='same')(bh)
    model_weight = BatchNormalization()(bh)
    model_weight = Activation(activation='sigmoid', name='model_weight_output')(bh)

    bh = Conv2D(2, 1, strides=1, padding='same')(bh)
    bh = GlobalAveragePooling2D()(ah)
    bh = Dense(1000)(bh)
    #---------------------

    # フラグ情報を読み取り、ネットワークから算出された重み付けマップを使うか、ユーザーの作った重み付けマップを使うか切り替える
    weight_h = Lambda(lambda x: switch_weight_map(x), name='swith_weight_map')([h, model_weight, user_weight_map, user_weight_map_flg])

    h = Add(name='weight_map_add')([h, weight_h])

    h = Conv2d(feature_ch*16, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)
    
    h = Conv2d(feature_ch*32, 3, strides=2, padding='same')(h)
    h = BatchNormalization()(h)
    h = Activation(activation='relu')(h)

    h = GlobalAveragePooling2D()(h)
    h = Dense(1000)(h)
  
    return h, bh, model_weight

def switch_weight_map(inputs):
    feature_map = inputs[0]
    model_weight_map = inputs[1]
    user_weight_map = inputs[2]
    user_weight_map_flg = inputs[3]
    
    model_weight = Multiply()([feature_map, model_weight_map])
    user_weight = Multiply()([feature_map, user_weight_map])

    weight_cond = k_equal(user_weight_map_flg, 0)
    
    weight_h = k_switch(weight_cond, model_weight, user_weight)

    return weight_h

# Save Network Architecture
def save_network_param(save_path, feature_ch):
    param = {'base_feature_num':feature_ch}
    
    with open(save_path, 'w') as f:
        yaml.dump(param, f, default_flow_style=False)

# Load Network Architecture
def load_network_param(load_path):
    with open(load_path) as f:
        param = yaml.load(f)

    return param

学習処理を回しているときに、コールバック関数keras.callbacks.ModelCheckpoint()で引数をsave_weights_only=Falseにしてepochごとにモデルをsaveしようとすると、エラーメッセージとして"can't pickle _thread.RLock objects"のようなものが出ました。
また、model.to_json()やmodel.to_yaml()でモデルを書き出しておこうとしても同様のエラーが出ました。
Lambdaに入力データが与えられるまで不定形のInputがあることなどでpickleのSerializeができないようなことが起こっているようでした。
keras.callbacks.ModelCheckpoint()では引数をsave_weights_only=Trueで保存しておきます。save_network_param()とload_network_param()を用意し、trainで作ったモデルをpredictで使うにはネットワークのコードと書き出したyamlファイルでネットワーク構造を再現し、model.load_weights()で各層の重みをセットします。

Lambdaを使った実装において、引数xに対して[h, model_weight, user_weight_map, user_weight_map_flg]のようにリスト化して与えるのがコツでした。
以下のようにLambdaの引数xに対してuser_weight_map_flgだけをとるようにしてしまうと、Kerasがネットワーク構造を解釈し、saveやloadをする際にmodel_weightが他の層に繋がるか判別できないようでうまくいきませんでした。

weight_h = Lambda(lambda x:k_switch(k_equal(x, 0), model_weight, user_weight), name='switch_weight_map')(user_weight_map_flg)

参考情報

https://stackoverflow.com/questions/52448652/attributeerror-nonetype-object-has-no-attribute-inbound-nodes-while-trying
https://stackoverflow.com/questions/44855603/typeerror-cant-pickle-thread-lock-objects-in-seq2seq
https://github.com/keras-team/keras/issues/8343
https://github.com/matterport/Mask_RCNN/issues/1126
https://stackoverflow.com/questions/53212672/read-only-mode-in-keras
https://stackoverflow.com/questions/47066635/checkpointing-keras-model-typeerror-cant-pickle-thread-lock-objects/55229794#55229794
https://blog.shikoan.com/lambda_arguments/
https://github.com/keras-team/keras/issues/6621
https://stackoverflow.com/questions/59635570/keras-backend-k-switch-for-loss-function-error

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?