0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

図解と最小例で直感的に理解する Self-attention、Cross-attention、Multi-head Attention

0
Posted at

attention intuition.png
図1:Self-attention における Query・Key・Value の生成、attention scores / weights の計算、Value の重み付き和による出力表現の生成フロー

高レベルの直感

Attention は次の問いに答える仕組みである。

  • 自然言語処理(NLP)タスクの場合: あるトークン(単語)を処理するとき
    • どの単語(自分自身を含む)に注意を向けるべきか?
    • そしてそれぞれの単語にどれくらい注意を向けるべきか?
  • 時系列処理タスクの場合: あるトークン(時間ステップ)を処理するとき
    • 現在の時間ステップにとって有用な情報を含む時間ステップ(現在の時間ステップ自身を含む)はどれか?
    • そして各時間ステップがこの現在の時間ステップにどの程度影響すべきか?
  • ここでいう「処理」とは、現在の要素(単語または時間ステップ)の表現を更新することを意味する。
  • 得られた表現は、その後の Transformer 層に入力され、さらに変換や予測が行われる。

Self-attention と Cross-attention

役割の違い

  • Self-attention は、トークンが 同一のシーケンス内 の他のトークンと相互作用し、文脈を考慮した自身の表現を構築できるようにする。
  • Cross-attention は、あるシーケンスのトークン別のシーケンス から情報を取得できるようにする。

これら 2 つの仕組みは、多くのシーケンスモデリングタスクで用いられる。

以下は、エンコーダ・デコーダ型モデルの場合の例である。

タスク Self-Attention Cross-Attention
機械翻訳 ソース文の文脈を表現する 翻訳生成の際にソース情報を取得する
テキスト要約 文書全体の文脈をモデル化する 要約生成のために文書情報を取得する

簡単な数値例

ここでは、機械翻訳における計算の流れを示す。

英語の文を日本語に翻訳したいとする。

I love machine learning

エンコーダ・デコーダ型の翻訳モデルでは、まずエンコーダが self-attention を用い、その後デコーダが cross-attention を用いてエンコーダの出力にアクセスする。

Self-Attention(エンコーダ)

Self-attention では、各トークンが 同一シーケンス内のすべてのトークン(自分自身を含む) を参照し、自身の表現を更新するためにどのトークンが最も有用かを判断する。

以下の計算の目的は、トークン love が文全体を利用してどのように自分自身の文脈化された表現を構築するかを示すことである。

Step 1: トークン埋め込み

トークン 埋め込み
I [1.0, 0.5]
love [0.8, 1.2]
machine [1.5, 0.7]
learning [0.6, 1.8]

注:実際には、埋め込みの値は学習によって得られる。ここでは計算の流れを分かりやすくするため、ダミーの値を用いている。

埋め込みは、以下では $X$ と表記する。

Step 2: すべてのトークンの Q, K, V を計算する

各トークンの現在の表現は、Query、Key、Value ベクトルへ射影される。

Transformer アーキテクチャでは、最初のエンコーダ層への入力はトークン埋め込みに位置情報を加えたものです。しかし、この例では単純化のため位置情報を無視します。

成分 計算 意味
Query $Q = X \cdot W_Q$ 現在のトークンが何を探しているか
Key $K = X \cdot W_K$ このトークンがどのような内容を持つかを表す記述子であり、他のトークンがそれとの関連性を判断するためのもの
Value $V = X \cdot W_V$ そのトークンに注意が向けられたときに提供される情報
  • トークン埋め込みの行列:$X$(値は上の Step 1 に示したもの)
  • 射影行列:
W_Q =
\begin{bmatrix}
1.0 & 0.2 \\
0.1 & 0.9
\end{bmatrix}
W_K =
\begin{bmatrix}
0.6 & 0.1 \\
-0.3 & 1.0
\end{bmatrix}
W_V =
\begin{bmatrix}
1.2 & 0.0 & 0.5 \\
-0.2 & 0.8 & 0.3
\end{bmatrix}

注:実際には、$X$ と射影行列の値はいずれもモデルの学習後に得られる。ここでは理解を容易にするため、計算過程を示す目的でダミーの値を用いている。

