目的
GPTやVision Transformer等のTransformerを活用した機械学習モデルにはAttention機構が用いられており、重要な役割を持っている。そのため、機械学習の分野の理解を深める上で、必要不可欠な技術である。そこで本記事では、図や数式、ソースコードなどを用いてAttentionの仕組みを説明し、その理解を深めることを目的とする。
概要
Attantion
・自然言語処理分野で発展した要素技術であり、Attentionを活用したTransformerは自然言語処理分野において、大きな発展に貢献した。現在では、自然言語処理分野だけでなく、画像処理分野など様々な分野に活用されている。
・入力された情報同士がどの程度似ているか(類似度)計算し、ベクトルの重みを大きくすることで、注意を向けさせる。
仕組み
Attentionの仕組みは以下の通りである。
1.バッチ毎の情報を埋め込みベクトルに変換する
2.埋め込みベクトルから類似度を求める
3.類似度からAttentionの加重和を求める
それぞれについて詳しく説明していく
1.バッチ毎の情報を埋め込みベクトルに変換する
埋め込みベクトルとは、言語や画像情報を数値データに置き換える手法である。これにより、単語や画像バッチの意味をベクトル空間に配置することができる。これにより、位置関係から要素同士の関係性を数値的に表すことが可能になる。
埋め込みベクトルは以下のように計算を行う。
(1)トークン(単語など)にIDを与える。
例:ID=s
(2)ID=sの要素を1とするOne Hot Vecterにする
例:vt = (0,・・・,1,0,・・・,0) (K次元)
(3)全結合NNを活用して埋め込みベクトルを算出する
例:Xt = (0.2,0.1,0.3,・・・・) (M次元)
以下の図はID=2の要素が1の時の埋め込みベクトルである。
上記のような計算式により、情報を重みベクトルに変換することができる。
また、K>Mであるため、K次元であったベクトルがM次元に次元削減を行うことができる。
2.埋め込みベクトルから類似度を求める
上記より、全結合NNを1つ活用して埋め込みベクトルを算出することができました。さらに、これを3つの全結合NNを用い、埋め込んだ各ベクトルをquery(クエリ)、key(キー)、value(バリュー)と呼びます。query、key、valueはそれぞれ同じベクトルを入力として用いていますが、3つの全結合NNの重みがそれぞれ異なるため、それぞれの値は異なっています。また、この全結合NNは学習を行う部分であり、最適化する必要があります。
これでquery、key、valueを求めることができたので、Attentionの内積について説明します。Attentionの内積にはqueryとkeyの行列積を計算することで、求めることができます。今回は4つのベクトルの内積を行列積を使うことで一度に求めています。これにより、それぞれの類似度を求めることができます。
内積の計算は以下の図のようになっています。ベクトルをオレンジの太線で表しています。
これをsoftmax関数にかけることで行の合計が1になります。
これによって行毎に確率で表すことができ、これが加重和の係数となります。この加重和の係数をAttention Weightを呼び、重みが大きいほど類似度が高く、重みが小さいほど類似度が低くなります。
3.類似度からAttentionの加重和を求める
最後に求めた類似度とvalueを加重和することで、より良いバッチ(表現力)を得ることができます。このようにバッチ毎にAttentionを求めるのに、全てのベクトルを利用します。そのため、Attentionはデータ全体を考慮して特徴量を学習出来ると言われています。
ここで、Attentionの式を見てみましょう。
$$
Attention(Q, K, V ) = softmax(\frac{QK^T}{\sqrt{d_k}})V
$$
式をみるとqueryとkeyの埋め込みベクトルの内積を計算し、$\sqrt{d_k}$で割っています。これは次元数が多くなると、内積の値が大きくなりすぎてしまうためです。queryとkeyの各ベクトルが平均0、標準偏差1の正規分布に従う場合、内積の期待値0、分散が$d_k$となリます。分散が大きいとソフトマックス関数を通した時に大きな値に対する勾配が急激に大きくなってしまいます。そのため、勾配消失や勾配爆発の問題が起こりやすくなります。この影響を緩和するために内積を次元数$d_k$の平方根で割っています。その後、ソフトマックス関数に入れ、valueとの加重和を求めることでAttentionを計算していることがわかると思います。
実装
今回はPytorchを利用してシンプルなセルフ・アテンションの実装をします。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleSelfAttention(nn.Module):
def __init__(self, d_model):
super(SimpleSelfAttention, self).__init__()
self.d_model = d_model
self.query_dense = nn.Linear(d_model, d_model)
self.key_dense = nn.Linear(d_model, d_model)
self.value_dense = nn.Linear(d_model, d_model)
def forward(self, x):
# クエリ、キー、バリューを計算
query = self.query_dense(x)
key = self.key_dense(x)
value = self.value_dense(x)
# スケール付きドットプロダクト注意を計算
scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, value)
return output
# 入力データの形状を定義
seq_len = 10 # シーケンスの長さ
d_model = 64 # 特徴次元
# ランダムな入力データを生成
inputs = torch.rand(1, seq_len, d_model)
# 自己注意層のインスタンスを作成し、入力データに適用
self_attention_layer = SimpleSelfAttention(d_model)
output = self_attention_layer(inputs)
print("入力データの形状:", inputs.shape)
print("出力データの形状:", output.shape)
このコードでは、__init__メソッドで、query、key、valueの線形層を定義し、埋め込みベクトルの計算を行えるようにしています。forwardメソッドでは、torch.matmulで行列積の計算を行い、類似度の計算を行っています。第一引数にquery、第二引数に$\frac{K^T}{\sqrt{d_k}}$を入れています。そして、softmax関数にscore($\frac{QK^T}{\sqrt{d_k}}$)を入れてAttention Weight($softmax(\frac{QK^T}{\sqrt{d_k}})$)を計算しています。
最後にAttention Weightとvalueの加重和を計算し、Attention(softmax($\frac{QK^T}{\sqrt{d_k}})V$)出力しています。
参考文献
・Attention Is All You Need
・【深層学習】図で理解するAttention機構
・ざっくり理解する分散表現, Attention, Self Attention, Transformer
・【詳説】Attention機構の起源から学ぶTransformer
・数式を追う!Transformerにおけるattention
・【論文】"Attention is all you need"の解説
・Transformerを理解するため!今からでもAttention入門 ~ イメージ・仕組み・コードの3面で理解する ~
・作って理解する Transformer / Attention
・Vision Transformer入門