Struct Mamba3DoubleSsdInput 

pub struct Mamba3DoubleSsdInput {
    pub v_bnlmhp: Tensor<6>,
    pub da_bnlh: Tensor<4>,
    pub b_bnlmhr: Tensor<6>,
    pub c_bnlmhr: Tensor<6>,
    pub initial_state_bhpr: Tensor<4>,
    pub init_state_hpr: Option<Tensor<3>>,
}

MIMO-first SSD input.

All tensors are pre-processed: B/C are already QK-normed, RoPE-applied, bias-added, and expanded to per-head (not per-group). V is already scaled by the (double-ssd) trapezoidal coefficient (γ or β). The combined log-decay da = Δ·A is pre-computed. D skip is handled by the caller.

Fields§

§v_bnlmhp: Tensor<6>

Value tensor, already scaled by (double-ssd) trapezoidal coefficient (γ or β).

§Shape

  • [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
§da_bnlh: Tensor<4>

Pre-combined log-decay Δ·A (negative).

§Shape

  • [batch, nchunks, chunk_len, nheads]
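
As an illustration of what "pre-combined" means, the following is a minimal sketch that multiplies a step size Δ by a per-head A to get da. It assumes Δ is positive and A is negative (so da is negative, as stated above), and it uses flat `f32` slices instead of this crate's `Tensor`. The helper name `combine_log_decay` and the `[chunk_len, nheads]` layout are hypothetical.

```rust
/// Hypothetical helper: combine a step size Δ (one value per position and head)
/// with a per-head A into the log-decay da = Δ·A.
/// `delta_lh` is row-major [chunk_len, nheads]; `a_h` is [nheads].
fn combine_log_decay(delta_lh: &[f32], a_h: &[f32]) -> Vec<f32> {
    let nheads = a_h.len();
    assert!(nheads > 0 && delta_lh.len() % nheads == 0);
    delta_lh
        .iter()
        .enumerate()
        // The head index is the fastest-varying axis in this layout.
        .map(|(i, &delta)| delta * a_h[i % nheads])
        .collect()
}

fn main() {
    // Two positions, two heads. Δ > 0 and A < 0 are assumptions of this sketch.
    let delta_lh = [0.5_f32, 1.0, 2.0, 0.25];
    let a_h = [-1.0_f32, -2.0];
    let da_lh = combine_log_decay(&delta_lh, &a_h);
    // Every entry is negative, so exp(da) lies in (0, 1) and acts as a decay.
    assert!(da_lh.iter().all(|&x| x < 0.0));
    println!("{:?}", da_lh);
}
```

Because da stays in log space, decays over several steps can be combined by summation before a single exponentiation.
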
§b_bnlmhr: Tensor<6>

Key/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head, per-rank.

§Shape

  • [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
§c_bnlmhr: Tensor<6>

Query/C tensor: same processing as B.

§Shape

  • [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
§initial_state_bhpr: Tensor<4>

Initial SSM hidden state.

§Shape

  • [batch, nheads, per_head_dim, state_rank]
§init_state_hpr: Option<Tensor<3>>

Optional learnable initial state (broadcast over batch).

§Shape

  • [nheads, per_head_dim, state_rank]
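
The field-name suffixes spell out the axis order listed above (b = batch, n = nchunks, l = chunk_len, m = mimo_rank, h = nheads, p = per_head_dim, r = state_rank). A small sketch of the expected shapes for a given configuration can help when checking inputs before building the struct. The `Dims` type and its methods are hypothetical and not part of this crate.

```rust
/// Hypothetical dimension bundle; field names mirror the shape labels above.
#[derive(Clone, Copy, Debug)]
struct Dims {
    batch: usize,
    nchunks: usize,
    chunk_len: usize,
    mimo_rank: usize,
    nheads: usize,
    per_head_dim: usize,
    state_rank: usize,
}

impl Dims {
    /// Expected shape of `v_bnlmhp`.
    fn v_bnlmhp(&self) -> [usize; 6] {
        [self.batch, self.nchunks, self.chunk_len, self.mimo_rank, self.nheads, self.per_head_dim]
    }
    /// Expected shape of `da_bnlh`.
    fn da_bnlh(&self) -> [usize; 4] {
        [self.batch, self.nchunks, self.chunk_len, self.nheads]
    }
    /// Expected shape of both `b_bnlmhr` and `c_bnlmhr`.
    fn bc_bnlmhr(&self) -> [usize; 6] {
        [self.batch, self.nchunks, self.chunk_len, self.mimo_rank, self.nheads, self.state_rank]
    }
    /// Expected shape of `initial_state_bhpr`.
    fn initial_state_bhpr(&self) -> [usize; 4] {
        [self.batch, self.nheads, self.per_head_dim, self.state_rank]
    }
    /// Expected shape of `init_state_hpr` when it is `Some`.
    fn init_state_hpr(&self) -> [usize; 3] {
        [self.nheads, self.per_head_dim, self.state_rank]
    }
}

fn main() {
    let dims = Dims {
        batch: 2, nchunks: 3, chunk_len: 4, mimo_rank: 2,
        nheads: 5, per_head_dim: 8, state_rank: 16,
    };
    println!("v:  {:?}", dims.v_bnlmhp());
    println!("da: {:?}", dims.da_bnlh());
    println!("b/c: {:?}", dims.bc_bnlmhr());
    println!("initial_state: {:?}", dims.initial_state_bhpr());
    println!("init_state: {:?}", dims.init_state_hpr());
}
```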

Implementations§

impl Mamba3DoubleSsdInput

pub fn double_ssd_minimal(self) -> (Tensor<6>, Tensor<4>)

MIMO-first chunkwise SSD — minimal/segsum variant.

Implements the four-step decomposition for the MIMO (double-ssd) trapezoidal recurrence. SISO (mimo_rank=1) is the degenerate case where the fused length equals the chunk length.

No D skip is applied here — the caller handles it.

§Shapes
  • input: see Mamba3DoubleSsdInput
  • output.0 y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • output.1 final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
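
For readers unfamiliar with the term, the sketch below shows a segment sum ("segsum") over a 1-D sequence of log-decays, following the usual chunkwise SSD formulation. It is an illustration on plain `f32` slices, not this crate's implementation, and the function name `segsum` is chosen here for the example. The result is a square matrix, which is where the quadratic memory cost of this variant comes from.

```rust
/// Segment sum of log-decays.
/// out[i][j] = da[j+1] + ... + da[i] for j <= i (0 on the diagonal),
/// and -inf for j > i, so that exp(out[i][j]) = 0 masks future positions.
fn segsum(da: &[f32]) -> Vec<Vec<f32>> {
    let len = da.len();
    // Inclusive prefix sums: cumsum[i] = da[0] + ... + da[i].
    let mut cumsum = Vec::with_capacity(len);
    let mut acc = 0.0_f32;
    for &x in da {
        acc += x;
        cumsum.push(acc);
    }
    // Start fully masked, then fill the lower triangle.
    let mut out = vec![vec![f32::NEG_INFINITY; len]; len];
    for i in 0..len {
        for j in 0..=i {
            out[i][j] = cumsum[i] - cumsum[j];
        }
    }
    out
}

fn main() {
    let da = [-1.0_f32, -2.0, -3.0];
    let log_decay = segsum(&da);
    for row in &log_decay {
        // exp turns the log-decays into multiplicative decay factors.
        let factors: Vec<f32> = row.iter().map(|x| x.exp()).collect();
        println!("{:?}", factors);
    }
}
```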

impl Mamba3DoubleSsdInput

pub fn double_ssd_serial(self) -> (Tensor<6>, Tensor<4>)

MIMO-first (Hybrid) Serial SSD.

Implements K1-K5 with a sequential loop (K4) for the inter-chunk scan instead of the quadratic segsum approach in Self::double_ssd_minimal. This is more memory-efficient for long sequences with many chunks.

SISO (mimo_rank=1) is the special case where the fused length equals the chunk length.

§Returns
  • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
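
The sequential inter-chunk scan can be pictured with a scalar stand-in for the state. The sketch below assumes each chunk contributes a summed log-decay and a locally accumulated state; the real state has shape [batch, nheads, per_head_dim, state_rank], and the function name `serial_chunk_scan` and its inputs are hypothetical. The loop keeps one running state, so memory grows linearly with the number of chunks rather than quadratically.

```rust
/// Hypothetical scalar model of a sequential inter-chunk scan.
/// `chunk_log_decay[c]` is the sum of da over chunk c;
/// `chunk_state[c]` is the state contribution produced inside chunk c.
/// Returns the state entering each chunk and the final state.
fn serial_chunk_scan(
    initial_state: f32,
    chunk_log_decay: &[f32],
    chunk_state: &[f32],
) -> (Vec<f32>, f32) {
    assert_eq!(chunk_log_decay.len(), chunk_state.len());
    let mut state = initial_state;
    let mut entering = Vec::with_capacity(chunk_state.len());
    for (&log_decay, &local) in chunk_log_decay.iter().zip(chunk_state) {
        // Record the state seen by this chunk before updating it.
        entering.push(state);
        // Decay the carried state across the chunk, then add the chunk's own contribution.
        state = log_decay.exp() * state + local;
    }
    (entering, state)
}

fn main() {
    // Zero log-decay means no decay, so the states simply accumulate.
    let (entering, final_state) = serial_chunk_scan(1.0, &[0.0, 0.0], &[2.0, 3.0]);
    println!("entering = {:?}, final = {}", entering, final_state);
}
```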

impl Mamba3DoubleSsdInput

pub fn double_ssd_serial_recalculated(self) -> (Tensor<6>, Tensor<4>)

MIMO-first Serial SSD with recalculated backward.

Delegates the full K1-K5 computation to Mamba3DoubleSsdBackendExt::double_ssd_serial_recalculated, which can provide a memory-efficient custom backward pass on supported backends.

Falls back to the standard K1-K5 serial computation on unsupported backends.

§Returns
  • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]

impl Mamba3DoubleSsdInput

pub fn sanity(&self)

Run the NaN/Inf guards on every input tensor.
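
A NaN/Inf guard of this kind can be sketched on plain `f32` slices as below. Whether the crate's guard panics, logs, or is compiled out in release builds is not stated here; the helper names `first_non_finite` and `assert_all_finite` are hypothetical.

```rust
/// Hypothetical helper: flat index of the first NaN or infinite value, if any.
fn first_non_finite(data: &[f32]) -> Option<usize> {
    data.iter().position(|v| !v.is_finite())
}

/// Hypothetical guard: panics with the tensor name and the offending index.
fn assert_all_finite(name: &str, data: &[f32]) {
    if let Some(idx) = first_non_finite(data) {
        panic!("{} contains non-finite value {} at flat index {}", name, data[idx], idx);
    }
}

fn main() {
    let da = [-0.5_f32, -2.0, -0.25];
    // Passes silently because every value is finite.
    assert_all_finite("da_bnlh", &da);

    let bad = [1.0_f32, f32::NAN, f32::INFINITY];
    println!("first bad index: {:?}", first_non_finite(&bad));
}
```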

impl Mamba3DoubleSsdInput

pub fn run(self, path: &Mamba3SsdPath) -> (Tensor<6>, Tensor<4>)

Run the selected double-ssd algorithm on this MIMO-first input.

Dispatches by Mamba3SsdPath variant to double_ssd_minimal, double_ssd_serial, or double_ssd_serial_recalculated.

§Returns
  • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
