3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

TensorFlow用ArcFaceの一実装

Last updated at Posted at 2021-01-04

はじめに

TensorFlow 2.x 向けにArcFaceをカスタムレイヤとカスタム損失関数の組み合わせとして実装した。

背景

深層距離学習の様々な手法の中でも,クラス分類問題の出力層に追加するだけで構成可能なArcFaceはシンプルで見通しの良い手法といえる。(参考・モダンな深層距離学習 (deep metric learning) 手法: SphereFace, CosFace, ArcFace)
通常のニューラルネットワークの結合層が出力値の計算において層への入力値と重みのみを必要とするのに対してArcFaceの計算では正解ラベルも必要とするため,TensorFlow 2.x (Keras)で実現するには一工夫必要となる。
ArcFaceをKerasで実現した先行例として「[Keras]MobileNetV2+ArcFaceを使ってペットボトルを分類してみた!」がある。同事例ではArcFaceを2入力のカスタムレイヤとして実装した上で,学習データセットのジェネレータから正解ラベルを正解ラベルをバイパスして入力している。
筆者も上記を参考に試してみたが,ArcFaceと通常のクラス分類を差し替えて試したいときにネットワーク構造が変わる点が煩雑に感じたため別の方法での実装を検討した。

実装方法

ArcFaceの大まかな計算手順を以下に示す。

  1. 入力Xと重みWをL2正規化
  2. 正解以外のラベルについて,cos(θ) = X・W (内積)を計算
  3. 正解ラベルについて,cos(θ+m)を計算
  4. Softmaxを計算

ここでcos(θ+m) = cos(θ)・cos(m)-sin(θ)・sin(m)となり(加法定理),さらにsin(θ)=√(1-cos(θ)^2)となることから実際には手順2.で正解ラベルも含めてcos(θ)を計算すれば良い。

2'. 全てのラベルについて,cos(θ) = X・W (内積)を計算

これで手順2'.までが正解ラベルを必要としなくなり,カスタムレイヤとして実装可能になる。
また,手順3以降は入力や重みを必要としないため損失関数内に実装することができる。
但し,Accuracyの計算時にもSoftmaxの計算がされるようにする必要がある。

実装例を以下に示す。

arcface.py
import TensorFlow as tf

# ArcFaceの前半部分
class ArcFaceLayer0(tf.keras.layers.Layer) :
    def __init__(self, num_outputs, kernel_regularizer = None, **kargs) :
        super(ArcFaceLayer0, self).__init__(**kargs)
        self.num_outputs = num_outputs
        self.kernel_regularizer = kernel_regularizer

    def build(self, input_shape) :
        weight_shape = (input_shape[-1] , self.num_outputs)
        self.kernel = self.add_weight(
            name='kernel',
            shape = weight_shape,
            initializer = tf.keras.initializers.TruncatedNormal(),
            regularizer = self.kernel_regularizer,
            trainable = True
            )
        super(ArcFaceLayer0, self).build(input_shape)

    def call(self, input) :
        n_input = tf.math.l2_normalize(input, axis = 1)               # inputのL2正規化
        n_kernel = tf.math.l2_normalize(self.kernel, axis = 0)        # 重みのL2正規化
        return tf.matmul(n_input, n_kernel)      # W.Txの内積

# 損失関数側に実装したArcFace
class ArcFaceLoss(tf.keras.losses.Loss) :
    # m:マージン
    # s:倍率
    # loss_func:本来の損失関数 tf.keras.losses.CategoricalCrossentropy(from_logits = True)など
    def __init__(self, loss_func, m = 0.5, s = 30, name = "arcface_loss", **kwargs) :
        self.loss_func = loss_func
        self.margin = m
        self.s = s
        self.enable = True
        super(ArcFaceLoss, self).__init__(name = name, **kwargs)

    def call(self, y_true, y_pred):
        # y_predは cos(θ)
        # 加法定理のためにsin(θ)を計算する
        sine = tf.keras.backend.sqrt(1.0 - tf.keras.backend.square(y_pred))
        phi = y_pred * self.cos_m - sine * self.sin_m       # cos(θ+m)の加法定理
        phi = tf.where(y_pred > 0, phi, y_pred)             # あさってを向いているときはそのまま

        # 正解クラス:cos(θ+m) 他のクラス:cosθ 
        logits = (y_true * phi) + ((1.0 - y_true) * y_pred)

        # 本来の損失関数を呼び出す
        return self.loss_func(y_true, logits * self.s)

# ArcFace用の評価関数
class ArcFaceAccuracy(tf.keras.metrics.Mean) :
    def __init__(self, metrics_func, s = 30, name = "arcface_accuracy", dtype = None) :
        self.metrics_func = metrics_func
        self.s = s
        super(ArcFaceAccuracy, self).__init__(name, dtype)

    def update_state(self, y_true, y_pred, sample_weight = None) :
        output = tf.nn.softmax(y_pred * self.s)
        matches = self.metrics_func(y_true, output)

        return super(ArcFaceAccuracy, self).update_state(matches, sample_weight = sample_weight) 

3
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?