はじめに
ニューラルネットワークにおいて、新たなタスクの学習を逐次的に行うと、以前学習したタスクに対する性能が急激に低下にcatastrophic forgetting(破滅的忘却)が発生してしまうことが課題となっており、それを解決するための手法を提案している。
Overcoming Catastrophic Forgetting by Incremental Moment Matching, NIPS'17.
https://arxiv.org/abs/1703.08475
以前にDeepMindのアプローチを読んで、これ系には興味があるので読んだ。背景や既存手法(EWC)の詳細は下記の記事に書いた。
ニューラルネットワークが持つ欠陥「破滅的忘却」を回避するアルゴリズムをDeepMindが開発した論文を読んだ
アプローチ
目的は、タスク1に関する学習データ$\mathcal{D}_1 = \{ X_1, y_1 \}$およびタスク2に関する学習データ$\mathcal{D}_2 = \{ X_2, y_2 \}$が与えられた際に、両方のタスクに対し有効なニューラルネットワークのパラメータ$\theta$の事後確率$p_{1:2} = p(\theta | \mathcal{D}_1, \mathcal{D}_2)$を求めることである。
これに対し、まずタスク1のパラメータの事後確率$p_1 = p(\theta | \mathcal{D}_1)$およびタスク2のパラメータの事後確率$p_2 = p(\theta | \mathcal{D}_2)$を求め、その後それらを統合して$p_{1:2} = p(\theta | \mathcal{D}_1, \mathcal{D}_2)$を求めることを提案する。
そのままでは解けないので、上記の事後確率を全てガウス分布で近似する(ラプラス近似):
p_{1:2} = p(\theta | \mathcal{D}_1, \mathcal{D}_2) \approx q_{1:2} = q(\theta|\mu_{1:2}, \Sigma_{1:2}), \\
p_1 = p(\theta | \mathcal{D}_1) \approx q_{1} = q(\theta|\mu_{1}, \Sigma_{1}), \\
p_2 = p(\theta | \mathcal{D}_2) \approx q_{2} = q(\theta|\mu_{2}, \Sigma_{2}).
これにより、上記の問題は、タスク1およびタスク2のニューラルネットワークのパラメータを学習し、その事後確率をガウス分布で近似した$q_1$および$q_2$を求め、最後にそれらを統合することで、両方のタスクに対し有効なニューラルネットワークパラメータ(の事後確率をガウス分布で近似した$q_{1:2}$)を求めるというステップで解くことができる。
すなわち、$q_{1:2}$が求まれば、両方のタスクに対し有効なニューラルネットワークのパラメータ$\theta_{1:2}$は$q_{1:2}$の平均値$\mu_{1:2}$として求めることができる。
2つの事後確率のパラメータ推定
まず、$q_1$および$q_2$のパラメータ$\mu_1, \Sigma_1$および$\mu_2, \Sigma_2$を求める必要がある。
タスク1およびタスク2をそれぞれ学習し、求められたニューラルネットワークのパラメータを$\theta_1^{*}$および$\theta_2^{*}$とすると、$\mu_1 = \theta_1, \mu_2 = \theta_2$となる(タスク1を学習した後のタスク2の学習に関しては、複数の方法での学習を提案しており、後述する)。
分散共分散行列行列$\Sigma_1, \Sigma_2$については、文献1と同様に、パラメータ$\theta$に関するフィッシャー情報行列$F$の逆行列として求めると、$\Sigma_1 = F_1^{-1}, \Sigma_2 = F_2^{-1}$となる。
2つの事後確率の統合方法
上記のように求めた$q_1$および$q_2$を統合して$p_{1:2}$を求める。
本論文では、mean-based incremental moment matching (mean-IMM) とmode-based incremental moment matching (mode-IMM) が提案されている。
mean-IMM
mean-IMMでは、$q_1$と$q_2$を混合比$\alpha$で混合した分布を1つのガウス分布$q_{1:2}$で近似する。このとき、$q_{1:2}$のパラメータは$q_{1:2}$と$(1-\alpha)q_1 + \alpha q_2$のKLダイバージェンスを最小化することで求められる:
\mu_{1:2}^*, \Sigma_{1:2}^* = \arg\min_{\mu_{1:2}, \Sigma_{1:2}} \mathrm{KL}(q_{1:2} || (1-\alpha)q_1 + \alpha q_2).
$\mu_{1:2}^*$はclosed-formで求めることができ、
\mu_{1:2}^* = (1-\alpha)\mu_1 + \alpha \mu_2
となる($\Sigma_{1:2}^*$もclosed-formで求めることができるが省略)。
mode-IMM
通常、前述のmean-IMMで求めた$q_{1:2}$のmode(最頻値, $\theta = \mu_{1:2}^*$)と、$(1-\alpha)q_1 + \alpha q_2$のmodeは異なる。mode-IMMでは、$(1-\alpha)q_1 + \alpha q_2$と$q_{1:2}$のmodeが同一となるように$\mu_{1:2}^{*}$を求める。
簡単のためmode-IMMでは$\alpha = 1/2$とすると、$\mu_{1:2}^*$は下記のように求めることができる($\Sigma_{1:2}^{*}$は略):
\mu_{1:2}^* = (\Sigma_1^{-1} + \Sigma_2^{-1})^{-1}(\Sigma_1^{-1} \mu + \Sigma_2^{-1} \mu_2).
ここで、$\Sigma_{1:2}^{*}$が対角であるという仮定を入れると、$\mu_{1:2}^*$は各次元$d$毎に求めることができる:
\mu_{{1:2}^*, d} = \frac{\mu_{1,d} / \sigma_{1,d}^{2} + \mu_{2,v} / \sigma_{2,v}^{2} }{1 / \sigma_{1,v}^{2} + 1 / \sigma_{2,d}^{2}}.
IMMのための転移学習
本論文では、タスク1およびタスク2で学習されたニューラルネットワーク(のパラメータ)を平均することを提案しているが、それらが独立に初期化されている場合には、それらのパラメータの間にはロスが大きくなるhigh cost barrierが存在する2ことが問題となる。
つまり、平均したニューラルネットワークがそのロスが大きくなる領域になってしまうと、性能が悪くなってしまうためである。
この問題を回避するために、幾つかの転移学習のテクニックが活用できることを示す。
Weight-transfer
最も重要なアイディアは、weight-transferである。これは単純に、タスク2の学習を行う際に、パラメータの初期値をタスク1を学習したパラメータで初期化するだけである。
文献2では、様々なニューラルネットワークにおいて、ある初期値から学習を開始した際に、その解となる(収束した)パラメータと初期値の間にはほとんどhigh cost barrierが存在しないことを経験的に示した。
このことから、weight-transferを行うと、2つのタスクを学習させたパラメータの間にhigh cost barrierが存在しないようにすることができる。
L2-transfer
L2-transferは、タスク2の学習時に、$\lambda ||\mu_1 - \mu_2||_2^2$をロスとして追加する。
Continuous learningの文脈では良くL2-transferが(恐らくベースラインとして)利用される。
既存手法では$\lambda$を大きくすることでなるべく$\mu_1$と$\mu_2$が遠くならないようにしているが、本論文では、$\mu_1$と$\mu_2$の間のロス関数がスムーズになることを目的としており、小さな$\lambda$でL2-transferを行うことが特徴である。
Drop-transfer
Drop-transferは本論文で新しく提案する転移学習の方法である。
これは、dropout ratio $p$に対し、タスク2の学習中の重み$\mu_{2i}$を、確率$p$で$\mu_{1i}$に、確率$1-p$で$\frac{1}{1-p} \mu_{2i} - \frac{p}{1-p} \mu_{1i}$とするものである。
このとき、パラメータの期待値は$\mu_{2i}$となる。
この設定だと、現在のパラメータ$\mu_{2i}$は直接利用されず、$\mu_{1i}$と外挿された$\frac{1}{1-p} \mu_{2i} - \frac{p}{1-p} \mu_{1i}$になる。$\mu_{1i}$から学習した差分をdropoutしていると考えれば良いのだろうか?
実験結果
興味深いMNISTでの実験結果のみ取り上げる。
ここでは、MNISTをベースに、Disjoint MNISTとShuffled MNISTという2種類の設定で実験を行っている。
Disjoint MNISTでは、タスク1では0から4までの数字の分類を、タスク2では5から9までの数字の分類を行う。
ポイントは、既存の論文の設定では独立した5クラス分類を2回解くという設定だが、本論文では、10クラス分類を解く前提で、データが2回に分けて与えられているという設定である。
Shuffled MNISTは文献1でも採用されている設定で、各タスクで、MNISTのピクセルがタスクごとに決まった順序でシャッフルされ、それぞれ10クラス分類を解くというものである。
上記が実験結果となっている。ポイントは、EWC1がDisjoint MNISTでダメダメな点である。ここまで悪くなるのはかなり不思議だが、あまり何故かという考察はない。
結果を見る限り、L2-transferでMean/Mode-IMMをしておけば良さそうな雰囲気で、わざわざDrop-transferとかは使わなくてもいいかも。
本論文の貢献は、Disjoint MNISTという問題設定を見つけたところが一番かもしれない。
あと、タスク1がImageNetで、タスク2がCaltech-UCSD Birds-200-2011でAlexNetをベースとした実験も行っており、既存手法より僅かに高精度な結果を達成している(雑)