はじめに
本記事ではKerasでmodel.save(あるいはmodel.to_json)しようとしてXX has arguments in `__init__` and therefore must override `get_config`
.が出たときの対処法を紹介したいと思います。
背景と原因
KerasではDense層やConv層など事前定義されたレイヤーが多数存在し、これらを組み合わせることで基本的なモデルを設計します。
しかし、より発展的には、カスタムレイヤーを自分で実装しモデルに追加することになります。
例えば最新の論文で発表された仕組みを利用したいには場合、Kerasの事前定義レイヤーには存在せず、Githubから引用したり、自分で実装したりする必要があります。
(カスタムレイヤーの実装について興味のある場合は、こちらの公式Exampleをご確認ください。)
あるいは、初学者においては、kaggleのkernelなどで公開されているスクリプトをフォークした際に、知らず知らずのうちにカスタムレイヤーを含んだモデルを利用しているかもしれません。(私自身もそのようにして今回のエラーに直面しました。)
さて、XX(カスタムレイヤー名) has arguments in `__init__` and therefore must override `get_config`
というエラーは、このカスタムレイヤーを含んだモデルに対して正しく対処できていない際に、Kerasから「そんなレイヤー知らないよ」と怒られて生じるものなのです。
解決方法
カスタムレイヤーのクラス内でget_config()
をオーバーライドすることで解決できます。
より具体的には、カスタムレイヤーのクラスの__init__
の引数を辞書にして、親クラスのconfigに追加して返すようなget_config()
を定義します。
これが意味するところは、__init__
の引数はこのカスタムレイヤーの設計書のようなものですから、勝手に作ったカスタムレイヤーの仕組みをKerasに明示的に教えてあげていることに相当します。
ちなみに、このようにして保存されたモデルはロードする際にもカスタムレイヤーをcustom_objectsアーギュメントで明示的に示す必要があります。
方法は非常に簡単で、以下のように行います。
load_model('my_model.h5', custom_objects={'NameOfCustomLayer': NameOfCustomLayer})
具体例
KaggleのこちらのPublic Kernelを例に説明いたします。
[GLRec] ResNet50 ArcFace (TF2.2)
このスクリプトのうち、実際にモデルの定義は以下で行われます。
backboneとなるモデルはResNet50でKerasに事前定義されています。(weightも今回のようにローカルに保存したものを使用するだけでなく、Kerasのパッケージで取得できます。)
また、pooling層やdropout層も事前定義されたものです。
このなかでmargin層だけは独自にインスタンス化していることがわかります。これがこのモデルのカスタム層です。
def create_model(input_shape,
n_classes,
dense_units=512,
dropout_rate=0.0,
scale=30,
margin=0.3):
backbone = tf.keras.applications.ResNet50(
include_top=False,
input_shape=input_shape,
weights=('../input/imagenet-weights/' +
'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
)
pooling = tf.keras.layers.GlobalAveragePooling2D(name='head/pooling')
dropout = tf.keras.layers.Dropout(dropout_rate, name='head/dropout')
dense = tf.keras.layers.Dense(dense_units, name='head/dense')
margin = ArcMarginProduct(
n_classes=n_classes,
s=scale,
m=margin,
name='head/arc_margin',
dtype='float32')
softmax = tf.keras.layers.Softmax(dtype='float32')
image = tf.keras.layers.Input(input_shape, name='input/image')
label = tf.keras.layers.Input((), name='input/label')
x = backbone(image)
x = pooling(x)
x = dropout(x)
x = dense(x)
x = margin([x, label])
x = softmax(x)
return tf.keras.Model(
inputs=[image, label], outputs=x)
margin層のクラスであるArcMarginProductを確認します。
すると、tf.keras.layers.Layerを継承したカスタムレイヤーであることがわかります。
(ちなみに、実装している技術はArcFaceといいます。)
このように独自定義されたカスタムレイヤー内で、get_config()
を正しくオーバーライドしていないとき、model.saveをすると冒頭のエラーに直面するのでした。
今回のKernelではget_config()
がクラス内で定義されていないので、そのままsaveしようとするとエラーがでます。
class ArcMarginProduct(tf.keras.layers.Layer):
'''
Implements large margin arc distance.
Reference:
https://arxiv.org/pdf/1801.07698.pdf
https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
blob/master/src/modeling/metric_learning.py
'''
def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
ls_eps=0.0, **kwargs):
super(ArcMarginProduct, self).__init__(**kwargs)
self.n_classes = n_classes
self.s = s
self.m = m
self.ls_eps = ls_eps
self.easy_margin = easy_margin
self.cos_m = tf.math.cos(m)
self.sin_m = tf.math.sin(m)
self.th = tf.math.cos(math.pi - m)
self.mm = tf.math.sin(math.pi - m) * m
def build(self, input_shape):
super(ArcMarginProduct, self).build(input_shape[0])
self.W = self.add_weight(
name='W',
shape=(int(input_shape[0][-1]), self.n_classes),
initializer='glorot_uniform',
dtype='float32',
trainable=True,
regularizer=None)
def call(self, inputs):
X, y = inputs
y = tf.cast(y, dtype=tf.int32)
cosine = tf.matmul(
tf.math.l2_normalize(X, axis=1),
tf.math.l2_normalize(self.W, axis=0)
)
sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
phi = cosine * self.cos_m - sine * self.sin_m
if self.easy_margin:
phi = tf.where(cosine > 0, phi, cosine)
else:
phi = tf.where(cosine > self.th, phi, cosine - self.mm)
one_hot = tf.cast(
tf.one_hot(y, depth=self.n_classes),
dtype=cosine.dtype
)
if self.ls_eps > 0:
one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
output *= self.s
return output
そこで、以下のような変更を加える必要があります。
具体的には、get_config()
をオーバーライドし、__init__
の引数と親クラスのconfigを返しています。
class ArcMarginProduct(tf.keras.layers.Layer):
'''
Implements large margin arc distance.
Reference:
https://arxiv.org/pdf/1801.07698.pdf
https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
blob/master/src/modeling/metric_learning.py
'''
def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
ls_eps=0.0, **kwargs):
super(ArcMarginProduct, self).__init__(**kwargs)
self.n_classes = n_classes
self.s = s
self.m = m
self.ls_eps = ls_eps
self.easy_margin = easy_margin
self.cos_m = tf.math.cos(m)
self.sin_m = tf.math.sin(m)
self.th = tf.math.cos(math.pi - m)
self.mm = tf.math.sin(math.pi - m) * m
### Start 追加されたコード
def get_config(self):
config = {
"n_classes" : self.n_classes,
"s" : self.s,
"m" : self.m,
"easy_margin" : self.easy_margin,
"ls_eps" : self.ls_eps
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
### End
def build(self, input_shape):
super(ArcMarginProduct, self).build(input_shape[0])
self.W = self.add_weight(
name='W',
shape=(int(input_shape[0][-1]), self.n_classes),
initializer='glorot_uniform',
dtype='float32',
trainable=True,
regularizer=None)
def call(self, inputs):
X, y = inputs
y = tf.cast(y, dtype=tf.int32)
cosine = tf.matmul(
tf.math.l2_normalize(X, axis=1),
tf.math.l2_normalize(self.W, axis=0)
)
sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
phi = cosine * self.cos_m - sine * self.sin_m
if self.easy_margin:
phi = tf.where(cosine > 0, phi, cosine)
else:
phi = tf.where(cosine > self.th, phi, cosine - self.mm)
one_hot = tf.cast(
tf.one_hot(y, depth=self.n_classes),
dtype=cosine.dtype
)
if self.ls_eps > 0:
one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
output *= self.s
return output
そしてモデルのロードは以下のように行う必要があります。
loaded_model =keras.models.load_model("path_to_model", custom_objects = {"ArcMarginProduct": ArcMarginProduct})
参考
Kerasでカスタムレイヤーを作成する方法
kerasでカスタムレイヤーのシリアライズを行う
NotImplementedError: Layers with arguments in __init__
must override get_config