0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

【論文解説】Attention Residuals: Kimi Team が10年間変わらなかった残差接続を書き換えた

0
Posted at

論文情報

  • タイトル: Attention Residuals
  • 著者: Kimi Team (Moonshot AI) — Guangyu Chen, Yu Zhang, Jianlin Su, Weixin Xu ほか計36名
  • arXiv: 2603.15031

📖 目次

  1. はじめに
  2. 背景:標準残差接続の仕組み
  3. 課題:PreNorm 希薄化問題
  4. 核心的な洞察:時間と深さの双対性
  5. 提案手法:Attention Residuals(AttnRes)
  6. インフラストラクチャ設計
  7. 実験結果
  8. アブレーション実験
  9. 既存手法との比較
  10. 実務への示唆
  11. まとめ

1. はじめに

2015年の ResNet 以来、残差接続(Residual Connection) はディープラーニングの根幹を成す技術であり、Transformer を含む現代の大規模言語モデル(LLM)すべてがこれを採用している。その仕組みは極めてシンプルだ:

h_l = h_{l-1} + f_{l-1}(h_{l-1})

しかし、このシンプルさにこそ落とし穴がある。すべての層の出力を重み1で均等に足し合わせるこの仕組みは、モデルが深くなるにつれて各層の寄与を希薄化させ、情報の選択的な取り出しを不可能にしてしまう。

Kimi Team(Moonshot AI)が発表した本論文は、この10年間本質的に変わっていなかった残差接続のパラダイムに正面から挑戦し、Attention Residuals(AttnRes) という新しいメカニズムを提案した。AttnRes は、固定的な均等加算を 学習可能な softmax アテンションに置き換え、各層が前の層の出力を入力に応じて動的に選択・集約できるようにする。


2. 背景:標準残差接続の仕組み

残差接続の2つの役割

残差接続には、通常語られる以上の2つの重要な役割がある。

役割①:勾配の高速道路(Gradient Highway)

逆伝播時、損失関数の勾配は以下のように計算される:

∂L/∂h_l = (∂L/∂h_L) · ∏_{j=l}^{L-1} (I + ∂f_j/∂h_j)

展開すると 恒等行列 I が常に含まれるため、勾配がどの層へも直接流れる「高速道路」が確保される。これにより勾配消失問題が大幅に緩和され、100層以上のネットワークでも安定した学習が可能になった。

役割②:層間の情報集約

再帰式を展開すると、第 l 層への入力は:

h_l = h_1 + Σ_{i=1}^{l-1} f_i(h_i)

つまり、入力埋め込み+すべての前層の出力の均等な足し合わせになっている。ここに「どの層の情報を重視するか」を制御するメカニズムは一切存在しない。

Highway Networks による一般化の試み

Highway Networks(2015年)は、学習可能なゲートを導入して解決を試みた:

h_l = (1 - g_l) ⊙ h_{l-1} + g_l ⊙ f_{l-1}(h_{l-1})

しかし、各層がアクセスできるのは直前の層の圧縮済み状態 h_{l-1} のみであり、過去の個別の層出力を選択的に取り出すことはできない。この根本的な制約は変わらない。


3. 課題:PreNorm 希薄化問題

現代の Transformer では PreNorm(LayerNorm を Attention / FFN の前に適用)が標準だが、これが希薄化問題を深刻化させている。

何が起きているか

各層の処理の流れ:
1. 累積された hidden state(全前層出力の合計)を正規化
2. 正規化された値を Attention / FFN で処理 → 標準的なスケールの出力
3. その出力を正規化前の累積値に足し合わせる
層の深さ 累積値の大きさ(概算) 新しい層の寄与の割合
第10層 ~10倍 ~1/10
第50層 ~50倍 ~1/50
第100層 ~100倍 ~1/100

結果として生じる3つの問題

  1. 情報の埋没: 初期層の有用な情報が後の層の出力に埋もれ、選択的に取り出せない
  2. 限定的なアクセス: 各層は圧縮済みの h_{l-1} しか参照できず、個々の過去層の出力にアクセスできない
  3. 出力の肥大化: 深い層が影響力を保つために、ますます大きな出力を学習する必要がある

