0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

LLMの蒸留についての学習メモ

Posted at

蒸留のざっくりとしたイメージ

大規模な教師モデルが出力する確率分布を用いて、小規模な生徒モデルを効率的に学習させる。
教師モデルの「温度」を上げることで出力分布を平滑化し、生徒の最適化を安定化する。


気になった理論的背景や論文

  • Phuong & Lampert (2019) の一般化境界:
    • データ数 n が入力次元 d 以上であれば生徒モデルの転移リスクがゼロになる。
    • 限定的に線形分類器対象だが、サンプル効率の飛躍的向上を示す。
  • Hinton et al. (2015)
    • アンサンブル学習を生徒へ圧縮する蒸留を定式化。
    • 生徒がアンサンブル分布とのKL最小化を行うことで、確率分布の最適凸結合に近づく。

平滑化による疑問点

教師モデルの温度を上げると学習安定化は図れるが、分布の細部(クラス間マージンなど)が失われるのではないか?


解決法1:FitNetsによる中間層特徴の蒸留

  • 教師の中間層(Hint Layer)と生徒の対応層(Guided Layer)の特徴表現をL2損失で合わせる。
  • 学習は2段階:
    1. Hintフェーズ:Regressorを介して生徒特徴を教師次元へ射影し、Hint損失で整合。
    2. 全体微調整フェーズ:出力のKD損失+クロスエントロピーで生徒モデル全体を最適化。
  • Regressorは1×1畳み込みなどで次元不一致を解消し、パラメータと計算コストを抑える。

解決法2:SDDによる出力ロジットの細粒度分解

  • 出力ロジットマップを複数の局所領域(マルチスケール)で平均プーリングし、局所的ログットを抽出。
  • 各局所ログットを「一貫項」(グローバル予測クラスと同一)と「補完項」(異なるクラス)に分離。
  • 最終損失は以下の和:
    • グローバルKD損失
    • 各局所セルごとのKD損失(補完項は重みを増強)
  • 細粒度情報を活用することで、平滑化によるクラス間マージンの欠如を補い、微細な判別力を維持。

Diffusionモデル蒸留の概要

  • 多ステップ生成(数百~数千ステップ)が必要な拡散モデルを、少ステップ(例:4ステップ、1ステップ)で同等性能へ高速化。
  • 代表手法:
    • Progressive Distillation(Salimans & Ho)
    • Consistency Models(Song et al.)
    • SCott(Liu et al.)
  • 各中間ステップの生成結果を生徒へ模倣させ、サンプリングコストを大幅に削減。

自分なりの結論整理

  1. 一般化境界が示すように、蒸留は通常学習よりも高いサンプル効率を実現。
  2. 温度平滑化は学習安定化に寄与するが、細部情報損失のリスクがある。
  3. FitNetsSDD のような手法により、中間層や局所出力を蒸留対象に含めることで、大まかな概念と細部構造の両立が可能。

雑感

ChatGPTとの対話を通じた学習は深掘りの助けになる。次はモデル実装を通じて理論を体感したい。

この記事の作成過程

chatgptと話した履歴
https://chatgpt.com/share/6811684a-b79c-8009-ba67-d78bf8200d8c

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?