Transformerの心臓部!QKV Attention機構入門
はじめに
この記事では,近年のAI技術,特に自然言語処理においてブレークスルーとなったTransformerモデルの心臓部であるAttention(アテンション)機構について,その基本となるQKV(Query, Key, Value)モデルの原理と,Softmax関数による重み算出のプロセスを,図やPyTorchのコードを交えながら徹底的に解説します.
この記事の対象読者
- TransformerやAttention機構に興味があるけど,数式だけだと難しく感じる方
- 「QKVって結局何?」と思っている方
- 今後PyTorchでAttentionを実装してみたいと考えている方
人間が文章を読むときに重要な部分へ自然と注目するように,Attention機構も入力情報の中から重要な部分を動的に選び出す仕組みです.この記事を読めば,その「注目」の仕組みが理論と実践の両面から理解できるようになります.
1. QKV注意機構とは? 〜図書館の比喩で理解する〜
このセクションでは,難しい数式やコードは一旦置いておいて,直感的に理解するために比喩表現で説明したいと思います.
今,あなたは図書館に来ています.「AIのことについて知りたいな」と思った時,どのように司書さんに質問するでしょうか? もちろん「AIに関する専門書はありますか?」と質問しますよね.この質問がQ (Query) です.
そして,図書館の本にはよく種類ごとにラベリングされています.それがK (Key) になります.司書さんはあなたのQuery(質問)と各本のKey(ラベル)を照合して,関連性の高い本を探してくれます.
最後に,そのラベルに紐づく本(情報)そのものがV (Value) になります.
まとめると,Attention機構とはQuery(知りたいこと)を使って,たくさんの情報(Key-Valueペアの集まり)の中から,関連性の高いKeyを見つけ出し,そのKeyに対応するValueを重点的に取り出す仕組みなのです.
では,次のステップとして,この関連度をコンピュータはどのように計算しているのかを見ていきましょう.
2. Attention計算の3ステップ
Attentionの計算は,大きく分けて以下の3つのステップで構成されています.
- スコア計算: Queryと各Keyの関連度を計算する.
- 重み算出: スコアをSoftmax関数で正規化し,「注目度」の重みに変換する.
- 情報集約: 重みを使って各Valueの加重和をとり,最終的な出力を得る.
一つずつ丁寧に見ていきましょう.
Step 1: アテンションスコアの計算 〜関連度を数値化する〜
まず,Queryに対して各Keyがどれだけ関連しているかを数値化します.これは,Queryベクトルと各Keyベクトルのドット積を計算することで実現されます.
なぜドット積?
ベクトル同士のドット積は,2つのベクトルの「向きがどれだけ似ているか」を表します.向きが似ているほど値は大きくなり,直交していれば0,逆方向なら負の値になります.この性質を利用して,QueryとKeyの関連性を測っているのです.
計算されたスコアは,論文 "Attention Is All You Need" に従って,Keyベクトルの次元数 $d_k$ の平方根 $\sqrt{d_k}$ で割ってスケール調整します.
$$
\text{score}(Q, K) = \frac{Q K^T}{\sqrt{d_k}}
$$
なぜスケーリング?
$d_k$ が大きい場合,ドット積の値が大きくなりすぎることがあります.そのまま次のステップのSoftmax関数に入力すると,勾配が非常に小さくなり学習が不安定になる「勾配消失問題」が起きやすくなります.これを防ぐためのおまじない,あるいは「スコアの暴走を防ぐ調整」だと考えてください.
Step 2: 重みの算出 〜Softmaxで注目度を決める〜
次に,計算したアテンションスコアをSoftmax関数に入力します.
$$
\text{Attention Weight} = \text{softmax}(\frac{Q K^T}{\sqrt{d_k}})
$$
Softmax関数は,入力された数値リストを,合計すると1になるような確率分布に変換する関数です.
Softmax関数の役割
- 正規化: 全ての出力の合計が1になるように正規化します.これにより,各Keyへの注目度を「確率」のように解釈できます.
- 強調: 値が大きいスコアはより大きな重みを,値が小さいスコアはより小さな(0に近い)重みを持つように,大小関係を強調します.
このSoftmax関数から出力された値が,どのKey(ひいてはValue)にどれだけ注目すべきかを示すアテンションウェイト (Attention Weight) となります.
Step 3: 情報の集約 〜Valueの加重和をとる〜
最後に,ステップ2で計算したアテンションウェイトを使って,各Valueベクトルの加重和 (Weighted Sum) を計算します.
$$
\text{Attention}(Q, K, V) = (\text{Attention Weight}) \cdot V
$$
これは,各Valueベクトルに,対応するアテンションウェイトを掛け合わせ,それらを全て足し合わせる処理です.
直感的なイメージ
「注目度が高い(アテンションウェイトが大きい)Valueの情報は,より強く最終結果に反映され,注目度が低いValueの情報はあまり反映されない」ということです.「重要な情報は大きな声で,そうでない情報は小さな声で聞き集める」ようなイメージですね.
この結果得られたベクトルが,Queryにとって必要な情報が重点的に抽出された「文脈を考慮した表現」となるのです.
3. PyTorchで再現してみよう
それでは,ここまで見てきた3つのステップを,PyTorchの簡単なコードで再現してみましょう.
import torch
import torch.nn.functional as F
import math
# --- ダミーデータの準備 ---
# シーケンス長=4, 特徴量次元=6 の入力データを想定
# 本来はQ, K, Vは入力シーケンスを線形変換して得られる
seq_len = 4
d_k = 6
# Query, Key, Value をランダムなテンソルで生成
# (バッチサイズ=1, シーケンス長, 次元)
Q = torch.randn(1, seq_len, d_k)
K = torch.randn(1, seq_len, d_k)
V = torch.randn(1, seq_len, d_k)
print("--- 入力テンソルのShape ---")
print("Q.shape:", Q.shape)
print("K.shape:", K.shape)
print("V.shape:", V.shape)
print("-" * 25)
# --- Step 1: アテンションスコアの計算 ---
# Q と Kの転置 を行列積
# K.transpose(-2, -1) は最後の2つの次元を転置 -> (1, 6, 4)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
# scores.shape: (1, 4, 4) 各Queryが他の全てのKeyと計算したスコア
print("\n--- Step 1: スコア計算後 ---")
print("Scores.shape:", scores.shape)
print("Scores:\n", scores)
print("-" * 25)
# --- Step 2: 重みの算出 (Softmax) ---
# 行ごと(最後の次元)にSoftmaxを適用
attention_weights = F.softmax(scores, dim=-1)
# attention_weights.shape: (1, 4, 4)
print("\n--- Step 2: Softmax適用後 ---")
print("Attention Weights.shape:", attention_weights.shape)
# 各行の合計が1になっていることを確認
print("Sum of weights per query:", attention_weights.sum(dim=-1))
print("Attention Weights:\n", attention_weights)
print("-" * 25)
# --- Step 3: 情報の集約 (加重和) ---
# アテンションウェイトとVを行列積
output = torch.matmul(attention_weights, V)
# output.shape: (1, 4, 6) 入力と同じ形状に戻る
print("\n--- Step 3: 加重和計算後 ---")
print("Output.shape:", output.shape)
print("Output:\n", output)
print("-" * 25)
実行結果は以下の通りです.
--- 入力テンソルのShape ---
Q.shape: torch.Size([1, 4, 6])
K.shape: torch.Size([1, 4, 6])
V.shape: torch.Size([1, 4, 6])
-------------------------
--- Step 1: スコア計算後 ---
Scores.shape: torch.Size([1, 4, 4])
Scores:
tensor([[[ 3.3316e-01, 7.1520e-02, 4.6809e-01, 1.1596e-01],
[ 2.1434e-01, -1.0959e+00, 1.0994e+00, -1.9563e+00],
[-7.9123e-02, 8.7879e-04, 5.2081e-01, -4.4296e-02],
[-3.6985e-01, 1.4227e+00, -9.5855e-01, 1.8621e-01]]])
-------------------------
--- Step 2: Softmax適用後 ---
Attention Weights.shape: torch.Size([1, 4, 4])
Sum of weights per query: tensor([[1.0000, 1.0000, 1.0000, 1.0000]])
Attention Weights:
tensor([[[0.2689, 0.2070, 0.3077, 0.2164],
[0.2627, 0.0709, 0.6365, 0.0300],
[0.2024, 0.2193, 0.3688, 0.2096],
[0.1075, 0.6454, 0.0597, 0.1874]]])
-------------------------
--- Step 3: 加重和計算後 ---
Output.shape: torch.Size([1, 4, 6])
Output:
tensor([[[-0.0875, -0.5641, -0.0059, -0.4153, -1.0604, 0.4528],
[-0.3880, -1.3888, -0.2552, -0.4095, -0.5432, -0.0216],
[-0.1129, -0.7199, -0.0742, -0.4281, -0.9563, 0.2142],
[ 0.0483, -0.1760, -0.0708, -0.2574, -1.3784, 0.4469]]])
-------------------------
この実行結果から,コードが各ステップをどのように処理しているかを見ていきましょう.
まず,shapeが(1, 4, 6)であるQ, K, Vテンソルから,Step 1でアテンションスコアを計算すると,shapeが(1, 4, 4)のScores行列が得られます.これは、4つの単語それぞれが,他の全単語との関連度を持っていることを示しています.
次に,Step 2でSoftmax関数を適用すると,各行の合計が1.0になるAttention Weightsに変換されました.これにより,単なるスコアが確率的な「注目度」の重みになったことが確認できますね.
最後にStep 3で,この重みとVで加重和を取ることで,最終的なOutputが得られました.shapeも入力と同じ(1, 4, 6)に戻っています.このOutputの各ベクトルには、他の単語の情報が「注目度」に応じて適切にブレンドされているのです.
このように,実際の計算も3つのステップでシンプルに行われていることがわかりますね.PyTorchにはtorch.nn.MultiheadAttentionという便利なクラスも用意されていますが,その内部ではこのような計算が行われています.
まとめ
今回は,Transformerの根幹をなすQKV Attention機構の原理について,比喩や図,PyTorchのコードを交えながら解説しました.
• QKVモデルは,Query(質問), Key(索引), Value(情報) の仕組みで関連情報を引き出す.
• Attentionの計算は,①スコア計算 → ②重み算出(Softmax) → ③加重和 の3ステップで行われる.
• これにより,入力シーケンス内のどの部分に注目すべきかを動的に判断し,文脈に応じた情報抽出が可能になる.
この記事で,Attention機構の基本的な考え方を掴んでいただけていれば幸いです.
次は,このAttentionを複数並列で行うMulti-Head Attentionや,実際のTransformerモデルの全体像についても学んでいくと,さらに理解が深まるはずです.
参考文献
• Ashish Vaswani, et al. (2017). Attention Is All You Need.
• 岡崎直観,荒瀬由紀,鈴木 潤,鶴岡慶雅,宮尾祐介,共著. (2023). 自然言語処理の基礎
最後までお読みいただきありがとうございました!