💡 実証的な証拠: 標準的な Transformer の相当数の層をそのまま削除しても、性能がほとんど低下しないことが報告されている。モデルはそれらの層の寄与が希薄化しすぎて、すでに「無視」することを学んでいたのだ。


4. 核心的な洞察:時間と深さの双対性

本論文の最も重要な着想は、深さ方向の情報圧縮と、RNN の時系列方向の情報圧縮が構造的に同一であるという発見だ。

RNN と残差接続の対比

固定重み集約 Attention 集約
時系列方向 RNN(前ステップの状態を固定重みで引き継ぐ) Transformer(全時刻を softmax で選択参照)
深さ方向 残差接続(前層を重み1で足す) AttnRes(全前層を softmax で選択参照)

RNN は過去のすべてのトークンを1つの固定サイズの隠れ状態に圧縮していたため、長距離依存の学習が困難だった。Transformer はこれを シーケンス方向の Attention で解決した。

残差接続も全く同じ構造で、過去のすべての層出力を1つの累積状態に圧縮している。深さ方向でも同じ解決策(Attention)を適用するのが AttnRes の本質だ。

RNNの問題(時系列方向の圧縮)   → Transformerの解決(時系列方向のAttention)
残差接続の問題(深さ方向の圧縮) → AttnResの解決(深さ方向のAttention)

5. 提案手法:Attention Residuals(AttnRes)

5.1 Full AttnRes

基本原理

固定的な均等加算:

h_l = Σ_{i=0}^{l-1} v_i      (すべて重み1)

を、学習可能な Attention 重みによる加重和に置き換える:

h_l = Σ_{i=0}^{l-1} α_{i→l} · v_i

ここで α_{i→l} は softmax 正規化されたアテンション重みであり、Σ α_{i→l} = 1 を満たす。

数式の詳細

クエリ・キー・バリュー:

q_l = w_l                    (層ごとの学習可能な擬似クエリベクトル、w_l ∈ R^d)
k_i = v_i = h_1              (i=0 のとき:トークン埋め込み)
k_i = v_i = f_i(h_i)         (i≥1 のとき:第i層の出力)

アテンション重みの計算:

φ(q, k) = exp(q^T · RMSNorm(k))

α_{i→l} = φ(q_l, k_i) / Σ_{j=0}^{l-1} φ(q_l, k_j)

層入力の計算:

h_l = Σ_{i=0}^{l-1} α_{i→l} · v_i

設計上のポイント

要素 設計 理由
クエリ 層ごとの学習可能パラメータ w_l 入力非依存なので並列計算が可能
RMSNorm キーに適用 出力のスケールが大きい層がアテンションを支配するのを防止
追加パラメータ 層あたり1つの d 次元ベクトルのみ 48Bパラメータモデルでは誤差程度の追加
初期化 w_l をゼロ初期化 学習開始時は標準残差接続と同等に動作し安定性を確保

計算コスト

  • 演算量: O(L²d) — ただし深さ L は系列長に比べ十分に小さいため実用上は軽微
  • メモリ: O(Ld) — 通常学習では逆伝播用に保持済みの活性化と完全に重複するため追加メモリゼロ

5.2 Block AttnRes

大規模分散学習では、パイプライン並列化と活性化再計算が標準的に用いられるため、Full AttnRes の O(Ld) の通信・メモリコストが問題になる。Block AttnRes はこれを実用的に解決する。

アプローチ

L 層を N 個のブロックに分割(ブロックサイズ S = L / N)

ブロック内:標準残差接続で加算 → 1つのブロック表現 b_n に圧縮
ブロック間:N 個のブロック表現に対して Full AttnRes を適用

ブロック表現の計算

b_n = Σ_{j∈B_n} f_j(h_j)    (ブロック n 内の全層出力の和)
b_0 = h_1                    (トークン埋め込みは常に保持)

ブロック間アテンション

n ブロックの第 i 層への入力は:

  • ブロックの最初の層 (i=1): [b_0, b_1, ..., b_{n-1}] に対してアテンション
  • ブロックの2番目以降 (i≥2): 上記に加え、現在のブロックの途中の部分和 b_n^{i-1} も参照

メモリ・計算コストの比較

