Matching Networks for One Shot Learning
概要
- 論文: Matching Networks for One Shot Learning
- (https://arxiv.org/pdf/1606.04080.pdf)
- 2016-0613
- Google DeepMindの人たちの論文
- 内容
- one short learningのためのネットワーク(MatchingNet)とその学習方法の提案
事前準備
One shot Learning
One short learningとは、Nクラスへの分類問題において、あるクラスに属する教師データが1個しかないような問題。
たとえば、人間の子供の場合、1枚のキリンの写真を見ればキリンを認識できるようになる。深層学習の場合、大量の教師データが必要になる。
深層学習以外の手法で1個の教師データから学ぶことができる方法もある。例えば、k-nearest neighbors法などがある。
問題設定
- Nクラス分類問題を解く。
- 各クラスの教師データはかなり少ない(1~5個程度)
各クラスに属するデータが少ないデータ・セット$S = {(x_1, x_2), ..., (x_k, y_k)}$ とする。
$x_i$ は入力データ(画像など)、$y_i$ は正解ラベル。
この$S$から分類器$c_S$を生成する写像 $M_{\theta}$ を精度良く学習できれば少ないデータから分類機を生成できる。
$$M_{\theta}: S \mapsto c_S$$
$$c_S: x \mapsto p \in R^N$$
$\theta$ はモデルパラメータ。$p$ は各クラスに属する確率を表す。
上の2式を組み合わせて
$$P_\theta(y|x, S) := \phi_y(M_{\theta}(S)(x))$$
$\phi_y$はベクトルからスカラーへのマップで$y$次元目の値を抜き出す。
$$\phi_j([v_1, ..., v_j, ..., v_N]) = v_j$$
提案手法
2つの要素から構成される。
- モデル: MatchingNet。AttentionやMemoryを取り入れたネットワーク。
- 学習方法: 少ない教師データから学習させる。
Model Architecture
モデルは小さなsupport $S$ から分類器 $c_S$ を作る。
提案するモデルは下記のモデルから着想を得ている。
- seq2seq with Attention
- memory Networks
- pointer netrowks
モデルをシンプルに表現すると下記のように書ける。
$$y'=\sum_{i=1}^k a(x', x_i) y_i \ \ \ \ \ \ (1)$$
ここで
- $x_i$, $y_i$ は support set $S$ に含まれるサンプルとラベル
- $y_i$ はone-hotベクトル、$y'$ は各ラベルへの重み。(と思われる。)
- $a$ はattention構造
- $a$ を $X \times X$ 上のカーネルとみなせば、式(1)はkernel destance estimator(KDE)とみなせる。
- もっとも $x'$ から遠い $b$ 個の $x_i$ に対して $a(x', x_i)$ が0ならば、(k-b) nearest neighborsとみなせる。
The Attention Kernel
式(1)の $a$ を決める必要がある。最もシンプルな構造はコサイン距離$c$を使って、その距離のsoftmaxをとる方法。
$a(h', x_i) = \exp(c(f(x'), g(x_i)) / \sum_{j=1}^k \exp(c(f(x'), g(x_j))$
ここで、$f$, $g$はニューラルネットによって近似する。
例えば、画像処理の場合にCNN、自然言語処理の場合に単語の埋め込み等を使う。
関数 g,f の詳細
$g(x_i)$ は他のデータにも依存して決める。なので、 $g(x_i, S)$ と表す。
まずneural network(CNNなど)を使って $g'(x)$ 入力データを変換する。
LSTMを使って $g$ を計算する。
$$g(x_i, S) = h^r_{i} + h^l_{i} + g'(x_i)$$
$$h^r_{i}, c^r_{i} = \text{LSTM}(g'(x_i), h^r_{i-1}, c^r_{i-1})$$
$$h^l_{i}, c^l_{i} = \text{LSTM}(g'(x_i), h^l_{i+1}, c^l_{i+1})$$
$f$ は LSTM with real-attention をつかって定義する。
(詳しくは、Order Matters: Sequence to sequence for sets, https://arxiv.org/abs/1511.06391)
$$f(x', S) = \text{attLSTM}(f'(x'), {g(x_i, S)}, K)$$
$$h_k', c_k = \text{LSTM}(f'(x'), [h_{k-1}, r_{k-1}], c_{k-1})$$
$$h_k = h'_k + f'(x')$$
$$r_{k-1} = \sum_{i=1}^{|S|} a(h_{k-1}, g(x_i)) g(x_i)$$
$$a(h_{k-1}, g(x_i)) = \text{softmax}(h_{k-1}^T g(x_i))$$
学習方法
$$\theta = \arg \max_{\theta} E_{L \sim T}
\left[
E_{S \sim L, B \sim L}
\left[
\sum_{(x, y) \in B} \log P_{\theta}(y|x, S)
\right]
\right]$$
- Tはタスク。クラスの集合。
- LはTから一様にサンプリングしたクラスの集合(5個程度)
- Sはsupport用に正解ラベルがLに含まれるデータの集合(25個程度)
- BはSを与えたときの学習用データ。