例えば、トークン love の場合、計算過程は次のようになる。

love の埋め込みは次の通りである:

x_{\text{love}} = [0.8,\ 1.2]

この行ベクトルを、それぞれの射影行列に掛ける。

  • love の Query
Q_{\text{love}} = x_{\text{love}} ⋅ W_Q
= [0.8,\ 1.2] ⋅ 
\begin{bmatrix}
1.0 & 0.2 \\
0.1 & 0.9
\end{bmatrix} \\
= \left[
0.8 \cdot 1.0 + 1.2 \cdot 0.1,\
0.8 \cdot 0.2 + 1.2 \cdot 0.9
\right]
= [0.92,\ 1.24]
  • love の Key
K_{\text{love}} = x_{\text{love}} ⋅ W_K
= [0.8,\ 1.2] ⋅ 
\begin{bmatrix}
0.6 & 0.1 \\
-0.3 & 1.0
\end{bmatrix}
= [0.12,\ 1.28]
  • love の Value
V_{\text{love}} = x_{\text{love}} ⋅ W_V
= [0.8,\ 1.2] ⋅ 
\begin{bmatrix}
1.2 & 0.0 & 0.5 \\
-0.2 & 0.8 & 0.3
\end{bmatrix}
= [0.72,\ 0.96,\ 0.76]

すべてのトークンに対する射影後のベクトル:

トークン Q K V
I [1.05, 0.65] [0.45, 0.60] [1.10, 0.40, 0.65]
love [0.92, 1.24] [0.12, 1.28] [0.72, 0.96, 0.76]
machine [1.57, 0.93] [0.69, 0.85] [1.66, 0.56, 0.96]
learning [0.78, 1.74] [-0.18, 1.86] [0.36, 1.44, 0.84]

Step 3: トークン love がすべてのトークン(自分自身を含む)に向ける attention スコアを計算する

love の Query:

Q_{\text{love}} = [0.92, 1.24]

各 Key とのドット積:

\begin{aligned}
Q_{\text{love}} \cdot K_I &= 1.158 \\
Q_{\text{love}} \cdot K_{\text{love}} &= 1.6976 \\
Q_{\text{love}} \cdot K_{\text{machine}} &= 1.6888 \\
Q_{\text{love}} \cdot K_{\text{learning}} &= 2.1408
\end{aligned}

生の attention スコア:

[1.158, 1.6976, 1.6888, 2.1408]

$\sqrt{d_k} = \sqrt{2}$ でスケーリングする:

Q と K は 2 次元ベクトルなので、$d_k = 2$ である。

[0.8188, 1.2004, 1.1942, 1.5138]

Step 4: Softmax

\begin{aligned}
softmax([0.8188, 1.2004, 1.1942, 1.5138])
&\approx [0.1688, 0.2472, 0.2457, 0.3383]
\end{aligned}

補足:softmax は、重みの合計が 1 になるように正規化するためのもの。

したがって love

  • I には少しだけ注意を向け
  • 自分自身と machine にはほぼ同程度に注意を向け
  • learning に最も強く注意を向ける

この例では、learning が最も大きい attention 重みを受け取るため、love の更新された表現に最も強く寄与する。

Step 5: Value の重み付き和

Values:

\begin{aligned}
V_I &= [1.10, 0.40, 0.65] \\
V_{\text{love}} &= [0.72, 0.96, 0.76] \\
V_{\text{machine}} &= [1.66, 0.56, 0.96] \\
V_{\text{learning}} &= [0.36, 1.44, 0.84]
\end{aligned}

上で計算した attention 重み $[0.1688, 0.2472, 0.2457, 0.3383]$ を用いて重み付き和を計算する:

\begin{aligned}
&0.1688 \cdot [1.10, 0.40, 0.65] \\
&+ 0.2472 \cdot [0.72, 0.96, 0.76] \\
&+ 0.2457 \cdot [1.66, 0.56, 0.96] \\
&+ 0.3383 \cdot [0.36, 1.44, 0.84] \\
&\approx [0.8933, 0.9295, 0.8176]
\end{aligned}

