10
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Kerasでmodel.saveしようとしてmust override `get_config`エラーが出たときの対処

Posted at

はじめに

本記事ではKerasでmodel.save(あるいはmodel.to_json)しようとしてXX has arguments in `__init__` and therefore must override `get_config` .が出たときの対処法を紹介したいと思います。

背景と原因

KerasではDense層やConv層など事前定義されたレイヤーが多数存在し、これらを組み合わせることで基本的なモデルを設計します。
しかし、より発展的には、カスタムレイヤーを自分で実装しモデルに追加することになります。
例えば最新の論文で発表された仕組みを利用したいには場合、Kerasの事前定義レイヤーには存在せず、Githubから引用したり、自分で実装したりする必要があります。
(カスタムレイヤーの実装について興味のある場合は、こちらの公式Exampleをご確認ください。)
あるいは、初学者においては、kaggleのkernelなどで公開されているスクリプトをフォークした際に、知らず知らずのうちにカスタムレイヤーを含んだモデルを利用しているかもしれません。(私自身もそのようにして今回のエラーに直面しました。)

さて、XX(カスタムレイヤー名) has arguments in `__init__` and therefore must override `get_config` というエラーは、このカスタムレイヤーを含んだモデルに対して正しく対処できていない際に、Kerasから「そんなレイヤー知らないよ」と怒られて生じるものなのです。

解決方法

カスタムレイヤーのクラス内でget_config()をオーバーライドすることで解決できます。
より具体的には、カスタムレイヤーのクラスの__init__の引数を辞書にして、親クラスのconfigに追加して返すようなget_config()を定義します。
これが意味するところは、__init__の引数はこのカスタムレイヤーの設計書のようなものですから、勝手に作ったカスタムレイヤーの仕組みをKerasに明示的に教えてあげていることに相当します。

ちなみに、このようにして保存されたモデルはロードする際にもカスタムレイヤーをcustom_objectsアーギュメントで明示的に示す必要があります。
方法は非常に簡単で、以下のように行います。

load_model('my_model.h5', custom_objects={'NameOfCustomLayer': NameOfCustomLayer})

具体例

KaggleのこちらのPublic Kernelを例に説明いたします。
[GLRec] ResNet50 ArcFace (TF2.2)

このスクリプトのうち、実際にモデルの定義は以下で行われます。
backboneとなるモデルはResNet50でKerasに事前定義されています。(weightも今回のようにローカルに保存したものを使用するだけでなく、Kerasのパッケージで取得できます。)
また、pooling層やdropout層も事前定義されたものです。

このなかでmargin層だけは独自にインスタンス化していることがわかります。これがこのモデルのカスタム層です。

create_model.py

def create_model(input_shape,
                 n_classes,
                 dense_units=512,
                 dropout_rate=0.0,
                 scale=30,
                 margin=0.3):

    backbone = tf.keras.applications.ResNet50(
        include_top=False,
        input_shape=input_shape,
        weights=('../input/imagenet-weights/' +
                 'resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5')
    )

    pooling = tf.keras.layers.GlobalAveragePooling2D(name='head/pooling')
    dropout = tf.keras.layers.Dropout(dropout_rate, name='head/dropout')
    dense = tf.keras.layers.Dense(dense_units, name='head/dense')

    margin = ArcMarginProduct(
        n_classes=n_classes,
        s=scale,
        m=margin,
        name='head/arc_margin',
        dtype='float32')

    softmax = tf.keras.layers.Softmax(dtype='float32')

    image = tf.keras.layers.Input(input_shape, name='input/image')
    label = tf.keras.layers.Input((), name='input/label')

    x = backbone(image)
    x = pooling(x)
    x = dropout(x)
    x = dense(x)
    x = margin([x, label])
    x = softmax(x)
    return tf.keras.Model(
        inputs=[image, label], outputs=x)

margin層のクラスであるArcMarginProductを確認します。
すると、tf.keras.layers.Layerを継承したカスタムレイヤーであることがわかります。
(ちなみに、実装している技術はArcFaceといいます。)

このように独自定義されたカスタムレイヤー内で、get_config()を正しくオーバーライドしていないとき、model.saveをすると冒頭のエラーに直面するのでした。

今回のKernelではget_config()がクラス内で定義されていないので、そのままsaveしようとするとエラーがでます。

custom_layer.py
class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m

    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

そこで、以下のような変更を加える必要があります。

具体的には、get_config()をオーバーライドし、__init__の引数と親クラスのconfigを返しています。

new_custom_layer.py

class ArcMarginProduct(tf.keras.layers.Layer):
    '''
    Implements large margin arc distance.

    Reference:
        https://arxiv.org/pdf/1801.07698.pdf
        https://github.com/lyakaap/Landmark2019-1st-and-3rd-Place-Solution/
            blob/master/src/modeling/metric_learning.py
    '''
    def __init__(self, n_classes, s=30, m=0.50, easy_margin=False,
                 ls_eps=0.0, **kwargs):

        super(ArcMarginProduct, self).__init__(**kwargs)

        self.n_classes = n_classes
        self.s = s
        self.m = m
        self.ls_eps = ls_eps
        self.easy_margin = easy_margin
        self.cos_m = tf.math.cos(m)
        self.sin_m = tf.math.sin(m)
        self.th = tf.math.cos(math.pi - m)
        self.mm = tf.math.sin(math.pi - m) * m


### Start 追加されたコード
    def get_config(self):
        config = {
            "n_classes" : self.n_classes,
            "s" : self.s,
            "m" : self.m,
            "easy_margin" : self.easy_margin,
            "ls_eps" : self.ls_eps
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

###  End       
        
    def build(self, input_shape):
        super(ArcMarginProduct, self).build(input_shape[0])

        self.W = self.add_weight(
            name='W',
            shape=(int(input_shape[0][-1]), self.n_classes),
            initializer='glorot_uniform',
            dtype='float32',
            trainable=True,
            regularizer=None)

    def call(self, inputs):
        X, y = inputs
        y = tf.cast(y, dtype=tf.int32)
        cosine = tf.matmul(
            tf.math.l2_normalize(X, axis=1),
            tf.math.l2_normalize(self.W, axis=0)
        )
        sine = tf.math.sqrt(1.0 - tf.math.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = tf.where(cosine > 0, phi, cosine)
        else:
            phi = tf.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = tf.cast(
            tf.one_hot(y, depth=self.n_classes),
            dtype=cosine.dtype
        )
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.n_classes

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s
        return output

そしてモデルのロードは以下のように行う必要があります。

load_model.py

loaded_model =keras.models.load_model("path_to_model", custom_objects = {"ArcMarginProduct": ArcMarginProduct})

参考

Kerasでカスタムレイヤーを作成する方法
kerasでカスタムレイヤーのシリアライズを行う
NotImplementedError: Layers with arguments in __init__ must override get_config

10
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
5

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?