pub struct Mamba3DoubleSsdInput {
    pub v_bnlmhp: Tensor<6>,
    pub da_bnlh: Tensor<4>,
    pub b_bnlmhr: Tensor<6>,
    pub c_bnlmhr: Tensor<6>,
    pub initial_state_bhpr: Tensor<4>,
    pub init_state_hpr: Option<Tensor<3>>,
}
MIMO-first SSD input.
All tensors are pre-processed: B/C are already QK-normed, RoPE-applied, bias-added, and
expanded to per-head (not per-group). V is already scaled by the (double-ssd) trapezoidal
coefficient (γ or β). The combined log-decay da = Δ·A is pre-computed. The D skip connection
is applied by the caller, not here.
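For orientation, the recurrence these tensors feed can be written as a naive per-token loop. The sketch below is a hypothetical reference for a single (batch, head) pair using plain `Vec<f64>` instead of `Tensor`. It assumes the standard SSD form extended to MIMO rank R: the state h (per_head_dim × state_rank) is decayed by exp(da_t), then receives a rank-R outer-product update from v_t and b_t, and each MIMO rank of the output reads the state through c_t. It ignores the double-ssd split into γ- and β-scaled passes and applies no D skip; `naive_mimo_scan` is an illustrative name, not part of this crate.

```rust
/// Hypothetical naive reference for one (batch, head) pair.
/// Assumed recurrence (standard SSD form, extended to MIMO rank R):
///   h_t[p][n] = exp(da_t) * h_{t-1}[p][n] + sum_m v_t[m][p] * b_t[m][n]
///   y_t[m][p] = sum_n c_t[m][n] * h_t[p][n]
/// Tokens are flattened over (nchunks, chunk_len); no D skip is applied.
pub fn naive_mimo_scan(
    da: &[f64],          // [T] log-decay per token
    v: &[Vec<Vec<f64>>], // [T][R][P] pre-scaled values
    b: &[Vec<Vec<f64>>], // [T][R][N]
    c: &[Vec<Vec<f64>>], // [T][R][N]
    h0: Vec<Vec<f64>>,   // [P][N] initial state
) -> (Vec<Vec<Vec<f64>>>, Vec<Vec<f64>>) {
    let mut h = h0;
    let mut ys = Vec::with_capacity(da.len());
    for t in 0..da.len() {
        let decay = da[t].exp();
        // Decay the state, then add the rank-R outer-product update.
        for p in 0..h.len() {
            for n in 0..h[p].len() {
                let upd: f64 = (0..v[t].len()).map(|m| v[t][m][p] * b[t][m][n]).sum();
                h[p][n] = decay * h[p][n] + upd;
            }
        }
        // Read out one output row per MIMO rank: y_t[m][p] = h[p] · c_t[m].
        let y_t: Vec<Vec<f64>> = c[t]
            .iter()
            .map(|c_m| {
                h.iter()
                    .map(|h_p| h_p.iter().zip(c_m).map(|(x, y)| x * y).sum::<f64>())
                    .collect::<Vec<f64>>()
            })
            .collect();
        ys.push(y_t);
    }
    (ys, h)
}
```

A loop like this is only useful as a correctness oracle: it is O(T) sequential steps with no chunking, which is exactly what the chunkwise variants below restructure.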
Fields
v_bnlmhp: Tensor<6>
Value tensor, already scaled by (double-ssd) trapezoidal coefficient (γ or β).
Shape
[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
da_bnlh: Tensor<4>
Combined log-decay da = Δ·A (pre-computed).
b_bnlmhr: Tensor<6>
Key/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head, per-rank.
Shape
[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
c_bnlmhr: Tensor<6>
Query/C tensor: same processing as B.
Shape
[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
initial_state_bhpr: Tensor<4>
init_state_hpr: Option<Tensor<3>>
Implementations
impl Mamba3DoubleSsdInput
pub fn double_ssd_minimal(self) -> (Tensor<6>, Tensor<4>)
MIMO-first chunkwise SSD: minimal/segsum variant.
Implements the four-step decomposition for the MIMO (double-ssd) trapezoidal recurrence. SISO (mimo_rank=1) is the degenerate case where the fused length equals the chunk length.
No D skip is applied here — the caller handles it.
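The "segsum" in this variant's name refers to the segment-sum trick used by Mamba-2's minimal SSD reference: for a chunk of log-decays x, segsum(x)[i][j] is the summed decay from position j+1 through i, and exp(segsum(x)) is the lower-triangular decay matrix that turns the recurrence into a masked quadratic form. The sketch below is hypothetical: it uses plain `Vec<f64>`, covers only the SISO intra-chunk term with per_head_dim = 1 and a zero incoming state, and assumes this is how the minimal path uses segsum; `intra_chunk_quadratic` is an illustrative name.

```rust
/// Segment sum over a 1-D log-decay sequence x of length T:
///   out[i][j] = x[j+1] + ... + x[i]  for j <= i (0.0 on the diagonal)
///   out[i][j] = -inf                 for j > i  (masked; exp(-inf) = 0)
pub fn segsum(x: &[f64]) -> Vec<Vec<f64>> {
    let t = x.len();
    let mut out = vec![vec![f64::NEG_INFINITY; t]; t];
    for i in 0..t {
        out[i][i] = 0.0;
        let mut acc = 0.0;
        // Walk left from the diagonal, accumulating x[j+1..=i].
        for j in (0..i).rev() {
            acc += x[j + 1];
            out[i][j] = acc;
        }
    }
    out
}

/// Hypothetical SISO intra-chunk output (per_head_dim = 1, zero incoming state):
///   y[i] = sum_{j <= i} exp(segsum(da)[i][j]) * (c[i] · b[j]) * v[j]
pub fn intra_chunk_quadratic(da: &[f64], v: &[f64], b: &[Vec<f64>], c: &[Vec<f64>]) -> Vec<f64> {
    let l = segsum(da);
    (0..da.len())
        .map(|i| {
            (0..=i)
                .map(|j| {
                    let cb: f64 = c[i].iter().zip(&b[j]).map(|(x, y)| x * y).sum();
                    l[i][j].exp() * cb * v[j]
                })
                .sum::<f64>()
        })
        .collect()
}
```

The quadratic form trades extra FLOPs (every token pair in a chunk) for dense, parallel matrix work, which is why it is applied per chunk rather than over the whole sequence.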
Shapes
- input: see Mamba3DoubleSsdInput
- output.0, y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim], where R = mimo_rank
- output.1, final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
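The axis names used in these shape lists can be collected into a small dimension bundle for validating inputs and outputs before dispatch. The sketch below is hypothetical: `SsdDims` and `expect_shape` are illustrative names, shapes are plain arrays rather than this crate's `Tensor` type, and the layout of da_bnlh is only assumed from its suffix (b, n, l, h), since that field carries no explicit shape documentation.

```rust
/// Hypothetical dimension bundle; field names mirror the axis names in these docs.
#[derive(Clone, Copy, Debug)]
pub struct SsdDims {
    pub batch: usize,
    pub nchunks: usize,
    pub chunk_len: usize,
    pub mimo_rank: usize,
    pub nheads: usize,
    pub per_head_dim: usize,
    pub state_rank: usize,
}

impl SsdDims {
    /// v_bnlmhp and the y output: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
    pub fn v_shape(&self) -> [usize; 6] {
        [self.batch, self.nchunks, self.chunk_len, self.mimo_rank, self.nheads, self.per_head_dim]
    }
    /// b_bnlmhr and c_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
    pub fn bc_shape(&self) -> [usize; 6] {
        [self.batch, self.nchunks, self.chunk_len, self.mimo_rank, self.nheads, self.state_rank]
    }
    /// da_bnlh, ASSUMED from the name suffix: [batch, nchunks, chunk_len, nheads]
    pub fn da_shape(&self) -> [usize; 4] {
        [self.batch, self.nchunks, self.chunk_len, self.nheads]
    }
    /// final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
    pub fn state_shape(&self) -> [usize; 4] {
        [self.batch, self.nheads, self.per_head_dim, self.state_rank]
    }
}

/// Compare an observed shape with the expected one, naming the tensor on mismatch.
pub fn expect_shape<const D: usize>(name: &str, got: [usize; D], want: [usize; D]) -> Result<(), String> {
    if got == want {
        Ok(())
    } else {
        Err(format!("{name}: expected {want:?}, got {got:?}"))
    }
}
```

Naming the tensor in the error matters here because per_head_dim and state_rank are easy to transpose in a state tensor, and both orderings have the same rank.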
impl Mamba3DoubleSsdInput
pub fn double_ssd_serial(self) -> (Tensor<6>, Tensor<4>)
MIMO-first (Hybrid) Serial SSD.
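Implements K1-K5 with a sequential loop (K4) for the inter-chunk scan instead
of the quadratic segsum approach used in Self::double_ssd_minimal.
Because the loop carries the state forward one chunk at a time instead of building a decay
matrix that grows quadratically with the number of chunks, this variant is more
memory-efficient for long sequences with many chunks.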
SISO (mimo_rank=1) is the special case where the fused length equals the chunk length.
Returns
- y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
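The sequential inter-chunk scan can be pictured with flattened per-head states. The sketch below is hypothetical and assumes the usual Mamba-2 style split: each chunk first produces its own end-of-chunk state from a zero start, plus its total log-decay (the sum of da over the chunk). The loop then carries the running state across chunks in nchunks sequential steps rather than forming an nchunks × nchunks decay matrix. `inter_chunk_scan` is an illustrative name, and its correspondence to this crate's K4 step is an assumption.

```rust
/// Hypothetical sequential inter-chunk scan over flattened per-head states.
/// chunk_decay[k]: total log-decay of chunk k (sum of da over its tokens).
/// chunk_state[k]: state produced by chunk k alone, starting from zero.
/// Returns (state entering each chunk, final state).
pub fn inter_chunk_scan(
    init: &[f64],
    chunk_decay: &[f64],
    chunk_state: &[Vec<f64>],
) -> (Vec<Vec<f64>>, Vec<f64>) {
    let mut carry = init.to_vec();
    let mut entering = Vec::with_capacity(chunk_decay.len());
    for (decay, local) in chunk_decay.iter().zip(chunk_state) {
        // Record the state each chunk starts from (needed for its output correction).
        entering.push(carry.clone());
        let a = decay.exp();
        // carry_{k+1} = exp(decay_k) * carry_k + local_k
        for (s, l) in carry.iter_mut().zip(local) {
            *s = a * *s + l;
        }
    }
    (entering, carry)
}
```

Storage here is one state per chunk, so it grows linearly in nchunks; the price is that the chunk loop cannot be parallelized.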
impl Mamba3DoubleSsdInput
pub fn double_ssd_serial_recalculated(self) -> (Tensor<6>, Tensor<4>)
MIMO-first Serial SSD with recalculated backward.
Delegates the full K1-K5 computation to Mamba3DoubleSsdBackendExt::double_ssd_serial_recalculated,
which can provide a memory-efficient custom backward pass on supported backends.
Falls back to the standard K1-K5 serial computation on unsupported backends.
Returns
- y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
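Delegation with a fallback is commonly expressed in Rust as an extension trait whose default method is the generic path, overridden only by backends that ship a specialized kernel. The sketch below is hypothetical: `SerialRecalcExt`, `ReferenceBackend` and `FusedBackend` are illustrative names, not the real Mamba3DoubleSsdBackendExt API, and a running prefix sum stands in for the K1-K5 computation so that both paths can be compared.

```rust
/// Hypothetical extension trait: the default method is the generic fallback.
pub trait SerialRecalcExt {
    /// Generic path, used by any backend that does not override it.
    fn serial_recalculated(&self, x: &[f64]) -> (Vec<f64>, &'static str) {
        let mut acc = 0.0;
        let out = x
            .iter()
            .map(|v| {
                acc += v;
                acc
            })
            .collect::<Vec<f64>>();
        (out, "fallback")
    }
}

/// A backend without a custom kernel: inherits the default.
pub struct ReferenceBackend;
impl SerialRecalcExt for ReferenceBackend {}

/// A backend with a specialized kernel: overrides the default.
pub struct FusedBackend;
impl SerialRecalcExt for FusedBackend {
    fn serial_recalculated(&self, x: &[f64]) -> (Vec<f64>, &'static str) {
        // Stand-in for a custom kernel; it must agree numerically with the fallback.
        let mut out = Vec::with_capacity(x.len());
        let mut acc = 0.0;
        for v in x {
            acc += *v;
            out.push(acc);
        }
        (out, "custom")
    }
}
```

The key invariant is that the override changes cost (memory for the backward pass), not results, so the fallback doubles as a test oracle for the specialized path.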
impl Mamba3DoubleSsdInput
pub fn sanity(&self)
Run the NaN/Inf guards on every input tensor.
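A NaN/Inf guard of this kind amounts to scanning each tensor's values and reporting the first non-finite element. The sketch below is hypothetical: it works on flattened `f32` slices rather than this crate's `Tensor` type, returns a `Result` instead of mirroring the `()` return type of `sanity`, and `check_finite`/`check_all` are illustrative names.

```rust
/// Return an error naming the tensor and the flat index of the first NaN or +/-Inf.
pub fn check_finite(name: &str, data: &[f32]) -> Result<(), String> {
    match data.iter().position(|x| !x.is_finite()) {
        Some(i) => Err(format!("{name}: non-finite value {} at flat index {i}", data[i])),
        None => Ok(()),
    }
}

/// Run the guard on every named input, stopping at the first failure.
pub fn check_all(inputs: &[(&str, &[f32])]) -> Result<(), String> {
    for (name, data) in inputs {
        check_finite(name, data)?;
    }
    Ok(())
}
```

`f32::is_finite` rejects NaN and both infinities in one test, so a single pass per tensor suffices.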
impl Mamba3DoubleSsdInput
pub fn run(self, path: &Mamba3SsdPath) -> (Tensor<6>, Tensor<4>)
Run the selected double-ssd algorithm on this MIMO-first input.
Dispatches by Mamba3SsdPath variant to double_ssd_minimal,
double_ssd_serial, or double_ssd_serial_recalculated.
Returns
- y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
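Dispatch of this kind is an exhaustive `match` on the path enum. Because every path evaluates the same recurrence, the choice should affect memory and speed but not the result, up to floating-point rounding. The sketch below is hypothetical: `SsdPath` and its variants stand in for Mamba3SsdPath, whose real variants are not listed here, and the branches compute a scalar SISO head (per_head_dim = state_rank = 1, zero initial state) once through the quadratic form and once through the sequential recurrence.

```rust
/// Hypothetical stand-in for Mamba3SsdPath.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SsdPath {
    Minimal,
    Serial,
    SerialRecalculated,
}

/// Quadratic form: y[i] = sum_{j<=i} exp(da[j+1] + ... + da[i]) * c[i] * b[j] * v[j].
fn quadratic(da: &[f64], v: &[f64], b: &[f64], c: &[f64]) -> Vec<f64> {
    (0..da.len())
        .map(|i| {
            (0..=i)
                .map(|j| {
                    let decay: f64 = da[j + 1..=i].iter().sum();
                    decay.exp() * c[i] * b[j] * v[j]
                })
                .sum::<f64>()
        })
        .collect()
}

/// Sequential recurrence: h = exp(da[t]) * h + b[t] * v[t]; y[t] = c[t] * h.
fn serial(da: &[f64], v: &[f64], b: &[f64], c: &[f64]) -> Vec<f64> {
    let mut h = 0.0;
    da.iter()
        .enumerate()
        .map(|(t, a)| {
            h = a.exp() * h + b[t] * v[t];
            c[t] * h
        })
        .collect()
}

/// Exhaustive dispatch; in this sketch the recalculated path shares the serial forward pass.
pub fn run(path: SsdPath, da: &[f64], v: &[f64], b: &[f64], c: &[f64]) -> Vec<f64> {
    match path {
        SsdPath::Minimal => quadratic(da, v, b, c),
        SsdPath::Serial | SsdPath::SerialRecalculated => serial(da, v, b, c),
    }
}
```

An exhaustive `match` without a wildcard arm means adding a new path variant is a compile error until it is wired up, which suits a small, closed set of algorithms.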