3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

アテンション機構を理解したい

Last updated at Posted at 2025-01-20

PyTorchにTransformerのチュートリアル記事があると聞き調べてみた。

私はそもそもニューラルネットワーク周りの知識に疎いので、アテンションについてわかったことを書いていく。

自己注意機構(Self-Attention)の定義

query、key、valueを用いて、単語の埋め込みベクトルから単語間の関係性を調べるとのこと。
マルチヘッドアテンションは単に自己注意機構を並列処理できるようにしただけなので、
自己注意機構の定義を理解できれば他は大体わかります。

数式としては

$$ Attention(Q, K,V)= softmax(\frac{QK^T}{\sqrt{d_k}})V $$

query, key, valueの定義は以下の通り。

$$ Q=XW_Q+b_Q $$

$$ K=XW_K+b_K $$

$$ V=XW_V+b_V $$

$W_Q, W_K, W_V$ は重み行列。
$b_Q, b_K, b_V$ はバイアス。
$X$ は入力埋め込み(文中の単語(トークン)の埋め込み行列)
$d_k$はkeyの次元数です。
Queryの次元はkeyと内積をとるためkeyと同じく$d_k$。
Valueの次元は$d_v$と独立しています。

実際に用いる場合は特定の情報について隠す必要があったりして、

$$ Attention(Q, K,V)= softmax(\frac{QK^T}{\sqrt{d_k}} - Mask)V $$

のように表現されることもある。

Q, K, Vってなに?

調べたところ多く言われているのは、
$Q$:query(質問)は「どの情報に注意を向けるべきか?を表すベクトル」を寄せ集めた行列
$K$:key(特徴、検索キー)はqueryと照らし合わせてどのトークンがどれだけ関連性があるかを表す行列。
$V$:value(実際の情報)入力内容の各トークンが持つ「実際の情報」を表すベクトルを集めた行列
とのこと。
それはわかった。

ただ、定義を見て
$$W_Q, W_K, W_V$$
ってどこからくるねん!
と思った方、いますよね???
(私がそうでした)

これらは学習対象の行列です。
Transformerはこの3つの行列を学習し、最適化することで、その優れた性能を発揮しています。

つまりアテンションは、
$W_Q, W_K, W_V$ という3つの重み行列(あとバイアス)を学習し、単語間の類似度や関係性についてうまく表現できるような仕組みです。

実際は最後にまた線形層をはさんだり、他の部分の最適化もあるので厳密には違うようですが、
まあいいでしょう。

このことを踏まえて考えてみます。

Value

まず$V$は、入力から適切な重みを学習し、単に線形変換してるだけです。
特に工夫もないので、文章から得られた特徴をほとんどそのまま学習し、表現している行列と解釈できます。

QueryとKey

やはりシンプルに考えるなら、内積=類似度という解釈です。
多分この理解で問題ないとは思うのですが、もう少し式をいじってみます。

QとKをメンドクサめに考えてみる

クエリとキーの内積の(i,j)成分を考えてみる。(バイアス項は一旦無視)
内積をシグマで表す。
GPTにTexを投げたら修正してくれたのでそれを載せます。
$d_{input}$は$d_{model}$と同義です。

$まず、クエリ ( Q ) とキー ( K ) の定義は次のようになります:$

Q_{i,j} = \sum^{d_{\text{input}}}_{a=1} X_{i,a} W^Q_{a,j}, \quad
K_{i,j} = \sum^{d_{\text{input}}}_{b=1} X_{i,b} W^K_{b,j}

$ ここで、( Q ) と ( K^\top ) の内積(QK^\top)は次のように計算されます: $

(QK^\top)_{i,j} = \sum_{k=1}^{d_k} Q_{i,k} K_{j,k}

$ これを ( Q ) と ( K ) の定義を代入して展開すると: $

(QK^\top)_{i,j} = \sum_{k=1}^{d_k} 
\left( \sum_{a=1}^{d_{\text{input}}} X_{i,a} W^Q_{a,k} \right) 
\left( \sum_{b=1}^{d_{\text{input}}} X_{j,b} W^K_{b,k} \right)

