今回はRust製の深層学習フレームワーク「Burn」のソースコードを読み解きながら、マルチヘッドアテンションの仕組みと実装方法について学んでいきます。
環境
この記事では以下の環境を前提としています。
- Rust 1.75以上
- Burn 0.14.0以降
マルチヘッドアテンションとは
マルチヘッドアテンションは、2017年にGoogleの研究者らが発表した論文「Attention Is All You Need」で提案された注意機構です。トランスフォーマーアーキテクチャの中核をなす技術で、異なる表現部分空間からの情報に同時に注意を向けることを可能にします。
基本的な考え方
マルチヘッドアテンションの基本的な考え方は、単一の注意機構よりも、複数の「ヘッド」に分割して並列的に注意計算を行うことで、異なる視点から情報を抽出できるようにすることです。これにより、シーケンスの異なる位置や異なる抽象レベルの情報に同時に注意を向けることができます。
例えば、自然言語処理では、あるヘッドが文法的な関係に注目し、別のヘッドが意味的な関連性に注目するというように、異なる特性を学習することが可能になります。
数学的な表現
マルチヘッドアテンションは、クエリ(Q)、キー(K)、バリュー(V)という3つの入力を使用します。基本的な計算は以下になるとのことです。
$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O $
$ \text{where } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $
下記で改めて理解していきましょう
1. マルチヘッドアテンションが主に使われる場面
まず、Transformerは、以下のようなタスクで非常に高い性能を発揮します。
-
自然言語処理 (NLP):
- 機械翻訳(例: 英語の文章を日本語に翻訳する)
- 文章要約(例: 長い記事を短い要約文にする)
- 質疑応答(例: 質問文に対して、与えられた文脈から回答を見つける)
- 文章生成(例: あるテーマについて文章を自動で作成する)
-
時系列データ処理:
- 音声認識
- 株価予測
-
コンピュータビジョン:
- 画像キャプション生成(画像の内容を説明する文章を生成する)
- 物体検出
これらのタスクに共通するのは、 シーケンス(順序を持ったデータの列) を扱うという点です。例えば、文章は単語のシーケンスですし、音声は音響特徴量のシーケンスです。
2. 入力データは「ベクトルのシーケンス」
マルチヘッドアテンションに直接入力されるのは、この「シーケンスデータ」が 数値化(ベクトル化) されたものです。
-
自然言語処理の場合:
- まず、入力文はトークン(多くの場合、単語やサブワード)に分割されます。
例: "猫が窓辺で日向ぼっこをしている" → ["猫", "が", "窓辺", "で", "日向ぼっこ", "を", "し", "て", "いる"] - 次に、各トークンは単語埋め込み (Word Embedding) という手法によって、固定長の数値ベクトルに変換されます。このベクトルは、その単語の意味や文脈的な特徴を捉えたものです。
例: "猫" →[0.1, -0.5, 0.3, ...]
(数十〜数百次元のベクトル) - 結果として、入力文は 「単語ベクトルのシーケンス」 として表現されます。これがマルチヘッドアテンションへの主要な入力の元となります。
例:[[0.1, -0.5, ...], [-0.2, 0.8, ...], [0.9, 0.1, ...], ...]
(シーケンス長 × ベクトル次元数の行列のような形)
- まず、入力文はトークン(多くの場合、単語やサブワード)に分割されます。
3. 入力シーケンスから Q (クエリ), K (キー), V (バリュー) を作る
マルチヘッドアテンションの数式に出てくる Q
, K
, V
は、この「ベクトルのシーケンス」から作られます。
最も基本的な自己アテンション (Self-Attention) の場合(入力シーケンス内の要素間の関係性に注目する場合)は以下のようになります。
-
入力として、上記で説明した「単語ベクトルのシーケンス」$X = [x_1, x_2, ..., x_n]$ (ここで $x_j$ は
j
番目の単語のベクトル) が与えられたとします。 -
この入力
X
から、 3種類の異なる線形変換(学習可能な重み行列を掛けること) によって、Q
,K
,V
をそれぞれ生成します。-
Q = X * W_Q_input
(入力シーケンスX
の各ベクトルに、クエリ用の重み行列W_Q_input
を掛ける) -
K = X * W_K_input
(入力シーケンスX
の各ベクトルに、キー用の重み行列W_K_input
を掛ける) -
V = X * W_V_input
(入力シーケンスX
の各ベクトルに、バリュー用の重み行列W_V_input
を掛ける)
※ ここでの
W_Q_input
,W_K_input
,W_V_input
は、マルチヘッドアテンションの各ヘッドが持つ $W_i^Q$, $W_i^K$, $W_i^V$ とは別の、Q, K, Vを準備するための初期変換行列です。 -
つまり、入力シーケンス内の各単語(ベクトル)が、それぞれクエリ、キー、バリューの役割を担うための専用のベクトルに変換されるということです。
-
Q
: シーケンス内の各単語が「他のどの単語の情報と関連付けたいか」を問い合わせるためのベクトル群。 -
K
: シーケンス内の各単語が「私はこんな情報を持っていますよ」と示すためのラベルのようなベクトル群。 -
V
: シーケンス内の各単語が実際に持つ「情報の中身」そのもののベクトル群。
これらの Q
, K
, V
が、先程の数式の最初の入力となります。
$ \text{MultiHead}(Q, K, V) $
通常、これらはバッチ処理されるため、形状は [バッチサイズ, シーケンス長, 特徴量次元]
のような3次元のテンソル(多次元配列)になります。
4. 次の数式の解説
この具体的な入力のイメージを持った上で、次の数式を見てみましょう。
$head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)$
-
Q
,K
,V
: 上記で説明した、入力シーケンス(単語ベクトルの集まり)から線形変換によって生成された「クエリの集合」「キーの集合」「バリューの集合」です。これらは通常、シーケンス長 × 特徴量次元数の行列の形をしています。 -
$W_i^Q$, $W_i^K$, $W_i^V$:
i
番目のアテンションヘッド専用の学習可能な「投影行列」です。- $QW_i^Q$: 全体のクエリ集合
Q
を、i
番目のヘッドの視点に合うようにさらに変換(投影)したもの。 - $KW_i^K$: 全体のキー集合
K
を、i
番目のヘッドの視点に合うようにさらに変換(投影)したもの。 - $VW_i^V$: 全体のバリュー集合
V
を、i
番目のヘッドの視点に合うようにさらに変換(投影)したもの。
これにより、各ヘッドは、同じ入力情報(元のQ, K, V)から、それぞれ異なる側面(部分空間)に注目して情報を処理します。例えば、あるヘッドは単語間の構文的な関係を見るための投影を行い、別のヘッドは意味的な類似性を見るための投影を行う、といった具合です。
- $QW_i^Q$: 全体のクエリ集合
-
Attention(...)
: この関数が、i
番目のヘッドの視点に投影された$QW_i^Q$、$KW_i^K$、$VW_i^V$ を使って、「どのバリュー ($VW_i^V$) にどれだけ注目するか」の重みを計算し、その重みに従ってバリューを合成します。- 具体的には、クエリ($QW_i^Q$ の各行)と全てのキー($KW_i^K$ の各行)との関連度を計算し、ソフトマックス関数で正規化して「アテンションウェイト(注意の重み)」を求めます (後述の「アテンションスコアの計算」を参照)。
- このアテンションウェイトを使って、バリュー($VW_i^V$ の各行)の加重和を計算します。これが $head_i$ の出力、つまり
i
番目のヘッドが抽出した文脈情報になります。
$ MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O $
- $Concat(head_1, ..., head_h)$:
h
個の各アテンションヘッドが、それぞれ異なる視点で抽出した情報($head_1$ から $head_h$ までのベクトル群)を単純に連結します。これにより、多様な視点からの情報が一つの大きなベクトルにまとめられます。 - $W^O$: 連結されて長くなったベクトルを、最終的な出力として適切な次元数に変換するための、学習可能な出力用の投影行列です。各ヘッドからの情報を統合し、次の層で扱いやすい形に整えます。
まとめると
- 入力として、例えば文章であれば「単語の意味を表すベクトルの並び」が与えられます。
- この「ベクトルの並び」から、初期の線形変換によって、全体の「クエリ(Q)の集合」「キー(K)の集合」「バリュー(V)の集合」が作られます。
- これらのQ, K, Vは、複数のアテンションヘッドに分配されます。
- 各ヘッドは、独自の「視点」(投影行列 $W_i^Q$, $W_i^K$, $W_i^V$)でQ, K, Vをさらに変換し、どの情報に注目すべきかを計算して、そのヘッドにとって重要な情報を抽出します($head_i$)。
- 全てのヘッドが抽出した情報を集めて($Concat$)、最後に一つの出力ベクトルにまとめ上げます($W^O$)。
ここで $W^O$ の上付き文字 O
は「Output(出力)」を意味します。
つまり、$W^O$ は、連結された各アテンションヘッドの出力 (Concat(head_1, ..., head_h)
) を、マルチヘッドアテンション層全体の最終的な出力として適切な形に変換するための重み行列(線形変換層)です。
各アテンションヘッド ($head_i$) は、入力シーケンスの異なる側面や表現部分空間からの情報を捉えます。これらの情報は Concat
によって一つにまとめられますが、そのままでは次の層で処理するには次元が大きすぎたり、情報が整理されていなかったりする可能性があります。
そこで、$W^O$ という出力用の線形層(重み行列を掛ける操作)を適用することで、
-
次元の調整: 連結されたベクトルの次元を、モデルが期待する次元(多くの場合、入力の次元
d_model
と同じ)に戻します。 - 情報の統合・集約: 各ヘッドから得られた多様な情報を適切に混ぜ合わせ、より洗練された表現に変換します。
という役割を果たしているとのことです。(「出力変換行列」もしくは「出力投影行列」)
Burnのマルチヘッドアテンションのコード解析
それでは、Burnのマルチヘッドアテンション実装を見ていきましょう。コードは GitHub上のリポジトリ で確認できます。
設定構造体(Configuration)
まず、マルチヘッドアテンションレイヤーを構築するための設定構造体から見ていきましょう。
#[derive(Config)]
pub struct MultiHeadAttentionConfig {
// 各線形レイヤーのサイズ
pub d_model: usize,
// ヘッドの数
pub n_heads: usize,
// ドロップアウト率(デフォルト: 0.1)
#[config(default = 0.1)]
pub dropout: f64,
// 浮動小数点の最小値(デフォルト: -1.0e4)
// これはアテンションスコアにマスクをかける際に使用されます
#[config(default = -1.0e4)]
pub min_float: f64,
// 通常のsoftmaxの代わりに「quiet softmax」を使用するかどうか
#[config(default = false)]
pub quiet_softmax: bool,
// ニューラルネットワークのパラメータを初期化するための関数タイプ
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0), fan_out_only:false}"
)]
pub initializer: Initializer,
}
この設定構造体は、マルチヘッドアテンションレイヤーの基本的なパラメータを定義しています。主なパラメータは以下の通りです:
-
d_model
: モデルの次元数(埋め込みサイズ) -
n_heads
: アテンションヘッドの数 -
dropout
: ドロップアウト率 -
min_float
: マスク適用時に使用する最小値 -
quiet_softmax
: 通常のsoftmaxではなくquiet softmaxを使用するかどうか -
initializer
: 重みの初期化方法
特に注目すべきはquiet_softmax
というパラメータです。これは通常のsoftmaxの代替として、「シーケンスに関連情報がない場合にアテンションヘッドが情報を出力しないようにする」ために使用できるオプションです。これにより、モデルの性能向上や圧縮効率の改善が期待できるとコメントされています。
モジュール定義
次に、マルチヘッドアテンションモジュール本体を見てみましょう。
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct MultiHeadAttention<B: Backend> {
// クエリ空間に入力特徴量を変換する線形層
pub query: nn::Linear<B>,
// キー空間に入力特徴量を変換する線形層
pub key: nn::Linear<B>,
// バリュー空間に入力特徴量を変換する線形層
pub value: nn::Linear<B>,
// 出力特徴量を元の空間に戻す線形層
pub output: nn::Linear<B>,
// ドロップアウト層
pub dropout: nn::Dropout,
// 活性化関数
pub activation: nn::Gelu,
// 各線形層のサイズ
pub d_model: usize,
// ヘッドの数
pub n_heads: usize,
// キーとクエリベクトルのサイズ
pub d_k: usize,
// 浮動小数点の最小値
pub min_float: f64,
// 「quiet softmax」を使用するかどうか
pub quiet_softmax: bool,
}
このモジュールは、マルチヘッドアテンションの実際の計算を行うためのコンポーネントを含んでいます。ジェネリックパラメータB: Backend
を使用して、バックエンドに依存しない実装となっています。これにより、CPUやGPUなど様々なハードウェアで同じコードを使用できます。
主なコンポーネントは以下の通りです:
- クエリ、キー、バリュー、出力用の線形変換層
- ドロップアウト層と活性化関数
- 各種パラメータ(モデルサイズ、ヘッド数など)
初期化処理
設定構造体には、モジュールを初期化するためのメソッドが定義されています。
impl MultiHeadAttentionConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> MultiHeadAttention<B> {
let linear = |config: &Self| {
nn::LinearConfig::new(config.d_model, config.d_model)
.with_initializer(self.initializer.clone())
.init(device)
};
MultiHeadAttention {
query: linear(self),
key: linear(self),
value: linear(self),
output: linear(self),
dropout: nn::DropoutConfig::new(self.dropout).init(),
activation: nn::Gelu::new(),
n_heads: self.n_heads,
d_k: self.d_model / self.n_heads,
min_float: self.min_float,
quiet_softmax: self.quiet_softmax,
d_model: self.d_model,
}
}
}
この初期化処理では、以下のことが行われています:
- クエリ、キー、バリュー、出力用の線形層を同じ設定で初期化
- ドロップアウト層と活性化関数の設定
- マルチヘッドアテンション特有のパラメータ(ヘッド数、キーサイズなど)の設定
d_k
はこの後も登場するので注目してください。d_k
はd_model / n_heads
として計算されています。これは、モデルの次元を複数のヘッドに分割する際の各ヘッドの次元サイズを表しています。
入力と出力の構造体
マルチヘッドアテンションの入力と出力を表現するための構造体も定義されています。
// 入力構造体
#[derive(Debug, Clone)]
pub struct MhaInput<B: Backend> {
// 形状 `[batch_size, seq_length_1, d_model]`
query: Tensor<B, 3>,
// 形状 `[batch_size, seq_length_2, d_model]`
key: Tensor<B, 3>,
// 形状 `[batch_size, seq_length_2, d_model]`
value: Tensor<B, 3>,
mask_pad: Option<Tensor<B, 2, Bool>>,
mask_attn: Option<Tensor<B, 3, Bool>>,
}
// 出力構造体
#[derive(Debug, Clone)]
pub struct MhaOutput<B: Backend> {
// アテンションの重み `[batch_size, n_heads, seq_length_1, seq_length_2]`
pub weights: Tensor<B, 4>,
// コンテキストテンソル `[batch_size, seq_length_1, d_model]`
pub context: Tensor<B, 3>,
}
入力構造体MhaInput
はクエリ、キー、バリューのテンソルに加えて、パディングマスクとアテンションマスクを含んでいます。これらのマスクはオプションで、計算中に特定の位置を無視したり、因果的なアテンション(未来の情報を見ないようにする)を実現するために使用できます。
出力構造体MhaOutput
はアテンションの重みとコンテキストテンソルを含んでいます。重みは各ヘッドの各位置がどの位置に注目しているかを表し、コンテキストテンソルは計算結果を表します。
フォワードパス(順伝播)
マルチヘッドアテンションの核心部分は、順伝播処理を行うforward
メソッドです。
impl<B: Backend> MultiHeadAttention<B> {
pub fn forward(&self, input: MhaInput<B>) -> MhaOutput<B> {
let [batch_size, seq_length_1, d_model] = input.query.dims();
let query = self.attention_linear(input.query, &self.query);
let key = self.attention_linear(input.key, &self.key);
let value = self.attention_linear(input.value, &self.value);
let attn_scores = self.attn_scores(query, key);
let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
let context = weights.clone().matmul(value);
let context = context
.swap_dims(1, 2)
.reshape([batch_size, seq_length_1, d_model]);
let context = self.output.forward(context);
MhaOutput { weights, context }
}
}
このメソッドでは、次のステップが順に実行されています:
- クエリ、キー、バリューを線形変換して各ヘッド用に分割
- クエリとキーからアテンションスコアを計算
- マスクを適用してアテンションの重みを計算(softmax)
- 重みとバリューの行列積を計算してコンテキスト情報を得る
- ヘッドを結合して元の形状に戻し、最終的な出力線形層を適用
この処理が、先ほど数学的に示した式の実装となっています。
アテンションスコアの計算
アテンションスコアの計算処理は別のメソッドに分離されています。
fn attn_scores(&self, query: Tensor<B, 4>, key: Tensor<B, 4>) -> Tensor<B, 4> {
let attn_scores = query
.matmul(key.transpose())
.div_scalar((self.d_k as f32).sqrt());
self.dropout.forward(attn_scores)
}
ここでは、クエリとキーの行列積を計算し、スケーリング(d_k
の平方根で割る)を適用してからドロップアウトを適用しています。これが、論文で提案されたスケーリングされたドット積アテンションの実装です。
アテンション重みの計算
アテンション重みの計算処理もメソッドに分離されています。
fn attn_weights(
&self,
mut attn_scores: Tensor<B, 4>,
mask_pad: Option<Tensor<B, 2, Bool>>,
mask_attn: Option<Tensor<B, 3, Bool>>,
) -> Tensor<B, 4> {
if let Some(mask_pad) = mask_pad {
let [batch_size, seq_length] = mask_pad.dims();
attn_scores = attn_scores.mask_fill(
mask_pad.reshape([batch_size, 1, 1, seq_length]),
self.min_float,
);
}
if let Some(mask_attn) = mask_attn {
let [batch_size, seq_length_1, seq_length_2] = mask_attn.dims();
attn_scores = attn_scores.mask_fill(
mask_attn.reshape([batch_size, 1, seq_length_1, seq_length_2]),
self.min_float,
);
}
if self.quiet_softmax {
activation::quiet_softmax(attn_scores, 3)
} else {
activation::softmax(attn_scores, 3)
}
}
このメソッドでは、以下の処理が行われています:
- パディングマスクがある場合、マスク位置のスコアを非常に小さな値(
min_float
)に設定 - アテンションマスクがある場合も同様にマスク位置のスコアを小さくする
-
quiet_softmax
フラグに応じて、通常のsoftmaxかquiet softmaxを適用
特に興味深いのはquiet_softmax
の使用です。これは通常のsoftmaxの変種で、アテンションヘッドが関連情報がない場合に「何も見ない」ようにします。
Quiet Softmaxについて
1. 通常のSoftmaxの挙動とおさらい
まず、通常のsoftmax関数がアテンションメカニズムでどのように使われるか思い出してみましょう。
- アテンションスコア(クエリと各キーの関連性の強さを示す数値)が計算されます。
- これらのスコアに対してsoftmax関数を適用すると、合計が1になるような「確率分布」が得られます。これがアテンションウェイト(各キーに対する注意の重み)になります。
- 例えば、スコアが
[1.0, 2.0, 0.5]
だった場合、softmaxを適用すると[0.24, 0.67, 0.09]
のようなウェイトが得られます(値は概算)。これは、「2番目のキーに最も強く注目し、1番目、3番目の順に注目する」という意味になります。
重要なポイント: 通常のsoftmaxは、入力されたスコアがどのような値であっても、必ずどこかのキー(バリュー)に注意を割り当てます。ウェイトの合計が1になるように正規化するため、全てのウェイトが同時にゼロになることは基本的にありません(全ての入力スコアがマイナス無限大でない限り)。
2. 通常のSoftmaxの「困った点」
もし、あるクエリに対して、シーケンス内のどのキーも本当は関連性が低い(つまり、どのバリューも現在の文脈理解に役立たない)場合、どうなるでしょうか?
- 通常のsoftmaxは、それでも無理やりどこかのキーに(比較的高い)ウェイトを割り当ててしまいます。
- その結果、アテンションヘッドは関連性の低い、あるいはノイズのような情報を取り込んでしまう可能性があります。
- これは、モデルが「特に見るべきものがないのに、無理やり何かを見ようとしてしまう」状態と言えます。
3. Quiet Softmaxが目指すこと:「何もない」ときは「何もしない」
Quiet Softmaxは、この「無理やり何かを見てしまう」問題を解決しようとします。
そのアイデアは、「もし関連情報がないのなら、どのキーにも実質的に注意を向けない(つまり『何も見ない』)」という選択肢をアテンションヘッドに与えることです。
4. 「何も見ない」とは具体的にどういうことか?
「何も見ない」というのは、比喩的な表現ですが、具体的には以下の状態を指します。
- 全てのアテンションウェイトが非常に小さくなる: Quiet Softmaxは、全ての入力スコアが低い(関連性がないことを示す)場合に、全てのアテンションウェイトが限りなく0に近い値になるように設計されています。
- 結果として、出力がほぼゼロベクトルになる: アテンションの最終的な出力(コンテキストベクトル)は、アテンションウェイトとバリューベクトルの加重和で計算されます。もし全てのアテンションウェイトがほぼ0ならば、どのバリューベクトルもほとんど考慮されず、結果として得られるコンテキストベクトルはほぼゼロベクトル(情報を持たないベクトル)になります。
アテンションヘッドがこのような状態になることを「静かになる (quiet)」と表現し、そのような挙動を可能にするのがQuiet Softmaxです。
5. Quiet Softmaxはどのようにそれを実現するのか? (Evan Miller氏の提案より)
Quiet Softmaxの一般的な実現方法の一つとして、Evan Miller氏が記事「Attention Is Off By One」で提案した方法があります。これは、softmaxの計算に少し手を加えるものです。
通常のSoftmax:
$ \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_j \exp(x_j)} $
Quiet Softmaxの一例(分母に1を加える):
$ \text{quiet_softmax}(x_i) = \frac{\exp(x_i)}{1 + \sum_j \exp(x_j)} $
-
分母に
1
を加えるのがポイントです。 - もし全てのスコア $x_j$ が非常に小さい(例えば大きな負の値で、関連性が低いことを示す)場合、$\exp(x_j)$ は非常に小さな正の値になります。すると、$\sum_j \exp(x_j)$ もほぼ0に近くなります。
- このとき、Quiet Softmaxの分母は $1 + (\text{ほぼ}0) \approx 1$ となります。
- 分子の $\exp(x_i)$ も非常に小さいので、結果として $\text{quiet_softmax}(x_i)$ の値は非常に小さくなります。
- 重要なのは、この場合、$\sum_i \text{quiet_softmax}(x_i)$ の合計は1よりもずっと小さくなる可能性があるということです。この「1に満たない部分」が、「どこにも注意を向けなかった分」と解釈できます。
6. 「何も見ない」ことのメリット
- ノイズの抑制: 無関係な情報に無理に注意を向けることを防ぎ、モデルがよりクリーンな情報を扱うのに役立ちます。
- スパース性の促進: 不要なアテンションが抑制されることで、本当に重要な情報にモデルが集中しやすくなります。
- 解釈性の向上: アテンションウェイトが高い値を示す場合、それはより「意図的に」その情報に注目している可能性が高まります。
- モデル圧縮: 一部のアテンションヘッドが「何もしない」状態になることは、モデルを圧縮する際の手がかりになる可能性があります(例えば、常に「何もしない」ヘッドは削除できるかもしれない)。
まとめ
Quiet Softmaxは、通常のsoftmaxが常にどこかに注意を割り当ててしまうのに対し、「見るべきものがなければ何もしない(どのキーにも実質的な注意を向けず、結果として情報を持たない出力をする)」という選択肢をアテンションヘッドに与えるための仕組みです。これにより、モデルがより頑健になり、不要な情報に惑わされにくくなることが期待されます。
線形変換と形状変更
クエリ、キー、バリューに線形変換を適用し、ヘッド単位に分割する処理も別メソッドに分離されています。
fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
let [batch_size, seq_length, _d_model] = x.dims();
linear
.forward(x)
.reshape([batch_size, seq_length, self.n_heads, self.d_k])
.swap_dims(1, 2)
}
この関数では、次の処理が行われています:
- 線形変換を適用
- 結果を再形成して、ヘッド次元を明示的に加える(
[batch_size, seq_length, n_heads, d_k]
) - 次元を入れ替えて、ヘッド次元をシーケンス長の前に移動(
[batch_size, n_heads, seq_length, d_k]
)
この形状変更により、各ヘッドが独立して計算を行えるようになります。
キャッシュを使用した最適化
Burnの実装では、推論時のパフォーマンスを向上させるためのキャッシュ機構も提供されています。これは、特に自己回帰的な生成(次のトークンを予測する)で効率的です。
pub fn forward_cache(&self, input: MhaInput<B>, cache: &mut MhaCache<B>) -> MhaOutput<B> {
let [batch_size, seq_length_1, d_model] = input.query.dims();
let query = cache
.query
.forward(input.query, |t| self.attention_linear(t, &self.query));
let key = cache
.key
.forward(input.key, |t| self.attention_linear(t, &self.key));
let value = cache
.value
.forward(input.value, |t| self.attention_linear(t, &self.value));
// 以下は通常のforwardと同様
let attn_scores = self.attn_scores(query, key);
let weights = self.attn_weights(attn_scores, input.mask_pad, input.mask_attn);
let context = weights.clone().matmul(value);
let context = context
.swap_dims(1, 2)
.reshape([batch_size, seq_length_1, d_model]);
let context = cache.output.forward(context, |t| self.output.forward(t));
MhaOutput { weights, context }
}
キャッシュを使用することで、すでに計算された値を再利用でき、特に長いシーケンスの生成時に計算コストを大幅に削減できます。
テスト
Burnのコードには、マルチヘッドアテンションの動作を検証するためのテストコードも含まれています。これらのテストは、実装が正しく動作することを確認するだけでなく、使用方法の良い例にもなっています。
#[test]
fn test_self_attention_shapes() {
let [batch_size, seq_length, d_model, n_heads] = [7, 13, 32, 4];
let device = Default::default();
let mha = MultiHeadAttentionConfig::new(d_model, n_heads).init::<TestBackend>(&device);
let input = MhaInput::self_attn(Tensor::random(
[batch_size, seq_length, d_model],
Distribution::Default,
&device,
));
let output = mha.forward(input);
assert_eq!(
output.context.shape(),
Shape::new([batch_size, seq_length, d_model]),
"Context should have the correct shape",
);
assert_eq!(
output.weights.shape(),
Shape::new([batch_size, n_heads, seq_length, seq_length]),
"Weights should have the correct shape",
);
}
このテストは、セルフアテンション(同じテンソルをクエリ、キー、バリューとして使用)の出力形状が期待通りであることを確認しています。
まとめ
Burnのソースコードを通じて、マルチヘッドアテンションの実装について見てきました。主なポイントは以下の通りです:
- マルチヘッドアテンションは、入力を複数のヘッドに分割し、それぞれが独立してアテンション計算を行う機構です
- 実装では、クエリ、キー、バリューの線形変換、スケーリングされたドット積、softmax計算が基本的な処理となります
- マスク機構により、パディングを無視したり、因果的なアテンションを実現したりできます
- キャッシュを用いた最適化により、自己回帰的な生成時のパフォーマンスを向上できます
Burnの実装を見ることで、トランスフォーマーモデルの基礎となるマルチヘッドアテンションを理解するための良い教材となりますね。
参考文献・リソース
- Attention Is All You Need - マルチヘッドアテンションとトランスフォーマーを提案した論文
- Burn公式リポジトリ - Burnのソースコード
- Burn Book - Burnの公式ドキュメント
- Quiet Softmax - Quiet Softmaxについての参考記事