はじめに
この記事は、ChatGPTやLLMの各所で活用されるTransfomerに対して、ベースとなる数学から理解を試みようとした際のメモです。高校数学の基礎的な話題から、必要最低限の知識に絞り最短距離で理解を目指します。
ターゲットとして大規模言語モデル入門(以降、本書)を読み進めるにあたり最低限必要な数学知識をメモしながら、内容の理解を目指します。
このメモは基本自分向けですが、もしかすると他の方にも役立つかも知れないと思い公開します。
初学者の観点ですので誤り多いと思います。気づいた方はコメントを残していただけると大変うれしいです。
お品書き
- ベクトルの知識
- 内積に関する知識
- 行列と行列の和、積
- 行列とベクトルの関係、および、LLMとの関連
- 行列の積と内積演算の関係
- 確率・統計(Transformerで使うものだけ)
- LLMの学習プロセスの大雑把な構造
- 自然言語とベクトルの関係(埋め込み行列とOneHotベクトル)
- Transformerの理解に挑戦(概要)
- Encoderの理解に挑戦
10-1. Input Embedding(入力埋め込み)
10-2. Positional Encoding(位置符号)
10-3. self-attention(自己注意機構)
10-4. Multi-Head Attention(マルチヘッド注意機構)
10-5. Feed Forward(フィードフォワード層)
10-6. 残差結合
10-7. Add & Norm(層正規化) - Encoder/Decoderの理解に挑戦(今後予定)
11-1. masked multi head attention(今後予定)
11-2. 交差注意機構(今後予定)
微積分については記載を省いています。以下理由になります。
あとで後述しますが、深層学習の本質的なポイントは微分(偏微分)です。
しかし、PyTorchを代表としたありがたいライブラリが数値微分という機構をサポートしているおかげで、我々は深層学習を役立てるために微分を計算する必要がなくなっています。
こういった背景もあるせいか、本書を読み進めるうえで、微分の知識はほぼ不要とわかりましたので、説明を省きます。
積分は統計学では必要ですが、本書では不要でした。
ベクトル
ベクトルとは平面上の向きと大きさ(長さ)を持った矢印としてとして表現されますが、第一の理解としてはそれが良い(入りやすさとして)と思います。
線形代数の基礎 第2回 - ベクトル(1)が詳しいので、このページをまず参考にされるのが良いかと思います。
ベクトルの要素が3つ(3次元)までは、3次元空間座標上の矢印として認識できるのですが、4次元以降は理解しにくいです。
そこで、プログラミング言語風の理解で、「ベクトルは単に一次元の配列である」と理解するとこの先の理解が楽になると思います。
vec = [1,2,3,4,5]
ここで、pythonプログラミング言語上は1次元の配列ですが、この配列をベクトルとみなせば、要素を5つもった5次元のベクトルとして理解できます。
このように、ベクトルはデータの配列であると理解しておけば、その先の行列演算の理解もスムースに行えることに気付きました。
ベクトルの加算・減算
ベクトルは加算・減算が可能です。単純に各々のベクトルの要素に対して演算します。
以下のn次元ベクトル$\vec{a}$と$\vec{b}$が存在した場合、
\vec{a}=(a_{1},a_{2},...,a_{n})
\vec{b}=(b_{1},b_{2},...,b_{n})
それぞれの演算は以下のように定義されます。定義なので、まずは飲み込みます。
\vec{a}-\vec{b}=(a_{1}-b_{1},a_{2}-b_{2},...,a_{n}-b_{n})
\vec{a}+\vec{b}=(a_{1}+b_{1},a_{2}+b_{2},...,a_{n}+b_{n})
ベクトルの大きさ
ベクトルの大きさとは何でしょうか。まず、2次元のベクトルで考えてみましょう。
\vec{a} = (1,2)
最初に、このベクトルの各要素に意味づけを与える必要があります。ここでは、第1要素目がx座標、第2要素目がy座標とします。
そうすると、$\vec{a}$は2次元平面に対応することが出来ます。ベクトルの大きさとは矢印の大きさ(長さ)ですから、三平方の定理を使って簡単に求める事ができます。
\lVert a \lVert = \sqrt{1*1 + 2*2} = \sqrt{5}
これを2次元の範囲で一般化すると、以下のようになります。
\lVert a \lVert = \sqrt{x^2 + y^2}
また、n次元の場合は、$\vec{a} = (a_{1},a_{2},a_{3},...,a_{n})$とすれば、ベクトルの大きさは、上記2次元の場合を素直に拡張して以下のようになります。n次元ベクトルの大きさの定義を紹介します。
\lVert a \lVert = \sqrt{a_{1}^2 + a_{2}^2 + ... + a_{n}^2}
4次元以上の場合、形が想像できないため納得感があまりありませんが、定義なのでとりあえず飲み込むしかありません。数学の定義とはそういうものだと理解するべきです(なにかの数学書に記載してあったのであとで補足する)
ベクトルの類似
LLMでは、ベクトルが似ているかどうかを考えるシーンが良く出てきます。内積はベクトルの類似度を考える上でのベースとなる定義/概念になります。
まず、ベクトルの類似度を考えてみます。理解のために2次元の場合に戻ります。
ベクトルは向きと大きさを持った量ですので、この2つが同じであればあるほど、ベクトルは似ているとみなすことができます。以下、$\vec{a}$と$\vec{b}$は大きさが違いますが、向きは同じベクトルです。$\vec{a}$と$\vec{c}$は大きさが同じですが向きが違うパターンです。$\vec{a}$と$\vec{b}$、$\vec{a}$と$\vec{c}$の2パターンを見比べると、向きが違うと全然違うベクトルのように見えると直感的に思います。このため、ベクトルの類似度は、まずは大雑把には「向き」を考えれば良さそうです。
ベクトルは向きと大きさを持った量ですので、ベクトルの類似度を考えるうえで矢印の始点は実は重要ではありませんので、ベクトルの始点を一致させて(物事をすこし単純にして)、2つのベクトルの類似度を考えることが可能です。始点を合わせた$\vec{a}$と$\vec{d}$をみると、両者のなす角が定義できます。
ここで、なす角を$\theta$として、おもむろに$\cos(\theta)$を考えます。$-1\leq \cos(\theta) \leq 1$ですから、$\vec{a}$と$\vec{d}$の開きが小さい場合(0度〜90度)の場合は$\cos(\theta)$は0から1の間の値を取ります。また、開きが大きくなった(90度〜180度)場合は-1から0の値を取ります。なので、$\cos(\theta)$はベクトルの向き(大雑把な類似度)を考えるうえで、重要な情報であることが分かると思います。
ところで、この$\cos(\theta)$はどのように計算すれば良いのでしょうか。それには、以下の内積が役に立ちます。
内積について
ここで、内積というものを定義します。
\vec{a}\cdot\vec{b} = \lVert \vec{a} \lVert\lVert \vec{b} \lVert\cos(\theta)
これについてはまずは、飲み込むしかありません。また、内積はもう1つ別の定義があります。
\vec{a}\cdot\vec{b} = a_{1}b_{1} + a_{2}b_{2} + ... + a_{n}b_{n}
つまり、
a_{1}b_{1} + a_{2}b_{2} + ... + a_{n}b_{n} = \lVert \vec{a} \lVert\lVert \vec{b} \lVert\cos(\theta)
です。左辺の値が計算できれば、2つのベクトルの方向が分かる。ということになります。内積すばらしい!角度という少し複雑な値の正体が実はベクトルの要素同士の簡単な演算で求められる。
なぜ、左辺と右辺が等しくなるのかについて証明を載せておきます。
最初の内積の定義式から始めます。
最初に、以下が成り立つことを確認します。内積の定義にしたがって計算すればすぐにわかります。
\vec{a}\cdot\vec{a} = \lVert \vec{a} \lVert^2
次に、この式の$\vec{a}$に$\vec{a}-\vec{b}$を代入すると以下となります。
(\vec{a}-\vec{b})\cdot(\vec{a}-\vec{b}) = \lVert \vec{a} - \vec{b} \lVert^2
ここで詳細は別途説明しますが、内積は分配法則が成立しますから、左辺配下のようになります。
\vec{a}\cdot\vec{a}-\vec{a}\cdot\vec{b}-\vec{b}\cdot\vec{a}-\vec{b}\cdot\vec{b}= \lVert \vec{a} - \vec{b} \lVert ^ 2
また、内積は交換法則が成り立つので(別途説明)、$\vec{a}-\vec{b}=\vec{b}-\vec{a}$したがって、
\vec{a}\cdot\vec{a}-\vec{a}\cdot\vec{b}-\vec{a}\cdot\vec{b}-\vec{b}\cdot\vec{b}= \lVert \vec{a} - \vec{b} \lVert ^ 2
\vec{a}\cdot\vec{a}-2(\vec{a}\cdot\vec{b}) +\vec{b}\cdot\vec{b}= \lVert \vec{a} - \vec{b} \lVert ^ 2
ここで、$\vec{a}\cdot\vec{a}=\lVert \vec{a} \lVert^2$($\vec{b}$も同様)なので、この結果を代入すれば、
\lVert \vec{a} \lVert ^ 2 -2(\vec{a}\cdot\vec{b}) + \lVert \vec{b} \lVert ^ 2= \lVert \vec{a} - \vec{b} \lVert ^ 2
ところで、ベクトルの差は図形の意味では、以下のように定義されます。引く方が引かれる方のベクトルに向かっていくイメージで覚えるとよいです。
ここで余弦定理を用いると、
\lVert \vec{a} - \vec{b} \lVert ^ 2 = \lVert \vec{a} \lVert ^ 2 + \lVert \vec{b} \lVert ^ 2 -2\lVert \vec{a} \lVert \lVert \vec{b} \lVert \cos(\theta)
この結果を先ほどの式に代入すれば、
\lVert \vec{a} \lVert ^ 2 -2(\vec{a}\cdot\vec{b}) + \lVert \vec{b} \lVert ^ 2= \lVert \vec{a} \lVert ^ 2 + \lVert \vec{b} \lVert ^ 2 -2\lVert \vec{a} \lVert \lVert \vec{b} \lVert \cos(\theta)
整理して、内積の最初の定義式が得られます。角度は余弦定理からきたのですね。
\vec{a}\cdot\vec{b} = \lVert \vec{a} \lVert\lVert \vec{b} \lVert\cos(\theta)
次に、2つめの内積の定義式について説明します。1つめの内積の式の証明で出てきた、以下の式からスタートします。
\lVert \vec{a} \lVert ^ 2 -2(\vec{a}\cdot\vec{b}) + \lVert \vec{b} \lVert ^ 2= \lVert \vec{a} - \vec{b} \lVert ^ 2
ここで、簡単のために$\vec{a}$と$\vec{b}$を2次元とします。
上の式に当てはめて素直に計算すれば、
(a_{1}^2+a_{2}^2)-2(\vec{a}\cdot\vec{b})=(a_{1}-b_{1})^2+(a_{2}-b_{2})^2
これを整理して、2つめの内積の定義式が得られます。
行列
行列の定義や説明は巷に溢れているので詳細はそういったものに譲りますが、ここではLLMの仕組みの基礎を理解するうえで、必要なものにフォーカスしたいと思います。
行列とは数値を縦と横に並べたものです。単純なものは以下のような行列でしょう。行列の横の要素の並びを列と言います。縦の並びを行と言います。
以下は、4行1列の行列になります(行列のサイズをこう表す)。
\mathbf{A} \ =\ \begin{bmatrix}
1\\
2\\
3\\
4
\end{bmatrix}
今度は列を1つ追加したものを考えてみます。以下は、4行2列の行列です。
\mathbf{A} \ =\ \begin{bmatrix}
1 & 10\\
2 & 20\\
3 & 30\\
4 & 40
\end{bmatrix}
1行1列目の要素が1,3行1列目が3、2行2列目が20になります。
行と列の場所(インデックス)を指定して、行列の要素を取り出すことが出来ます。
ちょうど、pythonなどのプログラミング言語でいうところの2次元配列が行列に相当します。
行列の加算・減算
以下のように2つの行列があったとき、
\mathbf{A} \ =\ \begin{bmatrix}
1 & 10\\
2 & 20\\
3 & 30\\
4 & 40
\end{bmatrix}
\mathbf{B} \ =\ \begin{bmatrix}
100 & 1000\\
200 & 2000\\
300 & 3000\\
400 & 4000
\end{bmatrix}
加算は以下のようになります。行列のそれぞれの要素を単に足します。減算は符号を逆にしただけなので、説明は省略します。
\mathbf{A+B} \ =\ \begin{bmatrix}
1+100 & 10+1000\\
2+200 & 20+2000\\
3+300 & 30+3000\\
4+400 & 40+4000
\end{bmatrix}
行列の転置
行列の横方向を列、縦方向を行と呼びます。以下の行列では、1行1列目の要素が1で、3行2列目が30になります。
行列の転置とは行列の要素の行と列を入れ替えたものになります。行列を転置したことを表す記号をつけて、以下のようになります。
\mathbf{A}^{T} \ =\ \begin{bmatrix}
1 & 2 & 3 & 4\\
10 & 20 & 30 & 40
\end{bmatrix}
行列の積
行列の積は複雑です。詳細は以下のサイトに詳しく説明されていますので、そちらを参照できます。
数学の景色(行列の演算)
以下、ポイントだけ説明します。説明を簡単にするために2行2列の行列を例に取ります。
一般形は以下です。
\mathbf{A} \ =\ \begin{bmatrix}
a_{11} & a_{12}\\
a_{21} & a_{22}
\end{bmatrix} \ \ \mathbf{B} \ =\ \begin{bmatrix}
b_{11} & b_{12}\\
b_{21} & b_{22}
\end{bmatrix}
この積は以下になります。
\mathbf{AB} \ =\ \begin{bmatrix}
a_{11} b_{11} +a_{12} b_{21} & a_{11} b_{12} +a_{12} b_{22}\\
a_{21} b_{11} +a_{22} b_{21} & a_{21} b_{12} +a_{22} b_{22}
\end{bmatrix}
具体値を当てはめて計算してみます。
\mathbf{A} \ =\ \begin{bmatrix}
1 & 3\\
2 & 4
\end{bmatrix} \ \ \ \ \mathbf{B} \ =\ \begin{bmatrix}
10 & 30\\
20 & 40
\end{bmatrix} \ \ \ \ \ \ \ \ \ \ \
結果は以下となります。
\mathbf{AB} \ =\ \begin{bmatrix}
1*10+3*20 & 1*30+3*40\\
2*10+4*20 & 2*30+4*40
\end{bmatrix} =\ \begin{bmatrix}
70 & 150\\
100 & 220
\end{bmatrix}
行列とベクトルの関係、および、LLMとの関係
ここまで、ベクトルと行列を見てきましたが、これはLLMとどのような関係があるのでしょう?
後で詳細を述べますが、LLMでは深層学習を活用した学習/推論を行う関係上、LLM内で扱うデータはベクトルと行列である必要があります。
このため、我々が暮らす自然言語はそのままで扱うことはできないため、何らかの方法で、深層学習の世界へ変換する処理が必要です。
LLMでは例えば、埋め込みと呼ばれる処理により、日本語の1つの単語(トークン)がベクトルに変換されます。
つまり、LLMにおける処理の最小単位である単語に1:1に対応したベクトルが存在することになります。
ところで、深層学習では大量のデータが必要ですので、数学を使った仕組みの記述のためには、沢山のベクトルを一括して扱うことが必要です。
ここで、行列が役に立ちます。
すなわち、LLMにおける行列の活用方法は「行列はベクトルを沢山ならべたもの」であり、そう考えておくと後々ラクです。
また、単に数字の羅列であった行列という仕組みが活き活きと動き出すことも面白みの1つです。
行列とベクトルの関係
ベクトルを行列で表すことを考えます。
いままで、ベクトルの要素を横に書いていましたが、行列でベクトルを表す場合は縦に記載するのが基本です。
また、$\vec{a}$のように記載していましたが、行列としてベクトルを表現する際は$\boldsymbol{A}$のようにボールド体で表すこととします。以下は、$\vec{a}=(1,2,3,4)$と等価です。
\mathbf{A} \ =\ \begin{bmatrix}
1\\
2\\
3\\
4
\end{bmatrix}
このようなベクトルの記載法を縦ベクトルと言います。
さて、「行列とはベクトルを沢山ならべたものだ」ということですが、新しいベクトルを$\mathbf{A}$に追加することができます。
今度は列を1つ追加したものを考えてみます。以下は、4行2列の行列です。
\mathbf{A} \ =\ \begin{bmatrix}
1 & 10\\
2 & 20\\
3 & 30\\
4 & 40
\end{bmatrix}
$(1,2,3,4)$と$(10,20,30,40)$のベクトルを2つもつ行列になります。同様になんこでもベクトルを追加することもできます。
上記は縦ベクトルの話でしたが、横ベクトルでも同様に成り立ちます。
大規模言語モデル入門は、縦ベクトルが基本になります。
行列の転置
この行列(縦ベクトルが2つ)を転置すると、
\mathbf{A}^{T} \ =\ \begin{bmatrix}
1 & 2 & 3 & 4\\
10 & 20 & 30 & 40
\end{bmatrix}
これを見るとわかるように、転置は要するに、縦ベクトルで構成されていた行列を、横ベクトルで構成した行列に変換する操作だと考えることもできます。
行列の積に関する重要な豆知識
行列の積の演算では、掛ける方の列のサイズと、掛けられる方の行のサイズが一致する必要があります。
つまり、以下のような場合はOKです。
A=(N行M列)で、B=(M行L列)
※このようにサイズが一致しないと、内積計算ができないので、このルールには納得性があります。
なお、掛けた結果は、以下のサイズになります。
AB=(N行L列)
この法則は頭に叩き込む必要があります。
LLMではベクトルや行列の積演算が多数出てきますが、都度、サイズがどう変化するか考えておかないと、理論や仕組みの理解ができなくなるためです。
内積計算を行列で行う
以下のように、ベクトルの内積演算は行列の積で表現することが可能です。
\mathbf{A} \ =\ \begin{bmatrix}
1\\
2
\end{bmatrix} \ \ \ \ \mathbf{B} \ =\ \begin{bmatrix}
10\\
20
\end{bmatrix} \ \ \ \ \ \ \ \ \ \ \
このベクトルの内積は$1 * 10+2 *20$ですが、行列の積をそのまま適用すると、Aが(2行1列)、Bも(2行1列)なので計算できません。
そこで、Aを転置してBとの積を取ります。
\begin{array}{l}
\mathbf{A}^{T} \ =\ \begin{bmatrix}
1 & 2
\end{bmatrix} \ \ \ \ \mathbf{B} \ =\ \begin{bmatrix}
10\\
20
\end{bmatrix} \ \\
\mathbf{A}^{T} \ \mathbf{B} \ =1*10+2*20\ \ \ \ \ \ \ \ \
\end{array}
また、AとBが複数の縦ベクトルで構成される場合も同様です。
\mathbf{A} \ =\ \begin{bmatrix}
1 & 3\\
2 & 4
\end{bmatrix} \ \ \ \ \mathbf{B} \ =\ \begin{bmatrix}
10 & 30\\
20 & 40
\end{bmatrix} \ \ \ \ \ \ \ \ \ \ \
\begin{array}{l}
\mathbf{A}^{T} \ =\ \begin{bmatrix}
1 & 2\\
3 & 4
\end{bmatrix} \ \ \ \ \mathbf{B} \ =\ \begin{bmatrix}
10 & 30\\
20 & 40
\end{bmatrix} \ \ \ \\
\mathbf{A}^{T} \ \mathbf{B} \ =\ \begin{bmatrix}
1*10+2*20\\
3*30+4*40
\end{bmatrix} \
\end{array}
これにて、LLMの基本を理解するベクトルと行列の基本知識は身についたかなと思います。
確率・統計(Transformerで使うものに絞る)
ここでは、確率・統計について、EncoderのAdd & Norm層で使う知識だけに絞って解説します。
以下の表はサイコロの出る目と確率の表です。
X | 1 | 2 | 3 | 4 | 5 | 6 |
---|---|---|---|---|---|---|
P | 1/6 | 1/6 | 1/6 | 1/6 | 1/6 | 1/6 |
このときサイコロの目の値を確率変数といい、$\boldsymbol{X}$で表します。事象$\boldsymbol{X}$が発生する確率を$\boldsymbol{P}$で表しています。
また、上の表を関数に見立てて、例えば、サイコロの目が3の場合の確率を以下のように記します。
P(X=3)=\frac{1}{6}
この$P(X)$を確率関数と言ったりします。
上の表を一般的に記述すると以下のような感じになります。
X | $x_1$ | ・・・ | $x_n$ |
---|---|---|---|
P | $p_1$ | ・・・ | $p_n$ |
このとき、以下のように期待値(平均)$\mathbb{E}[\mathbf{X}]$と分散$\mathbb{V}[\mathbf{X}]$を定義します。
\mathbb{E}[\mathbf{X}]=x_1p_1+\cdots+x_np_n=\sum_{i=1}^{n}x_ip_i
ここで、$\mathbb{E}[\mathbf{X}]=\mu$とおいて、
\mathbb{V}[\mathbf{X}]=\mathbb{E}[(\mathbf{X}-\mu)^2]=(x_1-\mu)^2p_1+\cdots+(x_n-\mu)^2p_n=\sum_{i=1}^{n}(x_i-\mu)^2p_i
標準偏差は以下になります。
\sigma=\sqrt{\mathbb{V}[\mathbb{X}]}
また、期待値と分散については以下の公式が成り立ちます(証明は省略します)。
\mathbb{E}[a\mathbf{X}+b]=a\mathbb{E}[\mathbf{X}]+b
\mathbb{V}[a\mathbf{X}+b]=a^2\mathbb{V}[\mathbf{X}]
次に標準化というものを考えます。
確率変数$\mathbf{X}$の期待値、標準偏差をそれぞれ$\mu=\mathbb{E}[\mathbb{X}]$、$\sigma=\sqrt{\mathbb{V}[\mathbb{X}]}$としたときに、以下の変数変換をすると、変換後の新しい$\mathbf{Z}$の期待値は0、標準偏差は1になります。
\mathbf{Z}=\frac{\mathbf{X}-\mu}{\sigma}
以下、確かめです。$\sigma=\sqrt{\mathbb{V}[\mathbf{X}]}$より$\sigma^2=\mathbb{V}[\mathbf{X}]$に注意します。
\mathbb{E}[\mathbf{Z}]=\mathbb{E}[\frac{\mathbf{X}-\mu}{\sigma}]=\frac{1}{\sigma}\mathbb{E}[\mathbf{X}-\mu]=\frac{1}{\sigma}(\mathbb{E}[\mathbf{X}]-\mu)=\frac{1}{\sigma}(\mu-\mu)=0
\mathbb{V}[\mathbf{Z}]=\mathbb{V}[\frac{\mathbf{X}-\mu}{\sigma}]=\frac{1}{\sigma^2}\mathbb{V}[\mathbf{X}]=\frac{1}{\sigma^2}\sigma^2=1
この変数変換が層正規化(Add & Norm)に登場してくるので、意識しておきましょう!
LLMの学習プロセスの大雑把な構造
ここで、単語の並びを$w_{i}$で表すことにします。例えば「サンディーが元気に走る」という文章があったとすると、$w_{1}$=サンディー、$w_{2}$=が、$w_{3}$=元気、$w_{4}$=に、$w_{5}$=走る。といった具合です。
ところで、これから本書を通じて理解しようとしているLLMの対象は文章を生成するものです。ここで、少し数学的なトリック(?)というものが登場してきます。それはつまり、文章の生成を「確率」として表現し、生成するべき文章の確率を高くするようにLLMを訓練していくという考え方です。
例えば、ある語の集合(大きさがN個)が存在したとして、ある語の生成確率は確率関数$P$を使って$P(w_{i})$のように表すことができます。また、文章は続いていくものですから、ある語の次に来る語の確率は$P(w_{i+1}|w_{i})$のように表せます。例えば、「サンディー」の次に「が」が来る確率は、$P(が|サンディー)$のように記載できます。そして、訓練対象の関数が確率関数$P$になるのですが、この関数は確率計算の元ネタとなるパラメータを受け取る必要があり、それが$\theta$になります。つまり、「サンディー」の次に「が」が来る確率は$P(が|サンディー,\theta)$となります。
ですから、「サンディーが元気に走る」という文章が確率関数$P$で生成される確率は以下のように表現できます。
P=P(が|サンディー,\theta)P(元気|が,\theta)P(に|元気,\theta)P(走る|に,\theta)
各事象が同時に発生してほしいと考えた式になります。これは積事象ですから、文章が生成される確率は、それぞれの積になります。
今度はこの$P$を最大化する$\theta$を考えるときに、積事象ですと確率関数$P$の結果の掛け算となり値が大きくなりすぎて、コンピュータで計算するのに都合が悪くなる場合がありますので、$P$の代わりに$\log(P)$を計算します。$P$も$\log(P)$も単調増加関数のため、このように代用しても大丈夫という考え方になります。
\log(P)=\log(P(が|サンディー,\theta)P(元気|が,\theta)P(に|元気,\theta)P(走る|に,\theta))
$log$の中の掛け算は以下のように足し算で表すことができるので、コンピュータでの計算上大変好都合です。
\log(P)=\log(P(が|サンディー,\theta))+\log(P(元気|が,\theta))+\log(P(に|元気,\theta))+\log(P(走る|に,\theta))
上記を記号で表すと、以下のような式になります。
\log(P)=\log(P(w_{2}|w_{1},\theta))+\log(P(w_{3}|w_{2},\theta))+\log(P(w_{4}|w_{3},\theta))+\log(P(w_{5}|w_{4},\theta))
今度は、$w_{i}$を中心に前1語、後ろ1語に着目した確率を求めると以下のようになります。
\log(P)=\log(P(w_{i}|w_{i-1},\theta))+\log(P(w_{i+1}|w_{i},\theta))
前1語、後ろ1語の1をpで表します。
\log(P)=\log(P(w_{i}|w_{i-p},\theta))+\log(P(w_{i+p}|w_{i},\theta))
和の記号を用いてスッキリさせます。
\log(P)=\sum_{-p\leq j\leq p,j\neq 0}\log P( w_{i+j} |w_{i} ,\theta )
なお、$j=0$だと$\log P( w_{i} |w_{i} ,\theta )$を計算することになり、いままでそういった話はありませんでしたので、$j\neq 0$です。
次にこれをすべての語(N個)について和を取ります。なお、実験回数やモデル毎にNが異なるとNによって$\log(P)$の値が大きなったり、小さくなったりして
それぞれの試行で確率値を比較することが困難になりますから、Nで割っておきます(Nの平均値を採用する)。最後にマイナスをかけたものを$L( \theta )$として定義し、損失関数と呼びます。
L( \theta ) =-\frac{1}{N}\sum _{i=1}^{N}\sum_{-p\leq j\leq p,j\neq 0}\log P( w_{i+j} |w_{i} ,\theta )
マイナスをかけて符号を逆転させると、今度は単調減少関数になることがわかります。深層学習では損失関数の値を学習の過程で減少させていくことになりますので、この符号の逆転が必要になるということになります。
以下のようなフローで$L(\theta)$を最小化する$\theta$を求めていきます。
この$\theta$へのフィードバックを何度も繰り返し実施していくのがLLMの学習プロセスの核となる考え方です。
また、以下の式により$\theta$値を更新していきます。$\nabla_{\theta}$がミソです。$\theta$に関する$L(\theta)$の増分を計算するため、微分計算($\frac{\partial L(\theta)}{\partial \theta}$)を実施し、それをコンパクトに$\nabla_{\theta}$と表します。偏微分の計算結果は$L(\theta)$の増加方向に向くので(グラディエント)、マイナスを付けてさらに$L(\theta)$をかけた分を$\theta$から引くことで、$L(\theta)$が最小に向かうような$\theta$を求めます。このプロセスを繰り返して最終的に最適な$\theta$を得ようとするのが、LLMの学習プロセスの本質です。
\theta = \theta - \alpha\nabla_{\theta}L(\theta)
なお、$\alpha$は学習率とよばれ、1回の$\theta$の更新でどれくらい、$\nabla_{\theta}$を反映させるかのパラメータで、「ハイパーパラメータ」と呼ばれます。
この種のパラメータはLLMの学習プロセスで学習されないため、データサイエンティストなどが経験則などをもとに設定していく形になります。
ここまでで、LLMの学習プロセスに関する説明は終わりです。
自然言語とベクトルの関係(埋め込み行列とOneHotベクトル)
LLMではデータの内部表現としてベクトルを用いるため、自然言語をベクトル化する必要があります。
それには2つのステップを踏みます。1つに単語をOneHotベクトルに変換し、2つにOneHotベクトルを埋め込み行列を使って最終的に単語をベクトルに変換します。
OneHotベクトル
OneHotベクトルは、ベクトルの成分のうちどれか1つが1で、それ以外が0になるベクトルです。
例えば、LLMで扱う単語の集合があり、その大きさがN個だとすると、LLMで扱うOneHotベクトルの次元数もNになります。
そして、単語とOneHotベクトルが一意に対応するように変換する処理を行います。
サンディー=\begin{bmatrix}
1\\
0\\
0\\
0\\
0
\end{bmatrix} ,\ \ が=\begin{bmatrix}
0\\
1\\
0\\
0\\
0
\end{bmatrix} ,\ \ 元気=\begin{bmatrix}
0\\
0\\
1\\
0\\
0
\end{bmatrix} ,\ \ に=\begin{bmatrix}
0\\
0\\
0\\
1\\
0
\end{bmatrix} ,\ \ 走る。=\begin{bmatrix}
0\\
0\\
0\\
0\\
1
\end{bmatrix}
LLMの理論の枠組みの中では、OneHotベクトル表現について特に言及されていない様子です。
このため実装依存の話になるのでしょう。日本語は英語と異なり、単語が空白で区切られていないので、形態素解析ソフトやライブラリなどで、単語に分解する必要があります。その後、その単語群に対してOneHotベクトルを順に割り当てていくという感じの処理になるのかと思います。
OneHotベクトルと埋め込み行列
先程と同じように、$w_{1}$=サンディー、$w_{2}$=が、$w_{3}$=元気、$w_{4}$=に、$w_{5}$=走る。という単語集合(大きさN)が存在する場合に、それらをOneHotベクトル化したものを$\boldsymbol{w_{1}},\boldsymbol{w_{2}},\boldsymbol{w_{3}},\boldsymbol{w_{4}},\boldsymbol{w_{5}}$とします($w_{i}$の$w$をボールド体にしたものをベクトルとして、上のように1列の行列表現とします)。
埋め込み行列を$\boldsymbol{X}$とすれば、単語を埋め込み行列を使って変換したあとのベクトルを「埋め込み」といいます。例えば、以下の計算例を見てみます。
\boldsymbol{O}=\boldsymbol{X}\boldsymbol{w_{3}}
は、具体的な値でいくと、以下になります。
\boldsymbol{O} =\begin{bmatrix}
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5
\end{bmatrix}\begin{bmatrix}
0\\
0\\
1\\
0\\
0
\end{bmatrix} =\begin{bmatrix}
3\\
3\\
3\\
3\\
3
\end{bmatrix}
5行5列の行列×5行1列の行列のサイズは、上記で注意した法則を使って、5行1列のサイズと簡単に求まります。計算結果もそうなっています。計算の結果、「$w_{3}$=元気」の埋め込みは上の結果のベクトルとなるわけです。
通常は、語1つ1つに対して処理せずに、以下のように行列を使います。
$\boldsymbol{W}=\begin{bmatrix}\boldsymbol{w_{1}},\boldsymbol{w_{2}},\boldsymbol{w_{3}},\boldsymbol{w_{4}},\boldsymbol{w_{5}}\end{bmatrix}$とした場合、埋め込み行列を$\boldsymbol{X}$とすれば、単語を埋め込み行列を使って変換したあとのベクトルの列は以下の通りです。
\boldsymbol{O}=\boldsymbol{X}\boldsymbol{W}
上の例を用いながら具体値を示してみると以下のようになります。なお、$\boldsymbol{X}$の要素の数値は適当です。$\boldsymbol{X}$の要素はLLMの学習を通じて適切な値に調整されます(重要)。
\boldsymbol{O} =\begin{bmatrix}
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5
\end{bmatrix}\begin{bmatrix}
1 & 0 & 0 & 0 & 0\\
0 & 1 & 0 & 0 & 0\\
0 & 0 & 1 & 0 & 0\\
0 & 0 & 0 & 1 & 0\\
0 & 0 & 0 & 0 & 1
\end{bmatrix} =\begin{bmatrix}
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5\\
1 & 2 & 3 & 4 & 5
\end{bmatrix}
なんだか簡単な結果になりましたが、埋め込み行列に埋め込みがあり、OneHotベクトルに対応する埋め込みを選択するようなイメージの計算になりますね。
Transformerの理解に挑戦(概要)
では、本丸のTransformerの理解に挑戦してみます。以下がTransformerのアーキテクチャです。
アーキテクチャとは、深層学習の層をどのように接続したかの全体像だと理解しています。
画像参照:Attention Is All You Need Figure 1: The Transformer – model architecture.
図の大きく左部分をEncoderとよび、右半分をDecoderというそうです。
本書によれば、Encoderでは入力(図:Inputs)に対して「文脈化トークン埋め込み」というものを与えるとあります。上記で説明した埋め込みは、単語に対応するベクトル(=埋め込み)だったのですが、Encoderでは文脈を考慮した埋め込みを与えるという意味だと思います。
また、このセットがN(=6)個直列的に連結されています。この仕組みにより、初段では表層的な情報、中段では文法的な情報、後段では意味的な情報が保持されていると考えられているそうです。画像分類の深層学習モデルで使われるCNNと似た感じがします。やっぱり層を厚くするとモデルに溜まる知識/情報量が増えて賢くなるのでしょうね。
Encoderが出力した文脈化トークン埋め込みを参照しながら具体的なタスクを実行して、入力に対する次の単語の予測を行う部分がDecoderになります。
この動きの具体例を日本語から英語への翻訳タスクを例に考えてみます(参考ページ)。
例えば「こたつでみかんを食べる」を翻訳して「I eat mandarin at the kotatsu」を得るシーンを考えてみます。
なお、「こたつでみかんを食べる」は入力として与えられていますが、「I eat mandarin at the kotatsu」は得られていないのが初期値です。
「こたつでみかんを食べる」がInputsとしてEncoderに投入された結果、この文章を構成する個々の単語の文脈化埋め込みが計算されます。それが、Decoderに流れてきます(交差注意機構)。
Decoderには"Outputs(shifted Right)"として最初は文章のスタートを示す記号が入力されます。先程の交差注意機構での情報を参照しながら、「I」を出力します。
今度は、Decoderは「I」を"Outputs(shifted Right)"として入力して、同様に「I eat」を出力。以降、同様に処理していき、最終的に翻訳結果の文章を生成するということになります。
Chat GPTをつかうとポツポツと出力が得られますが、このような動作原理が関係しているのかもしれません。
ここまで、Transformerの概要でした。次はEncoderの理解に挑戦していきます。
Encoderの理解に挑戦
上のTransformerのアーキテクチャ図を見ると、Encoderは以下の部品から成り立っていることがわかります。
9-1. Input Embedding(入力埋め込み)
9-2. Positional Encoding(位置符号)
9-3. self-attention(自己注意機構)
9-4. Multi-Head Attention(マルチヘッド注意機構)
9-5. Feed Forward(フィードフォワード層)
9-6. 残差結合
9-7. Add & Norm(層正規化)
なお、self-attentionと残差結合はアーキテクチャ図に明に記載されてませんが、必要順に見ていきます。
なお、Transfomerでは入力する単語をトークンと呼ぶことがあります。以降、両語が混在して登場する場合がありますが、特に意味はなく同一ものもと思ってください。
Input Embedding
これは、上で説明した埋め込み行列の事です。
自然言語の単語に1:1に対応した埋め込みベクトルを求める処理になります。
繰り返しになりますが、埋め込み行列は最初はランダムな初期値から始まり、LLMの学習過程を通じて最適な値になります。
実は、埋め込み行列はそれだけで、様々な活用シーンがある大切な情報源になりますが、Transformerの理解試行を超えるため、別記予定です。
Positional Encoding(位置符号)
Positional Encodingは本書だけだと理解が難しい点がありましたので、参考サイトなどを参照しながら理解をメモしていきます。
ここまで説明した入力トークンの埋め込みは、トークンの順序や位置に関する情報を含んでいませんでした。
このため、たとえば「こたつでみかんを食べる」とそのトークンの任意の並べ替え「食べるみかんをこたつで」を同一の入力として処理されてしまうことを解決するために、入力トークン埋め込みに対して、位置符号Positional Encoding)を付加します(足し算します)。
位置符号は正弦関数を使って、入力トークン列の中のトークンの位置を表現する方法です。
なぜ、正弦関数を使うのかについては、位置符号に以下の要件があるとのことです。
なお、No3は私が独自の理解から加えたものです。
【位置符号の要件】
- 文章の長さに関係なく、それぞれの位置が同じ位置ベクトルを持つ必要がある
- 位置ベクトルを大きくしすぎてはいけない
- トークン列の位置ごとにユニークな値を付加する必要がある
埋め込みの次元が$D$のとき、位置$i$に対する位置符号(ベクトル)$\boldsymbol{p_{i}}$は、$\boldsymbol{p_{i}}$の$j$番目の要素を$\boldsymbol{p_{i,j}}$とすると、$k \in \{0,1,...,\frac{D}{2} -1\}$に対して、下記のように計算されます(ただし$D$は偶数)。
\boldsymbol{p_{i,2k+1}}=\sin\Biggl(\frac{i}{10000^{2k/D}}\Biggr)
\boldsymbol{p_{i,2k+2}}=\cos\Biggl(\frac{i}{10000^{2k/D}}\Biggr)
$\boldsymbol{p_{i}}$の要素を書き下してみると、以下のようになります。
\boldsymbol{p}_{i} =\begin{bmatrix}
\sin( i)\\
\cos( i)\\
\sin\left(\frac{i}{10000^{2/D}}\right)\\
\cos\left(\frac{i}{10000^{2/D}}\right)\\
\vdots \\
\sin\left(\frac{i}{10000^{( D-2) /D}}\right)\\
\cos\left(\frac{i}{10000^{( D-2) /D}}\right)
\end{bmatrix}
位置$i$のトークン$\boldsymbol{w_{i}}$の入力トークン埋め込みを$\boldsymbol{e_{w_{i}}}$とすると、モデルの入力埋め込みを$\boldsymbol{x_{i}}$は以下のように計算されます。
\boldsymbol{x_{i}}=\sqrt{D}\boldsymbol{e_{w_{i}}}+\boldsymbol{p_{i}}
ここで、$\boldsymbol{p_{i}}$の大きさは$\sqrt{D}$です($\sin(x)^2+\cos(x)^2=1$に注意して、$\sqrt{\boldsymbol{p_{i}}^T\boldsymbol{p_{i}}}$を計算すればわかる)。そこで、$\boldsymbol{e_{w_{i}}}$の大きさを$\boldsymbol{p_{i}}$と合わせるため、$\boldsymbol{e_{w_{i}}}$に$\sqrt{D}$を掛けています。
具体的には、「こたつ」のモデルへの入力埋め込み$\boldsymbol{x_{1}}$は$\boldsymbol{x_{1}}=\sqrt{D}\boldsymbol{e_{こたつ}}+\boldsymbol{p_{1}}$のように計算されます。
トークン数をN個とすると、モデルには、$\boldsymbol{x_{1}},...,\boldsymbol{x_{N}}$が入力されます。
ところで、$\boldsymbol{p_{1}}$の具体的な値はどのようになるのでしょうか。以下にグラフ化してみます。
import numpy as np
import matplotlib.pyplot as plt
#D = 512 #Transformerの入力トークンは512次元とのこと
D = 32 #Transformerの入力トークンは512次元とのこと
pos_enc = np.empty(D)
i = 1
for k in range(D//2):
theta = i / (10000 ** ( 2 * k / D ))
pos_enc[2 * k ] = np.sin(theta)
pos_enc[2 * k + 1] = np.cos(theta)
#pos_enc[2 * k + 1] = 0
x = np.arange(len(pos_enc))
y = pos_enc
# 折れ線グラフを描く
# plt.plot(x, y)
plt.scatter(x, y, label='Positional encoding', color='blue', marker='o', s=10)
plt.plot(x, y, label='Connecting Line', color='red', linestyle='--')
# グラフにタイトルとラベルを追加
plt.title('Positional encoding')
plt.xlabel('index of element number of p')
plt.ylabel('element value')
グラフにしてみます。
X軸が$\boldsymbol{p_{1}}$の要素を先頭から番号(0~)を付けていったものです。
y軸が要素の値となっています。
グラフの下側に集中して存在する青点がsinのデータで、逆の上側がcosです。
sinとcosの周期が大きいので、それぞれの関数の値が変わっていないように見えますが、微妙に変わっています。また、sin関数または、cos関数だけの場合はx軸の値変化に対するy軸の値の変化量が小さすぎるので、表現力を加えるため、このようにsinとcosを混ぜていると考えられます。
これで、埋め込みの各要素毎にユニークな値が付加できますし(要件3)、この関数はいつも同じ出力であることから、要件1も満たせます。また、このように0~1の値を取る関数なので要件2も満たせます。
なお、$\sin$と$\cos$は周期関数のため、いずれは同じ値に戻ってしまいます。これを防ぐため、これだけの周期の大きさにしています。Transformerでは入力埋め込みの次元は512次元のため、この周期にしておけば実用上十分ということになります。
このあたりがわかるよう、D=512とした場合のグラフも載せておきます(単純な散布図なので、点間の接続は描画せず)
しかし、コレが一体何を示しているものなのか?というのは位置符号$\boldsymbol{p_{i}}$同士の内積をグラフ化するとよくわかります。今度は各単語の$\boldsymbol{p_{i}}$同士の内積を取ってみます。
import numpy as np
import matplotlib.pyplot as plt
K = 32 #単語の数
D = 32 #埋め込みの次元
pos_enc = np.empty((K,D))
for i in range(K):
for k in range(D//2):
theta = i / (10000 ** ( 2 * k / D ))
pos_enc[i, 2 * k ] = np.sin(theta)
pos_enc[i, 2 * k + 1] = np.cos(theta)
#pos_enc[2 * k + 1] = 0
#位置符号同士の内積を計算
dot_matrix = np.matmul(pos_enc, pos_enc.T)
#得た行列を画像で表示
im = plt.imshow(dot_matrix, origin="lower")
plt.xlabel("position")
plt.ylabel("position")
plt.colorbar(im)
plt.show()
このように同じ位置同士は大きな値になり、遠くなるにつれ値が小さくなることがわかります。
この位置符号を、入力埋め込みに足しておけば、埋め込み同士の類似度計算(内積)の際に、近い位置であれば類似度が高く、遠ければ低いということになります1。
これは、位置が近い単語のほうが遠い方よりも関連度が高くなりやすいという(経験的?)な言語の特性を学習するのに役に立つと考えられているそうです。
次は、Multi-Head Attentionに行く前に、self-attention(自己注意機構)について挑戦していきます。
self-attention(自己注意機構)
自己注意機構は、入力トークン埋め込みに対して文脈の情報を付与していくための機構です。この機構は埋め込み列を入力として受取り、それらを相互に参照して、新しい埋め込み列を計算します。自己注意機構の一種である、以下のキー・クエリ・バリュー注意機構は、入力された埋め込みに対して、Key、Query、Valueの3つの異なる埋め込みを計算します。
この注意機構には、クエリの埋め込み、キーの埋め込み、バリューの埋め込みを計算するための3つの$D\times D$次元の行列$\boldsymbol{W_q}$、$\boldsymbol{W_k}$、$\boldsymbol{W_v}$が含まれます。これらの行列を訓練時に学習することで、重要度を加味した文脈化ができるようになるそうです。
以下が、この注意機構の概略になります。
なお、前提として、$D$次元の埋め込み列$\boldsymbol{h_1},\boldsymbol{h_2},...,\boldsymbol{h_N}$がこの注意機構に入力されたとします。
最初にクエリが計算されます。計算式は以下になります。
\boldsymbol{q_i} = \boldsymbol{W_q}\boldsymbol{h_i}
次にキーが計算されます。すべての入力埋め込みに対するキーが計算されることに注意が必要です。
\boldsymbol{k_i} = \boldsymbol{W_k}\boldsymbol{h_i}
この概略図では上の計算式で、$\boldsymbol{k_1},\boldsymbol{k_2},\boldsymbol{k_3}$が計算されます。
今度は、$i$番目のトークンからみた$j$番目のトークンの関連スコア$s_{ij}$を、内積(類似度)を用いて、以下のように計算します。
s_{ij}=\frac{\boldsymbol{q_i}^T\boldsymbol{k_j}}{\sqrt{D}}
ここで、分母の$\sqrt{D}$は、次元$D$が増えるに伴って、内積の絶対値が大きな値になりすぎるのを防ぐことにより、訓練を安定化させるために導入されているとのことです。
最後に、出力埋め込み$\boldsymbol{o_i}$についてです。これは、関連性スコア$s_{ij}$をソフトマックス(softmax)関数を使って正規化(確率値化)した重み$\alpha_{ij}$によるバリュー埋め込みの重み付き和になります。
\boldsymbol{v_i} = \boldsymbol{W_v}\boldsymbol{h_i}
\alpha_{ij} = \frac{\exp(s_{ij})}{\sum_{j'=1}^{N}\exp(s_{ij'})}
\boldsymbol{o_i} = \sum_{j=1}^{N}\alpha_{ij}\boldsymbol{v_j}
感覚的な話ですが、入力単語列をこの機構に流し込むと、(位置符号によって、)単語の位置も考慮しつつ、入力単語列に近いような単語を出力してくれる感じに見えます。
なお、「自己注意機構」という名前は、自分自身の途中の計算結果に注意し(クエリがキーを参照する)、そこから読み込む(バリューの重み和を取る)ことからこの名がついているとのことです。
ここまでで、self-attention(自己注意機構)の説明は終わりです。次はmulti head attentionですが、self-attentionを理解していれば今までの繰り返しになるので、簡単です。
multi head attention(マルチヘッド注意機構)
この節は本書から引用しています。「表現力」のところに独自の注、行列サイズのところに独自の解説を入れています。
キー・クエリ・バリュー注意機構の表現力2を更に高めるために、この注意機構を同時に複数適用するマルチヘッド注意機構(multi head attention)が採用されています。
例えば、上述の「マウスでクリックする」という文の「マウス」について、タスクによってはトークンの意味の他に品詞や係り受けなどの文法的な情報が重要になる場合があります。
複数の注意機構を同時に適用することで、複数の観点から文脈化を行うことができます。
マルチヘッド注意機構では$D$次元の埋め込み$\boldsymbol{h_i}$に対して、$M$個の注意機構を同時に適用します。ここで、$M$は$D$の約数である事に注意が必要です。$\boldsymbol{h_i}$に対する$m\in
\{1,2,...M\}$番目の注意機構の埋め込みは下記のように計算されます。
\boldsymbol{q_i}^{(m)} = \boldsymbol{W_q}^{(m)}\boldsymbol{h_i}
\boldsymbol{k_i}^{(m)} = \boldsymbol{W_k}^{(m)}\boldsymbol{h_i}
\boldsymbol{v_i}^{(m)} = \boldsymbol{W_v}^{(m)}\boldsymbol{h_i}
ここで、$\boldsymbol{W_q}^{(m)}$、$\boldsymbol{W_k}^{(m)}$、$\boldsymbol{W_v}^{(m)}$は、$m$番目のヘッド(head)に対応する、$\frac{D}{M} \times D$の行列であり、$\boldsymbol{h_i}$が$D$次元の縦ベクトルのため($D行\times1列$の行列)したがって、$\boldsymbol{q_i}^{(m)}$、$\boldsymbol{k_i}^{(m)}$、$\boldsymbol{v_i}^{(m)}$は、$\frac{D}{M}$次元のベクトルとなります($\frac{D}{M}行 \times 1列$の縦ベクトル)。この$M$個のヘッドがそれぞれ異なる観点から文脈化を行います。
各ヘッドの$\boldsymbol{o_i}^{(m)}$は、単一のヘッドによる注意機構と同様に下記のように計算されます。
s_{ij}^{(m)}=\frac{\boldsymbol{q_i}^{(m)T}\boldsymbol{k_j}^{(m)}}{\sqrt{\frac{D}{M}}}
\alpha_{ij}^{(m)} = \frac{\exp(s_{ij}^{(m)})}{\sum_{j'=1}^{N}\exp(s_{ij'}^{(m)})}
\boldsymbol{o_i}^{(m)} = \sum_{j=1}^{N}\alpha_{ij}^{(m)}\boldsymbol{v_j}^{(m)}
マルチヘッド注意機構の出力は、$M$個の出力結果埋め込みを連結して計算されます。
\boldsymbol{o_i}^{(m)} =\boldsymbol{W_o}\begin{bmatrix}
\boldsymbol{o_i}^{(1)}\\
\vdots\\
\boldsymbol{o_i}^{(M)}\\
\end{bmatrix}
ここで$\boldsymbol{W_o}$は$D \times D$の重み行列です。
Feed Forward(フィードフォワード層)
本書を一部引用しながら、ここは簡単に説明します。
この層の目的は、Transformerに表現力を加えること、および、文脈に関する豊富な情報を記憶するための層とされています。
フィードフォワード層は以下のように、数式で表現されます。
この層への入力ベクトルを$\boldsymbol{u_i}$、出力ベクトルを$\boldsymbol{z_i}$とします。
\boldsymbol{z_i} = \boldsymbol{W_2}f(\boldsymbol{W_1}\boldsymbol{u_i} + \boldsymbol{b_1}) + \boldsymbol{b_2}
ここで、$\boldsymbol{W_1}$、$\boldsymbol{W_2}$は、それぞれ$D_f \times D$次元、$D \times D_f$次元の行列、$\boldsymbol{b_1}$と$\boldsymbol{b_2}$は、それぞれ、$D_f$次元、$D$次元のベクトルで、$f$は、非線形性をもつ活性化関数で、フィードフォワード層の表現力を高くするのに不可欠な要素です。また、この層が無いと、マルチヘッド注意機構により単に入力を線形変換して重み付きで足し合わせるだけになり、表現力が著しく低くなるそうです。
なお、提案時のTransformerでは入力次元$D=512$に対して、中間層の次元は4倍の$D_f=2048$が使われています。この結果、フィードフォワード層に含まれるパラメータ数は、Transformer全体の約2/3を占めることになるそうです。また、この層は文脈に関する情報をその豊富なパラメータの中に記憶し、入力された文脈に対して、関連する情報を付加する役割を果たしていると言われています3。
残差結合
従来、深層学習は層を重ねれば重ねるほど性能(精度)が向上すると考えられてきましたが、実は逆で性能が悪くなる場合が出てきました。原因は層を重ねれば重ねるほど顕著になる「勾配爆発」と「勾配消失」です。
そして、ここで説明する残差結合は勾配消失に対する対策法です。
以下はEncoderだけを抜き出した図です。
この図で赤い丸でくくったところが2箇所あります。これが残差結合と呼ばれるもので、「次の層への入力が分岐して、そのまま層を飛び越したものと層の出力が足し合わされたものを層の総合結果とする」になります。この工夫によって、勾配消失の可能性が軽減されるため、層を深くすることが可能になります。なお、勾配爆発への対処方法は内積を$\sqrt{D}$で割ったりするなど小さくすることで対処しています。以上で残差結合の説明は終わりです。
以下からは深層学習においてなぜ、勾配爆発や勾配消失が発生するかのメカニズム、および、勾配消失がどうして残差結合で軽減されるのかについて説明します。微分の知識が必要(説明はします)であり、以下を理解しなくてもTransformerの本質的な理解には影響しませんので、読み飛ばしを推奨します。
残差結合の詳細
この説明は参考書をかなり参考にしています。
最初に、簡単なニューラルネットワークを取り上げることで、勾配爆発/消失の説明を試みます。次の構造を持つニューラルネットワークを考えます。準備として$f^{(1)}(x)=w^{(1)}x=h^{(1)}$はパラメータ$w^{(1)}$をかけてスカラー$h^{(1)}$を出力する関数、$f^{(o)}(h,y)=w_y^{(o)}h=o_y$は、ラベル$y$のスコアを計算する関数です。なお、関数$f$、$w$、$h$の右肩の$(1)$や$(o)$は層の番号、位置を示します。つまり、$(1)$は1層目、$(o)$はスコアの出力層になります。また、追加で第2層、第3層も定義します、それぞれ、$f^{(2)}(h)=w^{(2)}h=h^{(2)}$、$f^{(3)}(h)=w^{(3)}h=h^{(3)}$です。ここで、スコア関数は以下のように定義します。
f(x,y)=f^{(o)}(f^{(3)}(f^{(2)}(f^{(1)}(x))),y)=w_y^{(o)}w^{(3)}w^{(2)}w^{(1)}x
ここで、損失関数$\ell(f(x,y))$とすれば、$f^{(1)}$のパラメータ$w^{(1)}$による微分は、連鎖律を使って以下のように計算できます。
\frac{\partial \ell(f^{(o)}(f^{(3)}(f^{(2)}(f^{(1)}(x))),y))}{\partial w^{(1)}}=\\
\frac{\partial \ell(f^{(o)}(f^{(3)}(f^{(2)}(f^{(1)}(x))),y))}{\partial f^{(o)}(f^{(3)}(f^{(2)}(f^{(1)}(x))),y)}\frac{\partial f^{(o)}(f^{(3)}(f^{(2)}(f^{(1)}(x))),y)}{\partial f^{(3)}(f^{(2)}(f^{(1)}(x)))}\frac{\partial f^{(3)}(f^{(2)}(f^{(1)}(x)))}{\partial f^{(2)}(f^{(1)}(x))}\frac{\partial f^{(2)}(f^{(1)}(x))}{\partial f^{(1)}(x)}\frac{\partial f^{(1)}(x)}{\partial w^{(1)}}
これは、上で出てきた関数の微分の結果を代入して整理すると、以下になります。
\frac{\partial \ell(f^{(o)}(f^{(3)}(f^{(2)}(f^{(1)}(x))),y))}{\partial w^{(1)}}=\\
\frac{\partial \ell(o_y)}{\partial o_y}w_y^{(o)}w^{(3)}w^{(2)}x
ここで、$w_y^{(o)}$、$w^{(3)}$、$w^{(2)}$が仮に100だとすると、偏微分の計算結果に$100^3$が乗じられる形になり、多少の重みの変化によって、随分と計算結果が大きくぶれそうです。実際にこうなると、損失関数の値がなかなか収束しないため、学習がうまく進みません。これを勾配爆発と呼びます。
逆に、それぞれの値が仮に0.01くらいだったとすると、偏微分の計算結果に$0.01^3$が乗じられ、勾配がかなり小さな値になってしまいます。こうなると、計算結果に差がでなくなり、損失関数の値が急に収束するようになり、満足する精度が出ないパラメータとなったまま、学習がストップするような形になってしまいます。これを勾配消失と呼び、残差結合で軽減できる対象です。
ここで、残差とは以下を考えることです。
f^{(1+)}(h)=f^{(1)}(h)+h
つまり、関数$f^{(1)}$の出力と入力$h$そのものの和を出力とする新しい関数(残差関数)を考え、これを対象に偏微分計算をすすめるのです!まず、$f^{(1+)}(h)$の偏微分を考えてみます。計算結果は以下の通りです。
\frac{\partial f^{(1+)}(h)}{\partial h}=\frac{\partial f^{(1)}(h)}{\partial h}+\frac{\partial h}{\partial h}=w^{(1)}+1
ここで、結果に1が出現しているのがミソです。次に、残差関数を使ってスコア関数の定義をし直します。
f(x,y)=f^{(o)}(f^{(3+)}(f^{(2+)}(f^{(1+)}(x))),y)
なお、出力層の$f^{(o)}$については残差を考えても結果に影響を与えなく無駄なだけなので、これを対象に残差関数は与えません。
損失関数の$w^{(1)}$に関する偏微分は以下の通りとなります。
\frac{\partial \ell(f^{(o)}(f^{(3+)}(f^{(2+)}(f^{(1+)}(x))),y))}{\partial w^{(1)}}=\\
\frac{\partial \ell(o_y)}{\partial o_y}w_y^{(o)}(w^{(3)}+1)(w^{(2)}+1)x
こうなるため、$w^{(3)}$と$w^{(2)}$の絶対値が小さくとも1が計算結果にたされるため、勾配消失は軽減されることになります。
ところで、もともとの関数から残差関数に変えて大丈夫なの?という疑問があると思いますが、結論から言うと問題ありません。なぜかというと、残差関数はもとの関数の出力に、入力がそのまま加算されただけであり、もともとの関数の形を大きく損なうわけじゃなく、もともとの関数と「大体同じ感じ」に扱うことができるからだと、解釈しています。
以上、残差結合の本質となります。
次は、Encoder最後の要素、Add & Norm(層正規化)になります。
Add & Norm(層正規化)
この層の役割は、ベクトルの要素の値が過剰に大きい値になることで、学習/訓練が不安定になることを防ぐために、ベクトルの値を平均0、分散1の分布の値に正規化することです。
この層への入力ベクトルを$D$次元の$\boldsymbol{x}$とします。また、ベクトルの要素の平均を$\mu_\boldsymbol{x}$、標準偏差$\sigma_\boldsymbol{x}$を以下とします。
\mu_\boldsymbol{x} = \frac{1}{D}\sum_{i=1}^{D}x_i
\sigma_\boldsymbol{x}=\sqrt{\frac{1}{D}\sum_{i=1}^{D}(x_i-\mu_\boldsymbol{x})^2}
そして、この層の役割を担う「層正規化関数」を$layernorm(\boldsymbol{x})$と定義します。この関数の出力はベクトルになります。また、出力ベクトルの$k$番目の要素を$layernorm(\boldsymbol{x})_k$とするとき、具体的な値は以下になります。
layernorm(\boldsymbol{x})_k=g_k\frac{x_k-\mu_\boldsymbol{x}}{\sigma_\boldsymbol{x}+\epsilon}+b_k
$g_k$と$b_k$はゲインベクトル$\boldsymbol{g}$とバイアスベクトル$\boldsymbol{b}$の$k$番目の要素です。この2つのベクトルはこの層の表現力を向上するために導入されていますが、$\boldsymbol{g}=\boldsymbol{1}$、$\boldsymbol{b}=\boldsymbol{0}$($\boldsymbol{g}$のすべて要素が1、$\boldsymbol{b}$のすべての要素が0)として無効化することも可能とのことです。なお、$\epsilon$はゼロ除算防止のために設けられており、0.000001などの非常に小さい値が用いられます。
なお、層正規化関数は、入力ベクトルの要素を平均0、分散1の分布を持つ値に変換します。ゲインベクトルとバイアスベクトルを無効化したあとの式を考えるとよくわかります。
layernorm(\boldsymbol{x})_k=\frac{x_k-\mu_\boldsymbol{x}}{\sigma_\boldsymbol{x}+\epsilon}
「確率・統計」のところに出てきた、確率分布の変数変換になります。確率変数$\boldsymbol{X}$を$\boldsymbol{Z}=\frac{\boldsymbol{X}-\mu}{\sigma}$として変換してやれば、平均0、分散1の分布に変換されます。「確率・統計」の変数変換が層正規化で使われていたのです!
これで、Add & Norm層の説明は終わりです!
ここまででTransformerのEncoderのほぼすべての要素を説明し終わりました。
エンコーダ・デコーダ構成の場合に交差注意機構、masked multihead attentionが出てきますが、Encoderで説明した話とほぼ同じため、リクエストがあれば追記することにします。
Encoder/Decoderの理解に挑戦
今後記載するかも。
masked multi head attention
交差注意機構
-
ところで、なぜ、このような計算結果になるかですが、計算結果は、「単語の数×位置符号の次元数」×「位置符号の次元数×単語の数」の行列の積の結果、「単語の数×単語の数」サイズの行列になっています。実際手で計算してみると分かるのですが、この行列の行および列が単語の番号であり、行列の要素が内積の計算結果となっています。 ↩
-
深層学習の誤差逆伝播法では、損失関数に対して変化を求めたいパラメータに関する偏微分を実施していきますが、その際、関数の値が滑らかで(連続しており)変化がある方がより結果が良いと言われています。逆に滑らかでない場合は微分の結果の値が急に収束してしまい、学習結果が安定しません。滑らかで変化がある具合を指して「表現が豊か」と言ったり、このニューラルネットワーク(モデル)は表現力があると言ったりします。 ↩
-
いわゆるファインチューニングではこの層をトレーニングしていくようです。 ↩