$ さらにまとめると、最終的に次の形になります:$

(QK^\top)_{i,j} = \sum_{k=1}^{d_k} \sum_{a=1}^{d_{\text{input}}} \sum_{b=1}^{d_{\text{input}}} 
X_{i,a} W^Q_{a,k} X_{j,b} W^K_{b,k}

ここでbのシグマを計算すると、

(QK^\top)_{i,j} = \sum_{k=1}^{d_k} \sum_{a=1}^{d_{\text{input}}} 
X_{i,a} W^Q_{a,k} (X_{j,1} W^K_{1,k}+X_{j,2} W^K_{2,k}+...+X_{j,d_{input}} W^K_{d_{input},k})
\\=\sum_{k=1}^{d_k} \sum_{a=1}^{d_{\text{input}}}(X_{i,a} W^Q_{a,k}X_{j,1} W^K_{1,k}+X_{i,a} W^Q_{a,k}X_{j,2} W^K_{2,k}+...+X_{i,a} W^Q_{a,k}X_{j,d_{input}} W^K_{d_{input},k})
\\=\sum_{k=1}^{d_k} \sum_{a=1}^{d_{\text{input}}}
(X_{i,a}X_{j,1}W^Q_{a,k}W^K_{1,k}+
X_{i,a}X_{j,2}W^Q_{a,k}W^K_{2,k}+
...+X_{i,a}X_{j,d_{input}}W^Q_{a,k}W^K_{d_{input},k})

aのシグマも計算しちゃうぞ。

(QK^\top)_{i,j}=\sum_{k=1}^{d_k}[
(X_{i,1}X_{j,1}W^Q_{1,k}W^K_{1,k}+
X_{i,1}X_{j,2}W^Q_{1,k}W^K_{2,k}+
...+X_{i,1}X_{j,d_{input}}W^Q_{1,k}W^K_{d_{input},k})
\\+(X_{i,2}X_{j,1}W^Q_{2,k}W^K_{1,k}+
X_{i,2}X_{j,2}W^Q_{2,k}W^K_{2,k}+
...+X_{i,2}X_{j,d_{input}}W^Q_{2,k}W^K_{d_{input},k})
\\+(X_{i,3}X_{j,1}W^Q_{3,k}W^K_{1,k}+
X_{i,3}X_{j,2}W^Q_{3,k}W^K_{2,k}+
...+X_{i,3}X_{j,d_{input}}W^Q_{3,k}W^K_{d_{input},k})
...
\\+(X_{i,d_{input}}X_{j,1}W^Q_{d_{input},k}W^K_{1,k}+
X_{i,d_{input}}X_{j,2}W^Q_{d_{input},k}W^K_{2,k}+
...+X_{i,d_{input}}X_{j,d_{input}}W^Q_{d_{input},k}W^K_{d_{input},k})
]

もう少し見やすくする。

a=1のとき(今回は先にbのシグマを消したのでこのように見る)は

(X_{i,1}X_{j,1}W^Q_{1,k}W^K_{1,k}+
X_{i,1}X_{j,2}W^Q_{1,k}W^K_{2,k}+
...+X_{i,1}X_{j,d_{input}}W^Q_{1,k}W^K_{d_{input},k})

a=2のとき

(X_{i,2}X_{j,1}W^Q_{2,k}W^K_{1,k}+
X_{i,2}X_{j,2}W^Q_{2,k}W^K_{2,k}+
...+X_{i,2}X_{j,d_{input}}W^Q_{2,k}W^K_{d_{input},k})

a=3のとき

(X_{i,3}X_{j,1}W^Q_{3,k}W^K_{1,k}+
X_{i,3}X_{j,2}W^Q_{3,k}W^K_{2,k}+
...+X_{i,3}X_{j,d_{input}}W^Q_{3,k}W^K_{d_{input},k})

kがシグマの変数、i, jは定数。
あと扱ってるのは成分だからどれも実数であることに気を付ける。
うーん。
わからない。

と思ったらシンプルな変形がありました。

QK^T=(XW_Q)(XW_K)^T=XW_QW_K^TX_K^T

行列の公式、忘れてました。

