この記事は話題になった Mamba の入門記事です。Mamba を解説し、実際に PyTorch 上に Mamba を実装してみます。Mamba は、 Transformer の代替として注目されているアーキテクチャです。近年では Samba のように、Mamba と attention を層単位で組み合わせるハイブリッドも注目されています。
今回作成した実装は以下に置いています:
https://github.com/torotoki/mamba-lm
記事を書いた動機
Mamba は既存の Transformer ベースのモデルと大きく異なり、State Space Model (SSM) というモデルアーキテクチャを基にしています。一方で Transformer に比べ Mamba は、同モデルサイズ帯の精度が多くのタスクにおいて優位または同程度だと報告されています1。また、Mamba は特に推論時の計算が少なくて済んだり、長いコンテクストに対応しやすいという利点があります。
Transformer はたくさん記事がありますが、Mamba の実装について詳しく解説した記事や、シンプルな実装はまだ少ないため、今回は実際に私が再現実装したコードを元に解説記事を書くことにしました。
Mambaとは何か
Mamba (Gu and Dao, 2023) は State Space Model (SSM) というモデルの一種です。このSSMというのは、古くはカルマンフィルタ (Kalman, 1960) のような数式で書かれた制御システムを言います。非常に端的に言うと、Recurrent Neural Network (RNN) のように再帰的に計算するモデルです。
今回のMambaの実装にあたり作成した以下のダイアグラムを見てください:

