Struct Mamba3 

Source
pub struct Mamba3<B: Backend> {
    pub in_proj: Linear<B>,
    pub dt_bias_h: Param<Tensor<B, 1>>,
    pub dt_limit: (f64, f64),
    pub a_floor: f64,
    pub d_h: Param<Tensor<B, 1>>,
    pub b_norm: RmsNorm<B>,
    pub c_norm: RmsNorm<B>,
    pub b_bias_hrn: Param<Tensor<B, 3>>,
    pub c_bias_hrn: Param<Tensor<B, 3>>,
    pub mimo_x: Option<Param<Tensor<B, 3>>>,
    pub mimo_z: Option<Param<Tensor<B, 3>>>,
    pub mimo_o: Option<Param<Tensor<B, 3>>>,
    pub out_norm: Option<RmsNormGated<B>>,
    pub out_proj: Linear<B>,
    pub init_state_hpr: Option<Param<Tensor<B, 3>>>,
    pub state_rank: usize,
    pub ngroups: usize,
    pub num_rope_angles: usize,
    pub rope_dim: usize,
    pub mimo_rank: usize,
}
The Mamba-3 SSM block.

Implements the full Mamba-3 layer with exponential-trapezoidal discretisation and data-dependent RoPE. Both SISO (mimo_rank=1) and MIMO (mimo_rank>1) are supported, in two execution modes:

  • Self::forward — chunkwise two-SSD algorithm for training / prefill
  • Self::step — recurrent form for token-by-token decoding

Fields§

§in_proj: Linear<B>

Input projection.

  • SISO (R=1): maps d_model → 2·d_inner + 2·ngroups·state_rank + 3·nheads + num_rope_angles
  • MIMO (R>1): maps d_model → 2·d_inner + 2·ngroups·state_rank·R + 3·nheads + num_rope_angles

Output splits: [z | x | B_raw | C_raw | dd_dt | dd_A | lam_raw | theta_raw]
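
As a rough illustration of how the projection width decomposes, the sketch below computes one width per segment of the split. The per-segment widths are inferred from the total above and the split order, and `InProjSplit` is a hypothetical helper, not a type from this crate.

```rust
/// Hypothetical helper: widths of the segments obtained by splitting the
/// `in_proj` output as [z | x | B_raw | C_raw | dd_dt | dd_A | lam_raw | theta_raw].
/// The per-segment widths are an assumption inferred from the documented total.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InProjSplit {
    pub z: usize,
    pub x: usize,
    pub b_raw: usize,
    pub c_raw: usize,
    pub dd_dt: usize,
    pub dd_a: usize,
    pub lam_raw: usize,
    pub theta_raw: usize,
}

impl InProjSplit {
    pub fn new(
        d_inner: usize,
        nheads: usize,
        ngroups: usize,
        state_rank: usize,
        mimo_rank: usize,
        num_rope_angles: usize,
    ) -> Self {
        // B and C each carry one state_rank-sized vector per group and per MIMO rank.
        let bc = ngroups * state_rank * mimo_rank;
        Self {
            z: d_inner,
            x: d_inner,
            b_raw: bc,
            c_raw: bc,
            dd_dt: nheads,
            dd_a: nheads,
            lam_raw: nheads,
            theta_raw: num_rope_angles,
        }
    }

    /// Total output width of the projection:
    /// 2·d_inner + 2·ngroups·state_rank·R + 3·nheads + num_rope_angles.
    pub fn total(&self) -> usize {
        self.z
            + self.x
            + self.b_raw
            + self.c_raw
            + self.dd_dt
            + self.dd_a
            + self.lam_raw
            + self.theta_raw
    }
}
```

With d_inner = 128, nheads = 4, ngroups = 1, state_rank = 16 and num_rope_angles = 8, the SISO width is 308 and the MIMO width for R = 2 is 340.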

§dt_bias_h: Param<Tensor<B, 1>>

Per-head bias for the discretisation step size Δ. Shape: [nheads]

§dt_limit: (f64, f64)

Hard clamp applied to Δ after softplus.

§a_floor: f64

Minimum absolute value of A: A ∈ (−∞, −a_floor].
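
The two constraints above can be sketched on scalars as follows. The exact parameterisation is an assumption: Δ is taken to be softplus(dd_dt + dt_bias) followed by the clamp, and A is taken to be −softplus(dd_A) capped at −a_floor. The function names are hypothetical and the layer's actual tensor code may differ.

```rust
/// Numerically stable softplus: ln(1 + exp(x)) = max(x, 0) + ln(1 + exp(-|x|)).
fn softplus(x: f64) -> f64 {
    x.max(0.0) + (-x.abs()).exp().ln_1p()
}

/// Hypothetical helper: step size for one head.
/// Assumes Δ = clamp(softplus(dd_dt + dt_bias), dt_limit.0, dt_limit.1).
pub fn step_size(dd_dt: f64, dt_bias: f64, dt_limit: (f64, f64)) -> f64 {
    softplus(dd_dt + dt_bias).clamp(dt_limit.0, dt_limit.1)
}

/// Hypothetical helper: decay rate for one head.
/// Assumes A = min(-softplus(dd_a), -a_floor), so that A ∈ (−∞, −a_floor].
pub fn decay_rate(dd_a: f64, a_floor: f64) -> f64 {
    (-softplus(dd_a)).min(-a_floor)
}
```

The floor keeps |A| bounded away from zero, so a very negative dd_A cannot switch the decay off entirely.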

§d_h: Param<Tensor<B, 1>>

Per-head skip (D) coefficient. Shape: [nheads]; initialised to ones.

§b_norm: RmsNorm<B>

RMSNorm applied to the B projection (QK-Norm, no gating). Normalises over the state_rank dimension.

§c_norm: RmsNorm<B>

RMSNorm applied to the C projection (QK-Norm, no gating). Normalises over the state_rank dimension.

§b_bias_hrn: Param<Tensor<B, 3>>

Learnable per-head, per-rank bias for B, added after QK-norm. Shape: [nheads, mimo_rank, state_rank]; initialised to ones.

For SISO (mimo_rank=1) this has shape [nheads, 1, state_rank].

§c_bias_hrn: Param<Tensor<B, 3>>

Learnable per-head, per-rank bias for C, added after QK-norm. Shape: [nheads, mimo_rank, state_rank]; initialised to ones.

§mimo_x: Option<Param<Tensor<B, 3>>>

MIMO up-projection for x (values). Shape: [nheads, mimo_rank, per_head_dim]. Some only when mimo_rank > 1; None for SISO.

§mimo_z: Option<Param<Tensor<B, 3>>>

MIMO up-projection for z (gate). Shape: [nheads, mimo_rank, per_head_dim]. Only present when mimo_rank > 1.

§mimo_o: Option<Param<Tensor<B, 3>>>

MIMO down-projection for the output. Shape: [nheads, mimo_rank, per_head_dim]. Only present when mimo_rank > 1.

§out_norm: Option<RmsNormGated<B>>

Optional gated RMSNorm applied before the output projection.

When Some, the SiLU gate at the block tail is replaced by RmsNormGated(y, z), which normalises y over per_head_dim and gates with SiLU(z). Created when has_outproj_norm = true.
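
A minimal sketch of a gated RMSNorm over one head's per_head_dim vector is given below. It assumes the order "normalise y, then multiply by SiLU(z)" and omits any learnable scale. Both choices are assumptions, and `rms_norm_gated` is a hypothetical free function rather than the `RmsNormGated` module itself.

```rust
/// SiLU activation: x · sigmoid(x).
fn silu(x: f64) -> f64 {
    x / (1.0 + (-x).exp())
}

/// Hypothetical helper: RMS-normalise `y` over its full length, then gate
/// element-wise with SiLU(z). No learnable weight is applied (assumption).
pub fn rms_norm_gated(y: &[f64], z: &[f64], eps: f64) -> Vec<f64> {
    assert_eq!(y.len(), z.len(), "y and z must have the same length");
    // Mean of squares over the normalised dimension.
    let mean_sq = y.iter().map(|v| v * v).sum::<f64>() / y.len() as f64;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    y.iter()
        .zip(z)
        .map(|(&yi, &zi)| yi * inv_rms * silu(zi))
        .collect()
}
```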

§out_proj: Linear<B>

Output projection: maps d_inner → d_model.

§init_state_hpr: Option<Param<Tensor<B, 3>>>

Optional learnable initial hidden state h₀. Shape: [nheads, per_head_dim, state_rank]

§state_rank: usize

State rank N.

§ngroups: usize

Number of B/C groups G. Must divide nheads.

§num_rope_angles: usize

Number of RoPE angle pairs (rope_dim / 2).

§rope_dim: usize

Effective RoPE dimension (= 2 · num_rope_angles). Always even and ≤ state_rank. Only the first rope_dim entries of B/C are rotated.
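
To illustrate the partial rotation, the sketch below rotates the first rope_dim = 2 · num_rope_angles entries of a single B or C vector and leaves the rest untouched. It assumes adjacent entries (2i, 2i+1) form a rotation pair; the layer may pair entries differently (for example first half with second half), and `rotate_prefix` is a hypothetical name.

```rust
/// Hypothetical helper: rotate the first 2·angles.len() entries of `v` in place.
/// Pair i is (v[2i], v[2i+1]) and is rotated by angles[i] (pairing is an assumption).
/// Entries at index >= rope_dim are left unchanged.
pub fn rotate_prefix(v: &mut [f64], angles: &[f64]) {
    let rope_dim = 2 * angles.len();
    assert!(rope_dim <= v.len(), "rope_dim must not exceed state_rank");
    for (i, &theta) in angles.iter().enumerate() {
        let (sin, cos) = theta.sin_cos();
        let (a, b) = (v[2 * i], v[2 * i + 1]);
        // Standard 2-D rotation of the pair (a, b) by theta.
        v[2 * i] = a * cos - b * sin;
        v[2 * i + 1] = a * sin + b * cos;
    }
}
```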

§mimo_rank: usize

MIMO rank R. 1 = SISO (standard Mamba-3).

Implementations§

Source§

impl<B: Backend> Mamba3<B>

Source

pub fn step( &self, input_bm: Tensor<B, 2>, cache: Option<Mamba3Cache<B>>, ) -> (Tensor<B, 2>, Mamba3Cache<B>)

Process a single token using the pure recurrent form.

For SISO (mimo_rank=1):

  hₜ = αₜ hₜ₋₁ + βₜ B_{t-1} ⊗ x_{t-1} + γₜ Bₜ ⊗ xₜ
  yₜ = Cₜᵀ hₜ + D xₜ

For MIMO (mimo_rank=R>1):

  hₜ = αₜ hₜ₋₁ + Σ_r βₜ B_{t-1}[r] ⊗ (x_{t-1} ⊙ mimo_x[r])
                 + Σ_r γₜ Bₜ[r] ⊗ (xₜ ⊙ mimo_x[r])
  yₜ[r] = Cₜ[r]ᵀ hₜ + D (xₜ ⊙ mimo_x[r])
  outₜ = Σ_r mimo_o[r] ⊙ silu(zₜ ⊙ mimo_z[r]) ⊙ yₜ[r]
§Shapes
  • input_bm : [batch, d_model]
  • output : [batch, d_model]
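
The SISO recurrence above can be sketched for a single head on plain slices. This is an illustration only: αₜ, βₜ and γₜ are taken as given inputs because their derivation from Δ, A and the trapezoidal weight is not spelled out here, the state is laid out as [per_head_dim][state_rank] to match init_state_hpr, and `siso_step` is a hypothetical name unrelated to the tensor implementation.

```rust
/// Hypothetical single-head SISO step:
///   h_t = α h_{t-1} + β B_{t-1} ⊗ x_{t-1} + γ B_t ⊗ x_t
///   y_t = C_tᵀ h_t + D x_t
/// `h` has layout [per_head_dim][state_rank] and is updated in place.
#[allow(clippy::too_many_arguments)]
pub fn siso_step(
    h: &mut [Vec<f64>],
    alpha: f64,
    beta: f64,
    gamma: f64,
    b_prev: &[f64],
    x_prev: &[f64],
    b_t: &[f64],
    x_t: &[f64],
    c_t: &[f64],
    d: f64,
) -> Vec<f64> {
    let mut y = Vec::with_capacity(x_t.len());
    for (p, row) in h.iter_mut().enumerate() {
        let mut acc = 0.0;
        for (n, h_pn) in row.iter_mut().enumerate() {
            // Decay the old state, then add the previous-token and
            // current-token contributions of the trapezoidal rule.
            *h_pn = alpha * *h_pn + beta * b_prev[n] * x_prev[p] + gamma * b_t[n] * x_t[p];
            // Read out with C over the state dimension.
            acc += c_t[n] * *h_pn;
        }
        // Skip connection with the per-head D coefficient.
        y.push(acc + d * x_t[p]);
    }
    y
}
```

Note that the update needs B and x from the previous token in addition to the hidden state, so a decoding cache presumably has to carry them between calls.
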
Source§

impl<B: Backend> Mamba3<B>

Source

pub fn d_inner(&self) -> usize

d_inner = expand · d_model. Inferred from the input dimension of out_proj, which maps d_inner → d_model.

Source

pub fn nheads(&self) -> usize

nheads = d_inner / per_head_dim. Inferred from the length of d_h, whose shape is [nheads].

Source

pub fn per_head_dim(&self) -> usize

per_head_dim P = d_inner / nheads.

Source§

impl<B: Backend + Mamba3BackendExt> Mamba3<B>

Source

pub fn forward( &self, input_bsm: Tensor<B, 3>, cache: Option<Mamba3Cache<B>>, ssd_path: Mamba3SsdPath, ) -> (Tensor<B, 3>, Mamba3Cache<B>)

Process a full input sequence using the trapezoidal two-SSD algorithm.

For SISO (mimo_rank=1), this is the standard two-SSD decomposition. For MIMO (mimo_rank=R>1), B/C have R parallel rank channels. The hidden state is shared across ranks; each rank contributes independently.

§Shapes
  • input_bsm : [batch, sequence, d_model]
  • output : [batch, sequence, d_model]
Source§

impl<B: Backend> Mamba3<B>

Source

pub fn ssd_minimal(input: Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>)

MIMO-first chunkwise SSD — minimal/segsum variant.

Implements the four-step decomposition for the MIMO trapezoidal recurrence. SISO (R=1) is the degenerate case where the fused length equals the chunk length.

No D skip is applied here — the caller handles it.

§Shapes
  • input: see Mamba3SsdInput
  • output.0 y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
  • output.1 final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
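
For readers unfamiliar with the "segsum" name, the sketch below shows a segment-sum over one chunk of log-decays, assuming the same definition as the segment-sum helper in the Mamba-2 reference implementation: entry (i, j) holds a[j+1] + … + a[i] for i ≥ j and −∞ above the diagonal. Whether this crate uses exactly this definition is an assumption, and the function here works on plain vectors rather than tensors.

```rust
/// Hypothetical helper: segment sums of `a`.
/// out[i][j] = a[j+1] + ... + a[i] for i >= j (zero on the diagonal),
/// and negative infinity for j > i, so that exp(out) is a causal decay mask.
pub fn segsum(a: &[f64]) -> Vec<Vec<f64>> {
    let t = a.len();
    // prefix[k] = a[0] + ... + a[k-1]
    let mut prefix = vec![0.0; t + 1];
    for i in 0..t {
        prefix[i + 1] = prefix[i] + a[i];
    }
    let mut out = vec![vec![f64::NEG_INFINITY; t]; t];
    for i in 0..t {
        for j in 0..=i {
            out[i][j] = prefix[i + 1] - prefix[j + 1];
        }
    }
    out
}
```

The output is a T × T matrix, which is the quadratic cost that a sequential scan avoids.
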
Source§

impl<B: Backend> Mamba3<B>

Source

pub fn ssd_serial(input: Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>)

MIMO-first (Hybrid) Serial SSD.

Implements K1-K5 with a sequential loop (K4) for the inter-chunk scan instead of the quadratic segsum approach in ssd_minimal. This is more memory-efficient for long sequences with many chunks.

SISO (R=1) is the special case where the fused length equals the chunk length.

§Returns
  • y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
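
The idea of a sequential inter-chunk scan can be sketched on scalars. The recurrence used here (state entering the next chunk = chunk decay × state entering this chunk + the chunk's own contribution) is the usual chunked state-passing form and is assumed to be what the K4 loop does; `inter_chunk_scan` and its arguments are hypothetical.

```rust
/// Hypothetical scalar sketch of a sequential inter-chunk scan.
/// `chunk_decay[c]` is the total decay across chunk c and `chunk_state[c]`
/// is the state contributed by the tokens inside chunk c.
/// Returns the state entering each chunk and the final state after the last one.
pub fn inter_chunk_scan(h0: f64, chunk_decay: &[f64], chunk_state: &[f64]) -> (Vec<f64>, f64) {
    assert_eq!(chunk_decay.len(), chunk_state.len());
    let mut h = h0;
    let mut entering = Vec::with_capacity(chunk_decay.len());
    for (&a, &s) in chunk_decay.iter().zip(chunk_state) {
        // Record the state visible to chunk c, then advance past it.
        entering.push(h);
        h = a * h + s;
    }
    (entering, h)
}
```

The loop keeps only one running state, so memory grows linearly with the number of chunks instead of quadratically.
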
Source§

impl<B: Backend + Mamba3BackendExt> Mamba3<B>

Source

pub fn ssd_serial_recalculated( input: Mamba3SsdInput<B>, ) -> (Tensor<B, 6>, Tensor<B, 4>)

MIMO-first Serial SSD with recalculated backward.

Computes K1 eagerly (so the cumsum is available for the backward pass), then delegates the remaining computation to Mamba3BackendExt::ssd_serial_recalculated which can provide a memory-efficient custom backward for supported backends.

Falls back to the standard K2-K5 serial computation on unsupported backends.

§Returns
  • y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]

Trait Implementations§

Source§

impl<B> AutodiffModule<B> for Mamba3<B>
where B: AutodiffBackend + Backend, <B as AutodiffBackend>::InnerBackend: Backend,

Source§

type InnerModule = Mamba3<<B as AutodiffBackend>::InnerBackend>

Inner module without auto-differentiation.
Source§

fn valid(&self) -> Self::InnerModule

Returns the same module, but on the inner backend without auto-differentiation.
Source§

fn from_inner(module: Self::InnerModule) -> Self

Wraps an inner module back into an auto-diff module.
Source§

impl<B: Backend> Clone for Mamba3<B>

Source§

fn clone(&self) -> Self

Returns a duplicate of the value.
1.0.0 (const: unstable) · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.
Source§

impl<B: Debug + Backend> Debug for Mamba3<B>

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.
Source§

impl<B: Backend> Display for Mamba3<B>

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.
Source§

impl<B> HasAutodiffModule<B> for Mamba3<B::InnerBackend>
where B: AutodiffBackend + Backend, <B as AutodiffBackend>::InnerBackend: Backend,

Source§

type TrainModule = Mamba3<B>

The module with auto-differentiation.
Source§

impl<B: Backend> Module<B> for Mamba3<B>

Source§

type Record = Mamba3Record<B>

Type to save and load the module.
Source§

fn load_record(self, record: Self::Record) -> Self

Load the module state from a record.
Source§

fn into_record(self) -> Self::Record

Convert the module into a record containing the state.
Source§

fn num_params(&self) -> usize

Get the number of parameters the module has, including all of its sub-modules.
Source§

fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor)

Visit each tensor parameter in the module with a visitor.
Source§

fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self

Map each tensor parameter in the module with a mapper.
Source§

fn collect_devices(&self, devices: Devices<B>) -> Devices<B>

Return all the devices found in the underneath module tree added to the given vector without duplicates.
Source§

fn to_device(self, device: &B::Device) -> Self

Move the module and all of its sub-modules to the given device.
Source§

fn fork(self, device: &B::Device) -> Self

Fork the module and all of its sub-modules to the given device.
§

fn devices(&self) -> Vec<<B as BackendTypes>::Device>

Return all the devices found in the underneath module tree without duplicates.
§

fn no_grad(self) -> Self

Each tensor in the module tree will not require grad.
§

fn train<AB>(self) -> Self::TrainModule
where AB: AutodiffBackend<InnerBackend = B>, Self: HasAutodiffModule<AB>,

Move the module and all of its sub-modules to the autodiff backend.
§

fn quantize_weights(self, quantizer: &mut Quantizer) -> Self

Quantize the weights of the module.
Source§

impl<B: Backend> ModuleDisplay for Mamba3<B>

§

fn format(&self, passed_settings: DisplaySettings) -> String

Formats the module with provided display settings.
§

fn custom_settings(&self) -> Option<DisplaySettings>

Custom display settings for the module.
§

fn custom_content(&self, _content: Content) -> Option<Content>

Custom attributes for the module.
Source§

impl<B: Backend> ModuleDisplayDefault for Mamba3<B>

Source§

fn content(&self, content: Content) -> Option<Content>

Attributes of the module used for display purposes.
Source§

fn num_params(&self) -> usize

Gets the number of the parameters of the module.

Auto Trait Implementations§

§

impl<B> !Freeze for Mamba3<B>

§

impl<B> !RefUnwindSafe for Mamba3<B>

§

impl<B> Send for Mamba3<B>

§

impl<B> Sync for Mamba3<B>

§

impl<B> Unpin for Mamba3<B>
where <B as BackendTypes>::Device: Unpin, <B as BackendTypes>::FloatTensorPrimitive: Unpin, <B as BackendTypes>::QuantizedTensorPrimitive: Unpin,

§

impl<B> UnsafeUnpin for Mamba3<B>
where <B as BackendTypes>::Device: UnsafeUnpin, <B as BackendTypes>::FloatTensorPrimitive: UnsafeUnpin, <B as BackendTypes>::QuantizedTensorPrimitive: UnsafeUnpin,

§

impl<B> !UnwindSafe for Mamba3<B>

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self.
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value.
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.
Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.
Source§

impl<T> ToString for T
where T: Display + ?Sized,

Source§

fn to_string(&self) -> String

Converts the given value to a String.
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.