pub struct Mamba3<B: Backend> {
pub in_proj: Linear<B>,
pub dt_bias_h: Param<Tensor<B, 1>>,
pub dt_limit: (f64, f64),
pub a_floor: f64,
pub d_h: Param<Tensor<B, 1>>,
pub b_norm: RmsNorm<B>,
pub c_norm: RmsNorm<B>,
pub b_bias_hrn: Param<Tensor<B, 3>>,
pub c_bias_hrn: Param<Tensor<B, 3>>,
pub mimo_x: Option<Param<Tensor<B, 3>>>,
pub mimo_z: Option<Param<Tensor<B, 3>>>,
pub mimo_o: Option<Param<Tensor<B, 3>>>,
pub out_norm: Option<RmsNormGated<B>>,
pub out_proj: Linear<B>,
pub init_state_hpr: Option<Param<Tensor<B, 3>>>,
pub state_rank: usize,
pub ngroups: usize,
pub num_rope_angles: usize,
pub rope_dim: usize,
pub mimo_rank: usize,
}
The Mamba-3 SSM block.
Implements the full Mamba-3 layer with exponential-trapezoidal discretisation and data-dependent RoPE. Both SISO (mimo_rank=1) and MIMO (mimo_rank>1) are supported, with two execution modes:
- Self::forward: chunkwise two-SSD algorithm for training / prefill
- Self::step: recurrent form for token-by-token decoding
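Both modes apply the same exponential-trapezoidal update. The scalar sketch below shows its three per-step coefficients. It assumes the formulation from the Mamba-3 paper: α = exp(Δ·A), β = (1 − λ)·Δ·α and γ = λ·Δ, with λ ∈ [0, 1]. It does not show how this crate derives λ from lam_raw, and TrapCoeffs and trapezoidal_coeffs are hypothetical names.

```rust
/// Coefficients of one exponential-trapezoidal step:
/// h_t = alpha * h_{t-1} + beta * (B_{t-1} x_{t-1}) + gamma * (B_t x_t).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TrapCoeffs {
    pub alpha: f64,
    pub beta: f64,
    pub gamma: f64,
}

/// Assumed formulation (Mamba-3 paper): delta > 0, a < 0, lambda in [0, 1].
pub fn trapezoidal_coeffs(delta: f64, a: f64, lambda: f64) -> TrapCoeffs {
    let alpha = (delta * a).exp(); // decay applied to the carried state
    TrapCoeffs {
        alpha,
        beta: (1.0 - lambda) * delta * alpha, // weight on the previous token
        gamma: lambda * delta,               // weight on the current token
    }
}
```

With λ = 1 the previous-token term vanishes, and the update reduces to the single-term exponential-Euler form h_t = exp(Δ·A)·h_{t-1} + Δ·B_t x_t used by Mamba-2.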
Fields
in_proj: Linear<B>
Input projection.
For SISO (R=1): maps d_model → 2·d_inner + 2·ngroups·state_rank + 3·nheads + num_rope_angles.
For MIMO (R>1): maps d_model → 2·d_inner + 2·ngroups·state_rank·R + 3·nheads + num_rope_angles.
Output splits: [z | x | B_raw | C_raw | dd_dt | dd_A | lam_raw | theta_raw]
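The segment widths follow directly from the formulas above. The plain-Rust sketch below computes each width and the start offset of each segment along the last axis, treating SISO as R = 1. The helpers in_proj_splits and split_offsets are hypothetical and not part of the crate.

```rust
/// Widths of the in_proj output segments, in the order
/// [z | x | B_raw | C_raw | dd_dt | dd_A | lam_raw | theta_raw].
pub fn in_proj_splits(
    d_inner: usize,
    ngroups: usize,
    state_rank: usize,
    nheads: usize,
    num_rope_angles: usize,
    mimo_rank: usize,
) -> [usize; 8] {
    // B_raw and C_raw each carry ngroups * state_rank * R values (R = 1 for SISO).
    let bc = ngroups * state_rank * mimo_rank;
    [d_inner, d_inner, bc, bc, nheads, nheads, nheads, num_rope_angles]
}

/// Start offset of each segment (running sum of the preceding widths).
pub fn split_offsets(widths: &[usize]) -> Vec<usize> {
    widths
        .iter()
        .scan(0, |acc, &w| {
            let start = *acc;
            *acc += w;
            Some(start)
        })
        .collect()
}
```

For example, d_inner = 64, ngroups = 1, state_rank = 16, nheads = 4, num_rope_angles = 8 and R = 1 gives a total width of 2·64 + 2·16 + 3·4 + 8 = 180.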
dt_bias_h: Param<Tensor<B, 1>>
Per-head bias for the discretisation step size Δ.
Shape: [nheads]
dt_limit: (f64, f64)
Hard clamp applied to Δ after softplus.
a_floor: f64
Minimum absolute value of A: A ∈ (−∞, −a_floor].
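Here is a scalar sketch of how Δ and A might be produced for one head. Δ follows the description above: softplus first, then a hard clamp. Three details are assumptions: dt_bias_h is added before the softplus, dt_limit is read as (lo, hi), and A is parameterised as a negated softplus of dd_A capped at −a_floor. That A formula is hypothetical, chosen only because it satisfies A ∈ (−∞, −a_floor].

```rust
/// Numerically stable softplus: ln(1 + e^x).
pub fn softplus(x: f64) -> f64 {
    if x > 20.0 {
        x // e^x dominates; avoids overflow in exp
    } else {
        x.exp().ln_1p()
    }
}

/// Delta = clamp(softplus(dd_dt + dt_bias), lo, hi).
/// Assumption: dt_limit is interpreted as (lo, hi) with lo <= hi.
pub fn step_size(dd_dt: f64, dt_bias: f64, dt_limit: (f64, f64)) -> f64 {
    softplus(dd_dt + dt_bias).clamp(dt_limit.0, dt_limit.1)
}

/// Hypothetical parameterisation: A = min(-softplus(dd_a), -a_floor),
/// which always lands in (−∞, −a_floor].
pub fn decay_rate(dd_a: f64, a_floor: f64) -> f64 {
    (-softplus(dd_a)).min(-a_floor)
}
```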
d_h: Param<Tensor<B, 1>>
Per-head skip (D) coefficient.
Shape: [nheads]; initialised to ones.
b_norm: RmsNorm<B>
RMSNorm applied to the B projection (QK-Norm, no gating).
Normalises over the state_rank dimension.
c_norm: RmsNorm<B>
RMSNorm applied to the C projection (QK-Norm, no gating).
Normalises over the state_rank dimension.
b_bias_hrn: Param<Tensor<B, 3>>
Learnable per-head, per-rank bias for B, added after QK-norm.
Shape: [nheads, mimo_rank, state_rank]; initialised to ones.
For SISO (mimo_rank=1) this has shape [nheads, 1, state_rank].
c_bias_hrn: Param<Tensor<B, 3>>
Learnable per-head, per-rank bias for C, added after QK-norm.
Shape: [nheads, mimo_rank, state_rank]; initialised to ones.
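For B and C, RMSNorm over state_rank comes first, then the learnable per-head, per-rank bias is added. The sketch below shows this for a single head h and rank r, where bias_hr stands for the slice b_bias_hrn[h][r][..]. It omits any learnable RMSNorm weight, and the helper names are hypothetical. C is handled the same way with c_bias_hrn.

```rust
/// RMSNorm over one state_rank vector, without a learnable scale
/// (simplifying assumption; RmsNorm<B> may also apply a weight).
pub fn rms_norm(v: &[f64], eps: f64) -> Vec<f64> {
    let mean_sq = v.iter().map(|x| x * x).sum::<f64>() / v.len() as f64;
    let inv = 1.0 / (mean_sq + eps).sqrt();
    v.iter().map(|x| x * inv).collect()
}

/// B for one head and one rank: normalise first, then add the bias.
pub fn normed_b(b_raw: &[f64], bias_hr: &[f64], eps: f64) -> Vec<f64> {
    assert_eq!(b_raw.len(), bias_hr.len(), "both span state_rank");
    rms_norm(b_raw, eps)
        .into_iter()
        .zip(bias_hr)
        .map(|(b, bias)| b + bias)
        .collect()
}
```

Because the bias is initialised to ones, a freshly initialised layer adds 1 to every normalised entry.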
mimo_x: Option<Param<Tensor<B, 3>>>
MIMO up-projection for x (values).
Shape: [nheads, mimo_rank, per_head_dim].
Only present when mimo_rank > 1; in the SISO case this is None.
mimo_z: Option<Param<Tensor<B, 3>>>
MIMO up-projection for z (gate).
Shape: [nheads, mimo_rank, per_head_dim].
Only present when mimo_rank > 1.
mimo_o: Option<Param<Tensor<B, 3>>>
MIMO down-projection for the output.
Shape: [nheads, mimo_rank, per_head_dim].
Only present when mimo_rank > 1.
out_norm: Option<RmsNormGated<B>>
Optional gated RMSNorm applied before the output projection.
When Some, the SiLU gate at the block tail is replaced by
RmsNormGated(y, z), which normalises y over per_head_dim and
gates with SiLU(z). Created when has_outproj_norm = true.
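Here is a sketch of RmsNormGated(y, z) for one head's per_head_dim slice. The description above does not say whether normalisation comes before gating, or whether a learnable scale is applied. This version normalises first and omits the scale; both choices are assumptions, and the function names are hypothetical.

```rust
/// SiLU(x) = x * sigmoid(x).
pub fn silu(x: f64) -> f64 {
    x / (1.0 + (-x).exp())
}

/// Normalise y over per_head_dim, then gate elementwise with SiLU(z).
pub fn rms_norm_gated(y: &[f64], z: &[f64], eps: f64) -> Vec<f64> {
    assert_eq!(y.len(), z.len(), "y and z both span per_head_dim");
    let mean_sq = y.iter().map(|v| v * v).sum::<f64>() / y.len() as f64;
    let inv = 1.0 / (mean_sq + eps).sqrt();
    y.iter().zip(z).map(|(yv, zv)| yv * inv * silu(*zv)).collect()
}
```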
out_proj: Linear<B>
Output projection: maps d_inner → d_model.
init_state_hpr: Option<Param<Tensor<B, 3>>>
Optional learnable initial hidden state h₀.
Shape: [nheads, per_head_dim, state_rank]
state_rank: usize
State rank N.
ngroups: usize
Number of B/C groups G. Must divide nheads.
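Because ngroups must divide nheads, each B/C group is shared by nheads / ngroups heads. The hypothetical helper below assumes the heads are split into contiguous blocks, one block per group. The crate's actual head-to-group layout may differ.

```rust
/// Which B/C group a head reads from, assuming ngroups contiguous
/// blocks of nheads / ngroups heads each (hypothetical layout).
pub fn group_of_head(head: usize, nheads: usize, ngroups: usize) -> usize {
    assert!(ngroups > 0 && nheads % ngroups == 0, "ngroups must divide nheads");
    assert!(head < nheads, "head index out of range");
    head / (nheads / ngroups)
}
```

With ngroups = 1 every head shares a single B/C projection.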
num_rope_angles: usize
Number of RoPE angle pairs (rope_dim / 2).
rope_dim: usize
Effective RoPE dimension (= 2 · num_rope_angles). Always even and
≤ state_rank. Only the first rope_dim entries of B/C are rotated.
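Only the leading rope_dim entries of B/C are rotated. The sketch below performs that partial rotation, assuming adjacent entries (2i, 2i + 1) form each pair and that the per-pair angles for the current token are already computed. It does not show how the crate derives those angles from theta_raw, and apply_partial_rope is a hypothetical name.

```rust
/// Rotate the first rope_dim = 2 * angles.len() entries of a B or C vector,
/// leaving the remaining state_rank - rope_dim entries untouched.
/// Assumption: pairs are (2i, 2i + 1).
pub fn apply_partial_rope(v: &mut [f64], angles: &[f64]) {
    let rope_dim = 2 * angles.len();
    assert!(rope_dim <= v.len(), "rope_dim must be <= state_rank");
    for (i, &theta) in angles.iter().enumerate() {
        let (s, c) = theta.sin_cos();
        let (a, b) = (v[2 * i], v[2 * i + 1]);
        // Standard 2-D rotation of the pair by theta.
        v[2 * i] = a * c - b * s;
        v[2 * i + 1] = a * s + b * c;
    }
}
```

Each pair is rotated rigidly, so the rotation preserves the norm of the rotated prefix and leaves the tail bit-for-bit unchanged.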
mimo_rank: usize
MIMO rank R. 1 = SISO (standard Mamba-3).
Implementations
impl<B: Backend> Mamba3<B>
pub fn step(
    &self,
    input_bm: Tensor<B, 2>,
    cache: Option<Mamba3Cache<B>>,
) -> (Tensor<B, 2>, Mamba3Cache<B>)
Process a single token using the pure recurrent form.
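For SISO (mimo_rank=1):
$$h_t = \alpha_t h_{t-1} + \beta_t\, B_{t-1} \otimes x_{t-1} + \gamma_t\, B_t \otimes x_t$$
$$y_t = C_t^\top h_t + D\, x_t$$
For MIMO (mimo_rank=R>1):
$$h_t = \alpha_t h_{t-1} + \sum_r \beta_t\, B_{t-1}[r] \otimes \left(x_{t-1} \odot \mathrm{mimo\_x}[r]\right) + \sum_r \gamma_t\, B_t[r] \otimes \left(x_t \odot \mathrm{mimo\_x}[r]\right)$$
$$y_t[r] = C_t[r]^\top h_t + D\, x_t \odot \mathrm{mimo\_x}[r]$$
$$\mathrm{out}_t = \sum_r \mathrm{mimo\_o}[r] \odot \mathrm{silu}\left(z_t \odot \mathrm{mimo\_z}[r]\right) \odot y_t[r]$$
Shapes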
- input_bm: [batch, d_model]
- output: [batch, d_model]
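The recurrence above can be sketched for a single head and a single token using plain nested Vecs. All names are hypothetical, the layout is not the crate's tensor layout, and the coefficients α, β, γ are taken as given. Setting R = 1 and filling mimo_x, mimo_z and mimo_o with ones recovers the SISO form, followed by the SiLU(z) gate at the block tail.

```rust
/// Inputs for one token of one head (hypothetical, nested-Vec layout).
pub struct StepInputs<'a> {
    pub alpha: f64,
    pub beta: f64,
    pub gamma: f64,
    pub d: f64,
    pub b_prev: &'a [Vec<f64>], // [R][N], B_{t-1}
    pub b_cur: &'a [Vec<f64>],  // [R][N], B_t
    pub c_cur: &'a [Vec<f64>],  // [R][N], C_t
    pub x_prev: &'a [f64],      // [P], x_{t-1}
    pub x_cur: &'a [f64],       // [P], x_t
    pub z: &'a [f64],           // [P], gate input z_t
    pub mimo_x: &'a [Vec<f64>], // [R][P]
    pub mimo_z: &'a [Vec<f64>], // [R][P]
    pub mimo_o: &'a [Vec<f64>], // [R][P]
}

fn silu(v: f64) -> f64 {
    v / (1.0 + (-v).exp())
}

/// Updates the state h ([P][N]) in place and returns out_t ([P]).
pub fn mamba3_step(h: &mut [Vec<f64>], s: &StepInputs) -> Vec<f64> {
    let rank = s.b_cur.len();
    let p_dim = s.x_cur.len();
    // State update: decay, then one previous-token and one current-token term per rank.
    for p in 0..p_dim {
        for n in 0..h[p].len() {
            let mut upd = 0.0;
            for r in 0..rank {
                upd += s.beta * s.b_prev[r][n] * s.x_prev[p] * s.mimo_x[r][p];
                upd += s.gamma * s.b_cur[r][n] * s.x_cur[p] * s.mimo_x[r][p];
            }
            h[p][n] = s.alpha * h[p][n] + upd;
        }
    }
    // Readout per rank (C_t[r]^T h_t plus the D skip), then gated rank sum.
    (0..p_dim)
        .map(|p| {
            (0..rank)
                .map(|r| {
                    let read: f64 = h[p].iter().zip(&s.c_cur[r]).map(|(hv, cv)| hv * cv).sum();
                    let y = read + s.d * s.x_cur[p] * s.mimo_x[r][p];
                    s.mimo_o[r][p] * silu(s.z[p] * s.mimo_z[r][p]) * y
                })
                .sum::<f64>()
        })
        .collect()
}
```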
impl<B: Backend + Mamba3BackendExt> Mamba3<B>
pub fn forward(
    &self,
    input_bsm: Tensor<B, 3>,
    cache: Option<Mamba3Cache<B>>,
    ssd_path: Mamba3SsdPath,
) -> (Tensor<B, 3>, Mamba3Cache<B>)
Process a full input sequence using the trapezoidal two-SSD algorithm.
For SISO (mimo_rank=1), this is the standard two-SSD decomposition. For MIMO (mimo_rank=R>1), B/C have R parallel rank channels. The hidden state is shared across ranks; each rank contributes independently.
Shapes
- input_bsm: [batch, sequence, d_model]
- output: [batch, sequence, d_model]
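One way to read "two-SSD" relies on linearity. Because the recurrence is linear in its inputs, the trapezoidal update splits into two ordinary SSD scans that share the decay α. One scan is driven by γₜ·uₜ and the other by the shifted input βₜ·uₜ₋₁, where uₜ = Bₜ ⊗ xₜ. The scalar sketch below checks that the two scans sum to the direct recurrence. Its function names are hypothetical, and it assumes u₋₁ = 0 when there is no cache. It illustrates the idea only, not necessarily how forward is organised internally.

```rust
/// Direct trapezoidal scan for a scalar state:
/// h_t = alpha_t h_{t-1} + beta_t u_{t-1} + gamma_t u_t.
pub fn trapezoidal_scan(alpha: &[f64], beta: &[f64], gamma: &[f64], u: &[f64], h0: f64) -> Vec<f64> {
    let mut h = h0;
    let mut u_prev = 0.0; // assumption: no previous token at t = 0
    let mut out = Vec::with_capacity(u.len());
    for t in 0..u.len() {
        h = alpha[t] * h + beta[t] * u_prev + gamma[t] * u[t];
        u_prev = u[t];
        out.push(h);
    }
    out
}

/// Ordinary SSD-style scan: h_t = alpha_t h_{t-1} + v_t.
pub fn linear_scan(alpha: &[f64], v: &[f64], h0: f64) -> Vec<f64> {
    let mut h = h0;
    alpha
        .iter()
        .zip(v)
        .map(|(a, x)| {
            h = a * h + x;
            h
        })
        .collect()
}

/// The same result from two scans sharing alpha.
pub fn two_scan(alpha: &[f64], beta: &[f64], gamma: &[f64], u: &[f64], h0: f64) -> Vec<f64> {
    // Scan 1: current-token input gamma_t * u_t, carrying the initial state.
    let current: Vec<f64> = gamma.iter().zip(u).map(|(g, x)| g * x).collect();
    // Scan 2: shifted input beta_t * u_{t-1}, starting from zero.
    let shifted: Vec<f64> = (0..u.len())
        .map(|t| if t == 0 { 0.0 } else { beta[t] * u[t - 1] })
        .collect();
    let a = linear_scan(alpha, &current, h0);
    let b = linear_scan(alpha, &shifted, 0.0);
    a.iter().zip(&b).map(|(x, y)| x + y).collect()
}
```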
impl<B: Backend> Mamba3<B>
pub fn ssd_minimal(input: Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>)
MIMO-first chunkwise SSD: minimal/segsum variant.
Implements the four-step decomposition for the MIMO trapezoidal recurrence. SISO (R=1) is the degenerate case where the fused length equals the chunk length.
No D skip is applied here; the caller handles it.
Shapes
- input: see Mamba3SsdInput
- output.0 y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
- output.1 final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
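The segsum primitive behind the quadratic variant can be sketched on a 1-D slice. For j ≤ i, segsum(x)[i][j] = x[j+1] + … + x[i]; above the diagonal it is −∞. As a result, exp(segsum(log α)) is the lower-triangular decay matrix of the scan. The plain-Rust version below is hypothetical and for illustration only. It uses a prefix-sum difference for brevity, which is less numerically robust than a masked cumulative sum.

```rust
/// segsum(x)[i][j] = x[j+1] + ... + x[i] for j <= i, -inf above the diagonal.
pub fn segsum(x: &[f64]) -> Vec<Vec<f64>> {
    let n = x.len();
    let mut prefix = vec![0.0; n + 1];
    for i in 0..n {
        prefix[i + 1] = prefix[i] + x[i];
    }
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| if j <= i { prefix[i + 1] - prefix[j + 1] } else { f64::NEG_INFINITY })
                .collect::<Vec<f64>>()
        })
        .collect()
}

/// Quadratic-form scan with zero initial state:
/// h_i = sum_j exp(segsum(log_a))[i][j] * u_j, i.e. h_i = a_i h_{i-1} + u_i.
pub fn quadratic_scan(log_a: &[f64], u: &[f64]) -> Vec<f64> {
    segsum(log_a)
        .iter()
        .map(|row| row.iter().zip(u).map(|(s, x)| s.exp() * x).sum::<f64>())
        .collect()
}
```

Materialising the full decay matrix costs O(L²) memory in the slice length L, which is the cost the serial variant avoids.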
impl<B: Backend> Mamba3<B>
pub fn ssd_serial(input: Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>)
MIMO-first (Hybrid) Serial SSD.
Implements K1-K5 with a sequential loop (K4) for the inter-chunk scan instead
of the quadratic segsum approach in ssd_minimal.
This is more memory-efficient for long sequences with many chunks.
SISO (R=1) is the special case where the fused length equals the chunk length.
Returns
- y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
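The sketch below contrasts the two inter-chunk strategies by computing the state entering each chunk in two ways. The first is a serial loop with O(nchunks) extra memory. The second is an O(nchunks²) sum over all earlier chunks, the shape of a segsum-style pass. States are scalars standing in for the [nheads, per_head_dim, state_rank] tensors. The function names and inputs are hypothetical: each chunk supplies its summed log-decay and its final state computed from a zero start.

```rust
/// Serial inter-chunk pass. Returns (state entering each chunk, final state).
pub fn serial_chunk_states(chunk_log_decay: &[f64], chunk_local: &[f64], h0: f64) -> (Vec<f64>, f64) {
    let mut states = Vec::with_capacity(chunk_local.len());
    let mut h = h0;
    for (a, s) in chunk_log_decay.iter().zip(chunk_local) {
        states.push(h);
        // Carry the state across chunk c: decay it, then add the chunk's own contribution.
        h = a.exp() * h + s;
    }
    (states, h)
}

/// Same entering states via a sum over all earlier chunks, O(nchunks^2).
pub fn quadratic_chunk_states(chunk_log_decay: &[f64], chunk_local: &[f64], h0: f64) -> Vec<f64> {
    let n = chunk_local.len();
    (0..n)
        .map(|c| {
            // Initial state decayed through chunks 0..c.
            let from_h0 = chunk_log_decay[..c].iter().sum::<f64>().exp() * h0;
            // Chunk j's local state decayed through chunks j+1..c.
            let from_chunks: f64 = (0..c)
                .map(|j| chunk_log_decay[j + 1..c].iter().sum::<f64>().exp() * chunk_local[j])
                .sum();
            from_h0 + from_chunks
        })
        .collect()
}
```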
impl<B: Backend + Mamba3BackendExt> Mamba3<B>
pub fn ssd_serial_recalculated(
    input: Mamba3SsdInput<B>,
) -> (Tensor<B, 6>, Tensor<B, 4>)
MIMO-first Serial SSD with recalculated backward.
Computes K1 eagerly (so the cumsum is available for the backward pass),
then delegates the remaining computation to Mamba3BackendExt::ssd_serial_recalculated
which can provide a memory-efficient custom backward for supported backends.
Falls back to the standard K2-K5 serial computation on unsupported backends.
Returns
- y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]