はじめに
新アーキテクチャ、Block Diffusion...
TransformerとDiffusion両方の強みを組み合わせながら、それらの限界を克服するものらしい...
概要
- 自己回帰型(Transformer)と拡散型(Diffusion)の両方の利点を兼ね備えた新モデル「Block Diffusion」を提案した
- 新たな学習アルゴリズム、勾配分散推定器、データ駆動型ノイズスケジュールを含む
結論
- 既存ディフュージョンモデルの主要な欠点を克服
- 自己回帰モデルとの品質ギャップ
- 任意長シーケンス生成不能
- KVキャッシング非対応
- 離散ディフュージョンモデルで新たな最高水準を確立
- 並列トークン生成と改善されたサンプル制御性を提供しつつ、標準LLMと競合可能な強力なディフュージョン言語モデル構築への有望なステップ
自己回帰型 vs ディフュージョン型言語モデル
特徴 | 自己回帰モデル | ディフュージョン型モデル |
---|---|---|
確率モデリング方法 | トークン間の条件付き確率を積の形で表現 | トークンを独立に扱い、並列生成が可能 |
生成プロセス | 逐次的に生成する必要があり、長いシーケンスでは遅い | 前向きノイズ過程を逆転させる過程でモデリング |
最適化方法 | - | 対数尤度の下界のみを最適化するため品質が低下 |
シーケンス長の制約 | - | 固定長シーケンスのみに制限される |
BD3-LMs: Block Discrete Denoising Diffusion Language Models
-
ブロックディフュージョン尤度:
- シーケンスをL'長のブロックに分割し、自己回帰的にモデル化
- 各ブロック内では離散ディフュージョンELBOを適用
- 重み付けされた交差エントロピー項の和として最終目的関数を定式化
-
効率的な学習・サンプリングアルゴリズム:
- 2回の順伝播のみで計算可能
- 最初の順伝播でKVキャッシュを計算
- 2回目の順伝播で全ブロックの予測を同時に計算
- ブロック単位で生成し、前のブロックに条件付け
自己回帰型とディフュージョン型モデル間の尤度ギャップの理解
-
単一トークン生成のケーススタディ:
- ブロック長L'=1の場合、理論上は自己回帰モデルと同等
- 実際には2ポイントのパープレキシティギャップが発生
-
ディフュージョン目的関数の高分散が原因:
- マスク率が低すぎる場合、再構成が容易で学習信号が弱い
- マスク率が高すぎる場合、最適解はデータ分布の周辺確率になり学習が単純化
低分散学習のためのデータ駆動型ノイズスケジュール
-
「クリップ」されたマスク率を使用:
- $t∈[tmin,tmax]$の範囲に制限
- ブロックサイズL'に応じて適応的に学習
- 5000勾配更新ごとに検証ステップでグリッド探索
-
最適化されたノイズスケジュール効果:
- 損失推定器の分散を低減
- 線形スケジュールや対数スケジュールより良いパープレキシティを達成
ノイズスケジュールとは、トークンをノイズ化(マスキング)する過程を制御するパラメータです。BD3-LMでは、この部分が革新的です。つまり、学習効率には最適なマスク率範囲が存在するという洞察です。
この手法の本質は、「学習にとって最も有益なマスク率範囲のみを使用する」という単純だが強力なアプローチです。
結果
-
尤度評価:
- ブロック長L'を調整することでディフュージョンと自己回帰の間を補間
- OpenWebTextで262B token訓練したモデルで最先端性能を達成
- L'=4でパープレキシティ≤20.73を記録
-
任意長シーケンス生成:
- 既存ディフュージョンモデルが1024トークンに制限される一方、任意長生成が可能
- 同等の生成ステップ数で比較した生成パープレキシティで最高性能
- L'=4で生成パープレキシティ25.7を達成
感想
分散最適化の重要性
論文の表面的な主張は新アーキテクチャの提案ですが、真の革新は学習過程の分散制御にあるのだと思います。トレーニング目的関数の分散を制御することが、最終的な性能向上の鍵で、モデル構造以上に、学習過程の安定性が重要だという学びでした。
用語
| カテゴリ | 用語 | 説明 |
|---------|------|------|
| アーキテクチャ | Block Diffusion | 自己回帰型とディフュージョン型言語モデルを融合したハイブリッドアーキテクチャ |
| アーキテクチャ | BD3-LMs | Block Discrete Denoising Diffusion Language Modelsの略称 |
| アーキテクチャ | D3PM | 離散デノイジングディフュージョン概率モデル(Austin et al.提案) |
| アーキテクチャ | MDLM | Maskable Diffusion Language Modelの略称、ディフュージョン型言語モデルの一種 |
| アーキテクチャ | SEDD | 離散ディフュージョン型言語モデルの一種(Lou et al.提案) |
| アーキテクチャ | SSD-LM | Semi-autoregressive Stochastic Diffusion Language Modelの略称 |
| 技術 | KVキャッシング | キーと値を事前計算してメモリに保存し、推論速度を向上させる技術 |
| 技術 | NFEs | Number of Function Evaluations、関数評価回数(生成効率の指標) |
| 技術 | クリップされたマスク率 | マスキング率の範囲を制限する技術(t∈[tmin,tmax]) |
| 技術 | ノイズスケジュール | ディフュージョンプロセスでノイズ追加のタイミングとレベルを制御するスケジュール |
| 技術 | パープレキシティ | 言語モデルの性能指標、値が低いほど良いモデル |
| 技術 | 生成パープレキシティ | 生成されたテキストの質を評価するためのパープレキシティ指標 |
| 手法 | ELBO | Evidence Lower Bound、対数尤度の下界 |
| 手法 | ディフュージョン | データに徐々にノイズを加え、それを取り除く学習をする生成モデル手法 |
| 手法 | ブロック単位のモデリング | シーケンスを固定長ブロックに分割して処理する手法 |
| 手法 | 自己回帰型 | 前のトークンに基づいて次のトークンを予測するシーケンシャルな手法 |
| 手法 | 離散ディフュージョン | カテゴリカルデータ(トークンなど)に適用するディフュージョンモデル |
| データセット | LM1B | 言語モデリング用の10億語コーパスデータセット |
| データセット | OpenWebText (OWT) | ウェブから収集されたテキストデータセット |