Self-attention の結果

love の埋め込み

[0.8, 1.2]

から、Attention の計算によって love の次の文脈表現が得られる:

[0.893, 0.930, 0.818]

このベクトルには 文全体の情報 が含まれており、以降の処理に利用できる。

Cross-Attention(デコーダ)

エンコーダがソース文を処理した後、デコーダは翻訳の生成を開始する。

実際の Transformer デコーダでは、cross-attention の前に、生成済みトークンに対する masked self-attention も行われる。ここでは cross-attention の仕組みに焦点を当てるため、その部分は省略する。

デコーダがすでに次のように生成しているとする:

私は

ここでは attention メカニズムに焦点を当てるため、すでに生成されたトークンがどのように生成されたかについては扱わない。

次のトークンを生成するために、デコーダは英語の文からの情報を必要とする。

これは cross-attention によって実現される。ここでは次のようになる:

  • Query $Q$デコーダ から来る
  • Key $K$Value $V$エンコーダ から来る

言い換えると、デコーダは現在の表現を用いて エンコーダの出力に対して照会(Query)を行い、次のトークンを予測するために必要なソース情報を重みに応じて取得する。

Step 1: エンコーダの表現を取得する

エンコーダが次のような表現を生成したと仮定する。

エンコーダのトークン エンコーダ出力の表現ベクトル
I [1.020, 0.910]
love [0.880, 1.020]
machine [1.140, 0.860]
learning [0.760, 1.150]

ここで示している表現ベクトルは、self-attention 演算の生の出力そのものではない。

実際には、デコーダの cross-attention が利用するベクトルは、エンコーダの self-attention 出力をそのまま使うわけではなく、その間にいくつかの追加処理が行われる。

本記事では説明を簡潔にするため、この部分の処理については扱わない。

Step 2: デコーダで表現と射影行列を用意する

  • トークン 私は に対する現在のデコーダ表現:$X_{\text{私は}} = [0.800, 1.500]$
  • cross-attention 用の射影行列:
W_Q =
\begin{bmatrix}
0.9 & 0.2 \\
0.3 & 0.8
\end{bmatrix}
W_K =
\begin{bmatrix}
0.5 & 0.1 \\
-0.4 & 1.1
\end{bmatrix}
W_V =
\begin{bmatrix}
1.1 & 0.0 & 0.4 \\
-0.2 & 0.9 & 0.2
\end{bmatrix}

注:前述の通り、実際にはこれらの値はモデルの学習過程で学習される。ここでは計算過程を追いやすくするため、説明用のダミー値を用いている。

Step 3: デコーダの Query を計算する

\begin{aligned}
Q_{\text{私は}}
&= x_{\text{私は}} \cdot W_Q \\

&=
[0.800,\ 1.500]
\begin{bmatrix}
0.9 & 0.2 \\
0.3 & 0.8
\end{bmatrix} \\

&=
\left[
0.800 \cdot 0.9 + 1.500 \cdot 0.3,\;
0.800 \cdot 0.2 + 1.500 \cdot 0.8
\right] \\

&=
[1.170,\ 1.360]
\end{aligned}

Step 4: エンコーダ表現から Key と Value を計算する

上で示したように、エンコーダの出力は次の通りである。

エンコーダのトークン エンコーダ出力の表現ベクトル
I [1.020, 0.910]
love [0.880, 1.020]
machine [1.140, 0.860]
learning [0.760, 1.150]
  • Key:
\begin{aligned}
K_I &= x_I \cdot W_K
= [1.020,0.910]
\begin{bmatrix}
0.5 & 0.1 \\
-0.4 & 1.1
\end{bmatrix}
= [0.146,\ 1.103] \\

K_{\text{love}} &= x_{\text{love}} \cdot W_K
= [0.880,1.020]
\begin{bmatrix}
0.5 & 0.1 \\
-0.4 & 1.1
\end{bmatrix}
= [0.032,\ 1.210] \\

K_{\text{machine}} &= x_{\text{machine}} \cdot W_K
= [1.140,0.860]
\begin{bmatrix}
0.5 & 0.1 \\
-0.4 & 1.1
\end{bmatrix}
= [0.226,\ 1.060] \\