一通り見て、何か面白い解釈がないか探ってみましたが、

  • 学習可能な重み付き内積
  • $X$に対する二次形式(双線形形式)

など無難な解釈に落ち着きました。

もしかしたら? 他の双線形形式や二次形式の理論も機械学習に応用できるかも。

なぜsqrt(key次元)でスケーリングしてる?

論文にちっちゃく書いてました。

To illustrate why the dot products get large, assume that the components of q and k are independent random
variables with mean 0 and variance 1. Then their dot product, q · k =
Pdk
i=1 qiki, has mean 0 and variance dk.

なぜドット積が大きくなるかを説明するために、$q$ と $k$ の成分が平均 0、分散 1 の独立した確率変数であると仮定します。このとき、ドット積 $q \cdot k = \sum_{i=1}^{d_k} q_i k_i $ は、平均 0、分散 $ d_k $ を持ちます。

統計とかあんまり覚えてないよ……
分散に線形性ってなかったんとちゃうか? と思いGPTに聞いたところ
二つの確率変数が独立であった場合には線形性が成り立つそう。
つまり、確率変数が独立でないなら

$$ V[X+Y] \not= V[X] + V[Y]$$
なのですが、独立なら

$$ V[X+Y] = V[X] + V[Y] $$

が成り立ち、線形性より

V[q \cdot k] = V[ \sum^{d_k}_{i=1}{q_i \cdot k_i} ] =\sum^{d_k}_{i=1}V[ {q_i \cdot k_i}]=\sum^{d_k}_{i=1}{1} = d_k

が成り立つ。
つまり内積の結果が標準正規分布$N(0,1)$に従うようにするためのスケーリングって訳ですね。

こうすることで$Softmax$の結果に散らばりが出にくくなり、適切な推論と学習が行われやすいと解釈できます。

埋め込みベクトルってなに?

PositionalEncodingを一番わかりやすく解説してるところ。

1.まず文章をトークンに分解

"This is a pen" → [ This , is , a , pen ]

2.トークンを整数に変換

This → 1
is → 100
a → 321
pen → 3000

すなわち

[1, 100, 321, 3000]

に変換される。

ここまで1, 2はTokenizerで行っています。
Tokenizerは基本的にはあらかじめ開発されてあるものを使います。

3.文章の各単語をベクトルに変換する。

ここでは埋め込み層を用いて、離散的な語彙インデックスを連続値の埋め込みベクトルに変換します。

ここで純粋な疑問。

埋め込み層ってなに?

以下、GPTに質問してわかったことをまとめます。


埋め込み層の数式定義

1.入力

語彙インデックス$x$を入力とする。

$x$の形状および要素は

Shape(x) = (Batch Size, Sequencd Length)
x \in \{0,1,2, ..., |Vocab|-1\}

を満たす。
これはそのままエンコーダへの入力にあたる。

語彙インデックス$x$への変換はモデルにもよるが、
大体トークナイザが行う。

tokenized = tokenizer(
    text,
    max_length=128,
    truncation=True,
)

$|Vocab|$は語彙サイズ。トークナイザ―の語彙数を表す。

これらの定義から$x$は

$Batch Size$ 行
$Sequence Length$ 列(トークナイザで文章を変換した後のトークン列の長さ)の

行列であることがわかる。

さっきの例だとトークナイザで変換した後の文字列は

[1, 100, 321, 3000]

であるから、$Sequence Length$は4である。

$BatchSize$は、1回の処理で同時にモデルに入力されるデータの数である。

例えば複数の入力文章が格納されたリスト

text=[
    "This is a pen",
    "Cat is cute",
    "He is tall"
    ]

をトークナイズすると

[
    [1, 100, 321, 3000],
    [3, 100, 20],
    [5, 100, 80],
]

に変換される。この場合、

$BatchSize=3$

であり、モデルには3つの文章が一度に渡され、同時に処理される。

自分の場合、ここで少し引っかかりました。
なので少し詳しく書くと、


多くのモデルの$Sequence Length$は固定であり、
$BatchSize$も固定である。

