初めに
この記事では、トランスフォーマーモデルにおける最先端の位置エンコーディングの発見過程を段階的に説明していきます。位置をエンコードするアプローチを段階的に改善していき、最新のLLama 3.2リリースや最新のトランスフォーマーで使用されている回転位置エンコーディング(RoPE)に到達します。この記事では、理解に必要な数学的知識を最小限に抑えることを意図していますが、基本的な線形代数、三角関数、そしてセルフアテンションの理解が前提となります。参考
問題提起
すべての問題と同様に、まず私たちが達成しようとしていることを正確に理解することから始めるのが最善です。トランスフォーマーにおけるセルフアテンション機構は、シーケンス内のトークン間の関係を理解するために使用されます。セルフアテンションは集合演算であり、これは順列等価であることを意味します。セルフアテンションに位置情報を付加しない場合、多くの重要な関係性を判断することができなくなります。
動機付けとなる例
異なる位置に同じ単語が出現する次のような文を考えてみましょう:
「The dog chased another dog」
直感的に、"dog"は2つの異なる実体を指しています。これらをまずトークン化し、Llama 3.2 1Bの実際のトークンエンベディングにマッピングして、torch.nn.MultiheadAttentionに通した場合、何が起こるか見てみましょう。
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
model_id = "meta-llama/Llama-3.2-1B"
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModel.from_pretrained(model_id)
text = "The dog chased another dog"
tokens = tok(text, return_tensors="pt")["input_ids"]
embeddings = model.embed_tokens(tokens)
hdim = embeddings.shape[-1]
W_q = nn.Linear(hdim, hdim, bias=False)
W_k = nn.Linear(hdim, hdim, bias=False)
W_v = nn.Linear(hdim, hdim, bias=False)
mha = nn.MultiheadAttention(embed_dim=hdim, num_heads=4, batch_first=True)
with torch.no_grad():
for param in mha.parameters():
nn.init.normal_(param, std=0.1) # 重みを無視できない値で初期化
output, _ = mha(W_q(embeddings), W_k(embeddings), W_v(embeddings))
dog1_out = output[0, 2]
dog2_out = output[0, 5]
print(f"Dog output identical?: {torch.allclose(dog1_out, dog2_out, atol=1e-6)}") #True
見ての通り、位置情報がない場合、トークンが明らかに異なる実体を表しているにもかかわらず、(マルチヘッドの)セルフアテンション操作の出力は異なる位置にある同じトークンに対して同一となります。それでは、単語間の関係性を位置によってエンコードできるように、位置情報を使ってセルフアテンションを強化する方法を設計していきましょう。
最適なエンコーディング方式を理解し設計するために、そのような方式が持つべき望ましい特性をいくつか見ていきましょう。
望ましい特性
最適化プロセスをできるだけ容易にするための望ましい特性をいくつか定義してみましょう。
特性1 - 各位置に対する一意のエンコーディング(シーケンス間で一貫)
各位置には、シーケンスの長さに関係なく一貫した一意のエンコーディングが必要です。例えば、位置5のトークンは、現在のシーケンスの長さが10であっても10,000であっても、同じエンコーディングを持つべきです。
特性2 - エンコードされた位置間の線形関係
位置間の関係は数学的にシンプルであるべきです。位置pのエンコーディングが分かっている場合、位置p+kのエンコーディングを計算することは簡単であるべきです。これにより、モデルが位置パターンを学習しやすくなります。
数直線上の数字の表現方法を考えてみると、5は3から2ステップ離れている、または10は15から5ステップ離れているということが簡単に理解できます。私たちのエンコーディングにも同様の直感的な関係性が存在するべきです。
特性3 - 学習時に遭遇したよりも長いシーケンスへの一般化
実世界でのモデルの有用性を高めるために、学習分布の外側にも一般化できるべきです。したがって、エンコーディング方式は、他の望ましい特性を損なうことなく、予期せぬ入力長にも対応できるほど適応性がなければなりません。
特性4 - モデルが学習可能な決定論的プロセスによる生成
位置エンコーディングが決定論的プロセスから導き出されることが理想的です。これにより、モデルが私たちのエンコーディング方式の背後にあるメカニズムを効率的に学習できるはずです。
特性5 - 複数次元への拡張可能性
マルチモーダルモデルが標準となりつつある中、位置エンコーディング方式が1次元からn次元へ自然に拡張できることが重要です。これにより、モデルは画像(2次元)や脳スキャン(4次元)のようなデータを処理できるようになります。
これらの理想的な特性(以降、Prnと呼びます)が分かったところで、エンコーディング方式の設計と反復的な改善を始めていきましょう。
整数位置エンコーディング
最初に思いつく方法は、トークン位置の整数値を単純にトークンエンベディングの各成分に加えることです。この値は0からLの範囲をとり、Lは現在のシーケンスの長さです。
上のアニメーションでは、chased
というトークンに対する位置エンコーディングベクトルをインデックスから作成し、トークンエンベディングに加えています。ここで示されているエンベディング値はLlama 3.2 1Bの実際の値の一部です。これらの値が0の周りに集中していることが分かります。これは学習中の勾配消失や勾配爆発を避けるために望ましい特性であり、モデル全体を通して維持したい性質です。
現在の素朴なアプローチには問題があることは明らかです。位置の値の大きさが入力の実際の値を大きく上回ってしまいます。これは信号対雑音比が非常に低くなることを意味し、モデルにとって意味的情報と位置情報を分離することが困難になります。
この新しい知見を踏まえると、自然な発展として位置の値をN分の1で正規化することが考えられます。これにより値は0から1の間に収まりますが、別の問題が発生します。Nを現在のシーケンスの長さに設定すると、異なる長さのシーケンスごとに位置の値が完全に異なってしまい、Pr1を違反することになります。
数値を0と1の間に収めるより良い方法はないでしょうか?しばらく深く考えてみると、10進数から2進数に切り替えるという発想が浮かぶかもしれません。
2進数位置エンコーディング
(場合によっては正規化された)整数位置をエンベディングの各成分に加える代わりに、それを2進数表現に変換し、エンベディングの次元に合わせて値を引き伸ばすことができます。以下で示すとおりです。
対象となる位置(252)を2進数表現(11111100)に変換し、各ビットをトークンエンベディングの対応する成分に加えました。最下位ビット(LSB)は後続の各トークンで0と1の間を循環し、一方で最上位ビット(MSB)は$2^{(n-1)}$トークンごとに循環します(ここでnはビット数です)。下のアニメーションで異なるインデックスに対する位置エンコーディングベクトルを確認できます。
値の範囲の問題は解決し、異なるシーケンス長にわたって一貫した一意のエンコーディングを得ることができました。では、トークンエンベディングの低次元バージョンをプロットし、異なる値に対する2進数位置ベクトルの加算を可視化するとどうなるでしょうか。
結果が非常に「飛び飛び」になっていることがわかります(2進数の離散的な性質から予想される通りです)。最適化プロセスは滑らかで、連続的で予測可能な変化を好みます。同様の値の範囲を持ち、滑らかで連続的な関数を知っているでしょうか?
少し考えてみると、sin関数とcos関数がぴったりだということに気づくかもしれません!
シヌソイド埋め込み(Sinusoidal Embeddings)の説明
上のアニメーションは、徐々に波長が増加するsin関数とcos関数から交互に各成分を取得した場合の位置エンベディングを可視化しています。前のアニメーションと比較すると、驚くべき類似性に気づくでしょう!
これにより、私たちは「Attention is all you need」論文で最初に定義された正弦波エンベディングに到達しました。式を見てみましょう:
$$
PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d}}\right)
$$
$$
PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d}}\right)
$$
ここで:
-
pos
はトークンの位置インデックスです。 -
i
は位置エンコーディングベクトル内の成分インデックスです。 -
d
はモデルの次元です。 -
10000
は基準波長$\theta$と呼びます)であり、成分インデックスに応じてこの波長を伸縮します。
$\sin$ と $\cos$ の使用理由:
重要な点は、2つのエンコードされた位置間に線形関係を持たせることです。この線形関係を生成するために、三角関数の性質を利用します。
変換行列 M の導出
位置シフトを行う線形変換行列 $M$ を導出することを考えます。シフト量を $k$、周波数を $\omega_i$ とすると、次の関係を満たす行列 $M$ を求めます:
$$
M \cdot \begin{bmatrix} \sin(\omega_i p) \\ \cos(\omega_i p) \end{bmatrix} = \begin{bmatrix} \sin(\omega_i (p+k)) \\cos(\omega_i (p+k)) \end{bmatrix}
$$
周波数 $omega_i$
$$
\omega_i = \frac{1}{10000^{2i/d}}
$$
一般的な 2 x 2 行列
行列 $M$ を次の形式で表します:
$$
M = \begin{bmatrix} u_1 & v_1 \\ u_2 & v_2 \end{bmatrix}
$$
これをシフト後の三角関数展開に適用します。三角関数の加法定理を使い、次の2つの方程式を得ます:
$$
u_1 \sin(\omega_i p) + v_1 \cos(\omega_i p) = \cos(\omega_i k) \sin(\omega_i p) + \sin(\omega_i k) \cos(\omega_i p)
$$
$$
u_2 \sin(\omega_i p) + v_2 \cos(\omega_i p) = -\sin(\omega_i k) \sin(\omega_i p) + \cos(\omega_i k) \cos(\omega_i p)
$$
解の導出
係数比較により、未知数 $u_1, v_1, u_2, v_2$ を解くと:
$$
u_1 = \cos(\omega_i k), \quad v_1 = \sin(\omega_i k)
$$
$$
u_2 = -\sin(\omega_i k), \quad v_2 = \cos(\omega_i k)
$$
最終的な変換行列 $M_k$
$$
M_k = \begin{bmatrix} \cos(\omega_i k) & \sin(\omega_i k) \\ -\sin(\omega_i k) & \cos(\omega_i k) \end{bmatrix}
$$
回転行列との関係
この結果は、回転行列と同じ形になります。つまり、シヌソイド埋め込みは、位置エンコーディングを回転として表現していたのです。
絶対位置と相対位置のエンコーディング
回転が重要であるという知識を得たところで、最初の例に戻り、次の反復に向けての直感を探ってみましょう。
絶対位置:
0 1 2 3 4
The dog chased another dog
相対位置("chased"から見た場合):
-2 -1 0 1 2
The dog chased another dog
上の例で、トークンの絶対位置と、chased
から見た各トークンへの相対位置が分かります。正弦波エンコーディングでは、絶対位置を表す別のベクトルを生成し、三角関数のテクニックを使って相対位置をエンコードすることができました。
しかし、文を理解しようとするとき、ある単語がこのブログ記事の2157番目の単語であることは重要なのでしょうか?それとも、その周囲の単語との関係性の方が重要なのでしょうか?単語の絶対位置が意味に関係することは稀で、重要なのは単語同士がどのように関連しているかということなのです。
文脈における位置エンコーディング
ここからは、位置エンコーディングをセルフアテンションの文脈で考えることが重要です。繰り返しになりますが、セルフアテンション機構により、モデルは入力シーケンスの異なる要素の重要性を重み付けし、出力に対するそれらの影響を動的に調整することができます。
$$
Attn(Q,K,V) = softmax(\frac{QK^T}{\sqrt(d_k)})V
$$
これまでの全ての反復において、私たちは別個の位置エンコーディングベクトルを生成し、Q
、K
、V
の射影の前にトークンエンベディングに加算していました。位置情報をトークンエンベディングに直接加えることで、意味情報を位置情報で汚染していることになります。ノルムを変更せずに情報をエンコードするよう試みるべきです。乗算的なアプローチに移行することがカギとなります。
辞書のアナロジーを使うと、単語(クエリ)を辞書(キー)で検索する時、近くの単語は遠くの単語より大きな影響を持つべきです。あるトークンが別のトークンに与える影響は$QK^T$の内積によって決定されます - だからこそ、そこに位置エンコーディングの焦点を当てるべきなのです!
$$\vec{a} \cdot \vec{b} = |\vec{a}| |\vec{b}| \cos \theta$$
上に示した内積の幾何学的解釈は、私たちに素晴らしい洞察を与えてくれます。2つのベクトル間の角度を増減させるだけで、内積の結果を調整できるのです。さらに、ベクトルを回転させることで、トークンの意味情報をエンコードするベクトルのノルムには全く影響を与えません。
これで私たちはアテンションをどこに向けるべきかが分かり、別の角度から位置情報をエンコードする「チャネル」として回転が理にかなっている理由も理解できました。では、これらすべてをまとめてみましょう!
回転位置エンベディング(RoPE)
回転位置エンベディング(Rotary Positional Embedding、略してRoPE)はRoFormer論文で定義されました(Jianlin Su)。最終結果だけを見ると不可思議に見えるかもしれませんが、セルフアテンションの文脈(特に内積)で正弦波エンコーディングについて考えることで、すべてがどのように組み合わさるのかが分かります。
正弦波エンコーディングと同様に、私たちのベクトル(射影前の$x$ではなく、$q$または$k$)を2次元のペア/チャンクに分解します。ゆっくりと減少する周波数の正弦波関数から導いたベクトルを加算することで絶対位置を直接エンコードする代わりに、端的に各ペアを回転行列と乗算することで相対位置をエンコードします。
位置$p$における入力ベクトル$q$または$k$について考えます。各成分ペアに対して望ましい回転を行う回転行列$M_i$からなるブロック対角行列を作成します:
$$
R(q, p) =
\begin{pmatrix}
M_1 & & \\
& M_2 & \\
& & \ddots & \\
& & & M_{d/2}
\end{pmatrix}
\begin{pmatrix}
q_1 \\
q_2 \\
\vdots \\
q_d
\end{pmatrix}
$$
ここで$M_i$は以下のような回転行列です:
$$
M_i =
\begin{pmatrix}
\cos(\omega_i p) & -\sin(\omega_i p) \\
\sin(\omega_i p) & \cos(\omega_i p)
\end{pmatrix}
$$
実際には、このスパース行列を使った行列乗算は計算コストが高いため、使用しません。その代わりに、計算の規則性を利用して、要素のペアごとに回転を直接適用します:
$$
R_{\Theta, p}^d q =
\begin{pmatrix}
q_1 \\
q_2 \\
q_3 \\
q_4 \\
\vdots \\
q_{d-1} \\
q_d
\end{pmatrix}
\otimes
\begin{pmatrix}
\cos(p \theta_1) \\
\cos(p \theta_1) \\
\cos(p \theta_2) \\
\cos(p \theta_2) \\
\vdots \\
\cos(p \theta_{d/2}) \\
\cos(p \theta_{d/2})
\end{pmatrix}
+
\begin{pmatrix}
-q_2 \\
q_1 \\
-q_4 \\
q_3 \\
\vdots \\
-q_d \\
q_{d-1}
\end{pmatrix}
\otimes
\begin{pmatrix}
\sin(p \theta_1) \\
\sin(p \theta_1) \\
\sin(p \theta_2) \\
\sin(p \theta_2) \\
\vdots \\
\sin(p \theta_{d/2}) \\
\sin(p \theta_{d/2})
\end{pmatrix}
$$
これで完了です!ベクトル $q$ と $k$ を2次元チャンクに分割して回転を適用し、加算から乗算に切り替えることで、評価時に大幅な性能向上が得られます。
nD次元へのRoPEの拡張
RoPEの1次元の場合について探究してきました。この時点で、直感的には理解しにくいトランスフォーマーのコンポーネントについて、直感的な理解が得られたことを願っています。最後に、画像などのより高次元への拡張について探ってみましょう。
最初の自然な直感として、画像から直接[x,y]
座標ペアを使用することが考えられます。これまでほぼ任意に成分をペアにしていたことを考えると、直感的に思えるかもしれません。しかし、これは間違いです!
1次元の場合、入力ベクトルから値のペアを回転させることで相対位置m-n
をエンコードします。2次元データの場合、水平方向と垂直方向の相対位置(例えばm-n
とi-j
)を独立してエンコードする必要があります。RoPEの素晴らしさは、複数の次元をどのように扱うかにあります。すべての位置情報を単一の回転でエンコードしようとするのではなく、同じ次元内の成分をペアにしてそれらを回転させます。そうしなければ、x
とy
のオフセット情報が混ざってしまいます。各次元を独立して扱うことで、空間の自然な構造を維持します。これは必要な次元数まで一般化できます!
位置エンコーディングの未来
RoPEは位置エンコーディングの最終形態でしょうか?DeepMindの最近の論文はRoPEを深く分析し、いくつかの根本的な問題を指摘しています。
ウェーブレットや階層的な実装といったシグナル処理からのアイデアを取り入れた、将来的なブレークスルーが起こる可能性があると予想しています。また、モデルが展開のためにますます量子化されていく中で、低精度演算でも堅牢さを保つエンコーディング方式にもイノベーションが見られるでしょう。
結論
位置エンコーディングは、トランスフォーマーにおいて後付けのように扱われ続けています。私たちはこれを異なる視点で見るべきです - セルフアテンションにはアキレス腱があり、それは繰り返しパッチが当てられてきたのです。
この記事を通じて、最初は直感的でなかったにもかかわらず、あなたも最先端の位置エンコーディングを発見できたかもしれないということを示せたことを願っています。次回の記事では、パフォーマンスを最大化するためのRoPEの実践的な実装の詳細について探究したいと思います。
参考
- https://fleetwood.dev/posts/you-could-have-designed-SOTA-positional-encoding
- https://kazemnejad.com/blog/transformer_architecture_positional_encoding/
- https://blog.eleuther.ai/rotary-embeddings/
- https://www.youtube.com/watch?v=T3OT8kqoqjc
- https://arxiv.org/pdf/1706.03762
- https://arxiv.org/pdf/2410.06205
- https://arxiv.org/pdf/2104.09864