Struct Mamba3SingleSsdInput 

pub struct Mamba3SingleSsdInput {
    pub v_bnlmhp: Tensor<6>,
    pub b_bnlmhr: Tensor<6>,
    pub c_bnlmhr: Tensor<6>,
    pub da_bnlh: Tensor<4>,
    pub gamma_bnlh: Tensor<4>,
    pub scale_bnlh: Tensor<4>,
    pub initial_state_bhpr: Tensor<4>,
    pub init_state_hpr: Option<Tensor<3>>,
}

MIMO-first input bundle for the merged-form SSD.

All tensors are pre-processed by the caller (Mamba3::forward_single_ssd). B and C are already QK-normed, RoPE-applied, bias-added, and expanded to per-head; V is the raw, unscaled, MIMO-expanded value. The combined log-decay da = Δ·A is pre-computed. The two trapezoidal coefficients γₜ (gamma_bnlh) and scaleₜ (scale_bnlh) are supplied separately because the SSD applies the K-scaling and the γₜ-weighted diagonal correction internally. D-skip and Z-gating are handled by the caller.

Fields§

§v_bnlmhp: Tensor<6>

Value tensor, MIMO-expanded but not trapezoidally scaled.

§Shape

  • [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
§b_bnlmhr: Tensor<6>

K/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head. Not pre-scaled — the SSD multiplies by scaleₜ internally for the lower-triangular and state-recurrence paths, while the diagonal correction reuses the unscaled tensor.

§Shape

  • [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
§c_bnlmhr: Tensor<6>

Q/C tensor: same processing as b_bnlmhr.

§Shape

  • [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
§da_bnlh: Tensor<4>

Pre-combined log-decay Δ·A (negative).

§Shape

  • [batch, nchunks, chunk_len, nheads]
§gamma_bnlh: Tensor<4>

γₜ = λₜ · Δₜ — used as the per-token diagonal multiplier.

§Shape

  • [batch, nchunks, chunk_len, nheads]
§scale_bnlh: Tensor<4>

scaleₜ = γₜ + (1 − λₜ₊₁) · Δₜ₊₁. K is multiplied by this factor for the lower-triangular and state-recurrence paths. The shifted term is zero at the last sequence position, where no future token exists.

§Shape

  • [batch, nchunks, chunk_len, nheads]
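
The two coefficient formulas above can be sketched for a single head over a flat sequence. This is a minimal illustration, not the caller's actual code: the function name `trapezoid_coeffs` and the per-token slices `delta` (Δₜ) and `lambda` (λₜ) are assumptions introduced here, and the real tensors carry extra batch, chunk, and head axes.

```rust
/// Trapezoidal coefficients for one head over one flat sequence.
/// `delta` and `lambda` are hypothetical per-token inputs (Δₜ and λₜ).
/// Returns (gamma, scale) with the same length as the inputs.
fn trapezoid_coeffs(delta: &[f64], lambda: &[f64]) -> (Vec<f64>, Vec<f64>) {
    assert_eq!(delta.len(), lambda.len());
    let n = delta.len();
    // γₜ = λₜ · Δₜ
    let gamma: Vec<f64> = (0..n).map(|t| lambda[t] * delta[t]).collect();
    // scaleₜ = γₜ + (1 − λₜ₊₁) · Δₜ₊₁
    let scale: Vec<f64> = (0..n)
        .map(|t| {
            // The shifted term reads token t+1, so it is zero at the last position.
            let shifted = if t + 1 < n {
                (1.0 - lambda[t + 1]) * delta[t + 1]
            } else {
                0.0
            };
            gamma[t] + shifted
        })
        .collect();
    (gamma, scale)
}
```

Note that the shift runs over the whole sequence, so in the chunked layout the last token of one chunk would read Δ and λ from the first token of the next chunk; only the final token of the final chunk gets a zero shifted term.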
§initial_state_bhpr: Tensor<4>

Initial SSM hidden state (merged-form accumulator).

When continuing from a prior call, this should already include the boundary β contribution (1 − λ₀) · Δ₀ · Σₘ Kₜ₋₁[m] ⊗ (xₜ₋₁ ⊙ mimo_xₘ), where index 0 denotes the first token of this call and t−1 the last token of the previous call. The previous call could not add this term because it did not yet know λ₀ and Δ₀.

§Shape

  • [batch, nheads, per_head_dim, state_rank]
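
To make the boundary term concrete, here is a scalar sketch (one head, mimo_rank = per_head_dim = state_rank = 1) of how a merged-form accumulator can reproduce a two-term trapezoidal recurrence, and why the caller must add the β contribution before a continued call. The recurrence h_t = α_t·h_{t−1} + β_t·b_{t−1}·x_{t−1} + γ_t·b_t·x_t with β_t = (1 − λ_t)·Δ_t·α_t is this sketch's reading of the field descriptions, and the names `Tok`, `trapezoid_ref`, `merged`, and `boundary` are hypothetical, not part of the crate.

```rust
/// One scalar token. All names here are illustrative, not the crate's API.
#[derive(Clone, Copy)]
struct Tok { x: f64, b: f64, c: f64, delta: f64, lambda: f64 }

/// Two-term (trapezoidal) reference recurrence, assumed form:
/// h_t = α_t h_{t-1} + β_t b_{t-1} x_{t-1} + γ_t b_t x_t, y_t = c_t h_t.
/// `prev_bx` is b·x of the token before this slice (0.0 at sequence start).
fn trapezoid_ref(toks: &[Tok], a: f64, mut h: f64, mut prev_bx: f64) -> Vec<f64> {
    let mut ys = Vec::with_capacity(toks.len());
    for t in toks {
        let alpha = (t.delta * a).exp();
        let beta = (1.0 - t.lambda) * t.delta * alpha;
        let gamma = t.lambda * t.delta;
        h = alpha * h + beta * prev_bx + gamma * t.b * t.x;
        ys.push(t.c * h);
        prev_bx = t.b * t.x;
    }
    ys
}

/// Merged form: a single accumulator `s` is fed scale_t · b_t · x_t, and the
/// output adds a γ_t-weighted same-step term instead of reading `s` after the update.
fn merged(toks: &[Tok], a: f64, mut s: f64) -> (Vec<f64>, f64) {
    let n = toks.len();
    let mut ys = Vec::with_capacity(n);
    for t in 0..n {
        let tok = toks[t];
        let gamma = tok.lambda * tok.delta;
        // Shifted term is zero at the last position of this call.
        let shifted = if t + 1 < n {
            (1.0 - toks[t + 1].lambda) * toks[t + 1].delta
        } else {
            0.0
        };
        let decayed = (tok.delta * a).exp() * s;
        ys.push(tok.c * (decayed + gamma * tok.b * tok.x));
        s = decayed + (gamma + shifted) * tok.b * tok.x;
    }
    (ys, s)
}

/// Boundary β contribution the caller adds to the stored state before the
/// next call, once λ₀ and Δ₀ of that call's first token are known.
fn boundary(last_prev: Tok, first_next: Tok) -> f64 {
    (1.0 - first_next.lambda) * first_next.delta * last_prev.b * last_prev.x
}
```

Under these assumptions, running `merged` on a sequence split into two calls, with `boundary` added to the first call's final state, gives the same outputs as one uninterrupted pass.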
§init_state_hpr: Option<Tensor<3>>

Optional learnable initial state (broadcast over batch).

§Shape

  • [nheads, per_head_dim, state_rank]

Implementations§

impl Mamba3SingleSsdInput

pub fn single_ssd_minimal(self) -> (Tensor<6>, Tensor<4>)

MIMO-first single-SSD — segsum variant.

See module documentation for the algorithm. Returns the chunked outputs and the final single-ssd accumulator.

§Shapes
  • input: see Mamba3SingleSsdInput
  • output (y_bnlmhp, final_state_bhpr):
    • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
    • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
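
For readers unfamiliar with the term, a "segsum" (segment sum) is commonly built as below in minimal SSD reference code: a lower-triangular matrix of summed log-decays whose exponential gives the decay between any two tokens of a chunk. This is a generic sketch over one chunk and one head; whether this crate's segsum variant is organized the same way is an assumption, and the module documentation remains the authority.

```rust
/// Segment sum of per-token log-decays within one chunk:
/// out[i][j] = da[j+1] + ... + da[i] for j <= i, and -inf for j > i,
/// so exp(out[i][j]) is the decay applied between token j and token i.
fn segsum(da: &[f64]) -> Vec<Vec<f64>> {
    let n = da.len();
    // Inclusive prefix sums: cum[i] = da[0] + ... + da[i].
    let mut cum = vec![0.0_f64; n];
    let mut acc = 0.0_f64;
    for i in 0..n {
        acc += da[i];
        cum[i] = acc;
    }
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| if j <= i { cum[i] - cum[j] } else { f64::NEG_INFINITY })
                .collect()
        })
        .collect()
}
```

Entries above the diagonal are negative infinity so that exponentiating the matrix masks them to zero.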

impl Mamba3SingleSsdInput

pub fn single_ssd_serial(self) -> (Tensor<6>, Tensor<4>)

Chunk-serial (K1–K5) variant of the MIMO-first single-SSD.

Sequence of kernels (matches the double-ssd ssd_serial):

  1. K1: intra-chunk cumulative log-decay and per-chunk decay totals.
  2. K2: cb = C · Bᵀ block matrix (unscaled).
  3. K3: per-chunk hidden state assuming zero initial state, fed K_scaled = scaleₜ · B.
  4. K4: sequential state passing across chunks (loop over chunks).
  5. K5 (this module’s new function): single-ssd chunk scan with strict lower-triangular masking, scale broadcasting, and the γₜ-weighted same-step diagonal correction.
§Returns
  • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank] — the single-ssd accumulator at the last token.
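
The K1–K5 steps listed above can be sketched in scalar form (one head, mimo_rank = per_head_dim = state_rank = 1). This is an illustration of the data flow only, under the assumption that each chunk is non-empty; the `Chunk` struct and `chunk_serial` function are hypothetical names, and the real kernels operate on batched tensors.

```rust
/// Scalar inputs for one head, already split into chunks.
/// Field names mirror the struct fields but the type itself is illustrative.
struct Chunk { v: Vec<f64>, b: Vec<f64>, c: Vec<f64>, da: Vec<f64>, gamma: Vec<f64>, scale: Vec<f64> }

fn chunk_serial(chunks: &[Chunk], initial_state: f64) -> (Vec<Vec<f64>>, f64) {
    // K1: inclusive cumulative log-decay inside each chunk, plus per-chunk totals.
    let cums: Vec<Vec<f64>> = chunks
        .iter()
        .map(|ch| {
            let mut acc = 0.0_f64;
            ch.da.iter().map(|d| { acc += *d; acc }).collect()
        })
        .collect();
    let totals: Vec<f64> = cums.iter().map(|c| *c.last().expect("non-empty chunk")).collect();

    // K3: per-chunk state from a zero initial state, fed K_scaled = scale * b.
    let locals: Vec<f64> = chunks
        .iter()
        .zip(&cums)
        .zip(&totals)
        .map(|((ch, cum), total)| {
            (0..cum.len())
                .map(|j| (*total - cum[j]).exp() * ch.scale[j] * ch.b[j] * ch.v[j])
                .sum::<f64>()
        })
        .collect();

    // K4: sequential state passing; states_in[n] is the state entering chunk n.
    let mut states_in = Vec::with_capacity(chunks.len());
    let mut state = initial_state;
    for (total, local) in totals.iter().zip(&locals) {
        states_in.push(state);
        state = total.exp() * state + *local;
    }

    // K2 + K5: unscaled c*b entries, strict lower triangle weighted by scale,
    // plus the gamma-weighted same-step diagonal term.
    let ys: Vec<Vec<f64>> = chunks
        .iter()
        .zip(&cums)
        .zip(&states_in)
        .map(|((ch, cum), s_in)| {
            (0..cum.len())
                .map(|i| {
                    let lower: f64 = (0..i)
                        .map(|j| ch.c[i] * ch.b[j] * (cum[i] - cum[j]).exp() * ch.scale[j] * ch.v[j])
                        .sum();
                    let diag = ch.gamma[i] * ch.c[i] * ch.b[i] * ch.v[i];
                    ch.c[i] * cum[i].exp() * *s_in + lower + diag
                })
                .collect()
        })
        .collect();
    (ys, state)
}
```

In this sketch only K4 carries a dependency between chunks; K1, K3, and K5 are independent per chunk, which is what makes the chunked layout worthwhile.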

impl Mamba3SingleSsdInput

pub fn single_ssd_serial_recalculated(self) -> (Tensor<6>, Tensor<4>)

Chunk-serial variant of the MIMO-first single-SSD with a recalculated backward pass.

Delegates the full K1–K5 (single-ssd) computation to Mamba3SingleSsdBackendExt::single_ssd_serial_recalculated, which can provide a memory-efficient custom backward for supported backends (the Autodiff wrapper) and falls back to the standard K1–K5 forward on others.

§Returns
  • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]

impl Mamba3SingleSsdInput

pub fn sanity(&self)

Run the NaN/Inf guards on every input tensor.
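
A NaN/Inf guard of this kind usually amounts to a finiteness scan over the tensor data. The sketch below works on a flat `f32` slice; the helper names `first_non_finite` and `guard_finite` are hypothetical, and whether the real `sanity` panics, logs, or is compiled out in release builds is not stated here.

```rust
/// Returns the flat index of the first NaN or infinite element, if any.
fn first_non_finite(data: &[f32]) -> Option<usize> {
    data.iter().position(|x| !x.is_finite())
}

/// Panics with the tensor name when a non-finite value is found.
/// (Assumed behavior for illustration only.)
fn guard_finite(name: &str, data: &[f32]) {
    if let Some(i) = first_non_finite(data) {
        panic!("{}: non-finite value {} at flat index {}", name, data[i], i);
    }
}
```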

impl Mamba3SingleSsdInput

pub fn run(self, path: &Mamba3SsdPath) -> (Tensor<6>, Tensor<4>)

Run the selected merged-form (single-ssd) algorithm on this MIMO-first input.

Dispatches by Mamba3SsdPath variant to single_ssd_minimal, single_ssd_serial, or single_ssd_serial_recalculated.

§Returns
  • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank] — the merged-form accumulator at the last token (to be stored in the cache for streaming).
