LoginSignup
14
13

More than 5 years have passed since last update.

「破滅的忘却」を回避する Overcoming Catastrophic Forgetting by Incremental Moment Matching, NIPS'17. 読んだ

Posted at

はじめに

ニューラルネットワークにおいて、新たなタスクの学習を逐次的に行うと、以前学習したタスクに対する性能が急激に低下にcatastrophic forgetting(破滅的忘却)が発生してしまうことが課題となっており、それを解決するための手法を提案している。
Overcoming Catastrophic Forgetting by Incremental Moment Matching, NIPS'17.
https://arxiv.org/abs/1703.08475

以前にDeepMindのアプローチを読んで、これ系には興味があるので読んだ。背景や既存手法(EWC)の詳細は下記の記事に書いた。
ニューラルネットワークが持つ欠陥「破滅的忘却」を回避するアルゴリズムをDeepMindが開発した論文を読んだ

アプローチ

目的は、タスク1に関する学習データ$\mathcal{D}_1 = \{ X_1, y_1 \}$およびタスク2に関する学習データ$\mathcal{D}_2 = \{ X_2, y_2 \}$が与えられた際に、両方のタスクに対し有効なニューラルネットワークのパラメータ$\theta$の事後確率$p_{1:2} = p(\theta | \mathcal{D}_1, \mathcal{D}_2)$を求めることである。

これに対し、まずタスク1のパラメータの事後確率$p_1 = p(\theta | \mathcal{D}_1)$およびタスク2のパラメータの事後確率$p_2 = p(\theta | \mathcal{D}_2)$を求め、その後それらを統合して$p_{1:2} = p(\theta | \mathcal{D}_1, \mathcal{D}_2)$を求めることを提案する。

そのままでは解けないので、上記の事後確率を全てガウス分布で近似する(ラプラス近似):

p_{1:2} = p(\theta | \mathcal{D}_1, \mathcal{D}_2) \approx q_{1:2} = q(\theta|\mu_{1:2}, \Sigma_{1:2}), \\
p_1 = p(\theta | \mathcal{D}_1) \approx q_{1} = q(\theta|\mu_{1}, \Sigma_{1}), \\
p_2 = p(\theta | \mathcal{D}_2) \approx q_{2} = q(\theta|\mu_{2}, \Sigma_{2}).

これにより、上記の問題は、タスク1およびタスク2のニューラルネットワークのパラメータを学習し、その事後確率をガウス分布で近似した$q_1$および$q_2$を求め、最後にそれらを統合することで、両方のタスクに対し有効なニューラルネットワークパラメータ(の事後確率をガウス分布で近似した$q_{1:2}$)を求めるというステップで解くことができる。

すなわち、$q_{1:2}$が求まれば、両方のタスクに対し有効なニューラルネットワークのパラメータ$\theta_{1:2}$は$q_{1:2}$の平均値$\mu_{1:2}$として求めることができる。

2つの事後確率のパラメータ推定

まず、$q_1$および$q_2$のパラメータ$\mu_1, \Sigma_1$および$\mu_2, \Sigma_2$を求める必要がある。
タスク1およびタスク2をそれぞれ学習し、求められたニューラルネットワークのパラメータを$\theta_1^{*}$および$\theta_2^{*}$とすると、$\mu_1 = \theta_1, \mu_2 = \theta_2$となる(タスク1を学習した後のタスク2の学習に関しては、複数の方法での学習を提案しており、後述する)。

分散共分散行列行列$\Sigma_1, \Sigma_2$については、文献1と同様に、パラメータ$\theta$に関するフィッシャー情報行列$F$の逆行列として求めると、$\Sigma_1 = F_1^{-1}, \Sigma_2 = F_2^{-1}$となる。

2つの事後確率の統合方法

上記のように求めた$q_1$および$q_2$を統合して$p_{1:2}$を求める。
本論文では、mean-based incremental moment matching (mean-IMM) とmode-based incremental moment matching (mode-IMM) が提案されている。

mean-IMM

mean-IMMでは、$q_1$と$q_2$を混合比$\alpha$で混合した分布を1つのガウス分布$q_{1:2}$で近似する。このとき、$q_{1:2}$のパラメータは$q_{1:2}$と$(1-\alpha)q_1 + \alpha q_2$のKLダイバージェンスを最小化することで求められる:

\mu_{1:2}^*, \Sigma_{1:2}^* = \arg\min_{\mu_{1:2}, \Sigma_{1:2}} \mathrm{KL}(q_{1:2} || (1-\alpha)q_1 + \alpha q_2).

$\mu_{1:2}^*$はclosed-formで求めることができ、

\mu_{1:2}^* = (1-\alpha)\mu_1 + \alpha \mu_2

となる($\Sigma_{1:2}^*$もclosed-formで求めることができるが省略)。

mode-IMM

