pub struct Mamba2<B: Backend> {
    pub in_proj: Linear<B>,
    pub conv1d: Conv1d<B>,
    pub dt_bias_h: Param<Tensor<B, 1>>,
    pub dt_limit: (f64, f64),
    pub a_log_h: Param<Tensor<B, 1>>,
    pub d_h: Param<Tensor<B, 1>>,
    pub norm: RmsNormGated<B>,
    pub out_proj: Linear<B>,
    pub init_state_hpr: Option<Param<Tensor<B, 3>>>,
    pub state_rank: usize,
    pub ngroups: usize,
}
The Mamba-2 SSM block.
Implements the full SSD layer as described in §5 of the paper. Supports two execution modes:
- Self::forward: chunkwise SSD for training / prefill (exploits tensor cores; linear in sequence length T)
- Self::step: pure recurrent form for token-by-token decoding (O(H·P·N) per step; no KV-cache)
§Architecture (one forward pass through the block)
u [B, T, D]
├─ in_proj ────────────────────────────────────┐
│                                              │
│  z [B,T,I]   xbc [B,T,V]   dt_raw [B,T,H]    │
│                   │                          │
│             causal Conv1d                    │
│                   │ SiLU                     │
│              split into                      │
│  x [B,T,H,P]   B [B,T,G,N]   C [B,T,G,N]     │
│                                              │
│  Δ = softplus(dt_raw + dt_bias)              │
│  Ā = exp(Δ · A)  [scalar per head]           │
│  B̄ = Δ · B                                   │
│                                              │
│  ┌──── chunked_selective_scan ──────────┐    │
│  │ (Steps 1–4, see below)               │    │
│  └──────────────────────────────────────┘    │
│  y [B,T,H,P]                                 │
│  + D skip                                    │
│  RmsNormGated(·, z)                          │
└─ out_proj ───────────────────────────────────┘
output [B, T, D]
§Fields
in_proj: Linear<B>
Input projection: maps d_model → d_inner + conv_dim + nheads.
The output is split into three parts:
- z [B, T, d_inner]: multiplicative gate for the output norm
- xbc [B, T, conv_dim]: input to the causal convolution, which is then split into (x, B, C) after activation
- dt_raw [B, T, nheads]: raw (pre-softplus) discretisation step Δ
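The three slice widths can be read off the shapes in the architecture diagram. The sketch below assumes d_inner = H·P and conv_dim = H·P + 2·G·N, which is what the shapes x [B,T,H,P], B [B,T,G,N] and C [B,T,G,N] imply; `in_proj_split` is a hypothetical helper, not part of this crate's API.

```rust
/// Widths of the (z, xbc, dt_raw) slices of the in_proj output.
/// Hypothetical helper: assumes d_inner = nheads * per_head_dim and
/// conv_dim = d_inner + 2 * ngroups * state_rank.
fn in_proj_split(
    nheads: usize,
    per_head_dim: usize,
    ngroups: usize,
    state_rank: usize,
) -> (usize, usize, usize) {
    let d_inner = nheads * per_head_dim;
    // xbc carries x (d_inner channels) plus B and C (ngroups * state_rank each).
    let conv_dim = d_inner + 2 * ngroups * state_rank;
    (d_inner, conv_dim, nheads)
}
```

With 8 heads of width 64, one group and state rank 128, this gives a z slice of 512, an xbc slice of 768 and a dt_raw slice of 8 channels.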
conv1d: Conv1d<B>
Causal depthwise Conv1d applied to the xbc projection.
- Input/output channels: conv_dim
- Kernel size: conv_kernel (typically 4)
- Groups: conv_dim (fully depthwise; each channel is independent)
- Padding: none (left-padding is applied manually so the convolution is strictly causal)
The convolution provides a local conv_kernel-token context window
before the SSM, which helps the model capture short-range dependencies
that the SSM’s recurrent form handles less efficiently.
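As a rough illustration of what "causal depthwise" means here, the following sketch convolves each channel independently and treats positions before the start of the sequence as zeros. It uses plain vectors rather than tensors, omits any bias term, and the function name and data layout are assumptions made for this example only.

```rust
/// Causal depthwise 1-D convolution over a single sequence.
/// `x[t][c]` is the input at time t, channel c; `w[c]` is that channel's kernel.
/// The sequence is implicitly left-padded with zeros, so y[t] only sees x[..=t].
fn causal_depthwise_conv1d(x: &[Vec<f64>], w: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let channels = w.len();
    let mut y = vec![vec![0.0; channels]; x.len()];
    for t in 0..x.len() {
        for c in 0..channels {
            let k_len = w[c].len();
            for k in 0..k_len {
                // Kernel tap k looks (k_len - 1 - k) tokens into the past.
                let lag = k_len - 1 - k;
                if t >= lag {
                    y[t][c] += w[c][k] * x[t - lag][c];
                }
            }
        }
    }
    y
}
```

Because each output only reads the current and earlier inputs, changing a later token never alters an earlier output.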
dt_bias_h: Param<Tensor<B, 1>>
Per-head bias for the discretisation step size Δ.
Shape: [nheads]
At run time (in both Self::forward and Self::step), Δₜ = softplus(dt_raw_t + dt_bias).
Initialised such that the corresponding initial Δ values are
log-uniformly distributed in [dt_min, dt_max].
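One common way to obtain such an initialisation is to sample Δ log-uniformly and then apply the inverse of softplus, so that softplus(dt_bias) recovers the sampled Δ when dt_raw is zero. The sketch below assumes that scheme; the function names are hypothetical and the crate's actual initialiser may differ.

```rust
/// Inverse of softplus: returns b such that ln(1 + exp(b)) = dt, for dt > 0.
fn inverse_softplus(dt: f64) -> f64 {
    // ln(exp(dt) - 1), written as dt + ln(1 - exp(-dt)) for numerical stability.
    dt + (-(-dt).exp_m1()).ln()
}

/// Map a uniform sample u in [0, 1] to a bias whose softplus is
/// log-uniformly distributed in [dt_min, dt_max].
fn dt_bias_from_uniform(u: f64, dt_min: f64, dt_max: f64) -> f64 {
    let log_dt = dt_min.ln() + u * (dt_max.ln() - dt_min.ln());
    inverse_softplus(log_dt.exp())
}
```

For u = 0 the resulting bias maps back to dt_min under softplus, and for u = 1 it maps back to dt_max.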
dt_limit: (f64, f64)
Hard clamp applied to Δ after softplus: Δ ∈ [dt_limit.0, dt_limit.1].
Prevents degenerate discretisations (e.g. Δ → 0 causes Ā → 1, meaning the state never decays; Δ → ∞ causes Ā → 0, meaning the state is immediately wiped each step).
a_log_h: Param<Tensor<B, 1>>
Per-head log-magnitude of the continuous-time decay parameter A.
Shape: [nheads]
The actual (negative) decay rate is A = -exp(a_log). The discrete
decay is Āₜ = exp(Δₜ · A) = exp(-Δₜ · exp(a_log)) ∈ (0, 1).
Storing the log of the magnitude and negating ensures A < 0 (decaying system) unconditionally and avoids any sign-constraint during gradient descent.
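Putting the bias, the clamp and the A parameterisation together, the per-head discretisation can be sketched on scalars as follows. This is an illustration with plain f64 values, not the tensor code used by the block; `softplus` and `discretise` are names chosen for this example.

```rust
/// Softplus with a large-input shortcut to avoid overflow in exp.
fn softplus(x: f64) -> f64 {
    if x > 20.0 { x } else { x.exp().ln_1p() }
}

/// Returns (Δ, Ā) for one head and one time step.
fn discretise(dt_raw: f64, dt_bias: f64, a_log: f64, dt_limit: (f64, f64)) -> (f64, f64) {
    // Δ = clamp(softplus(dt_raw + dt_bias), dt_limit)
    let delta = softplus(dt_raw + dt_bias).clamp(dt_limit.0, dt_limit.1);
    // A = -exp(a_log) is negative for every real a_log.
    let a = -a_log.exp();
    // Ā = exp(Δ · A), which lies in (0, 1) whenever Δ > 0.
    (delta, (delta * a).exp())
}
```

For example, with dt_raw = dt_bias = a_log = 0 and no effective clamp, Δ = ln 2 and Ā = 0.5.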
d_h: Param<Tensor<B, 1>>
Per-head skip (D) coefficient.
Shape: [nheads]
Adds a direct path from the (post-convolution, pre-SSM) input to the
output: yₜ += D · xₜ. Initialised to ones.
norm: RmsNormGated<B>
Gated RMSNorm applied to the SSM output, conditioned on the gate z.
Input channel dimension: d_inner.
This combines the multiplicative gate (from z) and a normalisation
step into a single fused operation, matching the architecture in §5.2
of the paper.
out_proj: Linear<B>
Output projection: maps d_inner → d_model.
init_state_hpr: Option<Param<Tensor<B, 3>>>
Optional learnable initial hidden state h₀.
Shape: [nheads, per_head_dim, state_rank] (i.e. [H, P, N])
When None, the initial state is zero (the standard default).
When Some, the stored tensor is used as the initial condition for
every forward call (not per-batch; it is broadcast over the batch
dimension).
state_rank: usize
State rank N: the number of latent dimensions in the SSM hidden
state h ∈ ℝ^{N×P} per head. Corresponds to the paper’s N.
ngroups: usize
Number of B/C groups G for grouped SSM heads (analogous to
grouped-query attention). G divides nheads; all nheads/G heads
within a group share the same B and C projections while having
independent X, A, and Z projections.
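The head-to-group relationship can be sketched as an index mapping. The example assumes heads are assigned to groups in contiguous blocks of nheads/G, which is a common convention but is not stated in this documentation; `group_of_head` is a hypothetical helper.

```rust
/// Index of the B/C group used by `head`.
/// Assumption: heads are assigned to groups in contiguous blocks.
fn group_of_head(head: usize, nheads: usize, ngroups: usize) -> usize {
    assert!(ngroups > 0 && nheads % ngroups == 0, "ngroups must divide nheads");
    assert!(head < nheads, "head index out of range");
    head / (nheads / ngroups)
}
```

With nheads = 8 and ngroups = 2, heads 0 to 3 read group 0 and heads 4 to 7 read group 1.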
§Implementations
impl<B: Backend> Mamba2<B>
pub fn step(
    &self,
    input_bm: Tensor<B, 2>,
    cache: Option<Mamba2Cache<B>>,
) -> (Tensor<B, 2>, Mamba2Cache<B>)
Process a single token using the pure recurrent SSM form.
This is the O(H·P·N)-per-token decoding path. It runs one tick of the discretised Mamba-2 recurrence:
Āₜ = exp(Δₜ · A) scalar per head, ∈ (0, 1)
B̄ₜ = Δₜ · Bₜ ∈ ℝᴺ (Euler discretisation)
hₜ = Āₜ · hₜ₋₁ + B̄ₜ · xₜᵀ ∈ ℝ^{N×P} (outer product update; stored transposed as [P, N])
yₜ = Cₜᵀ · hₜ + D · xₜ ∈ ℝᴾ (output)
The convolution is handled by manually sliding the cache window: the oldest input column is dropped and the new token's projection is appended.
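The sliding-window update can be pictured with the following sketch, which keeps the last K inputs per channel and applies the depthwise kernel to them. The window layout (oldest value first) and the name `conv_step` are assumptions for illustration; the real cache is a tensor and may be laid out differently.

```rust
/// Advance the convolution window by one token and return the conv output.
/// `window[c]` holds the last K inputs of channel c, oldest first;
/// `w[c]` is the K-tap kernel of channel c.
fn conv_step(window: &mut [Vec<f64>], new_xbc: &[f64], w: &[Vec<f64>]) -> Vec<f64> {
    let mut out = vec![0.0; new_xbc.len()];
    for c in 0..new_xbc.len() {
        window[c].remove(0); // drop the oldest input
        window[c].push(new_xbc[c]); // append the new token's projection
        out[c] = window[c].iter().zip(&w[c]).map(|(x, k)| x * k).sum::<f64>();
    }
    out
}
```

Starting from an all-zero window reproduces the zero left-padding of the full-sequence convolution.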
The SSM hidden state cache.ssm_bhpr is updated in-place via
the recurrence above.
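For a single head, one tick of this recurrence can be sketched with plain vectors as below. The state is held as [P][N] to mirror the [H, P, N] layout, Δ and A are passed in as scalars, and `ssm_step` is a hypothetical name; batching, grouping and the convolution are left out.

```rust
/// One recurrent step for a single head.
/// `h` is the state stored as [P][N]; `x` has length P; `b` and `c` have length N.
/// Returns y_t (length P) and updates `h` in place.
fn ssm_step(
    h: &mut [Vec<f64>],
    x: &[f64],
    b: &[f64],
    c: &[f64],
    delta: f64,
    a: f64,
    d: f64,
) -> Vec<f64> {
    let a_bar = (delta * a).exp(); // Ā = exp(Δ · A)
    let mut y = vec![0.0; x.len()];
    for p in 0..x.len() {
        for n in 0..b.len() {
            // h = Ā · h + (Δ · B) ⊗ x, one entry at a time
            h[p][n] = a_bar * h[p][n] + delta * b[n] * x[p];
            // y = Cᵀ · h
            y[p] += c[n] * h[p][n];
        }
        y[p] += d * x[p]; // D skip connection
    }
    y
}
```

The work per call is proportional to P·N for one head, which matches the O(H·P·N) per-token cost once all heads are included.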
§Shapes
- input_bm: [batch, d_model]
- output: [batch, d_model]
impl<B: Backend + Mamba2BackendExt> Mamba2<B>
pub fn forward(
    &self,
    input_bsm: Tensor<B, 3>,
    cache: Option<Mamba2Cache<B>>,
    ssd_path: Mamba2SsdPath,
) -> (Tensor<B, 3>, Mamba2Cache<B>)
Process a full input sequence using the chunkwise SSD algorithm.
This is the primary training and prefill path. The computation is linear in T but uses batched matrix multiplications (GEMMs) that can exploit GPU tensor cores — unlike the naive sequential recurrence, which requires O(T) serial steps.
§Full dataflow
- In-projection: u → (z, xbc, dt_raw) via a single linear layer.
- Causal Conv1d + SiLU: local context mixing over xbc.
- Split: xbc → (x, B, C).
- Discretise: Δ = softplus(dt_raw + dt_bias); Ā = exp(Δ · A); B̄ = Δ · B.
- Padding: zero-pad the sequence to a multiple of the chunk length Q (see Sequence padding).
- Chunked SSD: four-step chunkwise algorithm (see Self::chunked_selective_scan).
- Gated RMSNorm: y = RMSNorm(y) · σ(z).
- Out-projection: y → output.
§Sequence padding
If sequence_unpadded % chunk_len ≠ 0, the sequence is zero-padded
to the next multiple of the chunk length Q = chunk_len. Zero-padding is equivalent to inserting
identity steps (Δ = 0 ⇒ Ā = exp(0) = 1, B̄ = 0), so the SSM
state is carried forward unchanged through the pad — making it safe to
read the final state of the padded last chunk as the true final state.
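The padding arithmetic and the identity-step argument can be checked on scalars. In this sketch `padded_len` and `step_scalar` are hypothetical helpers, and the recurrence is reduced to a single state entry.

```rust
/// Smallest multiple of `chunk_len` that is >= `sequence_unpadded`.
fn padded_len(sequence_unpadded: usize, chunk_len: usize) -> usize {
    assert!(chunk_len > 0, "chunk_len must be positive");
    (sequence_unpadded + chunk_len - 1) / chunk_len * chunk_len
}

/// Scalar recurrence h' = exp(Δ · A) · h + Δ · B · x.
fn step_scalar(h: f64, x: f64, b: f64, delta: f64, a: f64) -> f64 {
    (delta * a).exp() * h + delta * b * x
}
```

With Δ = 0 the decay factor is exp(0) = 1 and the input term vanishes, so `step_scalar` returns `h` unchanged regardless of `x`, `b` and `a`.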
§Shapes
- input_bsm: [batch, sequence, d_model]
- output: [batch, sequence, d_model]
- cache (out): updated convolution window and SSM state
impl<B: Backend> Mamba2<B>
pub fn ssd_minimal(input: Mamba2SsdInput<B>) -> (Tensor<B, 5>, Tensor<B, 4>)
Minimal chunkwise SSD algorithm.
Implements the four-step decomposition of the semiseparable matrix
multiplication described in §4 of the paper. The sequence of length T
is split into nchunks = ⌈T/Q⌉ chunks of length Q.
§The four steps
§Step 1 — Intra-chunk outputs (Y_diag)
Within each chunk, compute the output assuming the initial hidden state is zero. This is the quadratic attention form of the SSD layer restricted to a window of Q tokens (§4.1):
Y_diag[n] = (L[n] ∘ C[n] B[n]ᵀ) · X[n]
where L[n] is the Q×Q 1-semiseparable mask for chunk n.
This step is a batched GEMM (exploits tensor cores).
§Step 2 — Chunk state (state_bnhpr)
Compute the final SSM state of each chunk assuming zero initial state (§4.1, Eq. 20):
s[n] = Σ_{t ∈ chunk n} exp(A_cum[end] - A_cum[t]) · B̄[t] · x[t]ᵀ
This is also a batched GEMM and is fully parallel across chunks.
§Step 3 — Inter-chunk state scan (state passing)
Propagate the true hidden state across chunk boundaries using the recurrence (§4.1, Eq. 22):
h[n] = Ā[n]_chunk · h[n-1] + s[n]
where Ā[n]_chunk = exp(Σ_{t ∈ chunk n} Δₜ · A) is the cumulative
decay over the whole chunk. This step is implemented as a single
batched matrix multiplication using the 1-semiseparable structure of
the inter-chunk decay matrix (same segsum trick, now over chunks).
The scan has length nchunks = ⌈T/Q⌉ rather than T, so its cost is
negligible for typical chunk sizes.
§Step 4 — State-to-output (Y_off)
For each chunk n, compute the contribution of the true initial state
h[n-1] to the outputs within that chunk (§4.1, Eq. 23):
Y_off[n, t] = C[n, t]ᵀ · exp(A_cum[t]) · h[n-1]
This is again a batched GEMM.
§Final output (with D skip-connection)
Y = Y_diag + Y_off + D · X
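To see why the four steps reproduce the sequential recurrence, the scalar sketch below (one head, P = N = 1, no D skip) computes the output both ways. Here `a[t]` stands for the log-decay Δₜ · A and `b[t]` for the already discretised B̄ₜ. Both function names are hypothetical, and Step 3 is written as a plain loop over chunks rather than the batched matrix form described above.

```rust
/// Serial reference: h_t = exp(a_t) · h_{t-1} + b_t · x_t, y_t = c_t · h_t.
fn ssd_serial_scalar(x: &[f64], a: &[f64], b: &[f64], c: &[f64]) -> (Vec<f64>, f64) {
    let mut h = 0.0;
    let mut y = Vec::with_capacity(x.len());
    for t in 0..x.len() {
        h = a[t].exp() * h + b[t] * x[t];
        y.push(c[t] * h);
    }
    (y, h)
}

/// Chunked form with chunk length `q`; the sequence length must be a multiple of `q`.
fn ssd_chunked_scalar(x: &[f64], a: &[f64], b: &[f64], c: &[f64], q: usize) -> (Vec<f64>, f64) {
    assert!(q > 0 && x.len() % q == 0, "sequence must be padded to a multiple of q");
    let mut y = vec![0.0; x.len()];
    let mut h_prev = 0.0; // true state entering the current chunk
    for t0 in (0..x.len()).step_by(q) {
        // Inclusive cumulative log-decay inside the chunk (A_cum).
        let mut a_cum = vec![0.0; q];
        let mut acc = 0.0;
        for i in 0..q {
            acc += a[t0 + i];
            a_cum[i] = acc;
        }
        let mut s = 0.0;
        for i in 0..q {
            // Step 1: intra-chunk output with zero initial state.
            let mut y_diag = 0.0;
            for j in 0..=i {
                y_diag += (a_cum[i] - a_cum[j]).exp() * c[t0 + i] * b[t0 + j] * x[t0 + j];
            }
            // Step 4: contribution of the state entering the chunk.
            let y_off = c[t0 + i] * a_cum[i].exp() * h_prev;
            y[t0 + i] = y_diag + y_off;
            // Step 2: chunk state assuming zero initial state.
            s += (a_cum[q - 1] - a_cum[i]).exp() * b[t0 + i] * x[t0 + i];
        }
        // Step 3: pass the state across the chunk boundary.
        h_prev = a_cum[q - 1].exp() * h_prev + s;
    }
    (y, h_prev)
}
```

For any chunk length that divides the sequence length, the two functions should agree up to floating-point rounding, both in the outputs and in the final state.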
impl<B: Backend> Mamba2<B>
pub fn ssd_serial(input: Mamba2SsdInput<B>) -> (Tensor<B, 5>, Tensor<B, 4>)
Forward pass for the Mamba-2 SSD module.
Returns:
- y_bnlhp
- final_state_bhpr
impl<B: Backend + Mamba2BackendExt> Mamba2<B>
pub fn ssd_serial_recalculated(
    input: Mamba2SsdInput<B>,
) -> (Tensor<B, 5>, Tensor<B, 4>)
Forward pass for the Mamba-2 SSD module.
Returns:
- y_bnlhp
- final_state_bhpr