Derivation of Mamba-3 Equations (1),(2)

Posted at 2026-03-25

This article derives the key equations (1) and (2) of the Mamba-3 paper with one goal: every single step should be easy to follow. No steps are skipped, and each equation transforms into the next for a stated reason. If you can follow basic calculus and linear algebra, you can follow this entire derivation. I will post derivations of the paper's other equations later.

1. Definitions: Dimensions and Parameters

1.1 Basic Dimensions

  • $T$: Sequence Length
  • $D$: Number of Input/Output Channels (Model Dimension)
  • $N$: State Dimension (per channel)

1.2 Variable Dimension Definitions

  • $\mathbf{X} \in \mathbb{R}^{T \times D}$: Input sequence
  • $\mathbf{Y} \in \mathbb{R}^{T \times D}$: Output sequence
  • $\mathbf{h}(t) \in \mathbb{C}^{D \times N}$: State matrix. Its $d$-th row, written as the column vector $\mathbf{h}_d(t) \in \mathbb{C}^{N \times 1}$, is the state vector of channel $d$.
  • $\mathbf{A}(t)$: Collection of diagonal components.
    • Storage dimension: $\mathbb{C}^{D \times N}$. Each element $A_{d,n}(t)$ is the complex coefficient for the $n$-th state of channel $d$.
    • Operational dimension: $\mathbb{C}^{D \times N \times N}$. For each channel $d$, it acts as an $N \times N$ diagonal matrix $\mathbf{A}_d(t) = \mathrm{diag}(A_{d,1}(t), \dots, A_{d,N}(t))$.
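  • $\mathbf{B}, \mathbf{C} \in \mathbb{R}^{T \times N}$: Input and output projections shared across all channels. Their $t$-th rows, written as column vectors $\mathbf{B}_t, \mathbf{C}_t \in \mathbb{R}^{N}$, are the projection vectors at time step $t$ (the continuous-time counterparts are $\mathbf{B}(t), \mathbf{C}(t) \in \mathbb{R}^{N}$).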
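The storage/operational distinction for $\mathbf{A}(t)$ can be illustrated with a minimal plain-Python sketch. It assumes one channel's coefficients are stored as a length-$N$ list of complex numbers; the sizes, random values, and helper names (`as_diagonal`, `matvec`, `apply_A`) are hypothetical. It shows that multiplying by the full $N \times N$ diagonal matrix $\mathbf{A}_d(t)$ gives the same result as an elementwise product with the stored row, which is why the $D \times N$ storage suffices.

```python
import random

random.seed(0)
D, N = 2, 3  # hypothetical sizes for illustration

# Storage form: a D x N table of complex diagonal entries A_{d,n}
# (negative real parts chosen here only as an example).
A_store = [[complex(-random.uniform(0.5, 2.0), random.uniform(-1.0, 1.0))
            for _ in range(N)] for _ in range(D)]

def as_diagonal(row):
    """Operational form for one channel: the N x N matrix diag(row)."""
    n = len(row)
    return [[row[i] if i == j else 0j for j in range(n)] for i in range(n)]

def matvec(M, v):
    """Dense matrix-vector product, O(N^2) work."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(m_i, v)) for m_i in M]

def apply_A(row, h):
    """Same product using only the stored diagonal, O(N) work."""
    return [a * h_n for a, h_n in zip(row, h)]

h_d = [1 + 0j, -2 + 0.5j, 0.5 - 1j]  # an arbitrary state vector of channel d
max_gap = max(
    abs(p - q)
    for d in range(D)
    for p, q in zip(matvec(as_diagonal(A_store[d]), h_d), apply_A(A_store[d], h_d))
)
print(max_gap)  # should be zero or negligibly small
```

In practice only the elementwise form is used, since the off-diagonal zeros never contribute.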

2. Derivation of Equation (1): The Discrete SSM Recurrence

The paper's equation (1) is the discrete-time recurrence:

$$\mathbf{h}_t = \alpha_t \mathbf{h}_{t-1} + \gamma_t \mathbf{B}_t \mathbf{x}_t, \qquad \mathbf{y}_t = \mathbf{C}_t^\top \mathbf{h}_t \tag{paper 1}$$

where:

  • $A_t < 0$ is a scalar (the *scalar SSM* parameterization: $\mathbf{A}_t = A_t \mathbf{I}_N$). The constraint $A_t < 0$ ensures stability: in the continuous-time ordinary differential equation (ODE) with constant $A$, the homogeneous solution is $\mathbf{h}(t) \propto e^{At}$, which decays only when $A < 0$. If $A > 0$, the state would grow exponentially.
  • $\Delta_t > 0$ is the time step size.
  • $\alpha_t := e^{\Delta_t A_t} \in (0,1)$ is the scalar state-transition.
  • $\gamma_t := \Delta_t$ is the discretization factor.
  • $\mathbf{B}_t \in \mathbb{R}^{N}$: input projection, $\mathbf{C}_t \in \mathbb{R}^{N}$: output projection, $\mathbf{h}_t \in \mathbb{R}^{N}$: hidden state.
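As a quick numerical illustration of these definitions, the following plain-Python sketch computes $\alpha_t$ and $\gamma_t$ from $(A_t, \Delta_t)$ and iterates the input-free update $h \leftarrow \alpha_t h$. The helper names `discretize` and `homogeneous_decay` are hypothetical, and the sketch assumes the same $A_t$ and $\Delta_t$ at every step.

```python
import math

def discretize(A_t, dt_t):
    """Return (alpha_t, gamma_t) = (exp(dt_t * A_t), dt_t)."""
    if dt_t <= 0:
        raise ValueError("the time step dt_t must be positive")
    return math.exp(dt_t * A_t), dt_t

def homogeneous_decay(A_t, dt_t, steps, h0=1.0):
    """Apply h <- alpha_t * h `steps` times with zero input."""
    alpha, _ = discretize(A_t, dt_t)
    h = h0
    for _ in range(steps):
        h = alpha * h
    return h

alpha, gamma = discretize(-0.7, 0.3)
print(0.0 < alpha < 1.0, gamma)          # True 0.3
print(homogeneous_decay(-1.0, 0.5, 20))  # A_t < 0: about exp(-10), the state decays
print(homogeneous_decay(+1.0, 0.5, 20))  # A_t > 0: about exp(10), the state grows
```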

We derive this from the continuous-time state space model (SSM).

2.1 Continuous-Time SSM

The underlying continuous-time ODE is:

$$\dot{\mathbf{h}}(t) = \mathbf{A}(t) \mathbf{h}(t) + \mathbf{B}(t) x(t), \qquad y(t) = \mathbf{C}(t)^\top \mathbf{h}(t) \tag{1}$$

where $\mathbf{h}(t) \in \mathbb{R}^{N}$, $\mathbf{A}(t) \in \mathbb{R}^{N \times N}$, $\mathbf{B}(t), \mathbf{C}(t) \in \mathbb{R}^{N}$, and $x(t), y(t) \in \mathbb{R}$.
This is for a single channel; the multi-channel (multi-input, multi-output (MIMO)) extension is discussed in Section 3.3.

2.2 Continuous-Time Solution (Integrating Factor Method)

To solve the first-order linear non-homogeneous ODE (1), we first move $\mathbf{A}(t)\mathbf{h}(t)$ to the left-hand side and then multiply both sides from the left by the integrating factor $\exp\left(-\int_{0}^{t} \mathbf{A}(\tau) d\tau\right)$:

$$\exp\left(-\int_{0}^{t}\mathbf{A}(\tau) d\tau\right) \left[\dot{\mathbf{h}}(t) - \mathbf{A}(t) \mathbf{h}(t) \right] = \exp\left(-\int_{0}^{t}\mathbf{A}(\tau) d\tau\right) \mathbf{B}(t) x(t) \tag{2}$$

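Because $\mathbf{A}$ is diagonal, the integrating factor commutes with $\mathbf{A}(t)$, so by the fundamental theorem of calculus $\frac{d}{dt}\exp\left(-\int_{0}^{t}\mathbf{A}(\tau) d\tau\right) = -\mathbf{A}(t)\exp\left(-\int_{0}^{t}\mathbf{A}(\tau) d\tau\right)$. The left-hand side of (2) is therefore exactly the derivative of a product, and we apply the product rule $\frac{d}{dt}(fg) = f'g + fg'$ in reverse: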

$$\frac{d}{dt}\left[\exp\left(-\int_{0}^{t}\mathbf{A}(\tau) d\tau\right) \mathbf{h}(t) \right] = \exp\left(-\int_{0}^{t}\mathbf{A}(\tau) d\tau\right) \mathbf{B}(t) x(t) \tag{3}$$

Setting the initial state $\mathbf{h}(0)=\mathbf{0}$ and integrating both sides from $0$ to $t$:

$$\exp\left(-\int_{0}^{t}\mathbf{A}(\tau) d\tau\right) \mathbf{h}(t) - \underbrace{\exp(0) \mathbf{h}(0)}_{= \mathbf{0}} = \int_{0}^{t}\exp\left(-\int_{0}^{s}\mathbf{A}(\tau) d\tau\right) \mathbf{B}(s) x(s) ds \tag{4}$$

Multiplying both sides from the left by $\exp\left(\int_{0}^{t}\mathbf{A}(\tau) d\tau\right)$:

$$\mathbf{h}(t) = \exp\left(\int_{0}^{t}\mathbf{A}(\tau) d\tau\right) \int_{0}^{t}\exp\left(-\int_{0}^{s}\mathbf{A}(\tau) d\tau\right) \mathbf{B}(s) x(s) ds \tag{5}$$

Since $\exp\left(\int_{0}^{t}\mathbf{A} d\tau\right)$ does not depend on $s$, we bring it inside the integral by linearity. Then, because $\mathbf{A}$ is diagonal, $\int_{0}^{t}\mathbf{A} d\tau$ and $\int_{0}^{s}\mathbf{A} d\tau$ commute, so we can combine the two matrix exponentials via $\exp(\mathbf{M}_1)\exp(\mathbf{M}_2) = \exp(\mathbf{M}_1 + \mathbf{M}_2)$:

$$\mathbf{h}(t) = \int_{0}^{t}\exp\left(\int_{0}^{t}\mathbf{A}(\tau) d\tau - \int_{0}^{s}\mathbf{A}(\tau) d\tau\right) \mathbf{B}(s) x(s) ds \tag{6}$$

By additivity of the integral over intervals, $\int_{0}^{t} - \int_{0}^{s} = \int_{s}^{t}$, and we obtain the continuous-time general solution:

$$\mathbf{h}(t) = \int_{0}^{t} \underbrace{\exp\left(\int_{s}^{t}\mathbf{A}(\tau) d\tau\right)}_{\Phi(t,s)} \mathbf{B}(s) x(s) ds \qquad[\mathbf{h}(t)\in\mathbb{R}^{N}] \tag{7}$$
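Equation (7) can be sanity-checked numerically. The sketch below is a minimal plain-Python check for a hypothetical scalar case ($N=1$, which covers each diagonal entry separately) with $A(t) = -1 - 0.5\sin t$, $B(t) = 1 + 0.3\cos t$, and $x(t) = \sin 2t$. These functions are chosen only because $\int_s^t A(\tau)\,d\tau$ has a closed form. The sketch evaluates (7) with the trapezoidal rule and compares the result with a direct fourth-order Runge-Kutta (RK4) integration of ODE (1) from $h(0)=0$.

```python
import math

# Hypothetical test functions for the scalar case (N = 1).
def A(t):
    return -1.0 - 0.5 * math.sin(t)  # A(t) < 0 for all t

def int_A(s, t):
    """Closed form of the integral of A over [s, t]."""
    return -(t - s) + 0.5 * (math.cos(t) - math.cos(s))

def B(t):
    return 1.0 + 0.3 * math.cos(t)

def x(t):
    return math.sin(2.0 * t)

def h_closed_form(t, n=4000):
    """Equation (7): integrate exp(int_A(s, t)) B(s) x(s) over [0, t] (trapezoidal rule)."""
    ds = t / n
    def integrand(s):
        return math.exp(int_A(s, t)) * B(s) * x(s)
    total = 0.5 * (integrand(0.0) + integrand(t))
    total += sum(integrand(k * ds) for k in range(1, n))
    return total * ds

def h_ode(t, n=4000):
    """Integrate dh/dt = A(t) h + B(t) x(t) with h(0) = 0 using classic RK4."""
    def f(tau, h):
        return A(tau) * h + B(tau) * x(tau)
    dt = t / n
    h = 0.0
    for step in range(n):
        tau = step * dt
        k1 = f(tau, h)
        k2 = f(tau + dt / 2, h + dt / 2 * k1)
        k3 = f(tau + dt / 2, h + dt / 2 * k2)
        k4 = f(tau + dt, h + dt * k3)
        h += dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return h

t_end = 3.0
print(h_closed_form(t_end), h_ode(t_end))  # the two values should agree closely
```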

2.3 Discretization

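Let $0 = \tau_0 < \tau_1 < \dots < \tau_T$ be the discrete time points and write $\mathbf{h}_t := \mathbf{h}(\tau_t)$. We derive the state $\mathbf{h}_t$ by splitting the integration range $[0, \tau_t]$ of (7) into $[0, \tau_{t-1}]$ and $[\tau_{t-1}, \tau_t]$: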

$$\mathbf{h}_t = \int_{0}^{\tau_{t-1}}\Phi(\tau_t,s) \mathbf{B}(s) x(s) ds + \int_{\tau_{t-1}}^{\tau_t}\Phi(\tau_t,s) \mathbf{B}(s) x(s) ds \tag{8}$$

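We apply the semigroup property $\Phi(t,s)=\Phi(t,u)\Phi(u,s)$ (which follows from $\int_s^t = \int_u^t + \int_s^u$ and the fact that diagonal matrices commute) with $u = \tau_{t-1}$ to the first term, and pull the $s$-independent factor $\Phi(\tau_t,\tau_{t-1})$ out of the integral: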

$$\mathbf{h}_t = \Phi(\tau_t,\tau_{t-1}) \underbrace{\int_{0}^{\tau_{t-1}}\Phi(\tau_{t-1},s) \mathbf{B}(s) x(s) ds}_{\mathbf{h}_{t-1}} + \int_{\tau_{t-1}}^{\tau_t}\Phi(\tau_t,s) \mathbf{B}(s) x(s) ds \tag{9}$$
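The semigroup property used in (9) can be checked numerically for a diagonal $\mathbf{A}(t)$. This hypothetical sketch takes $N = 2$ with $\mathbf{A}(\tau) = \mathrm{diag}(-1 - 0.5\sin\tau,\ -2 + 0.1\tau)$, chosen only because each diagonal entry has a closed-form integral. It stores $\Phi(t,s)$ by its diagonal, since the exponential of a diagonal matrix is the diagonal matrix of entrywise exponentials.

```python
import math

def int_A(s, t):
    """Diagonal of the integral of A over [s, t] for A(tau) = diag(-1 - 0.5 sin tau, -2 + 0.1 tau)."""
    a1 = -(t - s) + 0.5 * (math.cos(t) - math.cos(s))
    a2 = -2.0 * (t - s) + 0.05 * (t * t - s * s)
    return [a1, a2]

def Phi(t, s):
    """Diagonal of exp(integral of A over [s, t]); exp acts entrywise on a diagonal matrix."""
    return [math.exp(a) for a in int_A(s, t)]

def semigroup_gap(t, u, s):
    """Largest entrywise gap between Phi(t, s) and Phi(t, u) Phi(u, s)."""
    lhs = Phi(t, s)
    rhs = [p * q for p, q in zip(Phi(t, u), Phi(u, s))]
    return max(abs(a - b) for a, b in zip(lhs, rhs))

print(semigroup_gap(2.0, 1.3, 0.4))  # expected at floating-point rounding level
```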

Right-hand approximation. Following the paper (Section 3.1.1), we approximate $\mathbf{A}(s) \approx \mathbf{A}(\tau_t) =: \mathbf{A}_t$ for all $s \in [\tau_{t-1},\tau_t]$, with $\Delta_t := \tau_t - \tau_{t-1}$.

State-transition. From the definition of $\Phi$ in (7):

$$\Phi(\tau_t,\tau_{t-1}) = \exp\left(\int_{\tau_{t-1}}^{\tau_t}\mathbf{A}(s) ds\right) \approx \exp\left(\int_{\tau_{t-1}}^{\tau_t}\mathbf{A}_t ds\right) = \exp\left(\mathbf{A}_t(\tau_t - \tau_{t-1})\right) = \exp(\Delta_t \mathbf{A}_t) \tag{10}$$

State-input integral. Starting from the second term in (9), we expand $\Phi(\tau_t,s)$ under the right-hand approximation $\mathbf{A}(s)\approx\mathbf{A}_t$:

$$\int_{\tau_{t-1}}^{\tau_t}\Phi(\tau_t,s) \mathbf{B}(s) x(s) ds = \int_{\tau_{t-1}}^{\tau_t}\exp\left(\int_{s}^{\tau_t}\mathbf{A}(\sigma)d\sigma\right) \mathbf{B}(s) x(s) ds \approx \int_{\tau_{t-1}}^{\tau_t}\exp\left((\tau_t-s)\mathbf{A}_t\right) \mathbf{B}(s) x(s) ds \tag{11}$$

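Under the exponential-Euler rule, we further hold the input terms constant over the interval at their right-endpoint values, $\mathbf{B}(s)\approx\mathbf{B}_t := \mathbf{B}(\tau_t)$ and $x(s)\approx x_t := x(\tau_t)$:

$$\approx \left(\int_{\tau_{t-1}}^{\tau_t}\exp\left((\tau_t-s)\mathbf{A}_t\right) ds\right) \mathbf{B}_t x_t \approx \Delta_t \mathbf{B}_t x_t \tag{12}$$

The remaining integral equals $\mathbf{A}_t^{-1}\left(\exp(\Delta_t\mathbf{A}_t)-\mathbf{I}\right)$ exactly (the zero-order-hold coefficient). The last step instead evaluates the integrand at the right endpoint $s=\tau_t$, where $\exp(0)=\mathbf{I}$, and multiplies by the interval length $\Delta_t$; the two agree to first order in $\Delta_t$.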

This yields the general discrete update:

$$\mathbf{h}_t \approx \exp(\Delta_t \mathbf{A}_t)\mathbf{h}_{t-1} + \Delta_t \mathbf{B}_t x_t \tag{13}$$
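To see how much the exponential-Euler step in (12) gives up, this plain-Python sketch takes a scalar $A_t < 0$ (as in the scalar SSM parameterization) and compares the exact coefficient $\int_{\tau_{t-1}}^{\tau_t} e^{(\tau_t - s)A_t}\,ds = (e^{\Delta_t A_t} - 1)/A_t$ with the approximation $\Delta_t$. The function names are hypothetical. The sketch only illustrates that the relative gap shrinks roughly like $|A_t|\Delta_t/2$ as $\Delta_t \to 0$.

```python
import math

def zoh_coeff(A_t, dt_t):
    """Exact integral of exp(u * A_t) for u in [0, dt_t]: (exp(dt_t * A_t) - 1) / A_t, A_t != 0."""
    return math.expm1(dt_t * A_t) / A_t

def euler_coeff(A_t, dt_t):
    """Exponential-Euler value: the integrand at the right endpoint (exp(0) = 1) times dt_t."""
    return dt_t

A_t = -1.0  # an arbitrary stable value
for dt_t in (0.5, 0.1, 0.01):
    exact, approx = zoh_coeff(A_t, dt_t), euler_coeff(A_t, dt_t)
    rel_gap = (approx - exact) / exact
    print(f"dt={dt_t:<5} exact={exact:.6f} euler={approx:.6f} rel_gap={rel_gap:.4f}")
```

`math.expm1` is used instead of `math.exp(...) - 1` to avoid cancellation when $\Delta_t A_t$ is small.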

Scalar SSM parameterization. Mamba-2/3 parameterizes $\mathbf{A}_t = A_t \mathbf{I}_N$ with $A_t < 0$ (scalar times identity). Defining

$$\alpha_t := e^{\Delta_t A_t} \in (0,1), \qquad \gamma_t := \Delta_t \tag{14}$$

the update (13) becomes:

$$\mathbf{h}_t = \alpha_t \mathbf{h}_{t-1} + \gamma_t \mathbf{B}_t x_t, \qquad y_t = \mathbf{C}_t^\top \mathbf{h}_t \tag{15}$$

which is the paper's equation (1). $\blacksquare$
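The recurrence (15) translates directly into a sequential scan. The following is a minimal plain-Python sketch for one channel, assuming per-step lists `A`, `dt`, `x` of scalars and lists `B`, `C` of length-$N$ rows. The function name `scalar_ssm_scan` is hypothetical, and this is not the paper's hardware-efficient implementation.

```python
import math

def scalar_ssm_scan(A, dt, B, C, x):
    """Run h_t = alpha_t h_{t-1} + gamma_t B_t x_t, y_t = C_t^T h_t from h_0 = 0.

    A, dt, x: length-T lists of scalars (A_t < 0, dt_t > 0).
    B, C: length-T lists of length-N lists (the rows B_t, C_t).
    Returns the length-T list of outputs y_t.
    """
    h = [0.0] * len(B[0])
    ys = []
    for A_t, dt_t, B_t, C_t, x_t in zip(A, dt, B, C, x):
        alpha = math.exp(dt_t * A_t)  # state transition, in (0, 1)
        gamma = dt_t                  # exponential-Euler input scale
        h = [alpha * h_n + gamma * b_n * x_t for h_n, b_n in zip(h, B_t)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(C_t, h)))
    return ys

# Usage: dt_t = 1 and A_t = -ln 2 give alpha_t = 1/2; with N = 1 and B_t = C_t = [1],
# an impulse input x = [1, 0, 0] should give outputs close to [1, 0.5, 0.25].
y = scalar_ssm_scan([-math.log(2.0)] * 3, [1.0] * 3, [[1.0]] * 3, [[1.0]] * 3, [1.0, 0.0, 0.0])
print(y)
```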

2.4 Recursive Unrolling

We unroll the recurrence (15) starting from $\mathbf{h}_0 = \mathbf{0}$.

Base cases.

$$\begin{aligned} \mathbf{h}_0 &= \mathbf{0} \\ \mathbf{h}_1 &= \alpha_1\mathbf{h}_0 + \gamma_1\mathbf{B}_1 x_1 = \gamma_1\mathbf{B}_1 x_1 \\ \mathbf{h}_2 &= \alpha_2\mathbf{h}_1 + \gamma_2\mathbf{B}_2 x_2 = \alpha_2\gamma_1\mathbf{B}_1 x_1 + \gamma_2\mathbf{B}_2 x_2 \\ \mathbf{h}_3 &= \alpha_3\mathbf{h}_2 + \gamma_3\mathbf{B}_3 x_3 = \alpha_3\alpha_2\gamma_1\mathbf{B}_1 x_1 + \alpha_3\gamma_2\mathbf{B}_2 x_2 + \gamma_3\mathbf{B}_3 x_3 \end{aligned}$$

Pattern. Each term indexed by $s$ carries the input $\gamma_s \mathbf{B}_s x_s$ propagated forward through the product of transition scalars from step $s+1$ to $t$. Using the paper's cumulative product notation $\alpha_{t:s}^\times := \prod_{j=s}^{t}\alpha_j$ (with $\alpha_{t:t+1}^\times = 1$ when the range is empty):

$$\mathbf{h}_t = \sum_{s=1}^{t}\alpha_{t:s+1}^\times \gamma_s \mathbf{B}_s x_s \tag{16}$$
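The pattern can be confirmed by induction on $t$. For $t=1$, (16) gives $\alpha_{1:2}^\times\gamma_1\mathbf{B}_1x_1 = \gamma_1\mathbf{B}_1x_1 = \mathbf{h}_1$. Assuming (16) holds for $t-1$, substituting it into (15) and using $\alpha_t\,\alpha_{t-1:s+1}^\times = \alpha_{t:s+1}^\times$ gives:

$$\begin{aligned} \mathbf{h}_t &= \alpha_t \sum_{s=1}^{t-1}\alpha_{t-1:s+1}^\times \gamma_s \mathbf{B}_s x_s + \gamma_t\mathbf{B}_t x_t \\ &= \sum_{s=1}^{t-1}\alpha_{t:s+1}^\times \gamma_s \mathbf{B}_s x_s + \underbrace{\alpha_{t:t+1}^\times}_{=1}\gamma_t\mathbf{B}_t x_t = \sum_{s=1}^{t}\alpha_{t:s+1}^\times \gamma_s \mathbf{B}_s x_s \end{aligned}$$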

3. Derivation of Equation (2): The Matrix Output Form

Using the discrete SSM derived in Section 2, we now derive the matrix output form $\mathbf{Y} = (\mathbf{L}\odot\mathbf{C}\mathbf{B}^\top)\mathbf{X}$ (paper equation 2).

3.1 Output Equation

The output at time step $t$ is:

$$y_t = \mathbf{C}_t^\top \mathbf{h}_t \tag{17}$$

Substituting the unrolled state (16):

$$y_t = \mathbf{C}_t^\top \sum_{s=1}^{t}\alpha_{t:s+1}^\times \gamma_s \mathbf{B}_s x_s = \sum_{s=1}^{t}\alpha_{t:s+1}^\times \gamma_s (\mathbf{C}_t^\top\mathbf{B}_s) x_s \tag{18}$$

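Moving $\mathbf{C}_t^\top$ inside the sum is valid by linearity, and because $\alpha_{t:s+1}^\times \gamma_s$ is a scalar, it commutes with $\mathbf{C}_t^\top$, leaving the inner product $\mathbf{C}_t^\top\mathbf{B}_s$ in each term.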
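Equation (18) can be checked against the recurrence on random data. In this plain-Python sketch (function names hypothetical; $\alpha_t$ and $\gamma_t$ are drawn directly rather than computed from $A_t$ and $\Delta_t$), Python index `0` stands for time step $1$, so $\alpha_{t:s+1}^\times$ becomes the product of `alpha[s + 1 : t + 1]`, and an empty slice gives the empty product $1$.

```python
import math
import random

def recurrence_outputs(alpha, gamma, B, C, x):
    """y_t from the recurrence (15), starting at h_0 = 0."""
    h = [0.0] * len(B[0])
    ys = []
    for a, g, b_t, c_t, x_t in zip(alpha, gamma, B, C, x):
        h = [a * h_n + g * b_n * x_t for h_n, b_n in zip(h, b_t)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(c_t, h)))
    return ys

def unrolled_outputs(alpha, gamma, B, C, x):
    """y_t from the double sum (18)."""
    ys = []
    for t in range(len(x)):
        y_t = 0.0
        for s in range(t + 1):                           # causal: s <= t only
            decay = math.prod(alpha[s + 1 : t + 1])      # empty product = 1 when s == t
            cb = sum(c * b for c, b in zip(C[t], B[s]))  # inner product C_t^T B_s
            y_t += decay * gamma[s] * cb * x[s]
        ys.append(y_t)
    return ys

random.seed(1)
T, N = 6, 3  # arbitrary sizes
alpha = [random.uniform(0.5, 0.99) for _ in range(T)]
gamma = [random.uniform(0.01, 0.5) for _ in range(T)]
B = [[random.gauss(0, 1) for _ in range(N)] for _ in range(T)]
C = [[random.gauss(0, 1) for _ in range(N)] for _ in range(T)]
x = [random.gauss(0, 1) for _ in range(T)]

max_gap = max(abs(p - q) for p, q in zip(recurrence_outputs(alpha, gamma, B, C, x),
                                         unrolled_outputs(alpha, gamma, B, C, x)))
print(max_gap)  # a rounding-level difference is expected
```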

3.2 Kernel Factorization

Define (for $t,s = 1,\dots,T$):

  • Structured mask $\mathbf{L}\in\mathbb{R}^{T\times T}$: $(t,s)$-element $L_{ts}=\alpha_{t:s+1}^\times \gamma_s = \left(\prod_{j=s+1}^{t}\alpha_j\right)\gamma_s$ for $s \le t$, and $L_{ts}=0$ for $s>t$. The mask is lower triangular because the recurrence is causal: the unrolled sum (16) only includes inputs up to time $t$, so future inputs $x_s$ with $s>t$ contribute nothing.
  • Projection kernel $\mathbf{C}\mathbf{B}^\top\in\mathbb{R}^{T\times T}$: $(t,s)$-element $(\mathbf{C}\mathbf{B}^\top)_{ts}=\mathbf{C}_t^\top\mathbf{B}_s = \sum_{n=1}^{N}C_{t,n}B_{s,n}$

Note that $\mathbf{L}$ absorbs both the cumulative transition $\alpha_{t:s+1}^\times$ and the discretization factor $\gamma_s$. Both $\mathbf{L}$ and $\mathbf{C}\mathbf{B}^\top$ are $T\times T$ matrices (shared across channels).

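Because $L_{ts}=0$ for $s>t$, the sum in (18) can be extended to $s=T$ without changing its value, and the output becomes: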

$$y_t = \sum_{s=1}^{T}L_{ts}\cdot(\mathbf{C}\mathbf{B}^\top)_{ts}\cdot x_s = \sum_{s=1}^{T}(\mathbf{L}\odot\mathbf{C}\mathbf{B}^\top)_{ts} x_s \tag{19}$$
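The factorization can be made concrete by materializing both $T \times T$ matrices. This plain-Python sketch (hypothetical helper names, random test values) builds $\mathbf{L}$ and $\mathbf{C}\mathbf{B}^\top$ as nested lists, checks that $\mathbf{L}$ is lower triangular, and compares $(\mathbf{L}\odot\mathbf{C}\mathbf{B}^\top)x$ with the recurrence (15). It costs $O(T^2)$ and is meant only to illustrate (19), not to be efficient.

```python
import math
import random

def build_mask(alpha, gamma):
    """L[t][s] = alpha^x_{t:s+1} * gamma_s for s <= t, else 0 (0-based indices)."""
    T = len(alpha)
    return [[math.prod(alpha[s + 1 : t + 1]) * gamma[s] if s <= t else 0.0
             for s in range(T)] for t in range(T)]

def cb_kernel(C, B):
    """(C B^T)[t][s] = C_t^T B_s."""
    return [[sum(c * b for c, b in zip(c_t, b_s)) for b_s in B] for c_t in C]

def matrix_form_outputs(alpha, gamma, B, C, x):
    """y = (L ⊙ C B^T) x, equation (19); the Hadamard product is taken entrywise."""
    L, K = build_mask(alpha, gamma), cb_kernel(C, B)
    T = len(x)
    return [sum(L[t][s] * K[t][s] * x[s] for s in range(T)) for t in range(T)]

def recurrence_outputs(alpha, gamma, B, C, x):
    """Reference: the recurrence (15) from h_0 = 0."""
    h, ys = [0.0] * len(B[0]), []
    for a, g, b_t, c_t, x_t in zip(alpha, gamma, B, C, x):
        h = [a * h_n + g * b_n * x_t for h_n, b_n in zip(h, b_t)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(c_t, h)))
    return ys

random.seed(2)
T, N = 5, 4  # arbitrary sizes
alpha = [random.uniform(0.5, 0.99) for _ in range(T)]
gamma = [random.uniform(0.01, 0.5) for _ in range(T)]
B = [[random.gauss(0, 1) for _ in range(N)] for _ in range(T)]
C = [[random.gauss(0, 1) for _ in range(N)] for _ in range(T)]
x = [random.gauss(0, 1) for _ in range(T)]

L = build_mask(alpha, gamma)
is_lower = all(L[t][s] == 0.0 for t in range(T) for s in range(t + 1, T))
max_gap = max(abs(p - q) for p, q in zip(matrix_form_outputs(alpha, gamma, B, C, x),
                                         recurrence_outputs(alpha, gamma, B, C, x)))
print(is_lower, max_gap)
```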

3.3 Multi-Channel (MIMO) Extension

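The recurrence (15) is applied independently to each of the $D$ channels with the same $\alpha_t$, $\gamma_t$, $\mathbf{B}_t$, and $\mathbf{C}_t$ (hence the same $\mathbf{L}$ and $\mathbf{C}\mathbf{B}^\top$), but with potentially different inputs $x_t^{(d)}$. By (19), channel $d$ gives $y_t^{(d)} = \sum_{s=1}^{T}(\mathbf{L}\odot\mathbf{C}\mathbf{B}^\top)_{ts}\, x_s^{(d)}$. Collecting all channels as the columns of $\mathbf{X},\mathbf{Y}\in\mathbb{R}^{T\times D}$, the per-channel outputs stack into: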

$$\mathbf{Y} = (\mathbf{L}\odot\mathbf{C}\mathbf{B}^\top) \mathbf{X} \tag{20}$$

where $(\mathbf{L}\odot\mathbf{C}\mathbf{B}^\top)\in\mathbb{R}^{T\times T}$ is a single matrix applied to all $D$ columns of $\mathbf{X}$. $\blacksquare$
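Equation (20) says one $T \times T$ matrix serves every channel. The plain-Python sketch below (hypothetical helper names, random test values) builds $\mathbf{M} = \mathbf{L}\odot\mathbf{C}\mathbf{B}^\top$ once, computes $\mathbf{Y} = \mathbf{M}\mathbf{X}$, and compares each column with running the recurrence (15) separately on that channel. It assumes, as described above, that all $D$ channels share $\alpha_t$, $\gamma_t$, $\mathbf{B}_t$, and $\mathbf{C}_t$.

```python
import math
import random

def mixing_matrix(alpha, gamma, B, C):
    """M[t][s] = L[t][s] * (C B^T)[t][s], zero above the diagonal."""
    T = len(alpha)
    return [[(math.prod(alpha[s + 1 : t + 1]) * gamma[s]
              * sum(c * b for c, b in zip(C[t], B[s]))) if s <= t else 0.0
             for s in range(T)] for t in range(T)]

def recurrence_outputs(alpha, gamma, B, C, x):
    """Recurrence (15) for a single channel, from h_0 = 0."""
    h, ys = [0.0] * len(B[0]), []
    for a, g, b_t, c_t, x_t in zip(alpha, gamma, B, C, x):
        h = [a * h_n + g * b_n * x_t for h_n, b_n in zip(h, b_t)]
        ys.append(sum(c_n * h_n for c_n, h_n in zip(c_t, h)))
    return ys

random.seed(3)
T, N, D = 5, 3, 4  # arbitrary sizes
alpha = [random.uniform(0.5, 0.99) for _ in range(T)]
gamma = [random.uniform(0.01, 0.5) for _ in range(T)]
B = [[random.gauss(0, 1) for _ in range(N)] for _ in range(T)]
C = [[random.gauss(0, 1) for _ in range(N)] for _ in range(T)]
X = [[random.gauss(0, 1) for _ in range(D)] for _ in range(T)]  # T x D input

M = mixing_matrix(alpha, gamma, B, C)  # built once, shared by all channels
Y = [[sum(M[t][s] * X[s][d] for s in range(T)) for d in range(D)] for t in range(T)]

max_gap = 0.0
for d in range(D):
    column = [X[t][d] for t in range(T)]
    y_d = recurrence_outputs(alpha, gamma, B, C, column)
    max_gap = max(max_gap, max(abs(Y[t][d] - y_d[t]) for t in range(T)))
print(max_gap)  # a rounding-level difference is expected
```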
