14
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

[メモ]論文読み"Matching Networks for One Shot Learning"

Last updated at Posted at 2017-12-31

Matching Networks for One Shot Learning

概要

  • 論文: Matching Networks for One Shot Learning
  • 内容
    • one short learningのためのネットワーク(MatchingNet)とその学習方法の提案

事前準備

One shot Learning

One short learningとは、Nクラスへの分類問題において、あるクラスに属する教師データが1個しかないような問題。

たとえば、人間の子供の場合、1枚のキリンの写真を見ればキリンを認識できるようになる。深層学習の場合、大量の教師データが必要になる。

深層学習以外の手法で1個の教師データから学ぶことができる方法もある。例えば、k-nearest neighbors法などがある。

問題設定

  • Nクラス分類問題を解く。
  • 各クラスの教師データはかなり少ない(1~5個程度)

各クラスに属するデータが少ないデータ・セット$S = {(x_1, x_2), ..., (x_k, y_k)}$ とする。
$x_i$ は入力データ(画像など)、$y_i$ は正解ラベル。

この$S$から分類器$c_S$を生成する写像 $M_{\theta}$ を精度良く学習できれば少ないデータから分類機を生成できる。

$$M_{\theta}: S \mapsto c_S$$
$$c_S: x \mapsto p \in R^N$$

$\theta$ はモデルパラメータ。$p$ は各クラスに属する確率を表す。

上の2式を組み合わせて
$$P_\theta(y|x, S) := \phi_y(M_{\theta}(S)(x))$$

$\phi_y$はベクトルからスカラーへのマップで$y$次元目の値を抜き出す。
$$\phi_j([v_1, ..., v_j, ..., v_N]) = v_j$$

提案手法

2つの要素から構成される。

  • モデル: MatchingNet。AttentionやMemoryを取り入れたネットワーク。
  • 学習方法: 少ない教師データから学習させる。

Model Architecture

モデルは小さなsupport $S$ から分類器 $c_S$ を作る。

提案するモデルは下記のモデルから着想を得ている。

  • seq2seq with Attention
  • memory Networks
  • pointer netrowks

モデルをシンプルに表現すると下記のように書ける。
$$y'=\sum_{i=1}^k a(x', x_i) y_i \ \ \ \ \ \ (1)$$

ここで

  • $x_i$, $y_i$ は support set $S$ に含まれるサンプルとラベル
  • $y_i$ はone-hotベクトル、$y'$ は各ラベルへの重み。(と思われる。)
  • $a$ はattention構造
  • $a$ を $X \times X$ 上のカーネルとみなせば、式(1)はkernel destance estimator(KDE)とみなせる。
  • もっとも $x'$ から遠い $b$ 個の $x_i$ に対して $a(x', x_i)$ が0ならば、(k-b) nearest neighborsとみなせる。

The Attention Kernel

式(1)の $a$ を決める必要がある。最もシンプルな構造はコサイン距離$c$を使って、その距離のsoftmaxをとる方法。

$a(h', x_i) = \exp(c(f(x'), g(x_i)) / \sum_{j=1}^k \exp(c(f(x'), g(x_j))$

ここで、$f$, $g$はニューラルネットによって近似する。
例えば、画像処理の場合にCNN、自然言語処理の場合に単語の埋め込み等を使う。

関数 g,f の詳細

$g(x_i)$ は他のデータにも依存して決める。なので、 $g(x_i, S)$ と表す。
まずneural network(CNNなど)を使って $g'(x)$ 入力データを変換する。
LSTMを使って $g$ を計算する。

$$g(x_i, S) = h^r_{i} + h^l_{i} + g'(x_i)$$

$$h^r_{i}, c^r_{i} = \text{LSTM}(g'(x_i), h^r_{i-1}, c^r_{i-1})$$

$$h^l_{i}, c^l_{i} = \text{LSTM}(g'(x_i), h^l_{i+1}, c^l_{i+1})$$

$f$ は LSTM with real-attention をつかって定義する。
(詳しくは、Order Matters: Sequence to sequence for sets, https://arxiv.org/abs/1511.06391)

$$f(x', S) = \text{attLSTM}(f'(x'), {g(x_i, S)}, K)$$

$$h_k', c_k = \text{LSTM}(f'(x'), [h_{k-1}, r_{k-1}], c_{k-1})$$

$$h_k = h'_k + f'(x')$$

$$r_{k-1} = \sum_{i=1}^{|S|} a(h_{k-1}, g(x_i)) g(x_i)$$

$$a(h_{k-1}, g(x_i)) = \text{softmax}(h_{k-1}^T g(x_i))$$

学習方法

$$\theta = \arg \max_{\theta} E_{L \sim T}
\left[
E_{S \sim L, B \sim L}
\left[
\sum_{(x, y) \in B} \log P_{\theta}(y|x, S)
\right]
\right]$$

  • Tはタスク。クラスの集合。
  • LはTから一様にサンプリングしたクラスの集合(5個程度)
  • Sはsupport用に正解ラベルがLに含まれるデータの集合(25個程度)
  • BはSを与えたときの学習用データ。
14
18
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
18

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?