最近、Few-shot learningの欲が高まっている@oike17です。今さらSiamese Neural Networks for One-shot Image Recognition という論文を読んだので忘れないようにまとめておきたいと思います。
概要
この論文ではOne-shot Learningというタスクに対して、Deep metric learningの一手法であるSiamese Networkを適用することで当時のSOTA (State-of-the-Art)を達成した論文になります。現在の多くのFew-shot Learningの手法はこの手法に少なからず影響を受けていると思います。
One-shot Learning
では、まず対象とするタスクについて話していきます。One-shot Learningはクラスに対してひとつのサンプルしか与えられていない状況を対象としています。ただ、本当にサンプルひとつだけが渡されてもどうしようもないので、同じドメインの他のデータも同時に大量に与えられているという前提があります。この問題に対処するためには、対象とするドメイン固有の特徴量を得ることが必要となります。
N-way k-shot classification task
これはFew-shot Learningの評価を行うためによく用いられるタスクになります。以下のようなものです。
- 対象とするNクラスのそれぞれについて与えられるデータはk個のみ
- Nクラスについては学習せず、それ以外のデータを使って学習
- テストはNクラスを用いて行う
したがって、完全にランダムに予測した場合には精度が$1/N$となるようなタスクになります。
Deep Learningによる主なアプローチ
現在のOne-shot Learningの主なアプローチについても軽く触れておきます。Deep Learningによるアプローチとしては、主に以下のような手法が提案されています。
- Metric Learning
- この記事で取り扱うやつ
- Meta Learning
- 「少数の例をみただけで、すぐに認識できるようになる」というタスクを学習させる
- エピソード学習と呼ばれる
- Meta-learning with memory-augmented
neural networks
- Metric Learning × Meta Learning
がありますが、今回の話はMetric Learningのみを用いた手法になります。
Metric Learning
概要
Metric Learningとはデータ間の**計量(類似度や距離など)を学習する手法です。計量を学習するということは、「特徴空間を学習する」,「埋め込み方を学習する」**と捉えることもできます。僕的には埋め込み方を学習するという考え方が一番しっくりきています。直感的なイメージとしては、意味の近いデータは近く、意味の遠いデータは遠くなるように特徴空間を学習してる感じです。
意味的な距離を考慮した特徴量空間をうまく学習することができれば、未知クラスのデータに対してもロバストに対応できることも知られており、One-shot Learningというタスクに活用されることは自然な流れだと言えます。
Metric Learningのモデル
まず、Metric Learning Modelの入出力を明示的に示しておくと
- 入力:データのペア
- 出力:それらデータ間の類似度(距離)
となります。図で表すと以下のようになります。
この「類似度を求めるタスク」つまり「一致・不一致を判定するタスク」をverification taskといいます。ここで注意すべきことは、Metric Learninigで得られるものはあくまでSimilarity functionであり、識別モデルを学習するものではないということです。ただ、今対象としているタスクはclassification taskですので、verification taskのためのモデルをclassification taskへと落とし込むことが必要になります。
Metric Learning for Classification (One-shot)
テスト画像$x$と各クラスを代表する画像の集合$\{x_c\}_{c=1}^C$が与えられているとしましょう。このとき学習後のネットワークを利用して$x$と$\{x_c\}_{c=1}^C$の各画像の類似度を求めることができます。これを元に最大の類似度に対応するクラスを予測として出力してあげれば良さそうです。つまり、出力$C^*$は次のように定式化できます。
$$
C^*=argmax_c p^{(c)}
$$
この考え方は非常に直感的ですので理解しやすいと思います。
Deep metric learning
名前から大方の想像はつくと思いますが、Metric Learningに対してDeep learningを適用するのがDeep metric learningです。距離の定義部分にDeep learningを用いることで、特徴空間への非線形な変換を学習することが可能になり、より高い表現能力を獲得することができます。
本論文のアプローチ
- Siamese NetworkというDeep metric learningの一手法を用いて画像の類似度を出力するモデルを学習します
- 得られたモデルの再学習は不要であり、学習後のモデルをそのまま使ってOne-shot Learningを実現します
Siamese Network [Bromley, 1993]
Siamese Neworkが一番最初に発表されたのは1993年であり、非常に古典的な手法であると言えます。この手法は署名の検証のためにBromleytoとLeCunによって提案されました。(Signature verification using a siamese time delay neural network)
ネットワークの構造は下図の感じで、直感的な理解としては入力をそれぞれ特徴空間に埋め込んであげて、特徴空間内で距離(類似度)を求めてあげるといった感じでしょうか。
ペアとして与えられたどちらの入力も同一のサブネットワークを通ることで特徴ベクトルへ変換されます。出力はこの特徴ベクトル間の距離になります。ベクトル間の距離の定義(L1ノルム、L2ノルム、コサイン類似度など)は様々ありますが、この論文ではベクトルの各要素のL1ノルムの総和としています。
Siamese Networkの表現能力
個人的にSiamese Networkの表現能力に興味があったので、MNISTを用いて学習させてみました。以下が特徴ベクトルを可視化させてみた結果になります。
確かに同じラベルのデータは近く、異なるラベルのデータは遠くなるように学習が行えていることが確認できます。
提案モデル
この論文の提案モデルは下図のような感じです。
画像を対象として扱っているので、特徴ベクトルへの変換部分はCNNで行っており、特徴ベクトル間のL1ノルムの総和をシグモイド関数に投げて、類似度を計算しています。
Loss function
損失関数は、よく用いられるクロスエントロピーのやつです。
$$
L(x_1^{(i)},x_2^{(i)} )=y(x_1^{(i)},x_2^{(i)})\log p(x_1^{(i)}, x_2^{(i)})+(1-y(x_1^{(i)},x_2^{(i)})\log (1-p(x_1^{(i)},x_2^{(i)}))+\lambda^T|w|^2
$$
実験
データセット
- The Omniglot dataset
- よくある手書き文字のやつ
- MNIST
- よくある手書き文字(数字)のやつ
The Omniglot dataset
- データの60%をトレーニングデータとして使用
- 20-way One-shot classification task
色が変わっているものが提案手法です。HBPL (Hierarchial Bayesian Program Learning)以外の手法に対してはより高い精度が達成できています。
唯一提案手法の精度が劣っているHBPLに関してですが、HBPLは文字認識のみに特化しており、めちゃくちゃアドホックなモデルです。一方で提案手法はどんなドメインのデータセットにも対応でき、汎化性能の点で非常に優れています。
MNIST transpose
- Omniglotを用いて学習させたモデルを再学習なしにMNISTに適用
- 10-way One-shot classification task
Omniglotで学習したモデルはMNISTに対しても一定の効果をあげていることが確認できます。
所感
- 非常に単純なモデルでありながら、高い精度を達成できててすごい(小並感)
- クラス数が変わっても再学習なしで対応できる点は非常に強力
- Metric Learning面白い!!