K_{\text{learning}} &= x_{\text{learning}} \cdot W_K
= [0.760,1.150]
\begin{bmatrix}
0.5 & 0.1 \\
-0.4 & 1.1
\end{bmatrix}
= [-0.080,\ 1.341]
\end{aligned}
  • Value:
\begin{aligned}
V_I &= x_I \cdot W_V
= [1.020,0.910]
\begin{bmatrix}
1.1 & 0.0 & 0.4 \\
-0.2 & 0.9 & 0.2
\end{bmatrix}
= [0.940,\ 0.819,\ 0.590] \\

V_{\text{love}} &= x_{\text{love}} \cdot W_V
= [0.880,1.020]
\begin{bmatrix}
1.1 & 0.0 & 0.4 \\
-0.2 & 0.9 & 0.2
\end{bmatrix}
= [0.764,\ 0.918,\ 0.556] \\

V_{\text{machine}} &= x_{\text{machine}} \cdot W_V
= [1.140,0.860]
\begin{bmatrix}
1.1 & 0.0 & 0.4 \\
-0.2 & 0.9 & 0.2
\end{bmatrix}
= [1.082,\ 0.774,\ 0.628] \\

V_{\text{learning}} &= x_{\text{learning}} \cdot W_V
= [0.760,1.150]
\begin{bmatrix}
1.1 & 0.0 & 0.4 \\
-0.2 & 0.9 & 0.2
\end{bmatrix}
= [0.606,\ 1.035,\ 0.534]
\end{aligned}

Step 5: Attention スコア

\begin{aligned}
Q_{\text{私は}} \cdot K_I &= 1.6709 \\
Q_{\text{私は}} \cdot K_{\text{love}} &= 1.6830 \\
Q_{\text{私は}} \cdot K_{\text{machine}} &= 1.7060 \\
Q_{\text{私は}} \cdot K_{\text{learning}} &= 1.7302
\end{aligned}

生の attention スコア:

[1.6709,\ 1.6830,\ 1.7060,\ 1.7302]

$\sqrt{2}$ でスケーリングする:

[1.1815,\ 1.1901,\ 1.2063,\ 1.2234]

Step 6: Softmax

softmax([1.1815,1.1901,1.2063,1.2234])
\approx
[0.2453,\ 0.2474,\ 0.2515,\ 0.2558]

Step 7: 重み付き和

attention 重み

[0.2453,\ 0.2474,\ 0.2515,\ 0.2558]

を各トークンの Value

\begin{aligned}
V_I &= [0.940,\ 0.819,\ 0.590] \\
V_{\text{love}} &= [0.764,\ 0.918,\ 0.556] \\
V_{\text{machine}} &= [1.082,\ 0.774,\ 0.628] \\
V_{\text{learning}} &= [0.606,\ 1.035,\ 0.534]
\end{aligned}

に適用して、重み付き和を計算する。

\begin{aligned}
&0.2453 \cdot [0.940,0.819,0.590] \\
&+ 0.2474 \cdot [0.764,0.918,0.556] \\
&+ 0.2515 \cdot [1.082,0.774,0.628] \\
&+ 0.2558 \cdot [0.606,1.035,0.534] \\
&\approx [0.8467, 0.8874, 0.5768]
\end{aligned}

Cross-attention の結果

重み付き和

[0.847,\ 0.887,\ 0.577]

は、現在のデコーダトークン 私は に対する cross-attention の出力ベクトル である。

このベクトルは、デコーダトークン 私は に対して、ソース文 I love machine learning から重みに応じて混合された情報を表している。

注:$[0.847,\ 0.887,\ 0.577]$ は 次の日本語の単語そのものではない

  • これはトークン 私は に対する 中間ベクトル であり、英語文の関連部分を要約している。
  • デコーダはこのベクトルを後続の層で利用し、次のトークンの予測に役立てる。
  • 簡潔にするため、本記事では次のトークンを予測する仕組みについては扱わない。

Multi-head attention

