はじめに
前回は大規模言語モデルの概要と仕組みについて整理しましたが、今回より大規模言語モデルに関連する数学的な説明について整理しようと思います。今回はTransformerで重要な要素である注意機構です。
注意機構とは?
注意機構とは二つの系列間での各要素の関連度合いを計算することで、Transformerでは入力となる文章の注目すべき特徴を抽出することを目的に使用しています。
注意計算の式を行列形式で表すと以下の通りになります。
Attention(Q,K,V)=softmax({\frac{QK^T}{{\sqrt{d}}}})V
クエリ$Q$:検索対象のベクトル
キー$K$:他のトークンが持つ特徴をあらわすベクトル
バリュー$V$:実際の情報
\displaylines{
Q=(q_1, q_2, ..., q_n), q_i\in{R^{d_k}} \\
K=(k_1, k_2, ..., k_n), q_j\in{R^{d_k}} \\
V=(v_1, v_2, ..., v_m), q_j\in{R^{d_v}}
}
この式より、クエリとキーからどの情報を見るべきかを判断し、そこから必要な情報を抽出することを表現しています。
上記の式は以下の手順で導出できます。
注意計算の式導出
1.クエリとキーの類似度
あるクエリとキーの類似度${s_{ij}}$は、ベクトル同士の内積をとることで定義され、Transformerでは後続のソフトマックスでの偏りをなくすためスケーリングするため、埋め込み次元${d_k}$を使っています。
\displaylines{
s_{ij}=q_ik_j^T
}
Transformerにおける類似度
\displaylines{
s_{ij}=\frac{q_ik_j^T}{\sqrt{d_k}}
}
2.ソフトマックスによる確率分布への変換
類似度の確率分布を計算し、後続のバリューの重み付き平均の計算に利用します。
\displaylines{
a_{ij}=softmax(s_{ij})=\frac{\exp{s_{ij}}}{\sum_{k=1}^{n}\exp{s_{ik}}}
}
3.バリューの重み付き平均
最終的に特徴量抽出に必要なバリューを出力するため、Attentionの重みで平均します。
\displaylines{
o_{i}=\sum_{j=1}^{m}a_{ij}v_{j}
}
PyTorchでの実装例
生成AIに作らせましたが、実装例は以下のようになります。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
def __init__(self, d_model):
super().__init__()
self.d_k = d_model
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
def forward(self, X):
Q = self.W_Q(X)
K = self.W_K(X)
V = self.W_V(X)
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
attention = F.softmax(scores, dim=-1)
return torch.matmul(attention, V)
# ---- 動作例 ----
X = torch.randn(4, 8) # (seq_len, d_model)
attn = SelfAttention(d_model=8)
output = attn(X)
print(output)
マスク付き自己注意
デコーダでは未来の情報にアクセスできないため、マスク行列を導入したマスク付き自己注意を用いられます。
1.マスク行列の導入
マスク行列とは未来の情報(=トークン)に対する注意をゼロにするためのものです。後続のソフトマックス関数を適用するため、未来の情報に対して、スコアに非常に小さな値を割り当てます。通常、負の無限大に近い値をとります。
\displaylines{
M_{ij}=\left\{
\begin{array}{ll}
0 & (i \geq j) \\
-\infty & (i \lt j)
\end{array}
\right.
}
2.マスクの適用
スコアにマスク行列を加算します。現在までの情報に対してはクエリとキーのスコア値のまま、未来の情報に対しては非常に小さい値となるため、ソフトマックス関数を適用すると未来の情報に対しては注意を向けることはできなくなり、過去及び現在の情報から次の情報を予測するようになります。
Attention(Q,K,V)=softmax({\frac{QK^T}{{\sqrt{d}}}}+M)V
マルチヘッド注意機構
ヘッドとは独立した注意計算を行う小さなモジュールのことで、複数のヘッドを並列に処理することで異なる視点から情報に注目するのがマルチヘッド注意機構になります。
1.各ヘッドの計算
上記の説明の通り、各ヘッドで注意計算をするので、例の注意計算の式になります。
\displaylines{
head_{i}=Attention(Q_{i},K_{i},V_{i})
}
2.結合及び線形変換
各ヘッドの結果を結合することで、多面的な視点で情報に注目することができます。そのため、Concat関数を利用します。
そして、結合後の結果に重み$W_{o}$を付けることで線形変換します。
\displaylines{
MultiHead(Q,K,V)=Concat(head_{1}, ..., head_{n})W_{o}
}
最後に
今回より、数学的な要素を含めながら大規模言語モデルの技術要素を説明しました。説明する中で不明な点は生成AIと対話しながらになりましたが、各処理の具体的な展開を理解できました。
参考文献
Ashish Vaswani et al. (2017) "Attention Is All You Need" NeurIPS 2017