要約
- RoPE(Rotary Positional Embedding)とは、エンべディングなどで得られたベクトルを回転させることで、ベクトルに位置情報を付与する手法。
- RoPEの優れた特徴は以下
- シークエンスの長さに対して柔軟性がある。
- トークン同士の距離が離れるに従って、その関係を低く評価するようにできている。
- Self-Attentionを相対位置エンコーディングでエンコードできること。
- 2023年9月の論文で提唱された手法。
- 自然言語処理や画像生成AIで活用されている。
- 絶対位置エンコーディングや相対位置エンコーディングなどの既存の位置情報を扱う手法を凌駕した結果を出した。
以上です。本手法は画像生成AIのFLUXのAttention内部をいじっている最中に見つけ、それが勉強するきっかけとなりました。
各位置エンコーディング手法について
なぜ位置情報を扱う必要があるのか
まず、なぜ位置情報を扱う必要があるのかについて自然言語処理の文脈で説明します。
かの有名なTransformerが論文で発表されてから、自然言語処理でのタスクは以下のアーキテクチャがデファクトスタンダードになりました。
1, 単語のEmbedding
2, Transformerを活用
3, 下流タスクに合わせて線型結合などのNNを活用する。
1のEmbeddingでは単語同士の近さを学習させます(単語Aが出てきた時に、次に出てきやすい単語は何か?)。
2のTransformerではシークエンス内部にある各単語同士の相関を学習させます(1つの文章が与えられた時、その文章内における単語同士はどのような関係があるか?)。
この時、1と2だけ学習させても、単語の順序や文脈の情報が欠落しています。そこで単語がどの位置にあるのかという情報をAIに理解させるために位置情報を付与する方法が必要になるわけです。
したがってPositional Embeddingは単語を機械が理解できるようにEmbeddingでベクトル化した後、Transformerでシークエンス内部の関係性を学習させる前にシークエンス内部にある各トークンに対して追加する必要があります。
代表的にな位置情報付与の手法について簡単に説明します。
絶対位置エンべディング
各トークンはn次元のベクトルで構成されます。
このアプローチでは各トークンの各ベクトルの要素に対して、sinやcosなどのいい感じの関数を見つけ出し、その関数から抽出される値足し算することで、位置情報を付与します(以下の図のベクトル要素1や2に対して関数から見つけた数値を足す)。
よく知られている絶対位置エンコーディングの関数として以下があります。
\begin{align}
p_{2t} &= \sin\left( \frac{k}{10000^{\frac{2t}{d}}} \right) \\
p_{2t+1} &= \cos\left( \frac{k}{10000^{\frac{2t}{d}}} \right) \tag{4}
\end{align}
ここでkはシークエンスにおける位置です。上図で”私”がk=1、”は”がk=2となります。
dは次元数。
tはベクトル要素1や要素2のような各トークンが持つベクトル要素の位置を表現しています。
ベクトル要素tが偶数の時はsinカーブから抽出された数値を、tが奇数のときはcosカーブから抽出された数値をその要素に足し算するという意味になります。
相対位置エンべディング
絶対位置エンべディングでは何らかの固定された関数を見つけ出し、それから出てくる数値を利用して位置情報を付与しました。
相対位置エンべディングでは位置情報を学習可能なものとして扱う手法です。
論文を読むとさまざまな手法を紹介してくれていますが、難しかったので割愛します。
2.3 Relative position embeddingを参照:
RoPEの概要
まずは2次元平面で考えると
まず、2次元空間でRoPEを適用する例について考える
早速だが以下の式になる
\begin{align}
f^{\{q,k\}}(x_m, m) &=
\begin{bmatrix}
\cos(m\theta) & -\sin(m\theta) \\
\sin(m\theta) & \cos(m\theta)
\end{bmatrix}
\begin{bmatrix}
W^{(11)}_{\{q,k\}} & W^{(12)}_{\{q,k\}} \\
W^{(21)}_{\{q,k\}} & W^{(22)}_{\{q,k\}}
\end{bmatrix}
\begin{bmatrix}
x^{(1)}_m \\
x^{(2)}_m
\end{bmatrix}
\end{align}
ここでは単語Embeddingによって得られた位置$m$の2つの要素(=2次元)から成立するトークン$x_m^{(1)}$と$x_m^{(2)}$を考える。
まず線形変換でQとKを計算する。その部分が以下となる。
\begin{align}
\begin{bmatrix}
W^{(11)}_{\{q,k\}} & W^{(12)}_{\{q,k\}} \\
W^{(21)}_{\{q,k\}} & W^{(22)}_{\{q,k\}}
\end{bmatrix}
\begin{bmatrix}
x^{(1)}_m \\
x^{(2)}_m
\end{bmatrix} &=\begin{bmatrix}
W^{(11)}_{\{q,k\}} x^{(1)}_m + W^{(12)}_{\{q,k\}} x^{(2)}_m \\
W^{(21)}_{\{q,k\}} x^{(1)}_m + W^{(22)}_{\{q,k\}} x^{(2)}_m
\end{bmatrix}
\end{align}
右辺に書かれた部分の1行目がQの要素、2行目がKの要素を計算したものになる。
最後に以下の部分だが、これは回転行列と呼ばれ、名前の通りベクトルを回転する際に用いる。
\begin{align}
\begin{bmatrix}
\cos(m\theta) & -\sin(m\theta) \\
\sin(m\theta) & \cos(m\theta)
\end{bmatrix}
\end{align}
以上より、Embeddingによって得られたベクトルを線形変換することでQとKを抽出し、それに回転する作業を数式で記述できた。
回転行列に関するメモ書き
なぜSinやCosを使うと行列が回転したことになるのかをベクトルを基底ベクトルに分解してから〜という流れで説明している。
これを一般化すると
以下のようになる。
- $x_m$はシークエンスにおけるm番目のトークン
- dはそのトークンの次元数
- Rが回転行列。
回転は少なくとも2軸なければできない(1次元空間でベクトルを回転させることを考えてみてほしい)のでトークンの各要素は2要素を1セットにして扱う。例えば、トークンの全要素が768次元の時、1と2次元目が1組目、3と4次元が2組目といった感じ。そのため回転行列が対角行列ではなくなっている。加えて、トークンは偶数次元でなければならない。
RoPEの特徴
ここではRoPEの特徴について説明する
Long-term decay(遠くのトークンの関連性が低くなる性質)
文章などで、遠くにある単語と近くにある単語と任意の単語の関連を比べたとき、近くにある単語の方が当単語と関係があると考える方が好ましい。従って、遠くのトークンとの関連性がRoPEを使うと低くなる性質は好ましいものだと考えられている。
なぜ、このような性質があるのか?理由はシンプルで回転させているから。遠くにあるトークンはコサイン類似度が小さくなる。
自然と相対位置エンべディングとなっている
RoPEが相対位置エンべディングになるには以下の前提条件が必要である。
前提条件:kとqは同じ角度での回転をする、つまり$θ$が同じハイパラである。
これは
\begin{align}
\begin{bmatrix}
\cos(m\theta) & -\sin(m\theta) \\
\sin(m\theta) & \cos(m\theta)
\end{bmatrix}
\end{align}
にて説明を省略したが、
- mはトークンのシークエンスにおける位置を示し
- $\theta$は$\theta_i = 10000^{-2(i - 1)/d}(iはトークン内におけるi次元目を意味する)$という式で、10000というハイパラが入っている。
このハイパラがkとqで一致している限り、RoPEは相対位置エンべディングと同じ性質になる。
実装するには
以下は画像生成AIのFLUXで利用されているRoPEの実装のためのコードである。
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float() # 偶数の取り出し
x_odd = x[..., 1::2].float() # 奇数の取り出し
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
return torch.stack([cos, sin], dim=-1).flatten(-2)
引用元:
回転行列を活用してRoPEを適用している。
ただ、論文内部に計算効率最大化のための解析解があったのでそれを一応紹介する。
それが以下だ。
R^d_{\Theta, mx} =
\begin{bmatrix}
x_1 \\
x_2 \\
x_3 \\
x_4 \\
\vdots \\
x_{d-1} \\
x_d
\end{bmatrix}
\otimes
\begin{bmatrix}
\cos(m\theta_1) \\
\cos(m\theta_1) \\
\cos(m\theta_2) \\
\cos(m\theta_2) \\
\vdots \\
\cos\left(m\theta_{d/2}\right) \\
\cos\left(m\theta_{d/2}\right)
\end{bmatrix}
+
\begin{bmatrix}
- x_2 \\
x_1 \\
- x_4 \\
x_3 \\
\vdots \\
- x_d \\
x_{d-1}
\end{bmatrix}
\otimes
\begin{bmatrix}
\sin(m\theta_1) \\
\sin(m\theta_1) \\
\sin(m\theta_2) \\
\sin(m\theta_2) \\
\vdots \\
\sin\left(m\theta_{d/2}\right) \\
\sin\left(m\theta_{d/2}\right)
\end{bmatrix}
よく見るべきところは以下で
- アダマール積を用いている(同じ位置にある要素同士で掛け算する)
- 1と2次元、3と4次元のように要素ごとのペアが存在すること
以上の点を気を付けて見てみると、以下に書いてあることと全然変わらない。
\begin{align}
f^{\{q,k\}}(x_m, m) &=
\begin{bmatrix}
\cos(m\theta_1) & -\sin(m\theta_1) \\
\sin(m\theta_1) & \cos(m\theta_1)
\end{bmatrix}
\begin{bmatrix}
x_1 \\
x_2
\end{bmatrix}
\end{align}
- $x_1$と$cos(m\theta_1)$の掛け算 + $x_2$と$-\sin(m\theta_1)$の掛け算が足し算される。
- $x_1$と$sin(m\theta_1)$と掛け算 + $x_2$と$\cos(m\theta_1)$の掛け算が足し算される。
余談:画像生成AIのFLUXで活用されているRoPEを見て、改善する余地があるなと思ったところ
わかる人向けに書いています。
FLUXでテキストから画像(size=1024 * 1024)を生成する時を考える。
この時、潜在空間の隠れ状態は以下のようになる。
この隠れ状態のベクトルを線形変換し、QやKを求めた後、RoPEを適用する流れとなっている。
この時、シークエンス内部のImageVectorは以下のようになっている。
画像の左上から右に向かって順番に要素を一列に並べてシークエンスを構成している。
そして、画像を1列に並べたものに対して行列回転をおこな。
私は上記でRoPEの利点の1つの”Long-term decay”を紹介した。これは隣同士のシークエンス内の要素の関連性を強いと判断し、遠いもの同士の要素の関連性は弱いと判断する。
であるならば、FLUX内部では要素1番は以下の要素64番目の方が要素65番目よりもコサイン類似度の観点で見れば近いということになる。
ただ、画像では明らかに1番目の要素は64番目の要素よりも65番目の要素の方との関連性が高い。
Reference
RoPEの論文:
Positinal EmbeddingとPositional Encodingの違い:
回転行列について:
FLUXのAttentionのRotary :
Medium :