メモ
深層学習
論文読み

[メモ]論文読み"Matching Networks for One Shot Learning"

Matching Networks for One Shot Learning

概要

  • 論文: Matching Networks for One Shot Learning
  • 内容
    • one short learningのためのネットワーク(MatchingNet)とその学習方法の提案

事前準備

One shot Learning

One short learningとは、Nクラスへの分類問題において、あるクラスに属する教師データが1個しかないような問題。

たとえば、人間の子供の場合、1枚のキリンの写真を見ればキリンを認識できるようになる。深層学習の場合、大量の教師データが必要になる。

深層学習以外の手法で1個の教師データから学ぶことができる方法もある。例えば、k-nearest neighbors法などがある。

問題設定

  • Nクラス分類問題を解く。
  • 各クラスの教師データはかなり少ない(1~5個程度)

各クラスに属するデータが少ないデータ・セット$S = {(x_1, x_2), ..., (x_k, y_k)}$ とする。
$x_i$ は入力データ(画像など)、$y_i$ は正解ラベル。

この$S$から分類器$c_S$を生成する写像 $M_{\theta}$ を精度良く学習できれば少ないデータから分類機を生成できる。

$$M_{\theta}: S \mapsto c_S$$
$$c_S: x \mapsto p \in R^N$$

$\theta$ はモデルパラメータ。$p$ は各クラスに属する確率を表す。

上の2式を組み合わせて
$$P_\theta(y|x, S) := \phi_y(M_{\theta}(S)(x))$$

$\phi_y$はベクトルからスカラーへのマップで$y$次元目の値を抜き出す。
$$\phi_j([v_1, ..., v_j, ..., v_N]) = v_j$$

提案手法

2つの要素から構成される。
* モデル: MatchingNet。AttentionやMemoryを取り入れたネットワーク。
* 学習方法: 少ない教師データから学習させる。

Model Architecture

モデルは小さなsupport $S$ から分類器 $c_S$ を作る。

提案するモデルは下記のモデルから着想を得ている。

  • seq2seq with Attention
  • memory Networks
  • pointer netrowks

モデルをシンプルに表現すると下記のように書ける。
$$y'=\sum_{i=1}^k a(x', x_i) y_i \ \ \ \ \ \ (1)$$

ここで

  • $x_i$, $y_i$ は support set $S$ に含まれるサンプルとラベル
  • $y_i$ はone-hotベクトル、$y'$ は各ラベルへの重み。(と思われる。)
  • $a$ はattention構造
  • $a$ を $X \times X$ 上のカーネルとみなせば、式(1)はkernel destance estimator(KDE)とみなせる。
  • もっとも $x'$ から遠い $b$ 個の $x_i$ に対して $a(x', x_i)$ が0ならば、(k-b) nearest neighborsとみなせる。

The Attention Kernel

式(1)の $a$ を決める必要がある。最もシンプルな構造はコサイン距離$c$を使って、その距離のsoftmaxをとる方法。

$a(h', x_i) = \exp(c(f(x'), g(x_i)) / \sum_{j=1}^k \exp(c(f(x'), g(x_j))$

ここで、$f$, $g$はニューラルネットによって近似する。
例えば、画像処理の場合にCNN、自然言語処理の場合に単語の埋め込み等を使う。

関数 g,f の詳細

$g(x_i)$ は他のデータにも依存して決める。なので、 $g(x_i, S)$ と表す。
まずneural network(CNNなど)を使って $g'(x)$ 入力データを変換する。
LSTMを使って $g$ を計算する。

$$g(x_i, S) = h^r_{i} + h^l_{i} + g'(x_i)$$

$$h^r_{i}, c^r_{i} = \text{LSTM}(g'(x_i), h^r_{i-1}, c^r_{i-1})$$

$$h^l_{i}, c^l_{i} = \text{LSTM}(g'(x_i), h^l_{i+1}, c^l_{i+1})$$

$f$ は LSTM with real-attention をつかって定義する。
(詳しくは、Order Matters: Sequence to sequence for sets, https://arxiv.org/abs/1511.06391)

$$f(x', S) = \text{attLSTM}(f'(x'), {g(x_i, S)}, K)$$

$$h_k', c_k = \text{LSTM}(f'(x'), [h_{k-1}, r_{k-1}], c_{k-1})$$

$$h_k = h'_k + f'(x')$$

$$r_{k-1} = \sum_{i=1}^{|S|} a(h_{k-1}, g(x_i)) g(x_i)$$

$$a(h_{k-1}, g(x_i)) = \text{softmax}(h_{k-1}^T g(x_i))$$

学習方法

$$\theta = \arg \max_{\theta} E_{L \sim T}
\left[
E_{S \sim L, B \sim L}
\left[
\sum_{(x, y) \in B} \log P_{\theta}(y|x, S)
\right]
\right]$$

  • Tはタスク。クラスの集合。
  • LはTから一様にサンプリングしたクラスの集合(5個程度)
  • Sはsupport用に正解ラベルがLに含まれるデータの集合(25個程度)
  • BはSを与えたときの学習用データ。