概要
Google Research Brain Teamから出た"Meta Pseudo Labels"という短いタイトルの論文を読んでみました。公開されている論文はまだ Preprint なので今後書き直される可能性はありますが、大まかには理解できたので紹介してみます。
分類モデルの訓練は、一般的にはクロスエントロピー損失を最小化することを通じて行われます。これは、モデルが出力する分布をターゲット分布に近づけるための操作です。多くの場合、ターゲット分布は one-hot 表現が採用されます。しかし、one-hot 表現による訓練には、過学習が生じやすいという欠点があります。この問題に対処する方法として、例えば Label Smoothing のような方法が提案されています。また、半教師あり学習のシナリオでは、事前学習済みのモデル(教師モデル)の出力分布をモデル(生徒モデル)のターゲット分布とするといったことが行われます。このように、分類モデルのターゲット分布は、必ずしも one-hot 表現であるとは限らず、訓練のプロセスやシナリオに応じて適宜選択する必要があります。
本論文で提案している Meta Pseudo Labels(以下、MPL)とは、ひとことで言えばターゲット分布そのものをメタ学習によって獲得しようという手法です。
記事では、以下の順に説明します。
- 前提知識として、様々なターゲット分布の形を整理します。
- ターゲットのメタ学習である Meta Pseudo Labels を導入します
- MPL のために必要な訓練の手順を説明します。
- MPL がどのような効果をもたらすのかを説明します
書誌情報
- Pham, Hieu, et al. "Meta Pseudo Labels." arXiv preprint arXiv:2003.10580 (2020).
- https://arxiv.org/abs/2003.10580
様々なターゲット分布
分類モデルのターゲット分布$q_{*}(\mathbf{Y} | \mathbf{x})$は、訓練のシナリオによって、違うものが採用されます。まずはその例を眺めてみます。
訓練シナリオによる違い
完全な教師あり学習
完全にラベルが提供されている場合、ターゲット分布は通常 onehot 表現で与えられます。
q_{*}(\mathbf{Y} | \mathbf{x}) \triangleq \text{one-hot}( \mathbf{y} )
知識蒸留
モデルの軽量化などで使用される知識蒸留においては、大きい容量の訓練済みのモデル(教師モデル)の知識(Dark Knowledge)を小さい容量のモデル(生徒モデル)へと伝達することを考えます。
このシナリオでは、生徒モデルのターゲット分布は教師モデルの出力によって与えられます。
q_{*}(\mathbf{Y} | \mathbf{x}) \triangleq q_{\text{large}}(\mathbf{Y} | \mathbf{x})
半教師あり学習
半教師あり学習は、少数のラベルありデータと多数のラベルなしデータが存在しているときに、多数のラベルなしデータも有効活用したいときに採用するシナリオです。
基本的には、一部のラベルありデータから訓練されたモデル$q_{\xi}$を用いてラベルなしデータに対する疑似ラベルを作成し、それを用いてモデルを訓練していきます。
Hard/Soft 2 通りのやり方があります。
Hard
$q_{\xi}$の出力のうち最も大きいクラスを用いて onehot 表現で与えます。
q_{*}(\mathbf{Y} | \mathbf{x}) \triangleq \text{one-hot}(\text{argmax}_{\mathbf{y}} q_{\xi} (\mathbf{y} | \mathbf{x}))
Soft
$q_{\xi}$の出力によってターゲット分布を与えます。
q_{*}(\mathbf{Y} | \mathbf{x}) \triangleq q_{\xi} (\mathbf{Y} | \mathbf{x})
近年のテクニック
以上のようなシナリオに応じたターゲット分布ではうまく行かないときに、以下のようなヒューリスティックな手法によってモデルの精度を改善できるということがわかってきています。
Label Smoothing
ターゲット分布が onehot 表現のときにしばしば生じるのが過学習の問題です。この問題を緩めるために、正解のクラス以外にも数値を割り振った分布をターゲット分布とすることで過学習を抑えることができるということが知られています。
q_{*}(c | \mathbf{x}) \triangleq q_{\text{smooth}}(c | \mathbf{x})=\left\{\begin{array}{ll} 1-d+1 / C & {if} & c=\mathbf{y} \\ d / C & { if } & c \neq \mathbf{y} \end{array}\right.
Temperature Tuning
確率分布をシャープ/なだらかにする温度パラメータ$\tau$を導入し、知識蒸留や半教師あり学習(Soft)で使用されるターゲット分布を変形させることができます。$\tau \to 0$のとき、onehot へと近づきます。
q_{*}(c | \mathbf{x}) \triangleq \frac{\exp \left(l_{c}(\mathbf{x}) / \tau\right)}{\sum_{i=1}^{C} \exp \left(l_{i}(\mathbf{x}) / \tau\right)}
既存の手法の問題点
どのようにしてターゲット分布を選べばよいのかという観点からみてみると、これまでの方法には、共通して以下の 2 つの制約が暗に設けられていたということに気づきます。
- 事前に目標分布$q_∗$を選択します。訓練後は固定しておくか、訓練中にアドホックな方法で調整/更新します。
- $q_*$のシャープさ・なだらかさが、データサンプルに依存しない。
ターゲット分布のメタ学習
上で挙げた制約を制約を克服することが、本手法の目的となります。先程の制約を反転させると、以下のような性質を$q_*$に持たせることが必要になります。
- 目標分布$q_∗$は訓練中に、訓練プロセスに応じて動的に変化する。
- 目標分布$q_∗$のシャープさ・なだらかさは、データサンプルに応じて変化する。
つまり、$q_∗$は訓練のプロセスの中で変化する、つまりそれ自体が訓練されるものであるということが求められます。しかも、その分布の作られ方は、データサンプルに応じて変化することが求められます。例えば、ある訓練サンプルに対してすでに十分な自信を持ってモデルが予測できる場合は、過学習を避けるためになだらかなターゲット分布を提供するといった具合です。
以上のようなターゲット分布$q_*(\mathbf{\mathbf{x}})$を$q_{\Psi}(\mathbf{x})$によってパラメータライズし、勾配降下法によって訓練する、ということを考えます。ターゲット分布$q_{\Psi}(\mathbf{x})$はモデル$p_{\Theta}$の訓練のための疑似ラベルを提供するので、知識蒸留の用語を用いて$q_{\Psi}$を教師モデル、$p_{\Theta}$を生徒モデルと呼ぶことにします。
「生徒モデルへの適切な疑似ラベルを提供するための、教師モデルのメタ学習」という意味で、本手法を Meta Pseudo Labels と名付けています。
MPLの手順
MPL のアイデアを実現するためには、教師モデルの訓練と生徒モデルの訓練をうまく組み合わせた手順の確立が必要です。
基本的な流れ
MPL は更新ステップを、以下の 2 つのフェーズに分けて進めていきます。
- 生徒が教師の疑似ラベルから学ぶ
- 教師が生徒の Validation 損失から学ぶ
以下、順番に詳細を確認してきます。
Phase1: 生徒が教師の疑似ラベルから学ぶ
教師モデルの出力したターゲット分布、つまり疑似ラベルを使って、生徒モデルが学ぶというフェーズです。このフェーズは、知識蒸留で見られる訓練の過程と同じものです。
生徒モデルのパラメータ$\Theta$は、以下のような更新則が適用されます。
\Theta^{(t+1)} \triangleq \Theta^{(t)}-\left.\eta \nabla_{\Theta} \mathcal{L}_{\mathrm{CE}}\left(q_{\Psi}(\mathbf{x}), p_{\Theta}(\mathbf{x})\right)\right|_{\Theta(t)}
Phase2: 教師が生徒のValidation損失から学ぶ
次に、教師モデルを何らかの基準によって更新する必要があります。ここで、Validation データを使用します。
Validation データは生徒モデルに与えられ、以下のような損失が計算されます。これは、$\Theta^{(t+1)}$の関数として表現されます。
\mathcal{L}_{\mathrm{CE}}\left(\mathbf{y}_{\mathrm{val}}, p_{\Theta^{(t+1)}}\left(\mathbf{x}_{\mathrm{val}}\right)\right) \triangleq \mathcal{R}\left(\Theta^{(t+1)}\right)
そして、Phase1 の式に示したように、$\Theta^{(t+1)}$には教師モデル$q_{\Psi}$が含まれています。つまり、$\mathcal{R}\left(\Theta^{(t+1)}\right)$は$\Psi$の関数でもあります。
\mathcal{R}\left(\Theta^{(t+1)}\right) = \left.\left.\mathcal{R}\left(\Theta^{(t)}-\eta \nabla_{\Theta} \mathcal{L}_{\mathrm{CE}}(q_{\Psi}(\mathbf{x})), p_{\Theta}(\mathbf{x})\right)\right|_{\Theta^{(t)}}\right)
この損失を使用して教師モデルを更新します。これによって、Validation 損失を小さくできる方向に生徒モデルが訓練されるように、教師モデルは更新されます。
下図は、この訓練方法の意味を表す概念図です。Validation 損失によって教師モデルを訓練するのは、生徒モデルが Training 損失のみに従って過学習に陥ることを防ぎ、生徒モデルの Validation 損失もある程度小さくできるように導くという効果を与えます。
訓練プロセスの安定化
以上のような教師モデルと生徒モデルの訓練プロセスは、一見するとうまくいきそうに見えます。
しかし実際には、教師モデルの訓練が十分に進まないうちに生徒モデルが過学習状態になり、うまく行かないことがあったと述べられています。
そこで、教師モデルの更新のために、生徒モデルの Validation 損失に加えて通常の分類損失も追加します。つまり、教師モデルの損失関数としてラベル付きデータに対する分類損失を計算し、これも教師モデルの訓練に使用します。
以上の流れを 1 つの図にまとめると、以下のようになります。
ReducedMPLによるメモリ削減
ここまでで示した訓練プロセスでは、教師モデルと生徒モデルが同時にメモリに乗っていることが求められます。これは、軽量なモデルであれば可能ですが、ある程度の規模のモデルを使った訓練の際に問題になります。
そこで、以上のプロセスを簡略化した ReducedMPL が提案されています。
まず、非常に大きいモデル$q_{\text{large}}(\mathbf{x})$を訓練しておき、データセット(ラベル付き・なしどちらでも可)に対して適用することで事前処理されたデータセットを手に入れます。
次に、Small Teacher として、単純な複数層のネットワークを用意します。この小さなネットワークは、$q_{\text{large}}$が提供するターゲット分布を調整して生徒モデルに提供するという役割を担います。Small Teacher は基本的には恒等変換に近い関数の周りをウロウロしながら、生徒モデルの訓練に適したターゲット分布を提供できるように更新されます。
このやり方だと、メモリに同時に乗せる必要があるのは Small Teacher と生徒モデルのみになるため、通常の MPL で生じていたメモリ問題が解決されます。
ReducedMPL を図示すると、以下のようになります。図では省略されていますが、Student(t+1)には、Validation 損失を計算するためのデータが与えられます。これは、$q_{\text{large}}(\mathbf{x})$によって提供されるものではなく、本当のラベル付きデータです。
なお、通常の MPL で行っていた教師モデルの分類損失の適用は、Small Teacher に対して行っているのかは不明です。Small Teacher が非常に小さいモデルであり、あまり不安定になることがないということで追加の分類損失は必要ないのかもしれませんが、単に図版のミスという可能性もあります。このあたりは正式版の論文を確認してみたいところです。
MPLの効果
以下はトイデータ(TwoMoons)で、ラベル付きデータとして星印の部分のみが与えられている、という状態を表しています。MPL を用いていると、自然な分類モデルが訓練できているということが確認できます。
MPL における教師モデルは、生徒モデルの Training 勾配が Validation 勾配に近づくように促すような効果を与えています。このトイデータの訓練の過程で、生徒モデルの Training 勾配と Validation 勾配のコサイン類似度をプロットしたのが以下の図です。
生徒モデルの Training 勾配と Validation 勾配が近いということは、どちらの損失から見ても望ましいパラメータの更新方向が一致しているということであり、Training データに対して生徒モデルが過学習する心配がないということを意味します。Validation データを直接参照して訓練しているわけではなく、教師モデルによるターゲット分布の与え方のコントロールによってこれを実現しています。
まとめ
以上、疑似ラベルをメタ学習する手法 Meta Pseudo Labels について紹介してきました。
分類モデルの訓練によく使われる、Label Smoothing や疑似ラベルといった概念が、自分の中でバラバラに存在していたのですが、この論文を読むことでスッキリと整理された気がします。
まだ Preprint の段階なので、正式版ができたら再読してみたいと思います。