記事の概要
Attention機構がすごい、というのはわかっているが、いまいちなんでこれを計算することで、文脈を考慮した単語埋め込みが手に入るのかがわからなかった。
「いきなり、Query、Key、Valueの定義とか言われてもなんでそんなもん用意するんだ?」がわからなかった。
この問いに対して、ふと思いついた解釈の仕方があったので、備忘としてメモ。
注意。論文そのものを読み解いたわけではなく、数式を眺めつつ、手を動かしつつ、悶々と考えていたときに、ふと思いついたものであるため、「こう考えたら"私は"理解しやすい」という内容になっている。
そのため正しいかどうかまったくわからないが、私と同じことを考えている人がいれば参考になるのではないかと思う。
- ある程度、線形代数がわかっていると良い。
- テーブルデータを行列とみなして、線形代数使うと、線型回帰ができる、とかそういうレベルの理解でOK。
Attentionのベースの発想
-
インプットテキストのトークン数を$( N )$ とする。
- 例:this is a pen. というテキストであれば、
- this
- is
- a
- pen
- .
- という5トークン。$N=5$。
- 例:this is a pen. というテキストであれば、
-
1トークンの埋め込みベクトルの次元を $( d )$ とする(★)。
-
データ行列 $( \boldsymbol{X} )$ は、次のように表せる。
$$
\boldsymbol{X} \in \mathbb{R}^{N \times d}
$$ -
知りたいのは、トークン $( i )$ と関係性の強いトークン $( j )$ を見つけること。
- 正確に言えば、トークン$( i )$ とそれ意外のトークン $( j )$の関連性の強さを定量化する。
-
そのためには、内積を取れば良いのでは?という発想。
- テーブルデータにおいて、特徴量同士の関連性を計算するために、分散共分散行列(相関行列)を計算するのと同じ発想。違いとしては、分散共分散行列を計算するときは、$\boldsymbol{X}^T・\boldsymbol{X} \quad(\in \boldsymbol{R}^{d×d})$で計算するが、今回は、$\boldsymbol{X}・\boldsymbol{X}^T \quad(\in \boldsymbol{R}^{N×N})$ で計算する点。
- $( N ・ N )$ の行列を手に入れる。
-
したがって、"クソ雑Attention" $\boldsymbol{A}$ は以下のように計算できる。
$$
\boldsymbol{A} = \boldsymbol{X} \cdot \boldsymbol{X}^T \in \mathbb{R}^{N \times N}
$$ -
これを使って、元のベクトルに重み付けをする。
$$
\boldsymbol{X_{\text{new}}} = \boldsymbol{A} \cdot \boldsymbol{X}
$$ -
こうすることで、文書中のトークン同士の関係性を考慮した新しい埋め込みベクトルが得られる。
"クソ雑Attention"の問題点
- しかし、この方法は 雑すぎる。
-
学習ができない(=汎用的な言語モデルを構築できない)。
- 学習対象とすべき「重み」パラメータが存在しないため。
- そこで、この「クソ雑Attention」を改良する必要がある。
- そこで生まれたのが $Query(Q), Key(K), Value(V)$ の概念。
- (※:このように考えて本当に生まれたのかどうか私はわからない。ただ、このような流れでQKVを考えついたのではないか?という想像をすると私は理解しやすいと思った。 )
Query(Q), Key(K), Value(V)の導入
-
以下のように定義する。
$$
\boldsymbol{Q} = \boldsymbol{X} \cdot \boldsymbol{W_Q}, \quad
\boldsymbol{K} = \boldsymbol{X} \cdot \boldsymbol{W_K}, \quad
\boldsymbol{V} = \boldsymbol{X} \cdot \boldsymbol{W_V}
$$ -
$( \boldsymbol{W_Q}, \boldsymbol{W_K}, \boldsymbol{W_V} )$ は学習対象の重みパラメータ。
- それぞれの行列の次元を書き下すと、
- $ \boldsymbol{W_Q} \in \boldsymbol{R}^{d×d_k} $
- $ \boldsymbol{W_K} \in \boldsymbol{R}^{d×d_k} $
- $ \boldsymbol{W_V} \in \boldsymbol{R}^{d×d_v} $
- $d$:上述の(★)の次元
- 注意点として、$\boldsymbol{W_V}$だけ、$d_k$ではなく$d_v$なところに注意。
- $ \boldsymbol{Q} \in \boldsymbol{R}^{N×d_k} $
- $ \boldsymbol{K} \in \boldsymbol{R}^{N×d_k} $
- $ \boldsymbol{V} \in \boldsymbol{R}^{N×d_v} $
- それぞれの行列の次元を書き下すと、
-
先ほどと同様に内積を取ると、トークン同士の関係性を求められる。
$$
\boldsymbol{Q} \cdot \boldsymbol{K}^T \in \mathbb{R}^{N \times N}
$$- ちなみに、これを展開すると、以下のようになる。
$$
\boldsymbol{Q} \cdot \boldsymbol{K}^T =
(\boldsymbol{X} \cdot \boldsymbol{W_Q}) \cdot (\boldsymbol{X} \cdot \boldsymbol{W_K})^T =
\boldsymbol{X} \cdot \boldsymbol{W_Q} \cdot \boldsymbol{W_K}^T \cdot \boldsymbol{X}^T
$$- 途中に $( \boldsymbol{W_Q} \cdot \boldsymbol{W_K}^T )$ という重みが挟まっているが、本質的には内積を計算していることが分かる。
正しいAttentionの計算
-
実際のAttentionの計算は、もう少し工夫されている。
-
QueryとKeyの内積をスケール調整し、softmaxを適用する。
$$
\text{softmax} \left( \frac{\boldsymbol{Q} \cdot \boldsymbol{K}^T}{\sqrt{d_k}} \right) \in \mathbb{R}^{N \times N}
$$
$\sqrt{d_k}$はスケーリングパラメータ。埋め込みの次元が大きいと、$ \boldsymbol{Q} ・ \boldsymbol{K}^T $の各要素の値は大きくなりやすい。それを抑えるために、QやKの次元の大きさ $N × d_k$の$d_k$の平方根で割り算している。
-
さらに、更新方法も次のように変わる。
$$
\boldsymbol{X_{\text{new}}} = \text{softmax} \left( \frac{\boldsymbol{Q} \cdot \boldsymbol{K}^T}{\sqrt{d_k}} \right) \cdot \boldsymbol{V}
$$ -
$( \boldsymbol{W_Q}, \boldsymbol{W_K}, \boldsymbol{W_V} )$ を学習して良い感じの値にすることで、文脈を捉えた言葉の理解が可能になる(ように見える)。
まとめ
-
誤解を恐れずに言えば、Attentionは以下の流れで計算される。
- トークン数 × 埋め込みベクトルの内積を計算し、トークン同士の関係性を求める。
- 得られた関係性行列$(( N \times N )$ の正方行列)を元のデータ行列と掛け合わせ、"文脈を考慮した"埋め込みを作る。
上記とは直接関係ないが、最近いろいろ勉強していて思うのが、新しい概念を勉強するときに、わかりやすいなと思う説明方法はパターンがある。
それは、
- ①くそ単純な方法を説明(言い換えると、直感的あるいは、素直なやり方。)
- ②その方法だと〇〇という問題が発生する。
- ③それを解決するために、△△というアイデアを取り込んだらうまく行った。
- ④②③の繰り返し。
という説明の仕方が一番理解しやすいし、その方法やモデルの「気持ち」がわかりやすいなと思った。
例えば、今回のAttentionで言えば、"クソ雑Attention"の考え方を最初に解説されていると、QKVという難解なテンソルの出現理由がわかりやすい。かつ、「QKVは結局のところ、データ行列の内積を取りたい」だけ、という気持ちを理解できるので、応用も利きやすいとおもう。