蒸留のざっくりとしたイメージ
大規模な教師モデルが出力する確率分布を用いて、小規模な生徒モデルを効率的に学習させる。
教師モデルの「温度」を上げることで出力分布を平滑化し、生徒の最適化を安定化する。
気になった理論的背景や論文
-
Phuong & Lampert (2019) の一般化境界:
- データ数 n が入力次元 d 以上であれば生徒モデルの転移リスクがゼロになる。
- 限定的に線形分類器対象だが、サンプル効率の飛躍的向上を示す。
-
Hinton et al. (2015):
- アンサンブル学習を生徒へ圧縮する蒸留を定式化。
- 生徒がアンサンブル分布とのKL最小化を行うことで、確率分布の最適凸結合に近づく。
平滑化による疑問点
教師モデルの温度を上げると学習安定化は図れるが、分布の細部(クラス間マージンなど)が失われるのではないか?
解決法1:FitNetsによる中間層特徴の蒸留
- 教師の中間層(Hint Layer)と生徒の対応層(Guided Layer)の特徴表現をL2損失で合わせる。
- 学習は2段階:
- Hintフェーズ:Regressorを介して生徒特徴を教師次元へ射影し、Hint損失で整合。
- 全体微調整フェーズ:出力のKD損失+クロスエントロピーで生徒モデル全体を最適化。
- Regressorは1×1畳み込みなどで次元不一致を解消し、パラメータと計算コストを抑える。
解決法2:SDDによる出力ロジットの細粒度分解
- 出力ロジットマップを複数の局所領域(マルチスケール)で平均プーリングし、局所的ログットを抽出。
- 各局所ログットを「一貫項」(グローバル予測クラスと同一)と「補完項」(異なるクラス)に分離。
- 最終損失は以下の和:
- グローバルKD損失
- 各局所セルごとのKD損失(補完項は重みを増強)
- 細粒度情報を活用することで、平滑化によるクラス間マージンの欠如を補い、微細な判別力を維持。
Diffusionモデル蒸留の概要
- 多ステップ生成(数百~数千ステップ)が必要な拡散モデルを、少ステップ(例:4ステップ、1ステップ)で同等性能へ高速化。
- 代表手法:
- Progressive Distillation(Salimans & Ho)
- Consistency Models(Song et al.)
- SCott(Liu et al.)
- 各中間ステップの生成結果を生徒へ模倣させ、サンプリングコストを大幅に削減。
自分なりの結論整理
- 一般化境界が示すように、蒸留は通常学習よりも高いサンプル効率を実現。
- 温度平滑化は学習安定化に寄与するが、細部情報損失のリスクがある。
- FitNets と SDD のような手法により、中間層や局所出力を蒸留対象に含めることで、大まかな概念と細部構造の両立が可能。
雑感
ChatGPTとの対話を通じた学習は深掘りの助けになる。次はモデル実装を通じて理論を体感したい。
この記事の作成過程
chatgptと話した履歴
https://chatgpt.com/share/6811684a-b79c-8009-ba67-d78bf8200d8c