search
LoginSignup
0

posted at

updated at

kerasでカスタムレイヤ内に重みを複数追加する場合は名前を指定した方が良い

はじめに

タイトル通り。
研究しているときに遭遇した問題とその直し方について共有。

なぜ必要なのか

kerasでカスタムレイヤを作ったとき、tf.Variableあるいはself.add_weightをつかって重みを追加し、model.save_weightsh5を指定すると次のようなエラーが発生する。

import tensorflow as tf

class Preprocess(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mean = tf.Variable(tf.reshape([0.485, 0.456, 0.406], (1, 1, 1, -1)),
                                dtype=tf.float32, trainable=False)
        self.std = tf.Variable(tf.reshape([0.229, 0.224, 0.225], (1, 1, 1, -1)),
                               dtype=tf.float32, trainable=False)

    def call(self, inputs):
        return (inputs - self.mean) / self.std
    

def gen_model():
    inputs = tf.keras.Input(shape=(None, None, 3), name='vgg_input')
    x = Preprocess(name='preprocess_layer')(inputs)
    output = x = tf.keras.layers.Conv2D(32, (3, 3), padding='same',activation='relu')(x)
    model = tf.keras.Model(inputs, output)
    return model
    
gen_model().save_weights('bad_model.h5')
ValueError: Unable to create dataset (name already exists)

原因

kerasがやっているhdf5の保存方法はシンプルで、各レイヤの名前を元にh5のkeyを決めて保存している。
例えば、vggのh5の中身を見るとこんな感じで保存されている。

block1_conv1
block1_conv1/block1_conv1
block1_conv1/block1_conv1/bias:0 (64,)
block1_conv1/block1_conv1/kernel:0 (3, 3, 3, 64)
block1_conv2
block1_conv2/block1_conv2
block1_conv2/block1_conv2/bias:0 (64,)
block1_conv2/block1_conv2/kernel:0 (3, 3, 64, 64)
block1_pool
...

kernelbiasConv2Dの重み。それぞれにbias:0の名前とkernel:0という名前が割り当てられている。
どこで付けられているかは、githubを参照

では、上で紹介したモデルでは、どうなっているかというと

print(gen_model().get_layer('preprocess_layer').weights)
[<tf.Variable 'preprocess_layer/Variable:0' shape=(1, 1, 1, 3) dtype=float32, numpy=array([[[[0.485, 0.456, 0.406]]]], dtype=float32)>,
 <tf.Variable 'preprocess_layer/Variable:0' shape=(1, 1, 1, 3) dtype=float32, numpy=array([[[[0.229, 0.224, 0.225]]]], dtype=float32)>]

つまり、異なる重みに対して同じ名前が割り当てられるため、nameが重複してh5が保存できなくなってしまっている。

解決策

nameを指定すればok。

class Preprocess(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.mean = tf.Variable(tf.reshape([0.485, 0.456, 0.406], (1, 1, 1, -1)),
                                dtype=tf.float32, trainable=False, name='mean')
        self.std = tf.Variable(tf.reshape([0.229, 0.224, 0.225], (1, 1, 1, -1)),
                               dtype=tf.float32, trainable=False, name='std')
...
[<tf.Variable 'preprocess_layer/mean:0' shape=(1, 1, 1, 3) dtype=float32, numpy=array([[[[0.485, 0.456, 0.406]]]], dtype=float32)>,
 <tf.Variable 'preprocess_layer/std:0' shape=(1, 1, 1, 3) dtype=float32, numpy=array([[[[0.229, 0.224, 0.225]]]], dtype=float32)>]

おわりに

ちなみに、h5指定ではなくckpt指定にすると上のエラーは出ずに保存できます。
ただckptは複数ロードするとエラーが出まくるので、デバッグとか含めてh5の方がやりやすいです。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
What you can do with signing up
0