はじめに
この記事は、論文 Transformers are SSMs(https://arxiv.org/abs/2405.21060 ) とそこで提唱されたアーキテクチャ、Mamba2について解説することを目標にします。
Mamba (https://arxiv.org/abs/2312.00752) は2023年末に発表された論文です。Transformer が席巻する中、全く別のバックグラウンドや数学的な理論づけを持ち、Transformer を上回る性能の汎用アーキテクチャとして話題になりました。
その著者たちが2024年5月に発表したのがこの論文、 Transformers are SSMs です。その名の通り、著者たちは「Transformerの計算は、数学的にはSSM (つまりMamba 1)と等価である」ということを主張しています。その上で、Mambaを発展させたアーキテクチャ Mamba2 を提案しています。
想定読者として、(1)Transformer は知っている、(2) SSMはなんとなく知っている、ぐらいの人をターゲットにします。Transformerは既知として、そこから発展したLinear Attentionと、SSMについてはじめに解説します。
1. Mamba2のアウトライン
この論文の流れをざっくり説明すると以下のようになります。
まず、MambaをはじめとするSSMの処理が、ある性質の良い行列(系列半分離行列 Sequential Semiseparable Matrix: SSS) と入力との掛け算で表されることを示します。
つぎに、Transformerで用いられたScaled Dot-Product Attentionの高速化手法である Linear Attentionも、同様にSSSと入力との掛け算で表されることを示します。
さらに、Mambaの計算と、Attentionの特殊ケースである Structured Masked Attention という計算が一致することをみます。これらの観察から、MambaとTransformerを一般化した State Space Duality :SSD という骨組みを作り、MambaをMulti-Head Attentionに近い形で並列・高速化したMamba2 というアーキテクチャを提唱しています。
2. 前提知識① Linear Attention
Linear Attention (線形時間のAttention)は、Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention (Katharopoulos et al. ICML 2020) で提唱されました。
おさらいをすると、最初に Transformer の原論文で使われたScaled Dot-Product Attentionは、
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
という形をしていました。1トークンが $N$次元かつ、系列の長さが$L$とすると、$Q,K$は$(L,N)$サイズの行列です。計算量は $O(L^2)$ となります。
ここで、Softmaxも含めて第i行の計算を考えると、
\text{Attention}(Q, K, V)_i = \frac{\sum_{j=1}^L \exp(q_i^T k_j)\cdot v_j}{\sqrt{d_k}\cdot\sum_{j=1}^L \exp(q_i^T k_j)}
です。1行の計算中にL回のforループが必要で、これをL回繰り返すので、計算量は$O(L^2)$です。
言い換えれば、一回の計算を$O(1)$で行えるなら全体の計算量を落とすことができます。そのために、まず次のような一般化をします。
\text{Attention}(Q, K, V)_i = \frac{\sum_{j=1}^L \text{sim}(q_i, k_j)\cdot v_j}{\sum_{j=1}^L \text{sim}(q_i, k_j)}
ここで、$\text{sim}(q_i, k_j)$は、$q_i$と$k_j$の類似度を表す関数です。Softmaxを使うAttentionでは、$\text{sim}(q_i, k_j) = \exp({q_i^T k_j})/\sqrt{d_k}$ です。
ここで、 $\text{sim}(q_i, k_j) = \phi(q_i)^T \cdot \phi(k_j)$ という形で表されるなら、計算量を落とすことができます。なぜかというと、
\begin{align*}
\text{Attention}(Q, K, V)_i &= \frac{\sum_{j=1}^L \phi(q_i)^T \cdot \phi(k_j) \cdot v_j}{\sum_{j=1}^L \phi(q_i)^T \cdot \phi(k_j)}
\\
&= \frac{\phi(q_i)^T \cdot \sum_{j=1}^L \phi(k_j) \cdot v_j}{\phi(q_i)^T \cdot \sum_{j=1}^L \phi(k_j)}
\end{align*}
となって、$\sum_{j=1}^L \phi(k_j) \cdot v_j$ は すべてのiに対して同じ値 になるためです。これを事前計算しておけば、計算量は $O(L)$ になります。しかし残念ながら、Softmaxを関数を完全にこの形に分解することはできません。1
なので論文では $\phi(x) = \text{elu}(x) + 1$を採用しています。(eluは次のような形の活性化関数です)
\text{elu}(x) = \begin{cases}
x & (x > 0) \\
\exp(x) - 1 & (x \leq 0)
\end{cases}
大事なのは、Attentionは、適切な形で類似度関数を分解すれば線形時間で計算できる ということです。
3. 前提知識② Mamba
Mambaは状態空間モデル State Space Model : SSMという考え方をもとにした、深層学習のアーキテクチャの一つです。Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Gu & Dao, arXiv 2023) で発表されました。
Mambaを理解するには、HiPPO→S4→Mambaという3本の論文を読む必要があるのですが、これについては過去に書いた記事があります。
なのでここでは、理論は抜きにして現在使われているSSMの構造を紹介します。SSMの基本になるのは次の式です。
\begin{align*}
h_t &= Ah_{t-1} + Bx_t \\
y_t &= Ch_t
\end{align*}
ここで、$x_t$は $x_1, x_2, \cdots, x_T$ という形の入力データの一つで、$y_t$は $y_1, y_2, \cdots, y_T$という形の出力データのうち、時刻 $t$ のものです。
SSMは $x_t$と$y_t$ のほかにもうひとつ、$h_t$ という値を持っています。これは過去の状態を保存したもので、隠れ状態や潜在変数というように考えることができます。
さらに、いくつかのパラメータ $A, B, C$ があります。これらは入力→潜在変数、前の潜在変数→今の潜在変数、潜在変数→出力への変換を司るものです。
さて、最初に提唱された 構造化状態空間系列モデル Structured State Space Sequence Model: S4では、 $A, B, C$は時刻に依存しない(Time-invariant)パラメータでした。実は、これらを固定してあげると、一次元畳み込みを用いて計算を高速に行うことができます。
一方で、このパラメータを固定すると表現力に限界があるという問題もありました。そこで、$A, B, C$ を時間の関数$A_t, B_t, C_t$としてモデル化したのが Mamba です。(正確には $A$ そのものはパラメータのままで、離散化幅を決める $\Delta$ という値を学習可能にしています)
これによって計算速度は落ちるため、これを実装上のいくつかの工夫(Induction Heads, GPUのメモリを意識したアルゴリズムなど)によって補った、というのがMambaの論文です。もちろんどうやって工夫したのかが工学的に大事なところなのですが、Mamba2の話をするためにはこの位で十分です。詳しくは別の記事を参照してください。
さて、最終的に提唱されたMambaのアーキテクチャは次のような形をしています。ニューラルネットによるゲート機構の中にSSMが入っています。
4. Mambaと系列半分離行列(Sequencial Semi-Separable Matrix)
いよいよここからがMamba-2の話です!
この章では、論文を通して重要な概念であるSSSを定義します。そのために、まずMambaの式から話を始めることにします。
\begin{align*}
h_t &= A_t h_{t-1} + B_tx_t \\
y_t &= C_th_t
\end{align*}
まず、ステップ0では $h_0 = B_0x_0$ であるとします。すると帰納的に計算して、
\begin{align*}
h_1 &= A_1 h_0 + B_1 x_1 = A_1 B_0x_0 + B_1 x_ 1 \\
h_2 &= A_2 h_1 + B_2 x_2 = A_2(A_1 h_0 + B_1 x_1) + B_2 x_2\\
&= A_2A_1 B_0x_0 + A_2B_1 x_1 + B_2 x_2
\end{align*}
一般には次のようになります。
\begin{align*}
h_t &= A_t \cdots A_1B_0x_0
\\ &+ A_t\cdots A_2B_1x_1
\\&+ \cdots
\\&+A_t A_{t-1} B_{t-2}x_{t-2}
\\&+ A_tB_{t-1}x_{t-1}
\\&+ B_t x_t\\
&= \sum_{s=0}^t \left(\prod_{i=s + 1}^t A_i\right)B_sx_s
\end{align*}
(元論文の式(3))
ここで、$y_t = C_t^Th_t$ でした。ここに今の$h_t$を展開した結果を代入すると、
y_t = \sum_{s=0}^t C_t^T\left(\prod_{i=s + 1}^t A_i\right)B_sx_s
このようになります。実はこの式、次のような行列$M$を定義すると一発で書くことができます。
\begin{align*}
\boldsymbol{y} &= M\boldsymbol{x}\\
M_{ij} &= C^T_jA_j \cdots A_{i+1}B_i
\end{align*}
実際、t番目の要素を取り出すと
\begin{align*}
y_t &= M_{t0}x_0 + M_{t1}x_1 + \cdots M_{tT}x_T\\
&= \sum_{s =0}^TM_{ts}x_s\\
&= \sum_{s =0}^TC^T_sA_t \cdots A_{s+1}B_tx_s
\end{align*}
となって、先ほどの式(3)の定義と一致します。
さて、ここからSSSを定義します。一般に、行列が 半分離 であるとは次のことを言います。
- 定義(半分離行列)
行列 $A$ が$N$-半分離($N$-semiseparable)であるとは、$A$の下三角成分に位置する部分行列のランクが最大 $N$ であることを言う。
ここで、実は先ほどの$M$ は半分離であることが示せます。(原論文の Lemma 3.3.) 証明は割愛します。
逆に、どんな半分離行列も、$M$と同じような形で表せます。 具体的には、まず$N\text{-SSS}$(N-系列半分離)というのを、行列 $B_i,C_i \in \mathbb{R}^N$と $A_i \in \mathbb{R}^{(T,T)}$に対して
$$M_{ij} = C^T_jA_j \cdots A_{i+1}B_i$$
という行列のことと定義します。
このとき、どんな$N$-半分離行列も、$N\text{-SSS}$表現をもちます。(Proposition 3.4.)
これでMambaの計算を、SSSを用いた行列計算として扱うことができました。
5. SSSの計算
MambaがSSSの操作と同一視できると何が嬉しいかというと、(1)GPU上では行列計算は他の形の計算より高速であること、(2)SSSを計算するためのアルゴリズムが使えること、(3)Attentionとの等価性が言えることです。
(1)は文字通りで、Attentionの計算は行列計算ですが、Mambaの計算は行列計算ではありませんでした。これを行列計算の形にすることで、GPU上での性能向上が見込めます。この章では(2)について扱います。
5.1 SSSの定義
まず、$1\text{-SS}$というものを定義しておきます。これは先に定義した $N$-半分離行列 で $N=1$ とした場合で、どんな部分行列もランク1になります。具体的には、
1SS(a_{0:T}) =
\begin{pmatrix}
1 & &&&\\
a_1 & 1 \\
a_2a_1 & a_2 & 1\\
\vdots & \vdots & \ddots & \ddots\\
a_{T-1}. . .a_1 & a_{T-1}...a_2 & \cdots & a_{T-1} & 1
\end{pmatrix}
という形です。
このようにすると、どんな部分行列も「ある列(行)が別の列(行)の定数倍である」ようになり、ランク1です。例えば、
\begin{pmatrix}
a_1 & 1 \\
a_2a_1 & a_2
\end{pmatrix}
この要素を抜き出すと、左の列は右の列の $a_1$ 倍なので、線型独立な要素はひとつです。行列のランクは線型独立な列ベクトルの数なので、これはランク1です。
5.2 SSSを使った (N=1) SSMの計算
ここで、先程の$M_{ij} = C^T_jA_j \cdots A_{i+1}B_i$ という行列は、
$$M = \text{diag}(C)\cdot \text{1SS}(a_{0:T})\cdot \text{diag}(B)$$
と書き表すことができます。対角行列の掛け算は各要素ごとのスカラー倍と同じなので、Mambaの計算で一番面倒なのは $1\text{-SS}$ の乗算になります。
通常、行列の積はサイズの二乗オーダーになりますが、なんと $T\times T$サイズの $1\text{-SS}$ の掛け算は $O(T)$で行うことができます。
(一般に、$N\text{-SS}$ の掛け算は $O(NT)$です。原論文のProposition 3.6. )
具体的には、まず次のような操作をすればよいことになります。
- (1) 入力 $X$ と、対角行列 $B$ との積を取る
- (2) さらに、$1\text{-SS}$ 行列である $L$ との積を取る
- (2) さらに、対角行列 $C$ との積を取る
(1)(3)は要素ごとの掛け算なので$O(T)$であり、 (2)も $1\text{-SS}$の掛け算なので $O(T)$ です。全部で $O(3T) = O(T)$ の線形時間で計算できることになりました。
次に、この操作を状態数 $N$ に拡張します。
5.3 SSSを使った一般のSSMの計算
Mambaの潜在状態 $h_t$ は入力と同じだけのサイズ(= $P$)を持っていましたが、ここにさらに大きい状態を格納することを考えます。具体的には、サイズを$N$倍します。
すると、先ほどの行列計算の場合では、まず $1\text{-SS}$表現が $N\text{-SS}$ になり、 $B$が$(T)$サイズから $(N,T)$ の行列になり、 $C$が $(T,P)$ サイズから $(T,P,N)$ サイズになります。なので、先の3ステップの操作はこのようになります。
- (1) $Z$ を、入力 $X$ と、$B$ との積とする (サイズ $SP, SN \rightarrow SPN$)
- (2) $H$ を$1\text{-SS}$ 行列である $L$ と $Z$ の積とする
(サイズ $TSN, SPN \rightarrow TPN$) - (2) 出力 $Y$ を、 $C$ と$H$ との積とする (サイズ $TN, TPN \rightarrow TP$)
ここでは単なる行列積(ドット積)ではなく、torch.einsum
(アインシュタインの縮約記法)によって計算しているものと考えてください。この場合の計算量は $O(TN)$になります。
ここまでをまとめます。
- MambaのようなSSM(状態空間モデル) の計算は、SSS(系列半分離行列) を用いて、行列の掛け算の形で表すことができる。
- 状態の大きさ $N$、系列の長さ $T$ のとき、 $O(TN)$ で計算できる。
6. AttentionとSSS
さて、前節ではSSMとSSSの等価性の話をしました。次に、Attentionの一部の形が SSSと等価になる、ということを示していきます。
まずは Transformer で使われた Scaled Dot-Product Attention の式をみます。
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
ここで、 $Q$は一般に $(T,N)$ サイズ、 $K$は $(S,N)$ サイズ、 $V$ は $(S,P)$ サイズです。
Self-Attentionの場合は $S=T$であり、普通はさらに $N=P$ です。この場合は3つの次元はすべて同じ $(T,N)$ になります。
最初の方にあったLinear Attentionを思い出してください。そこでは、$\text{softmax}({QK^T})$ を、$\Phi(Q)\Phi(K)^T$という形にすることで、計算時間を削減していました。
$\Phi$は、$\phi(x) = 1 + \text{elu}(x)$ のような操作で、要素ごとに行うので行列のサイズと同じ計算量がかかります。つまり、この操作は$O(NT)$です。
このようにSoftmaxを分解してよいなら、$\Phi$を作用させる操作はどうせ線形時間なので、無視しても差し支えないです。なので、あらかじめ $Q$, $K$ にこのような関数がかかっているものとして、
y = \text{Attention}(Q, K, V) = QK^TV
というように簡単にしてしまいます。$\sqrt{d_k}$で割る部分も定数倍なので省略します。
さらに、TransformerではMasked Attention というものが用いられます。これは系列の先にあるデータを見ないように、マスクする役割があります。
マスクする、ということは次のような行列$L$を用意して、要素ごとの積(アダマール積)を取ることと同じです。
L = \begin{pmatrix}
1 \\
1 & 1 \\
1 & 1 & 1 \\
\vdots & \vdots & \ddots & \ddots \\
1 & 1 & 1 & \cdots & 1
\end{pmatrix}
なので、結局 Attention は
$$y = (L \circ (QK^T))\cdot V$$
というように定式化できました。この操作は、再度einsum
風の記法を使うと
- (1) $G$ を、入力 $Q$ と、$K$ との積とする (サイズ $TN, SN \rightarrow TS$)
- (2) $H$ を下三角行列である $L$ と $Z$ のアダマール積とする
(サイズ $TS, SN \rightarrow TS$) - (3) 出力 $Y$ を、 $H$ と$V$ との積とする (サイズ $TS, SP \rightarrow TP$)
というように書けます。
ここからさらに式を変形します。まず、
$$(QK^T)V = Q (VK^T)$$
であることを利用します。(この式は一般のAttentionでは成り立ちませんが、Q, K にあらかじめ $\Phi$ を作用させているなら成立します)
これにマスク行列 $L$ も作用させると、次のように書けます。
$$y = Q \cdot \text{cumsum}(K^TV)$$
ただし、$\text{cumsum}$ は、指定した軸方向に累積和をとった同じサイズの行列を返す関数です。(torch.cumsum
やnumpy.cumsum
です)
この操作を、先ほどのようにアインシュタインの縮約記法を使って書くと次のようになります。
- (1) $Z$ を、入力 $V$ と、$K$ との積とする (サイズ $SP, SN \rightarrow SPN$)
- (2) $H$ を、マスク行列である $L$ と $Z$ の積とする
(サイズ $TSN, SPN \rightarrow TPN$) - (2) 出力 $Y$ を、 $Q$ と$H$ との積とする (サイズ $TN, TPN \rightarrow TP$)
6. SSMとAttentionの等価性
ここで、まず5.3で導出したSSMの計算の定式化を見てみます。
- (1) $Z$ を、入力 $X$ と、$B$ との積とする (サイズ $SP, SN \rightarrow SPN$)
- (2) $H$ を$1\text{-SS}$ 行列である $L$ と $Z$ の積とする
(サイズ $TSN, SPN \rightarrow TPN$) - (2) 出力 $Y$ を、 $C$ と$H$ との積とする (サイズ $TN, TPN \rightarrow TP$)
次に、6. で導出した Masked Linear Attentionの式は以下です。
- (1) $Z$ を、入力 $V$ と、$K$ との積とする (サイズ $SP, SN \rightarrow SPN$)
- (2) $H$ を、マスク行列である $L$ と $Z$ の積とする
(サイズ $TSN, SPN \rightarrow TPN$) - (2) 出力 $Y$ を、 $Q$ と$H$ との積とする (サイズ $TN, TPN \rightarrow TP$)
この二つはほぼ同一の操作になっています。
違うのは、行列 $L$ がSSMでは任意の $1\text{-SS}$ であり、Attentionでは $0$ と $1$ だけからなる下三角行列である、という点です。
ところで、「$0$ と $1$ だけからなる下三角行列」は$1\text{-SS}$ になっています。なので結局、 AttentionはSSMの特殊な場合である ということが言えます。論文の題名で言われていた、Transformers are SSMs というのはこのことです。
もちろんすべてのTransformerがSSMであるわけではなく、今までさまざまな仮定を置いてきました。おさらいしましょう。
(1)まず、Attention の類似度関数(通常 Softmax)が、$\text{sim}(q_i,k_j) = \phi(q_i) \phi(k_j)$ と分解できる Linear Attention であることを仮定しました。
(2)さらに、すべての入力を扱う(Non-masked)Attentionではなく、入力をマスクする Masked Attention のみを扱っています。
(SSMはもともと系列データを扱うもので、先のデータはけして見られないので、これは仕方ないです)
この二つの仮定を入れたAttentionはSSMと同じになります。これ以降はこの二つを区別しなくてよいので、まとめて SSD(State Space Dual) と呼びます。
おわりに
ここまでで、この論文の前半部分、TransformerとSSMの等価性について説明し終わりました。
次回はこれを踏まえた新しいアーキテクチャ、 Mamba2 について説明します。
-
指数関数 $\exp(x)$ はいわゆるガウスカーネルであり、これを特徴関数に分解すると無限の次元が必要になります。詳しくはビショップ『パターン認識と機械学習 下』などに説明があります。 ↩