今回はこの図を用いて Mamba のアーキテクチャを解説していきます。
入力$X_t$: ここで入力は自然言語のトークン列の埋め込み$X_1, ..., X_t \in \mathbb{R}^{\text{embed_dim}}$で、次のトークンの埋め込み$X_{t+1}$を出力しています。(以降、利便性のため$X_t$をトークンと言いますが、実際には$X_t$はトークンに対応する埋め込みです。)
隠れ層$h_t$: 右上に隠れ層$h_t$についての再帰的な式がありますが、このような状態方程式で書かれたものを State Space Model (SSM) と呼ばれます。この部分がこのモデルの肝となっています。とはいえ、再帰的な計算なので、RNN の拡張に近いものと言えば分かりやすいかと思います。
離散化: もともとの SSM の式は入力が連続系(時間$t$が連続)ですが、今回は離散時間 $t \in \lbrace1, ..., n\rbrace$ の自然言語のトークン列 $X_1, ... X_t, ..., X_n$ が入力なので、離散化する処理もダイアグラムに入っています。この図では、$(\Delta, A, B)$ は連続的なパラメータ、それに対して $(\bar{A}, \bar{B})$ は離散的なパラメータです。この離散化は色々な方法があるのですが、元論文と同じゼロ次ホールド(Zero-order Hold, ZOH)と呼ばれる方法で行っています。名前がいかめしいですが、ようするに$t$が整数のタイミングで信号が変化し、それ以外は前と同じになっている連続系を考えているだけです。上図の状態方程式の離散化の式が導くには、ZOH で定義した連続系の状態方程式を積分することで出てきます(ここでは省略)。結果的に、図のような $\bar{A}$ と $\bar{B}$ の計算式が得られます。
状態$A,B,C$: 隠れ層 $h_t$ は $A_t, B_t, C_t$ に基づいて計算されます。$B_t$, $C_t$ は入力から畳み込み・線形層を通して計算するので、直観的でしょう。一方で $A_t$ は特殊で、対角行列に制限するため固有値から計算しています。これには以下のような数値計算的な理由があります。
状態$A$の計算: 状態方程式を見ると$\Delta_t$と$\bar{A_t}$は以下のような計算で使われています。本来、$\bar{A_t}$ は $\bar{B_t}, C_t$ と同じように制約なく計算しても構いません。
$$h_t = \bar{A_t} h_{t-1} + \bar{B_t} X_t$$
ただ、上記の式で一般的な $\bar{A_t}$ を許してしまうと、どんどん$h_t$が大きくなってしまいます。
そこで、$\bar{A_t}$ の固有値のみをすべて負の値で固定します。そして、$\lambda$ を $\bar{A_t}$ の固有値を対角成分に並べた対角行列として、$A_t$ を $$A_t = \lambda\Delta_t$$ と定義します。これにより、$h_t$ がオーバーフローする問題を回避しています。実際、$$\bar{A} = \exp{(A_t)}$$ なので、行列 $\bar{A_t}$ の要素の最小値は$0$、最大値は$1$になります。ようするに $A_t$ のパラメータを固有値に限定することで、値に制約を設けているわけですね。
並列化: 最後に、Mamba で採用されている並列計算について概要を解説します。状態方程式を見ると、SSMはプログラム的にはアフィン変換($Ax+b$ のような処理)の、$x$ についての再帰的な繰り返しとなります。この計算は scan と呼ばれる並列計算でよく知られるアルゴリズムを使うことにより、学習時は並列計算(計算結果は同じだが $h_t$ を $t=1, ..., n$ について並列で計算する)、推論時は autoregressive model として逐次的に出力(上図の式と同じ)を出すことが可能になっています。scan は機械学習の文脈ではあまり見かけませんでしたが、並列計算や GPU ではよく知られている問題設定なので、高速な計算が可能というわけです。
Mamba の実装
この章では、前章で解説した Mamba のアーキテクチャを実際に実装に落とし込んでみます。今回は以降の利便性を考えて、Hugging Face Transformer や PyTorch を用います。詳細は以下のファイルを見てください。
https://github.com/torotoki/mamba-lm/blob/v1.0/mamba.py
実装のひな型
Mamba の実装を直接実装する前に、まず「言語モデル」を実装するためのひな型を用意しました。ここでは、以下のクラスを継承して書きます。
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.utils import ModelOutput
class MambaConfig(PretrainedConfig)
class MambaBlock(GradientCheckpointingLayer)
class Mamba(PreTrainedModel)
@dataclass
class MambaOutput(ModelOutput)
特に PreTrainedModel, GradientCheckpointingLayer は torch.nn.Module の継承になっています。そのため、__init__(self, ...) の他に forward(self, x) を実装します。ここら辺は PyTorch とほぼ同じです。
MambaBlock というクラスは、上述で解説した Mamba の中身になっています。Mamba クラスは MambaBlock を複数(設定によりますが4つなど)縦に繋げて、さらに最初に embedding、最後に言語モデルの head を付けて、次の単語の logits を MambaOutput クラスとして出力するものになっています。
これらのクラスを用いてMambaを実装することで、後でhuggingfaceのTrainerによる学習ができます。詳細な実装は上記のリンクを見てください。
MambaBlock の実装
ここでは、実際に MambaBlock の実装を段階を追って見てみましょう。まずはフルの実装は以下です:
class MambaBlock(GradientCheckpointingLayer):
def forward(self, x: torch.Tensor, attention_mask: torch.LongTensor | None):
# x: (B, T, D ( = config.d_model))
# attention_mask: (B, T)
B, T, D = x.shape
x = self.norm(x)
if attention_mask is not None:
# Add dimension to attention_mask for broadcasting
x = x * attention_mask.unsqueeze(2)
# 1. Depth-wise convolution
# Swap the dimensions of x since
# self.seq_conv1d expects the dimension (B, D, T)
xT = x.transpose(1,2) # (B, D, T)
k = self.config.d_conv
xT = F.pad(xT, (k - 1, 0)) # left-most padding
hidden_states = self.seq_conv1d(xT)
hidden_states = hidden_states.transpose(1, 2)
# 2. Gated MLP's linear projection
dlt, Bt, Ct = self.to_params(hidden_states).chunk(3, dim=-1)
# Stabilization using softplus
dlt = F.softplus(dlt) # \Delta_t > 0
lam = -F.softplus(self.lam) # \lambda < 0 (D,)
# 3. Autoregressive State-Space Models (SSM)
ht = torch.zeros(B, D, device=x.device, dtype=x.dtype)
ys = []
for t in range(T):
# NOTE: dlt and x can be transposed for faster computation
dt = dlt[:, t, :] # (B,D)
At_diag = torch.exp(dt * lam) # (B,D) ← lam (D,) is broadcasted
Bt_bar = torch.where(
lam.abs() > 1e-4,
((At_diag - 1.0) / lam) * Bt[:, t, :],
dt * Bt[:, t, :]
) # (B,D)
ht = At_diag * ht + Bt_bar * x[:, t, :] # (B,D)
y = Ct[:, t, :] * ht
ys.append(y)
y = torch.stack(ys, dim=1) # (B,T,D)
# 4. Final linear projection
return self.out(y)
1. 畳み込み層
まずは入力の処理と最初の畳み込み層の部分です。
def forward(self, x: torch.Tensor, attention_mask: torch.LongTensor | None):
# x: (B, T, D ( = config.d_model))
# attention_mask: (B, T)
B, T, D = x.shape
x = self.norm(x)
...
# 1. Depth-wise convolution
# Swap the dimensions of x since
# self.seq_conv1d expects the dimension (B, D, T)
xT = x.transpose(1,2) # (B, D, T)
k = self.config.d_conv
xT = F.pad(xT, (k - 1, 0)) # left-most padding
hidden_states = self.seq_conv1d(xT)
hidden_states = hidden_states.transpose(1, 2)
PyTorch で用意されている畳み込み関数(self.seq_conv1dに相当)は最後の次元について畳み込みを取るため、x.transpose(1,2)でトークンに関する次元(=時間)が最後に来るようにして、畳み込みの後は再び次元を戻しています。
また、data leakage が発生しないように、トークン列の padding を左側に付けていることも重要です。
2. 線形層
続いて畳み込みの次にかけられる線形層を見てみましょう
# 2. Gated MLP's linear projection
dlt, Bt, Ct = self.to_params(hidden_states).chunk(3, dim=-1)
この部分では巨大な線形層を書けた後、chunk でそれぞれ均等な数のパラメータを隠れ層の次元について分割しています。1つずつ線形層をかけても良いですが、この方が計算が速いです。
また、前章で紹介したように、数値安定性のために変数の符号をほぼほぼ固定します。とはいえ、勾配がなくなってしまうと困るので、softplus関数を使って以下のように書きます。
# Stabilization using softplus
dlt = F.softplus(dlt) # \Delta_t > 0
lam = -F.softplus(self.lam) # \lambda < 0 (D,)
3. 自己回帰 State-Space Model (SSM)
あとは、トークン列について再帰的に計算するコードです。ここは次元に気を付ければ比較的図と同じで素直なのではないかと思います。
# 3. Autoregressive State-Space Models (SSM)
ht = torch.zeros(B, D, device=x.device, dtype=x.dtype)
ys = []
for t in range(T):
# NOTE: dlt and x can be transposed for faster computation
dt = dlt[:, t, :] # (B,D)
At_diag = torch.exp(dt * lam) # (B,D) ← lam (D,) is broadcasted
Bt_bar = torch.where(
lam.abs() > 1e-4,
((At_diag - 1.0) / lam) * Bt[:, t, :],
dt * Bt[:, t, :]
) # (B,D)
ht = At_diag * ht + Bt_bar * x[:, t, :] # (B,D)
y = Ct[:, t, :] * ht
ys.append(y)
y = torch.stack(ys, dim=1) # (B,T,D)
# 4. Final linear projection
return self.out(y)
Bt_barの定義を計算する部分では、0除算を避けるためあらかじめ場合分けしています。
なお、ここではscanを使った学習の高速化の部分は書いていません。学習時・推論時を共通のコードで書いています。
言語モデルの実装
さらに、今回は Mamba を言語モデルとして使用したいので、トークンの先読みが発生しないようにcausal maskをかける必要があります。
以下の2点を対応する必要がありました:
-
MambaBlock.forward(x, attention_mask)の実装時に、xのバッチ・トークンの次元にcausal maskを適用する2。 -
ロスの計算時に padding や mask を除外する。
実際の実験
非常に簡単なデータを使って、実際にロスが下がるか検証したところ、以下のように実際に安定してロスが下がりました。
実験のコードも同ファイルのTrainer を動かす部分に入っています。
作っている途中では、特に最初のイテレーションではロスが比較的大きくなるため、bfloat16などだとオーバーフローするなども発生しました。今回のコードは数値安定性を持たせる工夫もしていますが、residuationを入れたり、本家のコードに基づいてもう少し追加の工夫をしても良いかなと思います。
ただ、ロスは安定して下がっており、計算はある程度正しいようです。もう少し複雑な実験も行いたいところですが、長くなってしまうので別記事にしようかと思います。
結論・感想
この記事では、
- Mamba 自体のアーキテクチャの解説
- 実際に私が作った hugging face における実装
を書きました。
Mamba を実装する過程で論文や公式の実装(参考文献[3-4])を読みましたが、今回はコードでは紹介しきれなかった並列計算の部分など、速度面も非常に良く設計されていました。また、論文では自然言語だけでなく音声データも扱うなど、モデル・コード・論文で非常に良く設計されていると感じました。Transformer の代替というよりは、特徴が Transformer と異なるため、強みが発揮できる点で使われていくのではないかと思います。
参考文献
[1] Albert Gu and Tri Dao. "Mamba: Linear-time sequence modeling with selective state spaces." arXiv preprint arXiv:2312.00752 (2023).
[2] Rudolph Emil Kalman. "A new approach to linear filtering and prediction problems." (1960): 35-45.
[3] An official implementation of Mamba (includes highly-optimized implementations): https://github.com/state-spaces/mamba
[4] Another official implementation on huggingface transformer: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mamba/modeling_mamba.py
[5] Kaiyue Wen, Xingyu Dang, and Kaifeng Lyu. "Rnns are not transformers (yet): The key bottleneck on in-context retrieval." arXiv preprint arXiv:2402.18510 (2024).