Multi-head attention は、複数の self-attention または cross-attention の計算を並列に実行し、その結果を結合することで、トークン間のさまざまな関係を捉えられるようにする仕組みである。

簡単な数値例

上と同じ例を用い、self-attention の計算を multi-head self-attention に拡張する。

cross-attention においても multi-head attention の考え方は同じである。
同じ入力の隠れ状態を用いて複数の attention 計算を行うが、各ヘッドはそれぞれ異なる学習された射影行列を使用し、その結果を最後に結合する。

自然言語では、あるトークンの意味が 複数種類の情報に同時に依存する 場合がある。

I love machine learning という文では、love という単語は次のようなものと関係する可能性がある。

  • learning(意味的な関係)
  • I(文法上の主語)
  • machine learning(話題となる文脈)

単一の attention 計算では、これらの情報を 一つの観点にまとめて扱う ことになる。

multi-head attention では、モデルが文を 複数の観点から同時に見る ことができ、異なる種類の関係を捉えやすくなる。

Step 1: 各ヘッドが self-attention を実行する

各ヘッドは、self-attention セクションで説明したものとまったく同じ attention 計算の手順を実行するが、異なる射影行列を用いる。

もし 3 つのヘッドを使用する場合、次のようになる:

(W_Q^{(1)},W_K^{(1)},W_V^{(1)})
(W_Q^{(2)},W_K^{(2)},W_V^{(2)})
(W_Q^{(3)},W_K^{(3)},W_V^{(3)})

したがって、各ヘッドは それぞれ独自の文脈表現 を生成する。

トークン love が次のような出力を生成すると仮定する。

ヘッド 1:

z_{\text{love}}^{(1)}=[0.893,0.930,0.818]

ヘッド 2:

z_{\text{love}}^{(2)}=[1.020,0.710,0.640]

ヘッド 3:

z_{\text{love}}^{(3)}=[0.640,1.210,0.580]

各ヘッドは 文中の異なる関係 を捉えやすくなる。

Step 2: ヘッドの結果を連結する

出力は 連結(concatenate) され、1つのベクトルになる。

\begin{aligned}
z_{\text{concat}}
&= [0.893,0.930,0.818] \\
&\quad \Vert [1.020,0.710,0.640] \\
&\quad \Vert [0.640,1.210,0.580]
\end{aligned}

結果:

\begin{aligned}
z_{\text{concat}} = [&0.893,\ 0.930,\ 0.818,\ 1.020,\ 0.710,\\
&0.640,\ 0.640,\ 1.210,\ 0.580]
\end{aligned}

3つのヘッドの出力を結合したため、次元が増える。

Step 3: 最終射影

連結されたベクトルは、その後行列 $W_O$ を用いて射影される。

この行列の値も実際には学習によって得られるが、ここでは計算を分かりやすく示すためにダミーの値を用いる。

例:

W_O =
\begin{bmatrix}
0.40 & 0.00 & 0.00 \\
0.00 & 0.40 & 0.00 \\
0.00 & 0.00 & 0.40 \\
0.35 & 0.00 & 0.00 \\
0.00 & 0.35 & 0.00 \\
0.00 & 0.00 & 0.35 \\
0.25 & 0.00 & 0.00 \\
0.00 & 0.25 & 0.00 \\
0.00 & 0.00 & 0.25
\end{bmatrix}

次のように計算する。

\begin{aligned}
z_{\text{final}}
&= z_{\text{concat}} \cdot W_O \\
&= [0.893,\,0.930,\,0.818,\,1.020,\,0.710,\\
&\quad 0.640,\,0.640,\,1.210,\,0.580] \cdot W_O
\end{aligned}

結果:

z_{\text{final}} = [0.8742,\ 0.9230,\ 0.6962]

Multi-head attention の結果

最終的なベクトル

[0.8742,\ 0.9230,\ 0.6962]

は、トークン love に対する multi-head attention の出力 である。

このベクトルは、複数の attention ヘッド が得た情報を結合したものであり、それぞれのヘッドは文中の異なる関係に焦点を当てやすい。

この表現は、その後の処理に利用される。

参考文献

  • Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention Is All You Need. arXiv:1706.03762.
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?