距離学習(metric learning)とは、機械学習によって似たデータが近くにくるようなマッピングを学習する手法です。この記事では、この技術を使って、キャラクター顔画像の分類を試してみたいと思います。
使用したコードは以下に公開してあります。
距離学習とは?
距離学習の有名な応用例としては顔画像認識があります。同一の人物の顔画像を「似た画像」、別の人物の画像「似ていない画像」として学習することで、仮に学習データに存在しない顔であっても「同じ顔」として判別することができるようになります。
通常の画像分類モデルとの対比を考えてみましょう。画像分類モデル、例えば、猫画像と犬画像を分類するモデルであれば、出力は猫を示す1か、犬を示す2などでしょう(正確に言えば、それをOne Hot形式の2次元ベクトルで表現したものが出力になるでしょう)。
一方、距離学習モデルは画像を入力すると、数百次元のベクトルを出力します。このベクトルは、埋め込みベクトルなどと呼ばれます。埋め込みベクトル上では、別種の画像(犬の画像と猫の画像)は遠ざかり、同種の画像(犬画像同士、猫画像同士)が近くなります。
距離学習モデルの強みのひとつは、未知の集まりにも適用できる可能性があることです。
顔認識の場合で考えてみましょう。学習データセットに含まれるのが、Aさん、Bさんの顔の二種類だけだった場合、このモデルを使って、学習データセットに含まれていないCさんの顔を分類することはできません。そもそも出力として定義されているのが、Aさん、Bさんのみなので、それ以外の集まりに適用することは原理的に困難です。
一方、距離学習モデルの場合、適切に学習ができていれば、同一人物の顔が近いベクトルにマッピングされるという特性は、未知の顔に対しても満たされることが期待できます。
単純に言えば、距離学習なら未知の人物の顔でも分類できるというわけです。
今回やりたいこと: キャラクター顔画像の分類
今回は実写の顔画像ではなく、距離学習を使って、キャラクター顔画像を学習させてみたいと思います。応用として考えているのは、ゲームなどの素材画像の分類です。
距離学習であれば、未知のキャラクターの画像であっても、サンプルの画像が一枚あれば、サンプルと同じキャラクターの画像を見つけることができる可能性があります。うまく使えば大量の画像から特定のキャラクターの画像を探したり、自動で分類を行なうことができそうです。
ちなみに本当は右側のイラストは同じキャラの別のイラストにしたかったのですが、権利的に問題ない画像を見つけられなかったので、同じ画像になってしまっています。実際にやりたいのは、まったく同一の画像を検索することではなく、同じキャラクターの表情などが異なる画像を見つけることなので注意してください。
今回使用したデータセットと手法
今回はAnimeFace Character Datasetというオープンデータセットを使用させていただきました。こちらは、英語圏の掲示板に投稿されたイラストから顔部分のみを抜き出したデータセットのようです。キャラクター別に分類された顔画像が、1キャラクターあたり10〜50程度あり、キャラクター数で言えば100〜200程度含まれているようです
深層学習による距離学習と言えば、以前は、contrastive lossやtriplet lossという二つ組や三つ組を使って学習する手法が一般的でした。しかし近年は、ArcFaceやCosFaceなど画像分類モデルを拡張したSoftmax系のアプローチが主流になりつつあります。この種のアプローチの利点は、画像分類モデルにモジュールをひとつ追加するだけで簡単に使用できることです。
以下、画像分類モデルと比較しつつ、Softmax系アプローチについて紹介します。
通常の画像分類モデルの場合、モデルは、正解ラベルに対応する出力がもっとも大きくなるように学習します。モデルの構成としては、入力された画像がまず複数のConvolution層で処理され、最後にLinear層で処理されるという構成が一般的でしょう。役割から言うと、Convolution層は画像から特徴を抽出し、Linear層は特徴ベクトルを各ラベルに分類します。
ArcFaceなどの距離学習モデルの場合、モデルの構成や学習の仕方は画像分類モデルとほぼ同じです。ただし距離学習モデルの場合、正解ラベルに対応する出力をへらします。言い方を変えれば、わざとハンディキャップをつけた状態で学習するような感じです。この圧力によって、モデルは、同一ラベル内の分散を低く、ラベル間の分散が高くなるように学習します。
一方推論時(つまり実際に使用する際)は、Linear層以降の出力をショートカットし、Convolution層の出力をそのまま使用します。学習時にハンディキャップをつけることで圧力をかけ、特徴ベクトルが中心に集まるようにしているという感じです。
結果
先ほど説明したAnimeFace Character Datasetのデータを使って学習を試してみました。画像分類モデルのバックボーンにはresnet34を使用し、arcfaceを使って学習させています。また実際に判定を行なう場合はscikit-learnのNearestNeighborsを使っています。
バリデーションデータを使って性能を見ると、以下のような結果になりました。NearestNeighborsのしきい値の設定である程度変化するので数値は参考程度にとどめておいてください。
メトリクス | |
---|---|
F1 | 0.6244 |
precision | 0.8661 |
recall | 0.5070 |
結果を見ると、recall(つまり、同一キャラの画像を漏れなく拾えている率)の数値が低くなっていますが、precision(つまり、モデルが同一キャラだと判定した場合に実際に同一キャラである率)はある程度の正確さが出ています。簡単に言えば、このモデルではカバーできない画像もそれなりにあるが、わかりやすい画像はそれなりに拾えているという感じです。
画像を出していいのわからなかったので掲載しませんが、実際の判定結果を見ると、recallが低い理由はおそらく次のような理由ではないでしょうか。元々のデータセットはユーザー投稿画像で構成されているため、キャラクターの判別がかなり難しいものもあるようです。実際人間が見ても判定が難しい画像が多かったです。
学習データに存在しないキャラクターでも使えるかどうか試してみました。こちらも画像は出せませんが、上記のデータセットに含まれていないKLabのゲームのキャラクターの顔画像を使ってテストしたところ、上記の結果に近い性能をあげることができました。
性能などはまだ改善の余地があると思いますが、ひとまず既成のデータセットを使っただけでも一定の性能が出せることがわかりました。