Full AttnRes Block AttnRes
メモリ O(Ld) O(Nd)
計算量 O(L²) O(N²)
通信量 O(Ld) O(Nd)

N ≈ 8 で Full AttnRes の性能のほとんどを回復でき、実用上は標準残差接続のドロップイン置き換えとして機能する。

PyTorch 擬似コード

def block_attn_res(blocks, partial_block, proj, norm):
    """
    blocks: N個の完了済みブロック表現 [B, T, D]
    partial_block: 現在のブロック内の部分和 [B, T, D]
    """
    V = torch.stack(blocks + [partial_block])    # [N+1, B, T, D]
    K = norm(V)                                   # RMSNorm
    logits = torch.einsum('d, n b t d -> n b t',
                          proj.weight.squeeze(), K)
    h = torch.einsum('n b t, n b t d -> b t d',
                     logits.softmax(0), V)
    return h

6. インフラストラクチャ設計

AttnRes を大規模学習で実用化するための2つの重要な最適化。

6.1 学習時:クロスステージキャッシング

パイプライン並列化では、ステージ間でブロック表現を転送する必要がある。

ナイーブな方法:

  • 各ステージ遷移ごとに、蓄積された全ブロック表現を再送信
  • 通信コスト: O(C(C-1)/2 · N_p · d) (C = パイプラインチャンク数)

キャッシュベースの方法:

  • 受信側が過去に受け取ったブロック表現をローカルにキャッシュ
  • ステージ遷移時は差分(新しく完成したブロック)のみを転送
  • ピーク通信コストを O(C)O(P) に削減(V倍の改善)
例: 4 GPU × 2仮想ステージ

ナイーブ: [b0, b1, b2] を毎回フル送信
キャッシュ: +[b1, b2] の差分のみ送信 → 6回分の冗長転送を排除

6.2 推論時:Two-Phase 計算戦略

推論時のレイテンシを最小化するための2段階計算:

フェーズ 内容 特徴
Phase 1 ブロック間アテンション(すべてのクエリ w_l を一括処理) クエリが学習パラメータで入力非依存のため、ブロック内の全層分を並列計算可能
Phase 2 ブロック内のシーケンシャル処理(部分和の更新) Phase 1 の結果と online softmax で統合

結果:

  • 学習オーバーヘッド: < 4%
  • 推論レイテンシオーバーヘッド: < 2%

7. 実験結果

7.1 スケーリング則

5つのモデルサイズでの比較で、AttnRes はすべての計算予算で一貫してベースラインを上回った

🔑 最重要の結果: Block AttnRes は、1.25倍の計算量で学習したベースラインと同等の性能を達成。つまり、残差接続の変更だけで約20%の学習コスト削減に相当する。

7.2 ダウンストリームベンチマーク

Kimi Linear アーキテクチャ(48B 総パラメータ / 3B 活性化パラメータ)を 1.4T トークンで事前学習した結果:

カテゴリ ベンチマーク Baseline AttnRes 改善幅
一般 MMLU 73.5 74.6 +1.1
GPQA-Diamond 36.9 44.4 +7.5
BBH 76.3 78.0 +1.7
TriviaQA 69.9 71.8 +1.9
数学 & コード Math 53.5 57.1 +3.6
HumanEval 59.1 62.2 +3.1
MBPP 72.0 73.9 +1.9
中国語 CMMLU 82.0 82.9 +0.9
C-Eval 79.6 82.5 +2.9

📊 注目すべき傾向: 多段階推論タスク(GPQA-Diamond +7.5、Math +3.6)で最大の改善。これは「深い層が特定の浅い層の情報を選択的に参照する必要がある」タスクほど AttnRes の恩恵が大きいことを示す。

7.3 学習ダイナミクスの変化

指標 Baseline AttnRes
出力のノルム 深さに比例して単調増加 O(L) ブロック境界で周期的にリセットされ、有界に保たれる
勾配ノルム 初期層に集中(不均一) 全層にわたって均一に分布

→ AttnRes が PreNorm 希薄化を構造的に解消していることが、学習ダイナミクスからも確認された。


8. アブレーション実験

ブロック数 N の影響