通常、前述のmean-IMMで求めた$q_{1:2}$のmode(最頻値, $\theta = \mu_{1:2}^*$)と、$(1-\alpha)q_1 + \alpha q_2$のmodeは異なる。mode-IMMでは、$(1-\alpha)q_1 + \alpha q_2$と$q_{1:2}$のmodeが同一となるように$\mu_{1:2}^{*}$を求める。
簡単のためmode-IMMでは$\alpha = 1/2$とすると、$\mu_{1:2}^*$は下記のように求めることができる($\Sigma_{1:2}^{*}$は略):

\mu_{1:2}^* = (\Sigma_1^{-1} + \Sigma_2^{-1})^{-1}(\Sigma_1^{-1} \mu + \Sigma_2^{-1} \mu_2).

ここで、$\Sigma_{1:2}^{*}$が対角であるという仮定を入れると、$\mu_{1:2}^*$は各次元$d$毎に求めることができる:

\mu_{{1:2}^*, d} = \frac{\mu_{1,d} / \sigma_{1,d}^{2} +  \mu_{2,v} / \sigma_{2,v}^{2} }{1 / \sigma_{1,v}^{2} + 1 / \sigma_{2,d}^{2}}.

IMMのための転移学習

本論文では、タスク1およびタスク2で学習されたニューラルネットワーク(のパラメータ)を平均することを提案しているが、それらが独立に初期化されている場合には、それらのパラメータの間にはロスが大きくなるhigh cost barrierが存在する2ことが問題となる。
つまり、平均したニューラルネットワークがそのロスが大きくなる領域になってしまうと、性能が悪くなってしまうためである。
この問題を回避するために、幾つかの転移学習のテクニックが活用できることを示す。

Weight-transfer

最も重要なアイディアは、weight-transferである。これは単純に、タスク2の学習を行う際に、パラメータの初期値をタスク1を学習したパラメータで初期化するだけである。
文献2では、様々なニューラルネットワークにおいて、ある初期値から学習を開始した際に、その解となる(収束した)パラメータと初期値の間にはほとんどhigh cost barrierが存在しないことを経験的に示した。
このことから、weight-transferを行うと、2つのタスクを学習させたパラメータの間にhigh cost barrierが存在しないようにすることができる。

L2-transfer

L2-transferは、タスク2の学習時に、$\lambda ||\mu_1 - \mu_2||_2^2$をロスとして追加する。
Continuous learningの文脈では良くL2-transferが(恐らくベースラインとして)利用される。
既存手法では$\lambda$を大きくすることでなるべく$\mu_1$と$\mu_2$が遠くならないようにしているが、本論文では、$\mu_1$と$\mu_2$の間のロス関数がスムーズになることを目的としており、小さな$\lambda$でL2-transferを行うことが特徴である。

Drop-transfer

Drop-transferは本論文で新しく提案する転移学習の方法である。
これは、dropout ratio $p$に対し、タスク2の学習中の重み$\mu_{2i}$を、確率$p$で$\mu_{1i}$に、確率$1-p$で$\frac{1}{1-p} \mu_{2i} - \frac{p}{1-p} \mu_{1i}$とするものである。
このとき、パラメータの期待値は$\mu_{2i}$となる。
この設定だと、現在のパラメータ$\mu_{2i}$は直接利用されず、$\mu_{1i}$と外挿された$\frac{1}{1-p} \mu_{2i} - \frac{p}{1-p} \mu_{1i}$になる。$\mu_{1i}$から学習した差分をdropoutしていると考えれば良いのだろうか?

実験結果

興味深いMNISTでの実験結果のみ取り上げる。
ここでは、MNISTをベースに、Disjoint MNISTShuffled MNISTという2種類の設定で実験を行っている。

Disjoint MNISTでは、タスク1では0から4までの数字の分類を、タスク2では5から9までの数字の分類を行う。
ポイントは、既存の論文の設定では独立した5クラス分類を2回解くという設定だが、本論文では、10クラス分類を解く前提で、データが2回に分けて与えられているという設定である。

Shuffled MNISTは文献1でも採用されている設定で、各タスクで、MNISTのピクセルがタスクごとに決まった順序でシャッフルされ、それぞれ10クラス分類を解くというものである。

上記が実験結果となっている。ポイントは、EWC1Disjoint MNISTでダメダメな点である。ここまで悪くなるのはかなり不思議だが、あまり何故かという考察はない。
結果を見る限り、L2-transferでMean/Mode-IMMをしておけば良さそうな雰囲気で、わざわざDrop-transferとかは使わなくてもいいかも。
本論文の貢献は、Disjoint MNISTという問題設定を見つけたところが一番かもしれない。

あと、タスク1がImageNetで、タスク2がCaltech-UCSD Birds-200-2011でAlexNetをベースとした実験も行っており、既存手法より僅かに高精度な結果を達成している(雑)


  1. J. Kirkpatrick et al., "Overcoming Catastrophic Forgetting in Neural Networks," in PNAS, vol. 114, no. 13, pp. 3521-3526, 2017. 

  2. I. J. Goodfellow, O. Vinyals, and A. M. Saxe, "Qualitatively Characterizing Neural Network Optimization Problems," ICLR, 2017. 

14
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
13