Struct Mamba2

pub struct Mamba2<B: Backend> {
    pub in_proj: Linear<B>,
    pub conv1d: Conv1d<B>,
    pub dt_bias_h: Param<Tensor<B, 1>>,
    pub dt_limit: (f64, f64),
    pub a_log_h: Param<Tensor<B, 1>>,
    pub d_h: Param<Tensor<B, 1>>,
    pub norm: RmsNormGated<B>,
    pub out_proj: Linear<B>,
    pub init_state_hpr: Option<Param<Tensor<B, 3>>>,
    pub state_rank: usize,
    pub ngroups: usize,
}

The Mamba-2 SSM block.

Implements the full SSD layer as described in §5 of the paper. Supports two execution modes:

  • Self::forward: chunkwise SSD for training / prefill (exploits tensor cores; linear in sequence length T)
  • Self::step: pure recurrent form for token-by-token decoding (O(H·P·N) work per step; the cache has a fixed size instead of a KV-cache that grows with the sequence)

§Architecture (one forward pass through the block)

  u  [B, T, D]
  ├─ in_proj ──────────────────────────────────┐
  │                                            │
  │  z [B,T,I]   xbc [B,T,V]   dt_raw [B,T,H] │
  │                │                           │
  │            causal Conv1d                   │
  │                │ SiLU                      │
  │           split into                       │
  │       x [B,T,H,P]  B [B,T,G,N]  C [B,T,G,N]
  │                                            │
  │     Δ = softplus(dt_raw + dt_bias)         │
  │     Ā = exp(Δ · A)   [scalar per head]     │
  │     B̄ = Δ · B                              │
  │                                            │
  │  ┌──── chunked_selective_scan ─────────┐   │
  │  │  (Steps 1–4, see below)             │   │
  │  └─────────────────────────────────────┘   │
  │     y [B,T,H,P]                            │
  │     + D skip                               │
  │     RmsNormGated(·, z)                     │
  └─ out_proj ─────────────────────────────────┘
  output  [B, T, D]

Fields§

§in_proj: Linear<B>

Input projection: maps d_model → d_inner + conv_dim + nheads.

The output is split into three parts:

  • z [B, T, d_inner] — multiplicative gate for the output norm
  • xbc [B, T, conv_dim] — input to the causal convolution, which is then split into (x, B, C) after activation
  • dt_raw [B, T, nheads] — raw (pre-softplus) discretisation step Δ
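Below is a minimal sketch of both splits for a single token. It assumes the contiguous layout [z | xbc | dt_raw], in the order listed above. For the second split it assumes [x | B | C] inside the post-convolution, post-SiLU xbc, as in the architecture diagram. The helper names are hypothetical, and the block itself slices batched tensors rather than plain slices.

```rust
/// Hypothetical helper: split one token's in_proj output into (z, xbc, dt_raw),
/// assuming the contiguous layout [z | xbc | dt_raw].
fn split_in_proj(proj: &[f32], d_inner: usize, conv_dim: usize, nheads: usize) -> (&[f32], &[f32], &[f32]) {
    assert_eq!(proj.len(), d_inner + conv_dim + nheads, "unexpected in_proj width");
    let (z, rest) = proj.split_at(d_inner);
    let (xbc, dt_raw) = rest.split_at(conv_dim);
    (z, xbc, dt_raw)
}

/// Hypothetical helper: split the activated xbc into (x, B, C), assuming the
/// layout [x | B | C] with B and C each ngroups * state_rank wide.
fn split_xbc(xbc: &[f32], d_inner: usize, ngroups: usize, state_rank: usize) -> (&[f32], &[f32], &[f32]) {
    let gn = ngroups * state_rank;
    assert_eq!(xbc.len(), d_inner + 2 * gn, "xbc width must equal conv_dim");
    let (x, rest) = xbc.split_at(d_inner);
    let (b, c) = rest.split_at(gn);
    (x, b, c)
}
```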
§conv1d: Conv1d<B>

Causal depthwise Conv1d applied to the xbc projection.

  • Input/output channels: conv_dim
  • Kernel size: conv_kernel (typically 4)
  • Groups: conv_dim (fully depthwise — each channel is independent)
  • Padding: none (left-padding is applied manually so the convolution is strictly causal)

The convolution provides a local conv_kernel-token context window before the SSM, which helps the model capture short-range dependencies that the SSM’s recurrent form handles less efficiently.
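The causal padding rule can be sketched for one channel as follows. The helper is hypothetical, uses plain f32 slices instead of tensors, and follows the cross-correlation convention of most deep-learning Conv1d layers. Left-padding with conv_kernel - 1 zeros makes output t depend only on inputs 0..=t. A depthwise layer runs this independently for every channel, each with its own kernel.

```rust
/// Hypothetical helper: causal 1-D convolution of a single channel.
/// `weight` is the channel's kernel (length conv_kernel, must be non-empty).
fn causal_depthwise_conv(x: &[f32], weight: &[f32], bias: f32) -> Vec<f32> {
    let k = weight.len();
    // Left-pad with k - 1 zeros; no right padding, so no future input leaks in.
    let mut padded = vec![0.0f32; k - 1];
    padded.extend_from_slice(x);
    // Output t is the kernel applied to the length-k window ending at input t.
    padded
        .windows(k)
        .map(|w| w.iter().zip(weight).map(|(a, b)| a * b).sum::<f32>() + bias)
        .collect()
}
```

The SiLU activation listed in the architecture diagram is applied to this output before the split into (x, B, C).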

§dt_bias_h: Param<Tensor<B, 1>>

Per-head bias for the discretisation step size Δ.

Shape: [nheads]

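Δₜ = softplus(dt_raw_t + dt_bias), followed by the dt_limit clamp. The same formula is used by both forward and step, not only at inference time. The bias is initialised so that softplus(dt_bias), i.e. the Δ obtained for dt_raw = 0, is log-uniformly distributed in [dt_min, dt_max].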
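Below is a sketch of that initialisation. It assumes the crate follows the reference Mamba-2 recipe: draw Δ log-uniformly, then invert softplus so that softplus(dt_bias) = Δ. The uniform samples u are passed in to avoid an RNG dependency. All function names are hypothetical.

```rust
/// Numerically stable softplus, ln(1 + e^x). Above x = 20 the correction
/// term e^{-x} is below 2.1e-9, so x itself is returned.
fn softplus(x: f64) -> f64 {
    if x > 20.0 { x } else { x.exp().ln_1p() }
}

/// Inverse of softplus for y > 0: y + ln(1 - e^{-y}), using exp_m1 for
/// accuracy when y is small.
fn inverse_softplus(y: f64) -> f64 {
    y + (-(-y).exp_m1()).ln()
}

/// Hypothetical initialiser: map uniform samples u in [0, 1) to a log-uniform
/// Δ in [dt_min, dt_max], then return the bias b with softplus(b) = Δ.
fn init_dt_bias(u: &[f64], dt_min: f64, dt_max: f64) -> Vec<f64> {
    let (lo, hi) = (dt_min.ln(), dt_max.ln());
    u.iter()
        .map(|&ui| inverse_softplus((lo + ui * (hi - lo)).exp()))
        .collect()
}
```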

§dt_limit: (f64, f64)

Hard clamp applied to Δ after softplus: Δ ∈ [dt_limit.0, dt_limit.1].

Prevents degenerate discretisations (e.g. Δ → 0 causes Ā → 1, meaning the state never decays; Δ → ∞ causes Ā → 0, meaning the state is immediately wiped each step).
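Here is a per-head sketch of the discretisation step with the clamp. The function is hypothetical and works on scalar f64 values instead of tensors. Extreme dt_raw values saturate at the dt_limit bounds, so for a finite A < 0, Ā = exp(Δ · A) stays strictly inside (0, 1).

```rust
/// Hypothetical per-head helper: Δ = clamp(softplus(dt_raw + dt_bias), dt_limit).
/// For very large arguments exp() overflows to +inf, ln_1p(+inf) = +inf, and
/// the clamp maps that to dt_limit.1, so no special-casing is needed.
fn discretise_dt(dt_raw: f64, dt_bias: f64, dt_limit: (f64, f64)) -> f64 {
    let softplus = (dt_raw + dt_bias).exp().ln_1p(); // ln(1 + e^x)
    softplus.clamp(dt_limit.0, dt_limit.1)
}
```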

§a_log_h: Param<Tensor<B, 1>>

Per-head log-magnitude of the continuous-time decay parameter A.

Shape: [nheads]

The actual (negative) decay rate is A = -exp(a_log). The discrete decay is Āₜ = exp(Δₜ · A) = exp(-Δₜ · exp(a_log)) ∈ (0, 1).

Storing the log of the magnitude and negating ensures A < 0 (decaying system) unconditionally and avoids any sign-constraint during gradient descent.
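The sign constraint can be checked numerically with a small sketch (hypothetical helper). Whatever real value a_log takes during training, A = -exp(a_log) is negative. Ā therefore lands in (0, 1) for any Δ > 0, up to floating-point underflow for very large Δ · exp(a_log).

```rust
/// Hypothetical helper: discrete per-head decay Ā = exp(Δ · A) with A = -exp(a_log).
fn discrete_decay(delta: f64, a_log: f64) -> f64 {
    let a = -a_log.exp(); // continuous-time rate A, negative for every finite a_log
    (delta * a).exp()     // Ā = exp(Δ · A)
}
```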

§d_h: Param<Tensor<B, 1>>

Per-head skip (D) coefficient.

Shape: [nheads]

Adds a direct path from the (post-convolution, pre-SSM) input to the output: yₜ += D · xₜ. Initialised to ones.

§norm: RmsNormGated<B>

Gated RMSNorm applied to the SSM output, conditioned on the gate z.

Input channel dimension: d_inner.

This combines the multiplicative gate (from z) and a normalisation step into a single fused operation, matching the architecture in §5.2 of the paper.
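The following scalar sketch of a gated RMSNorm over one d_inner vector is for illustration only. It assumes a SiLU gate and a single normalisation group. It also exposes a hypothetical norm_before_gate flag, because two orderings are in common use: RMSNorm(y) · silu(z), which matches how the forward dataflow writes it, and RMSNorm(y · silu(z)).

```rust
/// SiLU / swish activation: v · sigmoid(v).
fn silu(v: f32) -> f32 {
    v / (1.0 + (-v).exp())
}

/// Hypothetical gated RMSNorm over one vector (single group, learned `weight`).
/// norm_before_gate = true  -> RMSNorm(y) · silu(z)
/// norm_before_gate = false -> RMSNorm(y · silu(z))
fn rms_norm_gated(y: &[f32], z: &[f32], weight: &[f32], eps: f32, norm_before_gate: bool) -> Vec<f32> {
    assert!(y.len() == z.len() && y.len() == weight.len());
    // Scale v by 1 / sqrt(mean(v²) + eps), then by the per-channel weight.
    let rms_normalise = |v: &[f32]| -> Vec<f32> {
        let mean_sq = v.iter().map(|a| a * a).sum::<f32>() / v.len() as f32;
        let inv_rms = 1.0 / (mean_sq + eps).sqrt();
        v.iter().zip(weight).map(|(a, w)| a * inv_rms * w).collect()
    };
    if norm_before_gate {
        rms_normalise(y).iter().zip(z).map(|(n, g)| n * silu(*g)).collect()
    } else {
        let gated: Vec<f32> = y.iter().zip(z).map(|(a, g)| a * silu(*g)).collect();
        rms_normalise(&gated)
    }
}
```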

§out_proj: Linear<B>

Output projection: maps d_inner → d_model.

§init_state_hpr: Option<Param<Tensor<B, 3>>>

Optional learnable initial hidden state h₀.

Shape: [nheads, per_head_dim, state_rank] (i.e. [H, P, N])

When None, the initial state is zero (the standard default). When Some, the stored tensor is used as the initial condition for every forward call (not per-batch; it is broadcast over the batch dimension).

§state_rank: usize

State rank N: the number of latent dimensions in the SSM hidden state h ∈ ℝ^{P×N} per head (matching the [H, P, N] layout of init_state_hpr). Corresponds to the paper’s N.

§ngroups: usize

Number of B/C groups G for grouped SSM heads (analogous to grouped-query attention). G must divide nheads. The nheads/G heads within a group share the same B and C projections, while each head keeps its own x, z, Δ, A and D.
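Mapping a head to its B/C group is an integer division. This assumes heads are laid out group-major: the first nheads/G heads use group 0, the next nheads/G use group 1, and so on. The helper is hypothetical.

```rust
/// Hypothetical helper: index of the B/C group used by head `h`,
/// assuming a group-major head layout.
fn group_of_head(h: usize, nheads: usize, ngroups: usize) -> usize {
    assert!(ngroups > 0 && nheads % ngroups == 0, "ngroups must divide nheads");
    assert!(h < nheads, "head index out of range");
    h / (nheads / ngroups) // heads per group = nheads / ngroups
}
```

With ngroups = 1 every head shares a single B/C pair. With ngroups = nheads each head has its own.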

Implementations§

impl<B: Backend> Mamba2<B>

pub fn step( &self, input_bm: Tensor<B, 2>, cache: Option<Mamba2Cache<B>>, ) -> (Tensor<B, 2>, Mamba2Cache<B>)

Process a single token using the pure recurrent SSM form.

This is the O(H·P·N)-per-token decoding path. It runs one tick of the discretised Mamba-2 recurrence:

  Āₜ  = exp(Δₜ · A)           scalar per head, ∈ (0, 1)
  B̄ₜ  = Δₜ · Bₜ               ∈ ℝᴺ   (Euler discretisation)
  hₜ  = Āₜ · hₜ₋₁ + xₜ · B̄ₜᵀ  ∈ ℝ^{P×N}   (outer product update)
  yₜ  = hₜ · Cₜ + D · xₜ      ∈ ℝᴾ   (output)

The convolution is handled by manually sliding the cache window: the oldest input column is dropped and the new token’s projection is appended.
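Here is a one-channel sketch of that sliding window. The helper is hypothetical, and it assumes the cache keeps the last conv_kernel inputs, oldest first, zero-initialised before the first token. Feeding tokens one at a time reproduces the causal convolution of the full sequence.

```rust
/// Hypothetical helper: one decoding step of a causal depthwise conv for one channel.
/// `window` holds the last conv_kernel inputs (oldest first) and is updated in place.
fn conv_step(window: &mut [f32], new_input: f32, weight: &[f32], bias: f32) -> f32 {
    assert!(!window.is_empty() && window.len() == weight.len());
    window.rotate_left(1); // the oldest input moves to the last slot...
    if let Some(last) = window.last_mut() {
        *last = new_input; // ...where it is overwritten by the new token
    }
    window.iter().zip(weight).map(|(a, w)| a * w).sum::<f32>() + bias
}
```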

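The SSM hidden state cache.ssm_bhpr is advanced by one step of the recurrence above. Because step takes the cache by value, the updated state is returned in the new Mamba2Cache rather than mutated in place.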
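The recurrence for one head and one batch element can be sketched with plain loops, storing hₜ row-major as [P, N]. The helper is hypothetical; the block performs the same arithmetic as batched tensor ops over all heads. Following the equations above, B̄ₜ = Δₜ · Bₜ is formed on the fly and the D skip is added last.

```rust
/// Hypothetical single-head step. `h` is the P × N state stored row-major
/// (h[p * n + r]); x has length P, b and c have length N. `a` is the negative
/// continuous-time rate A and `d` the skip coefficient D. Returns y_t (length P).
fn ssm_step_head(h: &mut [f64], x: &[f64], b: &[f64], c: &[f64], delta: f64, a: f64, d: f64) -> Vec<f64> {
    let (p_dim, n_dim) = (x.len(), b.len());
    assert_eq!(h.len(), p_dim * n_dim, "state must be P × N");
    assert_eq!(c.len(), n_dim, "C must have length N");
    let decay = (delta * a).exp(); // Ā_t = exp(Δ_t · A)
    let mut y = vec![0.0; p_dim];
    for p in 0..p_dim {
        for r in 0..n_dim {
            let idx = p * n_dim + r;
            // h_t = Ā_t · h_{t-1} + x_t · B̄_tᵀ, with B̄_t = Δ_t · B_t
            h[idx] = decay * h[idx] + x[p] * (delta * b[r]);
            // y_t = h_t · C_t, accumulated over the state dimension
            y[p] += h[idx] * c[r];
        }
        y[p] += d * x[p]; // D skip connection
    }
    y
}
```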

§Shapes
  • input_bm : [batch, d_model]
  • output : [batch, d_model]

impl<B: Backend> Mamba2<B>

pub fn d_inner(&self) -> usize

d_inner = expand · d_model. Inferred from the norm’s weight shape.


pub fn nheads(&self) -> usize

nheads = d_inner / per_head_dim. Inferred from the length of a_log_h (shape [nheads]).

pub fn per_head_dim(&self) -> usize

per_head_dim P = d_inner / nheads.


pub fn conv_dim(&self) -> usize

conv_dim = d_inner + 2 · ngroups · state_rank.
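The shape relations above can be checked with a small bookkeeping struct. The struct is hypothetical, and the example values (d_model = 768, expand = 2, P = 64, G = 1, N = 128) are illustrative rather than this crate's defaults.

```rust
/// Hypothetical dimension bookkeeping mirroring the accessors above.
struct Mamba2Dims {
    d_model: usize,
    expand: usize,
    per_head_dim: usize,
    ngroups: usize,
    state_rank: usize,
}

impl Mamba2Dims {
    /// d_inner = expand · d_model
    fn d_inner(&self) -> usize {
        self.expand * self.d_model
    }
    /// nheads = d_inner / per_head_dim
    fn nheads(&self) -> usize {
        assert_eq!(self.d_inner() % self.per_head_dim, 0, "per_head_dim must divide d_inner");
        self.d_inner() / self.per_head_dim
    }
    /// conv_dim = d_inner + 2 · ngroups · state_rank (x, then B and C)
    fn conv_dim(&self) -> usize {
        self.d_inner() + 2 * self.ngroups * self.state_rank
    }
    /// Width of the in_proj output: z (d_inner) + xbc (conv_dim) + dt_raw (nheads).
    fn in_proj_out(&self) -> usize {
        self.d_inner() + self.conv_dim() + self.nheads()
    }
}
```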

impl<B: Backend + Mamba2BackendExt> Mamba2<B>

pub fn forward( &self, input_bsm: Tensor<B, 3>, cache: Option<Mamba2Cache<B>>, ssd_path: Mamba2SsdPath, ) -> (Tensor<B, 3>, Mamba2Cache<B>)

Process a full input sequence using the chunkwise SSD algorithm.

This is the primary training and prefill path. The computation is linear in T but uses batched matrix multiplications (GEMMs) that can exploit GPU tensor cores — unlike the naive sequential recurrence, which requires O(T) serial steps.

§Full dataflow
  1. In-projection: u → (z, xbc, dt_raw) via a single linear layer.
  2. Causal Conv1d + SiLU: local context mixing over xbc.
  3. Split: xbc → (x, B, C).
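  4. Discretise: Δ = clamp(softplus(dt_raw + dt_bias), dt_limit); Ā = exp(Δ · A); B̄ = Δ · B.
  5. Padding: zero-pad the discretised sequence to a multiple of the chunk length Q (see §Sequence padding).
  6. Chunked SSD: four-step chunkwise algorithm, dispatched on ssd_path (the four steps are documented on Self::ssd_minimal).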
  7. Gated RMSNorm: y = RMSNorm(y) · σ(z).
  8. Out-projection: y → output.
§Sequence padding

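If sequence_unpadded % chunk_len ≠ 0, the discretised inputs are zero-padded along the sequence axis to the next multiple of Q = chunk_len. A padded step has Δ = 0 and B̄ = 0, so Ā = exp(0) = 1 and hₜ = 1 · hₜ₋₁ + 0, which makes it an identity step. The SSM state is therefore carried through the padding unchanged, so the final state of the padded last chunk can safely be read as the true final state. This relies on padding after discretisation (step 4 before step 5). Zero-padding dt_raw instead would give Δ = softplus(dt_bias) > 0 on the padded steps, which would decay the state.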
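The scalar sketch below shows why padding is safe. The helpers are hypothetical, track a single state entry, and take steps as (Δ, B, x) triples. Appending steps with Δ = 0 leaves the final state bit-for-bit unchanged, because exp(0 · A) = 1 and Δ · B · x = 0.

```rust
/// Hypothetical helper: sequence length rounded up to the next multiple of chunk_len.
fn padded_len(t: usize, chunk_len: usize) -> usize {
    assert!(chunk_len > 0);
    t.div_ceil(chunk_len) * chunk_len
}

/// Hypothetical helper: scalar-state scan h <- exp(Δ·A)·h + Δ·B·x over (Δ, B, x) steps.
fn scan(h0: f64, a: f64, steps: &[(f64, f64, f64)]) -> f64 {
    steps
        .iter()
        .fold(h0, |h, &(delta, b, x)| (delta * a).exp() * h + delta * b * x)
}
```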

§Shapes
  • input_bsm : [batch, sequence, d_model]
  • output : [batch, sequence, d_model]
  • cache (out) : updated convolution window and SSM state

impl<B: Backend> Mamba2<B>

pub fn ssd_minimal(input: Mamba2SsdInput<B>) -> (Tensor<B, 5>, Tensor<B, 4>)

Minimal chunkwise SSD algorithm.

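Implements the four-step decomposition of the semiseparable matrix multiplication described in §4 of the paper. The sequence of length T is split into nchunks = ⌈T/Q⌉ chunks of length Q. Within a chunk, A_cum[t] denotes the cumulative sum of Δₛ · A over the chunk's steps s ≤ t. Consequently exp(A_cum[t] - A_cum[s]) = Āₛ₊₁ ⋯ Āₜ is the total decay from step s to step t.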

§The four steps
§Step 1 — Intra-chunk outputs (Y_diag)

Within each chunk, compute the output assuming the initial hidden state is zero. This is the quadratic attention form of the SSD layer restricted to a window of Q tokens (§4.1):

  Y_diag[n] = (L[n] ∘ C[n] B̄[n]ᵀ) · X[n]

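where L[n] is the Q×Q 1-semiseparable mask for chunk n, with entries L[n][i, j] = exp(A_cum[i] - A_cum[j]) for i ≥ j and 0 for i < j (the decay from token j to token i). This step is a batched GEMM (exploits tensor cores).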
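Below is a sketch of building L[n] for one head from the per-step log-decays Δₜ · A. The helper is hypothetical and uses nested Vecs instead of tensors. It accumulates each segment sum directly instead of subtracting two cumulative sums, which avoids cancellation error when the cumulative sums grow large in magnitude.

```rust
/// Hypothetical helper: 1-semiseparable mask for one chunk.
/// log_a[t] = Δ_t · A; entry (i, j) = exp(Σ_{s=j+1..=i} log_a[s]) for i ≥ j, else 0.
fn ss_mask(log_a: &[f64]) -> Vec<Vec<f64>> {
    let q = log_a.len();
    let mut l = vec![vec![0.0; q]; q]; // upper triangle stays 0 (causality)
    for i in 0..q {
        let mut seg = 0.0f64; // running segment sum Σ_{s=j+1..=i} log_a[s]
        for j in (0..=i).rev() {
            l[i][j] = seg.exp(); // decay from step j to step i (1 on the diagonal)
            seg += log_a[j];
        }
    }
    l
}
```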

§Step 2 — Chunk state (state_bnhpr)

Compute the final SSM state of each chunk assuming zero initial state (§4.1, Eq. 20):

  s[n] = Σ_{t ∈ chunk n}  exp(A_cum[end] - A_cum[t]) · x[t] · B̄[t]ᵀ    ∈ ℝ^{P×N}

This is also a batched GEMM and is fully parallel across chunks.

§Step 3 — Inter-chunk state scan (state passing)

Propagate the true hidden state across chunk boundaries using the recurrence (§4.1, Eq. 22):

  h[n] = Ā[n]_chunk · h[n-1] + s[n]

where Ā[n]_chunk = exp(Σ_{t ∈ chunk n} Δₜ · A) = exp(A_cum[end]) is the cumulative decay over the whole chunk. This step is implemented as a single batched matrix multiplication that exploits the 1-semiseparable structure of the inter-chunk decay matrix. It uses the same segment-sum (segsum) construction as the mask L[n] in Step 1, applied over chunks instead of tokens. The recurrence has length nchunks = ⌈T/Q⌉ rather than T, so its cost is negligible for typical chunk sizes.
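The claim that the inter-chunk recurrence is a single matrix product can be checked with scalar states. The helpers are hypothetical; in the block each state is a P × N matrix, and the same scalar chunk decay multiplies every entry. Prepending h₀ as a pseudo-chunk with decay 1 turns the recurrence into a lower-triangular matrix-vector product. For readability, the matrix entries here are plain products of chunk decays rather than exponentiated segment sums.

```rust
/// Hypothetical helper: sequential recurrence h[n] = a[n] · h[n-1] + s[n], from h0.
/// Returns the state at the end of every chunk.
fn chunk_scan_loop(h0: f64, a_chunk: &[f64], s: &[f64]) -> Vec<f64> {
    let mut h = h0;
    a_chunk
        .iter()
        .zip(s)
        .map(|(&a, &sn)| {
            h = a * h + sn;
            h
        })
        .collect()
}

/// Hypothetical helper: the same result as one product with the lower-triangular
/// matrix M[i][j] = Π_{k=j+1..=i} a[k], after prepending h0 with decay 1.
fn chunk_scan_matmul(h0: f64, a_chunk: &[f64], s: &[f64]) -> Vec<f64> {
    let a: Vec<f64> = std::iter::once(1.0).chain(a_chunk.iter().copied()).collect();
    let states: Vec<f64> = std::iter::once(h0).chain(s.iter().copied()).collect();
    (1..a.len())
        .map(|i| {
            (0..=i)
                .map(|j| {
                    let decay: f64 = a[j + 1..=i].iter().product(); // empty product = 1
                    decay * states[j]
                })
                .sum::<f64>()
        })
        .collect()
}
```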

§Step 4 — State-to-output (Y_off)

For each chunk n, compute the contribution of the true initial state h[n-1] to the outputs within that chunk (§4.1, Eq. 23):

  Y_off[n, t] = exp(A_cum[t]) · h[n-1] · C[n, t]    ∈ ℝᴾ

This is again a batched GEMM.

§Final output (with D skip-connection)
  Y = Y_diag + Y_off + D · X
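Below is an end-to-end sketch of Steps 1 to 4 for a single head with P = N = 1, so all quantities are scalars. It is checked against the serial recurrence hₜ = Āₜ · hₜ₋₁ + B̄ₜ · xₜ. The names are hypothetical, and the loops stand in for the batched GEMMs. The point is only that the chunked decomposition reproduces the serial outputs and final state for any chunk length Q that divides the (padded) sequence length.

```rust
/// Per-step inputs for one head with P = N = 1 (hypothetical type).
struct Step { x: f64, delta: f64, b: f64, c: f64 }

/// Reference serial recurrence; returns (outputs, final state).
fn ssd_serial_scalar(steps: &[Step], a: f64, d: f64, h0: f64) -> (Vec<f64>, f64) {
    let mut h = h0;
    let y: Vec<f64> = steps
        .iter()
        .map(|s| {
            h = (s.delta * a).exp() * h + s.delta * s.b * s.x; // h_t = Ā_t h_{t-1} + B̄_t x_t
            s.c * h + d * s.x                                   // y_t = C_t h_t + D x_t
        })
        .collect();
    (y, h)
}

/// Chunked SSD following Steps 1 to 4; steps.len() must be a multiple of q.
fn ssd_chunked_scalar(steps: &[Step], a: f64, d: f64, h0: f64, q: usize) -> (Vec<f64>, f64) {
    assert!(q > 0 && steps.len() % q == 0, "sequence must be padded to a multiple of q");
    let mut y = Vec::with_capacity(steps.len());
    let mut h = h0; // true state entering the current chunk, h[n-1]
    for chunk in steps.chunks(q) {
        // Within-chunk cumulative log-decay: a_cum[i] = Σ_{k ≤ i} Δ_k · A.
        let a_cum: Vec<f64> = chunk
            .iter()
            .scan(0.0f64, |acc, s| {
                *acc += s.delta * a;
                Some(*acc)
            })
            .collect();
        let a_end = a_cum[q - 1];
        let mut s_chunk = 0.0f64; // Step 2: chunk state from a zero initial state
        for (i, si) in chunk.iter().enumerate() {
            // Step 1: intra-chunk output, Σ_j L[i][j] · C_i · B̄_j · x_j.
            let y_diag: f64 = chunk[..=i]
                .iter()
                .zip(&a_cum)
                .map(|(sj, &cj)| (a_cum[i] - cj).exp() * si.c * sj.delta * sj.b * sj.x)
                .sum();
            // Step 4: contribution of the incoming state, exp(a_cum[i]) · h[n-1] · C_i.
            let y_off = a_cum[i].exp() * h * si.c;
            y.push(y_diag + y_off + d * si.x); // final output with the D skip
            s_chunk += (a_end - a_cum[i]).exp() * si.delta * si.b * si.x;
        }
        // Step 3: pass the state across the chunk boundary.
        h = a_end.exp() * h + s_chunk;
    }
    (y, h)
}
```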

impl<B: Backend> Mamba2<B>

pub fn ssd_serial(input: Mamba2SsdInput<B>) -> (Tensor<B, 5>, Tensor<B, 4>)

Forward pass for the Mamba-2 SSD module.

Returns:

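  • y_bnlhp: the SSD output, shape [batch, nchunks, chunk_len, nheads, per_head_dim].
  • final_state_bhpr: the final SSM state, shape [batch, nheads, per_head_dim, state_rank].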

impl<B: Backend + Mamba2BackendExt> Mamba2<B>

pub fn ssd_serial_recalculated( input: Mamba2SsdInput<B>, ) -> (Tensor<B, 5>, Tensor<B, 4>)

Forward pass for the Mamba-2 SSD module.

Returns:

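  • y_bnlhp: the SSD output, shape [batch, nchunks, chunk_len, nheads, per_head_dim].
  • final_state_bhpr: the final SSM state, shape [batch, nheads, per_head_dim, state_rank].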

Trait Implementations§

impl<B> AutodiffModule<B> for Mamba2<B>
where B: AutodiffBackend + Backend, <B as AutodiffBackend>::InnerBackend: Backend,

type InnerModule = Mamba2<<B as AutodiffBackend>::InnerBackend>

Inner module without auto-differentiation.

fn valid(&self) -> Self::InnerModule

Returns the same module, but on the inner backend without auto-differentiation.

fn from_inner(module: Self::InnerModule) -> Self

Wraps an inner module back into an auto-diff module.

impl<B: Backend> Clone for Mamba2<B>

fn clone(&self) -> Self

Returns a duplicate of the value.

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.

impl<B: Debug + Backend> Debug for Mamba2<B>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<B: Backend> Display for Mamba2<B>

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.

impl<B> HasAutodiffModule<B> for Mamba2<B::InnerBackend>
where B: AutodiffBackend + Backend, <B as AutodiffBackend>::InnerBackend: Backend,

type TrainModule = Mamba2<B>

The module with auto-differentiation.

impl<B: Backend> Module<B> for Mamba2<B>

type Record = Mamba2Record<B>

Type to save and load the module.

fn load_record(self, record: Self::Record) -> Self

Load the module state from a record.

fn into_record(self) -> Self::Record

Convert the module into a record containing the state.

fn num_params(&self) -> usize

Get the number of parameters the module has, including all of its sub-modules.

fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor)

Visit each tensor parameter in the module with a visitor.

fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self

Map each tensor parameter in the module with a mapper.

fn collect_devices(&self, devices: Devices<B>) -> Devices<B>

Return all the devices found in the underneath module tree added to the given vector without duplicates.

fn to_device(self, device: &B::Device) -> Self

Move the module and all of its sub-modules to the given device.

fn fork(self, device: &B::Device) -> Self

Fork the module and all of its sub-modules to the given device.

fn devices(&self) -> Vec<<B as BackendTypes>::Device>

Return all the devices found in the underneath module tree without duplicates.

fn no_grad(self) -> Self

Each tensor in the module tree will not require grad.

fn train<AB>(self) -> Self::TrainModule
where AB: AutodiffBackend<InnerBackend = B>, Self: HasAutodiffModule<AB>,

Move the module and all of its sub-modules to the autodiff backend.

fn quantize_weights(self, quantizer: &mut Quantizer) -> Self

Quantize the weights of the module.

impl<B: Backend> ModuleDisplay for Mamba2<B>

fn format(&self, passed_settings: DisplaySettings) -> String

Formats the module with provided display settings.

fn custom_settings(&self) -> Option<DisplaySettings>

Custom display settings for the module.

fn custom_content(&self, _content: Content) -> Option<Content>

Custom attributes for the module.

impl<B: Backend> ModuleDisplayDefault for Mamba2<B>

fn content(&self, content: Content) -> Option<Content>

Attributes of the module used for display purposes.

fn num_params(&self) -> usize

Gets the number of the parameters of the module.
Auto Trait Implementations§

impl<B> !Freeze for Mamba2<B>

impl<B> !RefUnwindSafe for Mamba2<B>

impl<B> Send for Mamba2<B>

impl<B> Sync for Mamba2<B>

impl<B> Unpin for Mamba2<B>
where <B as BackendTypes>::Device: Unpin, <B as BackendTypes>::FloatTensorPrimitive: Unpin, <B as BackendTypes>::QuantizedTensorPrimitive: Unpin,

impl<B> UnsafeUnpin for Mamba2<B>
where <B as BackendTypes>::Device: UnsafeUnpin, <B as BackendTypes>::FloatTensorPrimitive: UnsafeUnpin, <B as BackendTypes>::QuantizedTensorPrimitive: UnsafeUnpin,

impl<B> !UnwindSafe for Mamba2<B>

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> CloneToUninit for T
where T: Clone,

unsafe fn clone_to_uninit(&self, dest: *mut u8)

This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> ToOwned for T
where T: Clone,

type Owned = T

The resulting type after obtaining ownership.

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.

impl<T> ToString for T
where T: Display + ?Sized,

fn to_string(&self) -> String

Converts the given value to a String.

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.