一度の処理につき理論上は$SequenceLength \times BatchSize$だけ処理できるが、
普通に処理させると、$SequenceLength$を超過した分の文章は切り捨てられる。

$SequenceLength=3, BatchSize=2$だとして、
"This is a pen"($[1, 100, 321, 3000]$)を入力したら、入力$x$は

[
    [1, 100, 321], 
    [3000, 0, 0]
]

となるかと思いきや実際は

[
    [1, 100, 321],
    [0, 0, 0]
]

と切り捨てられてしまう。そのため、トークナイザ側の設定で

# トークナイザで分割
tokenized = tokenizer(
    long_text,
    max_length=10,
    truncation=True,
    stride=5,
    return_overflowing_tokens=True
)

という具合にして長い文章を明示的に分割して渡す必要がある。

この際、モデルは各バッチを並列に処理するため分割された他の部分は参照されない。
そのため、$SequenceLength$をオーバーした部分を次のバッチに入れると文脈を理解できなかったりする。

こういうことを減らすためにstride=5として重複する部分をもつように長い文章を分割したりする。

2.埋め込み行列

埋め込み層は埋め込み行列$W$を内部に保持している。定義は

W \in \mathbb{R}^{|Vocab|\times d_{model}}

であり、

$|Vocab|:語彙のサイズ$

$d_{model}:埋め込みベクトルの次元数$

である。

すなわち、埋め込み層の重み行列$W$は

$|Vocab|$行、$d_{model}$(ユーザ側が設定)列の行列だとわかる。

この行列はモデルの学習を通して単語の意味や関係性を学習します。

この$d_{model}$はGPTによるとこんな感じで設定されているらしい。

小型モデル(例: DistilBERT): $𝑑_{model}=128 \sim 256$
標準モデル(例: BERT-base): $d_{model}=768$
大型モデル(例: BERT-large, GPT-3): $d_{model}=1024 \sim 4096$

3.埋め込みベクトルへの変換

語彙インデックス(整数型のトークン列)$x$を埋め込み行列$W$に対応付けて埋め込みベクトルを取得する。

y=W[x]

$y:埋め込みベクトル 形状:(BatchSize,SequenceLength, d_{model})$

$W[x]:入力インデックスxに対応する行(Wの行列形式)を選択$

埋め込みベクトルは、離散的なインデックスを$d_{model}$次元の連続値ベクトルに割り当てている。

どういうことかというと、

$語彙インデックスx[i][j]$に対応する$W[x[i][j]]$

を埋め込み行列$W$から取り出しています。

これはつまり、

$W$の$x[i][j]$行目を取り出す操作です。

例えば

W=\begin{pmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
7 & 8 & 9 \\
\end{pmatrix}

とします。This is a penを
仮に$BatchSize=2$, $SequenceLength=3$として$x$を作成、変換するとします。

# [This ,is ,a]
# [pen , "", ""]
[
    [1, 100, 321], 
    [3000, 0, 0]
]

Thisを変換する場合:
$x[0][0]$の位置がThisにあたります。
これを埋め込みベクトル$y$に変換してみると、

y=W[x[0][0]]=W[1]=[4, 5, 6] \ \ (Wの1行目を取り出した)

よって、Thisの埋め込みベクトル$y_{this}$は、

y_{this}=[4, 5, 6]

とわかります。

埋め込み行列なので、行に単語のベクトルが埋め込まれているというイメージです。
isなら100行目、penなら3000行目を埋め込みベクトルとするわけです。
例の場合、3行しか書いてないので足りませんが……

これをトークン列の全てに対して行い、埋め込みベクトルに変換。
変換したベクトル全てには次のPositional Encodingが行われます。

4.Positional Encoding

埋め込みベクトルだけでは位置情報がたりない

さっきの埋め込みベクトルだけでは、単語の位置情報が失われてしまいます。
これはさっき挙げたページで分かりやすく解説してるので省きます。
さっくり言うと、位置に応じて固定のベクトルを加算することで埋め込みベクトルに位置情報を付加しています。

おわり

Attention機構の基礎的な部分をメモしました。
おおよそ何をしているのかが理解できたので、自分の研究でも何かしら応用したいですね。

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?