はじめに
メタラーニングの論文の1つ、通称MAML(マムル)
[1] C. Finn, et. al. "Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks"
を読んでみた。
arXiv:
https://arxiv.org/abs/1703.03400
公式の(?)コード:
https://github.com/cbfinn/maml
ちなみに nikkei robotics 2019年1月号 の岡野原氏の記事でもこの論文が紹介されていた。
概要
- メタラーニングのモデル
- 複数のタスクで学習させる事で、それぞれのタスクに特化したパラメータをすぐさま得られる
モデル
問題設定
メタラーニングのモデル: $f$
observation : $\bf{x}$
モデルの出力: $\bf{a}$
初めの obsevation の分布: $q(\bf{x}\rm_1)$
traisition 分布: $q(\bf{x}\rm_{t+1}|\bf{x}\rm_t, \bf{a}\rm_t)$
エピソードの長さ: $H$
これらを用いて各タスク:
\mathcal{T} = \{ \mathcal{L}(\bf{x}\rm_1, \bf{a}\rm_1, \cdots , \bf{x}\rm_H, \bf{a}\rm_H ), q(\bf{x}\rm_1), q(\bf{x}\rm_{t+1}|\bf{x}\rm_t, \bf{a}\rm_t) \}
教師あり学習の場合は $H=1$ 。
タスク全体の分布: $p(\mathcal{T})$
個々のタスクの分布: $p(\mathcal{T} _i)$
ロス: $\mathcal{L}(\bf{x}\rm_1, \bf{a}\rm_1, \cdots, \bf{x}\rm_H,
\bf{a}\rm_H) \to \mathbb{R}$
タスク $\mathcal{T} _i$ のロス: $\mathcal{L} _{\mathcal{T} _i}$
モデルのパラメータ: $\theta$
モデルからの出力: $f_{\theta}$
手法
パラメータの初期値 $\theta$ とする。
$i$ 番目のタスク $\mathcal{T} _i$ のロス $\mathcal{L} _{\mathcal{T} _i}$ から通常の勾配降下法でパラメータの更新値 $\theta ' _i$ を求める。
\theta ' _i = \theta - \alpha \nabla _{\theta} \mathcal{L} _{\mathcal{T} _i} (f_{\theta})
このタスクごとに求まった $\theta ' _i$ のロス $\mathcal{L} _{\mathcal{T} _{i} }(\theta ' _i)$ の全タスクにわたる合計値が最小化するような $\theta$ を求める。
\min _{\theta} \sum _{\mathcal{T} _i \sim p(\mathcal{T})} \mathcal{L} _{\mathcal{T} _i} (f_{\theta ' _i}) = \min _{\theta} \sum _{\mathcal{T} _i \sim p(\mathcal{T})} \mathcal{L} _{\mathcal{T} _i} (f_{\theta - \alpha \nabla _{\theta} \mathcal{L} _{\mathcal{T} _i} (f_{\theta})})
これに関しても勾配降下法で勾配を求める。そしてパラメータの初期値を更新する。
\theta \gets \theta - \beta \nabla _{\theta} \sum _{\mathcal{T} _i \sim p(\mathcal{T})} \mathcal{L} _{\mathcal{T} _i} (f _{\theta ' _i})
このとき $\theta$ に関して微分しているので、
\begin{eqnarray}
\theta &\gets& \theta -\beta \nabla _{\theta} \sum _{\mathcal{T} _i \sim p(\mathcal{T})} \mathcal{L} _{\mathcal{T} _i} (f _{\theta ' _i}) \\
&=& \theta -\beta \nabla _{\theta} \sum _{\mathcal{T} _i \sim p(\mathcal{T})} \mathcal{L} _{\mathcal{T} _i} (f_{\theta - \alpha \nabla _{\theta} \mathcal{L} _{\mathcal{T} _i} (f_{\theta})} ) \\
\end{eqnarray}
となり、微分の微分が登場するので、Hessian を求める必要がある。
書きかけ