はじめに
タイトル通り。
研究しているときに遭遇した問題とその直し方について共有。
なぜ必要なのか
kerasでカスタムレイヤを作ったとき、tf.Variable
あるいはself.add_weight
をつかって重みを追加し、model.save_weights
でh5
を指定すると次のようなエラーが発生する。
import tensorflow as tf
class Preprocess(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.mean = tf.Variable(tf.reshape([0.485, 0.456, 0.406], (1, 1, 1, -1)),
dtype=tf.float32, trainable=False)
self.std = tf.Variable(tf.reshape([0.229, 0.224, 0.225], (1, 1, 1, -1)),
dtype=tf.float32, trainable=False)
def call(self, inputs):
return (inputs - self.mean) / self.std
def gen_model():
inputs = tf.keras.Input(shape=(None, None, 3), name='vgg_input')
x = Preprocess(name='preprocess_layer')(inputs)
output = x = tf.keras.layers.Conv2D(32, (3, 3), padding='same',activation='relu')(x)
model = tf.keras.Model(inputs, output)
return model
gen_model().save_weights('bad_model.h5')
ValueError: Unable to create dataset (name already exists)
原因
kerasがやっているhdf5の保存方法はシンプルで、各レイヤの名前を元にh5のkeyを決めて保存している。
例えば、vggのh5の中身を見るとこんな感じで保存されている。
block1_conv1
block1_conv1/block1_conv1
block1_conv1/block1_conv1/bias:0 (64,)
block1_conv1/block1_conv1/kernel:0 (3, 3, 3, 64)
block1_conv2
block1_conv2/block1_conv2
block1_conv2/block1_conv2/bias:0 (64,)
block1_conv2/block1_conv2/kernel:0 (3, 3, 64, 64)
block1_pool
...
kernel
とbias
はConv2D
の重み。それぞれにbias:0
の名前とkernel:0
という名前が割り当てられている。
どこで付けられているかは、githubを参照。
では、上で紹介したモデルでは、どうなっているかというと
print(gen_model().get_layer('preprocess_layer').weights)
[<tf.Variable 'preprocess_layer/Variable:0' shape=(1, 1, 1, 3) dtype=float32, numpy=array([[[[0.485, 0.456, 0.406]]]], dtype=float32)>,
<tf.Variable 'preprocess_layer/Variable:0' shape=(1, 1, 1, 3) dtype=float32, numpy=array([[[[0.229, 0.224, 0.225]]]], dtype=float32)>]
つまり、異なる重みに対して同じ名前が割り当てられるため、nameが重複してh5が保存できなくなってしまっている。
解決策
nameを指定すればok。
class Preprocess(tf.keras.layers.Layer):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.mean = tf.Variable(tf.reshape([0.485, 0.456, 0.406], (1, 1, 1, -1)),
dtype=tf.float32, trainable=False, name='mean')
self.std = tf.Variable(tf.reshape([0.229, 0.224, 0.225], (1, 1, 1, -1)),
dtype=tf.float32, trainable=False, name='std')
...
[<tf.Variable 'preprocess_layer/mean:0' shape=(1, 1, 1, 3) dtype=float32, numpy=array([[[[0.485, 0.456, 0.406]]]], dtype=float32)>,
<tf.Variable 'preprocess_layer/std:0' shape=(1, 1, 1, 3) dtype=float32, numpy=array([[[[0.229, 0.224, 0.225]]]], dtype=float32)>]
おわりに
ちなみに、h5指定ではなくckpt指定にすると上のエラーは出ずに保存できます。
ただckptは複数ロードするとエラーが出まくるので、デバッグとか含めてh5の方がやりやすいです。