サマリ
- metric learningとは、関連する画像同士における特徴量表現の距離を意図的に近く集合させることが可能な学習モデルであり、サンプルが少ないときや、未知クラスがあるときに有効な手法である。
- ADFI(異常検知用画像データセット)をもとに普段よく用いられる画像分類モデル(ResNet34)と、metric learning(ArcFace)の画像データセットに対する精度を評価。
- metric learning(AUC=0.933)を用いたことで、通常の画像分類モデル(AUC=0.907)より精度向上が見られた。
背景・目的
- 東京のとあるデータ分析会社のデータサイエンティスト職として従事。
- 社内メンバーで日本酒コンペにチーム参加した際、metric learning(=距離学習)を用いた画像検索問題に出会う。
- 銅メダルをチームで取得できたが、metric learningについての理解が浅かったので知識の整理のためにまとめたい。(コンペの話については年明け以降に弊社テックブログにて掲載予定。)
- qiitaにてmetric learningについての過去記事を調査すると、異常検知との組み合わせた検証記事が多かったのでそれらを参考に、今回は異常検知×metric learningの検証を行った。
用語
metric learningとは?
- metric learningは、画像処理モデルの一つとして利用されている学習手法。
- 特徴空間内におけるデータ間の距離や類似性を適切に学習するため、分類モデルなどによって画像データからベクトルを抽出し、関連する画像同士は意図的にベクトルの距離が近く、関連しない画像同士は距離が遠くなるようデータ間の距離が計算されような学習を実行。
- 定番のクラス分類の学習であれば、サンプル間の距離を考慮せず全結合層で分離可能な(separable)特徴量になるよう学習してしまうが、metric learningだとサンプル間の距離が考慮されるので識別的な(discriminative)特徴量を得ることができる。
- 上記の経緯から、metric learningは各クラスのサンプルが少ないときや、未知クラスがあるときでも十分な性能を発揮できることが強みであり、顔認識や異常検知、画像検索タスクなどで用いられている。
- 今回は、metric learningの中でも実装が手軽なArcFaceを用いて異常検知用データにて精度を検証。
- ArcFaceではソフトマックス損失関数を拡張して同一クラス間の分散を小さくする工夫を追加することで、同一クラスの距離が近くなるように学習させることが可能。ArcFaceの損失関数についての詳細はこちらの記事の説明が分かりやすい。
異常検知とは?
- 異常値(Anomaly)とは、あるデータセットにおける予期せぬ変化、または予期されるパターンからの逸脱を指す。
- 異常検知とは、それらの異常値を検出し警告するために使用される技術。
手法
データセット
- 使用したデータセット:ADFI, “Real-World Dataset for Anomaly Detection”, AI Robotics LTD., 2022, https://adfi.jp
-
Hazelnut
,Coffee beans
,Rotary beacon light
の3種類の画像に「正常」「異常」どちらかのラベルが付与。
実験条件
こちらを参考に学習・評価を実施。
評価対象
- ①Conventional = 距離学習を利用しない通常のCNNの場合
- ②ArcFace
- ①②どちらもResNet34をベースにモデル構造を定義。
- ②ArcFaceに関しては最終層に距離学習を行うスクリプトを追加し、出力されるembeddingsは512次元とした。
学習手順
- 学習時に際しては「異常」画像を排除し、「正常」画像セットのみを利用。
-
train
データの「正常」データのみを利用し、Hazelnut
,Coffee beans
,Rotary beacon light
の3クラスで画像分類モデルを作成した。
評価手順
- 評価用画像は
test
データ全量を用いた。 - すべての評価用画像を、学習済みモデル(Conventional, ArcFace)に与えて512次元のembeddingsを得る。
- 評価用画像のembeddingsと学習用画像(「正常」のみ画像セット)のembeddingsにおける距離を計算。※距離計算はcosine類似度により算出。
- 評価用画像jごとに、すべての学習用画像(「正常」のみ画像セット)との距離を計算し、そのうち一番短い距離の値を、評価用画像jの正解までの距離
distance_j
とした。 - モデルの精度はAUCで評価、評価用画像の内、正常を
0
、異常を1
でラベリング。 - 評価用画像の正解までの距離指標(
distance_j
)をもとに、AUCを算出。
結果
- 精度評価においてArcFaceがConventionalモデルの精度を上回る結果に。
# | 学習条件 | AUC |
---|---|---|
1 | ArcFace | 0.933 |
2 | Conventional | 0.907 |
あとがき
- 今回の検証スクリプトについてはこちら。
- 今回、metric learningを用いたことで異常検知に関する精度の向上が見られたが、従来の画像分類モデルの方が精度が高い結果になっているケースも他の人の検証結果で見かけたので、従来モデルより精度劣化するユースケースやその理由についても今後理解できるようになりたい。
- metric learning手法は今回採用したArcFace以外にも複数あるが、実装が簡単な点で気軽に試せるのが良かった。今後のデータ分析コンペで上手く活用できるよう研鑚を積みたい。