Transformer や Attention の解説を読んでいると、$QK^T$ や $AV$ のような行列積が何度も出てきます。
行列積そのものは線形代数で習う基本的な演算ですが、AI の式の中に入った瞬間に、「この行列積は何をしているのか」が見えにくくなることがあります。
この記事では、行列を「点の座標ベクトルを縦に積んだもの」、つまり ベクトル列 として読んでみます。そして、行列積を次の2種類に分けて考えます。
- $AB$ 型:$B$ の点群を結合する
- $AB^T$ 型:$B$ の点群と照合する
この見方をすると、Attention に出てくる $QK^T$ は「照合」、$AV$ は「結合」として読みやすくなります。
もともとは、大学のデータサイエンス授業で Transformer と Attention を説明するために整理した補足資料です。授業用に作ったものですが、同じように Attention の式を読むための足場を探している方の参考になればと思い、公開します。
なお、私は大学教員ですが、専門は生物学です。この記事は、行列積について厳密な数学的証明を与えることを目的としたものではありません。AI の式を読むときに、行列積に対する認識を少し切り替えると見通しがよくなる、という趣旨の整理です。
本文は授業用補足資料をほぼそのまま載せているため、「この授業では」という表現を残しています。また、$AB$ 型・$AB^T$ 型という呼び方は、授業用に名付けた便宜的な呼び方で、一般名称ではありません。
補足資料:アテンションの数学
1. 行列とベクトル列
はじめに、$m$ 行 $\ell$ 列 ($m \times \ell$) の二次元行列 $A$ を考える。
A =
\begin{bmatrix}
a_{11} & a_{12} & \cdots & a_{1\ell} \\
a_{21} & a_{22} & \cdots & a_{2\ell} \\
\vdots & \vdots & \ddots & \vdots \\
a_{m1} & a_{m2} & \cdots & a_{m\ell}
\end{bmatrix}
この $A$ を、$\ell$ 次元空間における、$m$ 個の点の座標を表すベクトルを束ね、縦に積んだものとみなすことを考える。この授業では、このような二次元行列のことを、ベクトル列と呼ぶ。
たとえば、
A =
\begin{bmatrix}
\boldsymbol{a}_1 \\
\boldsymbol{a}_2 \\
\vdots \\
\boldsymbol{a}_m
\end{bmatrix}
\in \mathbb{R}^{m \times \ell}
と書くとき、$A$ は $m$ 個のベクトル
\boldsymbol{a}_1, \boldsymbol{a}_2, \dots, \boldsymbol{a}_m
を縦に積んだものとして表せる。
このとき、各行ベクトル $\boldsymbol{a}_i$ は、ある空間における第 $i$ 点の座標を表す $\ell$ 次元ベクトルである。
\boldsymbol{a}_i =
\begin{bmatrix}
a_{i1} & a_{i2} & \cdots & a_{i\ell}
\end{bmatrix}
ここで、このように定義したベクトル列 $A, B, C$ を考える。
ベクトル列 $A, C$ の形状は、
A \in \mathbb{R}^{m \times \ell}, \quad C \in \mathbb{R}^{m \times n}
とする。一方、$B$ については、演算ごとに、$A, C$ に基づき、定まるものとする。
2. 行列積と写像
次に、2種類の行列積
C = AB
および、
C = AB^T
について考える。この授業では便宜的に、前者を $AB$ 型、後者を $AB^T$ 型と呼ぶ。このとき、いずれの場合も、右辺の左側のベクトル列 $A$ の第 $i$ 点は、左辺のベクトル列 $C$ の第 $i$ 点に対応する(後述)。
すなわち、いずれの場合も、
A \longrightarrow C
という、$\ell$ 次元空間の $m$ 個の点群から $n$ 次元空間の $m$ 個の点群への写像とみなせる。このとき、$A$ の第 $i$ 行の点が、$C$ の第 $i$ 行の点に対応する。
つまり、第 $i$ 行の点に注目すると、
\boldsymbol{a}_i \longmapsto \boldsymbol{c}_i
である。
3. AB 型の行列積:結合
まず、
C = AB
について考える。
ここで、
A=
\begin{bmatrix}
\boldsymbol{a}_1 \\
\boldsymbol{a}_2 \\
\vdots \\
\boldsymbol{a}_m
\end{bmatrix} \in \mathbb{R}^{m \times \ell},
\qquad
B=
\begin{bmatrix}
\boldsymbol{b}_1 \\
\boldsymbol{b}_2 \\
\vdots \\
\boldsymbol{b}_{\ell}
\end{bmatrix} \in \mathbb{R}^{\ell \times n},
\qquad
C=
\begin{bmatrix}
\boldsymbol{c}_1 \\
\boldsymbol{c}_2 \\
\vdots \\
\boldsymbol{c}_m
\end{bmatrix} \in \mathbb{R}^{m \times n}
$A$ の各行ベクトルを成分表示まで書き下すと、
A =
\begin{bmatrix}
\begin{bmatrix}
a_{11} & a_{12} & \cdots & a_{1\ell}
\end{bmatrix}
\\
\begin{bmatrix}
a_{21} & a_{22} & \cdots & a_{2\ell}
\end{bmatrix}
\\
\vdots
\\
\begin{bmatrix}
a_{m1} & a_{m2} & \cdots & a_{m\ell}
\end{bmatrix}
\end{bmatrix}
である。
したがって、$C=AB$ を計算すると、次のようになる。
AB=
\begin{bmatrix}
\begin{bmatrix}
a_{11} & a_{12} & \cdots & a_{1\ell}
\end{bmatrix}
\\
\begin{bmatrix}
a_{21} & a_{22} & \cdots & a_{2\ell}
\end{bmatrix}
\\
\vdots
\\
\begin{bmatrix}
a_{m1} & a_{m2} & \cdots & a_{m\ell}
\end{bmatrix}
\end{bmatrix}
\begin{bmatrix}
\boldsymbol{b}_1 \\
\boldsymbol{b}_2 \\
\vdots \\
\boldsymbol{b}_{\ell}
\end{bmatrix}
=
\begin{bmatrix}
a_{11}\boldsymbol{b}_1 + a_{12}\boldsymbol{b}_2 + \cdots + a_{1\ell}\boldsymbol{b}_{\ell} \\
a_{21}\boldsymbol{b}_1 + a_{22}\boldsymbol{b}_2 + \cdots + a_{2\ell}\boldsymbol{b}_{\ell} \\
\vdots \\
a_{m1}\boldsymbol{b}_1 + a_{m2}\boldsymbol{b}_2 + \cdots + a_{m\ell}\boldsymbol{b}_{\ell}
\end{bmatrix}
=
\begin{bmatrix}
\boldsymbol{c}_1 \\
\boldsymbol{c}_2 \\
\vdots \\
\boldsymbol{c}_m
\end{bmatrix}
=
C
$C$ の第 $i$ 行の点 $\boldsymbol{c}_i$ を抜き出すと、
\boldsymbol{c}_i
=
a_{i1}\boldsymbol{b}_1
+
a_{i2}\boldsymbol{b}_2
+
\cdots
+
a_{i\ell}\boldsymbol{b}_{\ell}
=
\sum_{j=1}^{\ell} a_{ij}\boldsymbol{b}_j
=
\boldsymbol{a}_i B
である。
したがって、$AB$ 型行列積では、$A$ の対応する点 $\boldsymbol{a}_i$ の各成分
a_{i1}, a_{i2}, \dots, a_{i\ell}
を重みとして、$B$ を構成する全ての点
\boldsymbol{b}_1, \boldsymbol{b}_2, \dots, \boldsymbol{b}_{\ell}
の情報を結合することで、$C$ の対応する点 $\boldsymbol{c}_i$ を作る。
つまり、$AB$ 型行列積は、$A$ の各点を、その成分に基づいて $B$ の全ての点の情報を結合し、$C$ の対応する点へ写し、その結果を縦に積んで、ベクトル列 $C$ を作る操作として読める。
4. AB^T 型の行列積:照合
次に、
C = AB^T
について考える。
ここで、
A=
\begin{bmatrix}
\boldsymbol{a}_1 \\
\boldsymbol{a}_2 \\
\vdots \\
\boldsymbol{a}_m
\end{bmatrix} \in \mathbb{R}^{m \times \ell},
\qquad
B=
\begin{bmatrix}
\boldsymbol{b}_1 \\
\boldsymbol{b}_2 \\
\vdots \\
\boldsymbol{b}_n
\end{bmatrix} \in \mathbb{R}^{n \times \ell},
\qquad
C=
\begin{bmatrix}
\boldsymbol{c}_1 \\
\boldsymbol{c}_2 \\
\vdots \\
\boldsymbol{c}_m
\end{bmatrix} \in \mathbb{R}^{m \times n}
である。
$B$ を転置すると、$\ell$ 次元ベクトル $\boldsymbol{b}_1, \boldsymbol{b}_2, \dots, \boldsymbol{b}_n$ が横に並んだ $B^T$ を得る。
B^T
=
\begin{bmatrix}
\boldsymbol{b}_1^T & \boldsymbol{b}_2^T & \cdots & \boldsymbol{b}_n^T
\end{bmatrix}
\in \mathbb{R}^{\ell \times n}
したがって、$C=AB^T$ を計算すると、次のようになる。
AB^T
=
\begin{bmatrix}
\boldsymbol{a}_1 \\
\boldsymbol{a}_2 \\
\vdots \\
\boldsymbol{a}_m
\end{bmatrix}
\begin{bmatrix}
\boldsymbol{b}_1^T & \boldsymbol{b}_2^T & \cdots & \boldsymbol{b}_n^T
\end{bmatrix}
=
\begin{bmatrix}
\boldsymbol{a}_1 \cdot \boldsymbol{b}_1
&
\boldsymbol{a}_1 \cdot \boldsymbol{b}_2
&
\cdots
&
\boldsymbol{a}_1 \cdot \boldsymbol{b}_n
\\
\boldsymbol{a}_2 \cdot \boldsymbol{b}_1
&
\boldsymbol{a}_2 \cdot \boldsymbol{b}_2
&
\cdots
&
\boldsymbol{a}_2 \cdot \boldsymbol{b}_n
\\
\vdots
&
\vdots
&
\ddots
&
\vdots
\\
\boldsymbol{a}_m \cdot \boldsymbol{b}_1
&
\boldsymbol{a}_m \cdot \boldsymbol{b}_2
&
\cdots
&
\boldsymbol{a}_m \cdot \boldsymbol{b}_n
\end{bmatrix}
=
\begin{bmatrix}
\boldsymbol{c}_1 \\
\boldsymbol{c}_2 \\
\vdots \\
\boldsymbol{c}_m
\end{bmatrix}
=
C
$C$ の第 $i$ 行の点 $\boldsymbol{c}_i$ を抜き出すと、
\boldsymbol{c}_i
=
\begin{bmatrix}
\boldsymbol{a}_i \cdot \boldsymbol{b}_1
&
\boldsymbol{a}_i \cdot \boldsymbol{b}_2
&
\cdots
&
\boldsymbol{a}_i \cdot \boldsymbol{b}_n
\end{bmatrix}
=
\boldsymbol{a}_i B^T
である。
したがって、$AB^T$ 型行列積では、$A$ の対応する点 $\boldsymbol{a}_i$ と、$B$ を構成する全ての点
\boldsymbol{b}_1, \boldsymbol{b}_2, \dots, \boldsymbol{b}_n
とを内積で照合することで、$C$ の対応する点 $\boldsymbol{c}_i$ を作る。
つまり、$AB^T$ 型行列積は、$A$ の各点を、$B$ の全ての点との内積による照合結果を表す点へ写し、その結果を縦に積んで、ベクトル列 $C$ を作る操作として読める。
おわりに
この記事では Attention を例に、行列積をベクトル列の写像として読む見方を整理しました。
ここで述べた読み方は、Attention に限られたものではありません。データを、各行が1つの対象・部分・時点・トークン・パッチなどを表すベクトル列として扱う限り、$AB^T$ 型の行列積は $A$ と $B$ の点群どうしの内積による照合、$AB$ 型の行列積は $A$ の各点を重みとした $B$ の全点群情報の結合として読むことができます。
Attention は、この「照合して、配分に変換し、結合する」という構造が、特に明示的に現れる例だと考えられます。
行列積を数式としてだけでなく、点群が別の点群へ写されていく過程として眺めると、AI の式の中で何が起きているのかを少し追いやすくなります。この記事が、Attention やデータサイエンスの式を読むための、小さな足場になれば幸いです。