LoginSignup
37
33

More than 3 years have passed since last update.

コード解剖!! <深層距離学習 ArcFace>

Last updated at Posted at 2021-04-26

はじめに

画像認識を用いて解決したいタスクとして、しばしばある画像と類似する画像の検索を行いたいというタスクが表れることがあります。そのような、画像同士の類似度を学習する仕組みの中で深層学習を利用した方法に深層距離学習というものがあります。

今回はその深層距離学習のモデルの1つの「ArcFace」というモデルに関してPyTorchによる実装コードを追いながら勉強した内容を紹介したいと思います。

今回解説するコード

from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import math


class ArcMarginProduct(nn.Module):
    r"""Implement of large margin arc distance: :
        Args:
            in_features: size of each input sample
            out_features: size of each output sample
            s: norm of input feature
            m: margin
            cos(theta + m)
        """
    def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------------
        # one_hot = torch.zeros(cosine.size(), requires_grad=True, device='cuda')
        one_hot = torch.zeros(cosine.size(), device='cuda')
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        # -------------torch.where(out_i = {x_i if condition_i else y_i) -------------
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4
        output *= self.s
        # print(output)

        return output

今回はArcFaceのモデルの概要について紹介し、その後上記のArcFaceのPyTorchによる実装(https://github.com/ronghuaiyang/arcface-pytorch/blob/master/models/metrics.py)
でモデルの処理・数式が上記コードで実現できることを見ていきたいと思います。

ArcFaceの概要

ここではArcFaceの概要やそこに表れる数式に簡単に説明します。
ArcFaceの詳しい説明は、「モダンな深層距離学習 (deep metric learning) 手法: SphereFace, CosFace, ArcFace」や元論文の説明がわかりやすかったです。
ArcFaceはクラス分類問題において非常頻繁に用いられるSoftmax Cross Entropyという損失関数
$$
L_1 = -\frac{1}{N}\sum^N_{i=1}\log\frac{\exp({W^T_{y_i}x_i+b_i)}}{\sum^C_k\exp({W^T_{k}x_i+b_k})} \;\;\;\;\;\;\;\;\;\;\;\;\;\; (1)
$$
に対して2つの工夫を施すことで得られる手法になります。
なお、式中に出てくるパラメーターの定義は以下であるとします。

  • N: バッチサイズ
  • i: バッチ内の学習サンプル
  • $W_k$: 最後の全結合層の重み$W$のk番目の列
  • $x_i$: 全結合層への入力となるD次元ベクトル
  • $b_k$: バイアス
  • $y_i$: i番目のサンプルの正解クラスID

[1つ目の工夫] 重み・特徴ベクトルの正規化

ArcFaceにおける工夫の1つ目はSoftmax Cross Entropyの中の重みベクトル$W^T_{y_k}$と特徴ベクトル$x_i$をそれぞれ$L_2$正規化(ノルムが1にする正規化)を行うというものです。
この、正規化によって重みベクトルと特徴ベクトルはそれぞれ

  • $W_{y_k}$->半径1の超球面上に射影されたクラス$k$の代表ベクトル
  • $x_i$->半径1の超球面上に射影されたi番目のデータの特徴ベクトル

と解釈することができます。
この時、クラス$k$の代表ベクトルとi番目のデータの特徴ベクトルがなす角度$\theta_{k}$とすると、それらの内積は
$$
W^T_{k}x_i=|W^T_{k}||x_i|\cos\theta_{k}=\cos\theta_{k}
$$
となります。
さらに、バイアス項$b_k=0$とすることで$L_2$正規化を施したCross entropy lossは
$$
L_2 = -\frac{1}{N}\sum^N_{i=1}\log\frac{\exp(\cos\theta_{y_i})}{\sum^C_k\exp(\cos\theta_{k})} \;\;\;\;\;\;\;\;\;\;\;\;\;\; (2)
$$
と表せます。
この式の意味するところは、各クラスの代表ベクトル$W_k$と特徴ベクトル$x_i$のなす角$\theta_k$が最も小さいクラスに分類する損失を表します。この工夫によって代表ベクトル$W_k$や特徴ベクトル$x_i$距離の近さを比べたいベクトル同士の成す角度の近さと解釈できるようになります。

[2つ目の工夫]角度に対するマージン

次に行うのはAngular Margin Penaltyと呼ばれる工夫です。その工夫とは「(2)式において、正解クラス$y_i$に対応する$\cos\theta_{y_i}$のみ」に対して、$\cos(θ_{y_i}+m) (m>0)$ という変更(ペナルティ)を加えるというものです。

こうようにすることで正解クラスの代表ベクトル$W_{y_i}$とのなす角$θy_i$が他のどのクラスの代表ベクトルとの角度よりも、$θ_{y_i}$がm以上のマージンを持って小さくなるように学習されます。

このペナルティにより、モデルは特徴ベクトルxのクラス内分散を小さくし、クラス間分散を大きくして、このマージンを稼ぐように学習が進みます。

最後にこのcosθの値はs (s>1)倍されてsoftmaxへ送られます。これは$\cos\theta_k$の値が小さすぎるとsoftmaxが機能しなくなるために調節を行っている処理になります。

以上のことをまとめると、ArcFaceのロスを書き下すと下記になります。
$$
L_3 = -\frac{1}{N}\sum^N_{i=1}\log\frac{\exp(s\cos(\theta_{y_i}+m)}{\exp(s\cos(\theta_{y_i}+m)+\sum^C_{k≠y_i}\exp(s\cos(\theta_k))} \;\;\;\;\;\;\;\;\;\;\;\;\; (3)
$$

上記のような処理を加えたモデルを通常のクラス分類として学習させることで、特徴ベクトルxのクラス内分散を小さくし、クラス間分散を大きくするような学習が実現されます。実際に利用する場合には、正規化された特徴ベクトル$x_i$を抽出し、それらのコサイン類似度によってサンプル間の類似度を算出することができます。

コード解説

① コードの前半部分(コンストラクターの定義まで)

   from __future__ import print_function
   from __future__ import division
   import torch
   import torch.nn as nn
   import torch.nn.functional as F
   from torch.nn import Parameter
   import math  

   class ArcMarginProduct(nn.Module):
       r"""Implement of large margin arc distance: :
           Args:
               in_features: size of each input sample
               out_features: size of each output sample
               s: norm of input feature
               m: margin
               cos(theta + m)
           """
       def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):
           super(ArcMarginProduct, self).__init__()
           self.in_features = in_features
           self.out_features = out_features
           self.s = s
           self.m = m
           self.weight = Parameter(torch.FloatTensor(out_features, in_features))
           nn.init.xavier_uniform_(self.weight)

           self.easy_margin = easy_margin
           self.cos_m = math.cos(m)
           self.sin_m = math.sin(m)
           self.th = math.cos(math.pi - m)
           self.mm = math.sin(math.pi - m) * m

ここでは、コンストラクターでこれから使う変数の宣言・初期化などをしています。
また、コンストラクターの要求する引数は

  • in_features: 特徴ベクトルの次元数
  • out_features: 分類クラスするクラス数
  • s: スケール因子
  • m: マージン
  • easy_margin: 後述

を意味します。

② cosの計算

       def forward(self, input, label):
           # --------------------------- cos(theta) & phi(theta) ---------------------------
           cosine = F.linear(F.normalize(input), F.normalize(self.weight))

このdef forwardからがArcFaceのアルゴリズムが実装されている箇所になります。

まず、関数の定義部分を見るとArcFaceの前段の層からの出力(input)と正解ラベル(label)が引数であることがわかります。

そのあとF.linear(F.normalize(input), F.normalize(self.weight))という量を計算しています。

この処理は、

  1. F.normalize()という処理が$L_2$正規化の処理で、

    • F.normalize(input)で入力データの$L_2$正規化
    • F.normalize(self.weight)で各クラスの代表ベクトルの$L_2$正規化
      ​を実施します。
  2. その後、1.で計算したそれぞれをF.linear()という$W^T x_i$を計算する関数に引数として渡しています。

この時2.の結果として[$\cos\theta_1$,・・・,$\cos\theta_{y_i}$,$\cos\theta_C$]という入力データと各クラスの代表ベクトルとのcosを要素として持つ配列が計算されます。

③ マージンの適用

   sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
   phi = cosine * self.cos_m - sine * self.sin_m
  1. 次の処理の準備のためにsinを計算します。

  2. 三角関数の加法定理
    $$
    cos(A+B)=cosAcosB-sinAsinB
    $$
    を使いcosineに対して各要素の角度からmラジアンを足した配列を$\phi$作ります。
    $$
    \phi = [\cos(\theta_1 + m),・・・,\cos(\theta_{y_i} + m),・・・,\cos(\theta_C + m)]
    $$

④ cosの周期性を考慮した処理。

   if self.easy_margin:
      phi = torch.where(cosine > 0, phi, cosine)
   else:
      phi = torch.where(cosine > self.th, phi, cosine - self.mm)

$\cos\theta$は$\theta=\pi$を境に減少から増加に転じます。ここでマージンmの目的は入力データと答えのクラスの代表ベクトルを遠ざけることが目的であったことを思い返すと、$\cos(\theta + m)<\cos\theta$である必要があります。そのため、$\theta>\pi-m$の範囲でマージンをとった場合上記の問題が発生する可能性があります。

  1. そのため、$\theta>\pi-m$の場合はマージンの取り方を変更し角度ではなくcosそのものにマージンを適用します(cosarcのマージンのとりかた)。この処理がelseの中身です。

  2. 一方上記のように複雑に考えずにcosine < 0になったらマージンの適用をやめるという方法で上記の問題に対処するという方針をとるのがself.easy_margin=Trueの場合です。

⑤ 正解ラベルをone hotベクトル化する。

      one_hot = torch.zeros(cosine.size(), device='cuda')
      one_hot.scatter_(1, label.view(-1, 1).long(), 1)

⑥ 配列cosineに関して正解ラベルの要素の部分を配列$\phi$のものに置き換える。

      output = (one_hot * phi) + ((1.0 - one_hot) * cosine)  # you can use torch.where if your torch.__version__ is 0.4

⑦ 最後sでスケールする。

      output *= self.s

利用方法

最後に既存のPyTorchで実装されたCNNのモデルがある場合、それに対してArcFaceを適用する方法を紹介したいと思います。

① ソースコードのダウンロード
以下からソースコードをダウンロードします。
https://github.com/ronghuaiyang/arcface-pytorch/blob/master/models/metrics.py

② ソースファイルのインポート

   import metrics

③ ArcFaceのインスタンス生成
例: input_data=10、分類クラス数10、 s=30.0, m=0.05, easy_margin=Trueでインスタンス生成する場合

   metric = metrics.ArcMarginProduct(10, 10,  s=30.0, m=0.05, easy_margin=True)

④ Optimizerのインスタンス生成部分の変更
ArcFaceの重みパラメーターをOptimizerで更新するためにOptimizerのインスタンス生成部分にArcFaceの部分を記述します。
例:既存のモデルがmodelという名前で最適化関数がAdamの場合

   optimizer = optim.Adam([{'params': model.parameters()}, {'params': metric.parameters()}], lr=0.00001)

⑤ 学習部分に関して
損失関数の計算~誤差逆伝搬によるパラメーターの更新の部分は以下のように実装できます。
例:損失関数の名前がcriterionである場合

    #・・・割愛(epoch数に関するfor文やdataloaderからdataやlabelを取得する処理等)・・・
    # 前回の勾配情報をリセット
       optimizer.zero_grad()

       # 予測
       features = model(inputs)
       outputs = metric(features, labels)
       # 予測結果と教師ラベルを比べて損失を計算
       loss = criterion(outputs, labels)
       # 損失に基づいてネットワークのパラメーターを更新
       loss.backward()
       optimizer.step()

注意する点としては損失関数の入力となるoutputsを予測するための処理が

  1. まず、既存の分類モデルであるmodelに入力データを代入して特徴ベクトルfeatures得る。
  2. 次に、出力ベクトルと正解ラベルをArcFaceに代入しoutputsを得る。 という2ステップになっている点です。

以上のようにして既存の分類モデルにArcFaceを組み込むことができます。

おわりに

今回はArcFaceという深層距離学習の理論・実装・使い方を確認しました。
今回記事を書くにあたり元論文を読んだりもしたのですが難解な数式や背景知識などがあまり現れずかなり読みやすく理解しやすいという印象をうけました。また、今回紹介した実装コードも100行に満たない簡潔で理解し易いコードと感じました。

このように、ArcFaceは導入にあたって必要なラーニングコストが少ないアルゴリズムであると感じました。是非みなさんも機会があれば活用してみてはいかがでしょうか?

37
33
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
37
33