ブロック数 N Validation Loss 備考
Full (N=L) 1.738 最高性能だがスケーラビリティに制限
2 1.740 ほぼ Full と同等
4 1.740 ほぼ Full と同等
8 1.740 推奨設定:性能とコストの最適バランス
16 1.745 性能が低下し始める
32 1.750 ベースラインに近づく
1 (= 標準残差) 1.766 ベースライン

N = 8 が性能とスケーラビリティの最適な妥協点。

入力依存の重要性

手法 Validation Loss
AttnRes(入力依存の重み) 1.738
入力非依存の固定学習重み 1.749
DenseFormer(固定重み) 1.767
ベースライン 1.766

「入力に応じて動的に重みを変える」ことが本質的に重要であり、単に過去の層にアクセスを与えるだけでは効果がない。


9. 既存手法との比較

手法 層間アクセス 重みの性質 スケーラビリティ 効果
標準残差接続 直前層のみ(圧縮状態) 固定(重み1) ベースライン
Highway Networks 直前層のみ 学習可能なゲート 限定的改善
DenseFormer 全前層 入力非依存の固定学習重み 改善なし
Scaled Residuals 直前層のみ スケーリング係数 限定的改善
Full AttnRes 全前層 入力依存 softmax △(大規模時) 最高性能
Block AttnRes ブロック単位 + 部分和 入力依存 softmax ほぼ Full と同等

10. 実務への示唆

🔧 LLM を利用する側(推論・ファインチューニング)

  • 既存ワークフローへの変更は不要
  • AttnRes を組み込んだモデルは、特に推論集約型タスクでそのまま性能が向上する

🏗️ LLM を学習する側

  • ドロップイン置き換え: 標準残差接続を Block AttnRes に置き換えるだけで、学習コスト約20%相当の性能改善
  • アーキテクチャ探索への影響: AttnRes は深く、狭いネットワークにより有利に働く(追加された深さを有効活用できるため)
  • 実装: GitHub リポジトリに PyTorch 参照実装が公開

💡 学んだアテンション重みのパターン

学習されたアテンション重みの可視化から、以下の興味深いパターンが判明:

  1. 局所性の保持: 各層は依然として直前の層に最大の重みを置く
  2. 選択的な遠距離参照: 特定の層が、はるか前の層に意味のある重みを割り当てる
  3. Attention 層と MLP 層の違い:
    • Attention 層の前 → 幅広い過去の層を参照
    • MLP 層の前 → 最近の層に集中

11. まとめ

本論文の貢献

貢献 内容
Attention Residuals 固定残差接続を深さ方向の softmax Attention に置き換える新手法とそのスケーラブルな変形(Block AttnRes)
理論的統一 標準残差接続=深さ方向の線形 Attention、AttnRes=深さ方向の softmax Attention という統一的な見方を提示
インフラストラクチャ クロスステージキャッシングと Two-Phase 計算により大規模学習・推論を実用化
包括的な検証 スケーリング則、アブレーション、48B パラメータモデルでの下流タスク評価

一言でまとめると

Transformer がシーケンス方向で RNN の固定重み集約を softmax Attention に置き換えて成功したように、AttnRes は深さ方向で同じ変革を起こした。追加コストはわずか(学習 < 4%、推論 < 2%)で、Block AttnRes により 1.25倍の計算量で学習したベースラインと同等の性能を実現する。

今後の展望・未解決の課題

  • 独立した第三者による再現検証がまだ行われていない
  • 1B、7B、70B パラメータスケールでの 1.25倍優位性の一般化可能性
  • 長コンテキストタスクでの性能評価
  • ブロックサイズ N の最適化に関するさらなる分析
  • 将来のインターコネクト技術の進歩により、Full AttnRes も大規模で実用化できる可能性がある

📚 参考文献

  • 本論文: Kimi Team. "Attention Residuals." arXiv:2603.15031, Mar 2026. Link
  • 実装: github.com/MoonshotAI/Attention-Residuals
  • He et al. "Deep Residual Learning for Image Recognition." CVPR 2016.(ResNet — 残差接続の原論文)
  • Vaswani et al. "Attention Is All You Need." NeurIPS 2017.(Transformer の原論文)
  • Srivastava et al. "Highway Networks." ICML 2015.(ゲート付き残差)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?