pub struct Mamba3SingleSsdInput {
pub v_bnlmhp: Tensor<6>,
pub b_bnlmhr: Tensor<6>,
pub c_bnlmhr: Tensor<6>,
pub da_bnlh: Tensor<4>,
pub gamma_bnlh: Tensor<4>,
pub scale_bnlh: Tensor<4>,
pub initial_state_bhpr: Tensor<4>,
pub init_state_hpr: Option<Tensor<3>>,
}
MIMO-first input bundle for the merged-form SSD.
All tensors are pre-processed by the caller (Mamba3::forward_single_ssd): B/C are
already QK-normed, RoPE-applied, bias-added, and expanded to per-head; V is
the raw, unscaled MIMO-expanded value. The combined log-decay da = Δ·A
is pre-computed. The two trapezoidal coefficients gammaₜ and scaleₜ are
supplied separately because the SSD itself does the K-scaling and γ-weighted
diagonal correction internally. D-skip and Z-gating are handled by the
caller.
Fields
v_bnlmhp: Tensor<6>
Value tensor, MIMO-expanded but not trapezoidally scaled.
§Shape
[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
b_bnlmhr: Tensor<6>
K/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head.
Not pre-scaled — the SSD multiplies by scaleₜ internally for the
lower-triangular and state-recurrence paths, while the diagonal
correction reuses the unscaled tensor.
§Shape
[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
c_bnlmhr: Tensor<6>
Q/C tensor: same processing as b_bnlmhr.
§Shape
[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
da_bnlh: Tensor<4>
Combined log-decay da = Δ·A, pre-computed by the caller.
gamma_bnlh: Tensor<4>
γₜ = λₜ · Δₜ, used as the per-token diagonal multiplier.
§Shape
[batch, nchunks, chunk_len, nheads]
scale_bnlh: Tensor<4>
scaleₜ = γₜ + (1 − λₜ₊₁) · Δₜ₊₁. K is multiplied by this for the
lower-triangular and state recurrence paths. The shifted term is zero
at the very last sequence position (no future token exists).
§Shape
[batch, nchunks, chunk_len, nheads]
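The two coefficient definitions above can be sketched for a single head over one full sequence. This is a minimal illustration on plain slices, not the crate's tensor code: the function name `trapezoid_coeffs` and the `lambda`/`delta` inputs are hypothetical, and the real caller works on `[batch, nchunks, chunk_len, nheads]` tensors.

```rust
/// Sketch: compute (gamma, scale) for one head over a whole sequence.
/// `lambda` and `delta` are hypothetical per-token inputs of equal length.
fn trapezoid_coeffs(lambda: &[f64], delta: &[f64]) -> (Vec<f64>, Vec<f64>) {
    assert_eq!(lambda.len(), delta.len(), "lambda and delta must have equal length");
    let n = lambda.len();

    // gamma_t = lambda_t * delta_t
    let gamma: Vec<f64> = lambda.iter().zip(delta).map(|(l, d)| l * d).collect();

    // scale_t = gamma_t + (1 - lambda_{t+1}) * delta_{t+1};
    // the shifted term is zero at the last position (no future token).
    let scale: Vec<f64> = (0..n)
        .map(|t| {
            let shifted = if t + 1 < n {
                (1.0 - lambda[t + 1]) * delta[t + 1]
            } else {
                0.0
            };
            gamma[t] + shifted
        })
        .collect();

    (gamma, scale)
}
```

For example, with λ = [0.5, 0.5, 1.0] and Δ = [2, 4, 6], γ is [1, 2, 6] and scale is [3, 2, 6]: the last entry equals γ because the shifted term vanishes there.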
initial_state_bhpr: Tensor<4>
Initial SSM hidden state (merged-form accumulator).
When continuing from a prior call, this should already include the
boundary β contribution (1 − λ₀) · Δ₀ · Σₘ Kₜ₋₁[m] ⊗ (xₜ₋₁ ⊙ mimo_xₘ)
(which the previous call could not yet add because it did not know
λ₀, Δ₀).
§Shape
[batch, nheads, per_head_dim, state_rank]
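The boundary β contribution described above can be sketched for one batch element and one head. This is an illustration under assumptions, not the crate's implementation: the function name `add_boundary_beta` and all parameter names are hypothetical, nested `Vec`s stand in for tensors, and the outer product is laid out as `[per_head_dim][state_rank]` to match the documented state shape.

```rust
/// Sketch: add (1 - lambda0) * delta0 * sum_m K_prev[m] (outer) (x_prev * mimo_x[m])
/// to a carried state for one batch element and one head.
fn add_boundary_beta(
    state_pr: &mut [Vec<f64>], // [per_head_dim][state_rank]
    k_prev_mr: &[Vec<f64>],    // [mimo_rank][state_rank], K at the previous call's last token
    x_prev_p: &[f64],          // [per_head_dim], x at the previous call's last token
    mimo_x_mp: &[Vec<f64>],    // [mimo_rank][per_head_dim], MIMO expansion weights (assumed layout)
    lambda0: f64,              // lambda of the first token of the new call
    delta0: f64,               // delta of the first token of the new call
) {
    let beta = (1.0 - lambda0) * delta0;
    // Sum over the MIMO rank m.
    for (k_r, mimo_p) in k_prev_mr.iter().zip(mimo_x_mp) {
        for (p, row) in state_pr.iter_mut().enumerate() {
            // Value component for this m: x_prev elementwise-times mimo_x[m].
            let v = x_prev_p[p] * mimo_p[p];
            for (s, k) in row.iter_mut().zip(k_r) {
                *s += beta * k * v;
            }
        }
    }
}
```

The point of the sketch is the ordering: the previous call cannot add this term because λ₀ and Δ₀ belong to the next call's first token, so the caller folds it into the state before passing it in.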
init_state_hpr: Option<Tensor<3>>
Implementations
impl Mamba3SingleSsdInput
pub fn single_ssd_minimal(self) -> (Tensor<6>, Tensor<4>)
MIMO-first single-SSD — segsum variant.
See module documentation for the algorithm. Returns the chunked outputs and the final single-ssd accumulator.
§Shapes
- input: see Mamba3SingleSsdInput
- output (y_bnlmhp, final_state_bhpr):
  - y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  - final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
impl Mamba3SingleSsdInput
pub fn single_ssd_serial(self) -> (Tensor<6>, Tensor<4>)
MIMO-first single-SSD, chunk-serial (K1–K5) variant.
Sequence of kernels (matches the double-ssd ssd_serial):
- K1: intra-chunk cumulative log-decay and per-chunk decay totals.
- K2: cb = C · Bᵀ block matrix (unscaled).
- K3: per-chunk hidden state assuming zero initial state, fed K_scaled = scaleₜ · B.
- K4: sequential state passing across chunks (loop over chunks).
- K5 (this module’s new function): single-ssd chunk scan with
strict lower-triangular masking, scale broadcasting, and the
γₜ-weighted same-step diagonal correction.
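As an illustration of the K1 step in the list above, the following sketch computes the intra-chunk cumulative log-decay and the per-chunk totals for a single head. It assumes an inclusive cumulative sum and keeps the totals in log space; the function name `chunk_cumulative_decay` is hypothetical and plain slices stand in for the `da_bnlh` tensor.

```rust
/// Sketch of K1 for one head: per-chunk inclusive cumulative sums of the
/// log-decay `da`, plus each chunk's total log-decay.
fn chunk_cumulative_decay(da: &[f64], chunk_len: usize) -> (Vec<Vec<f64>>, Vec<f64>) {
    assert!(
        chunk_len > 0 && da.len() % chunk_len == 0,
        "sequence length must be a positive multiple of chunk_len"
    );
    let mut cumsums = Vec::new();
    let mut totals = Vec::new();
    for chunk in da.chunks(chunk_len) {
        // Running sum restarts at every chunk boundary.
        let mut acc = 0.0;
        let cum: Vec<f64> = chunk
            .iter()
            .map(|d| {
                acc += d;
                acc
            })
            .collect();
        // The last running value is the chunk's total log-decay.
        totals.push(acc);
        cumsums.push(cum);
    }
    (cumsums, totals)
}
```

With `da = [-1, -2, -3, -4]` and `chunk_len = 2`, the cumulative sums are `[[-1, -3], [-3, -7]]` and the totals are `[-3, -7]`. The totals are what a K4-style loop would use to decay the carried state from one chunk to the next.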
§Returns
- y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank], the single-ssd accumulator at the last token.
impl Mamba3SingleSsdInput
pub fn single_ssd_serial_recalculated(self) -> (Tensor<6>, Tensor<4>)
MIMO-first single-SSD, chunk-serial variant with a recalculated backward pass.
Delegates the full K1–K5 (single-ssd) computation to
Mamba3SingleSsdBackendExt::single_ssd_serial_recalculated, which can provide
a memory-efficient custom backward for supported backends (the Autodiff
wrapper) and falls back to the standard K1–K5 forward on others.
§Returns
- y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
impl Mamba3SingleSsdInput
pub fn sanity(&self)
Run the NaN/Inf guards on every input tensor.
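A NaN/Inf guard of this kind can be sketched on a flat slice of values. This is only an illustration of the check itself: the helper name `first_non_finite` is hypothetical, and how `sanity` actually reports a failure (panic, log, or otherwise) is not stated here.

```rust
/// Sketch: return the flat index of the first NaN or infinite value, if any.
fn first_non_finite(values: &[f64]) -> Option<usize> {
    // `f64::is_finite` is false for NaN, +inf and -inf.
    values.iter().position(|v| !v.is_finite())
}
```

A guard over every input would apply such a check to each field in turn and report the field name together with the offending index.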
impl Mamba3SingleSsdInput
pub fn run(self, path: &Mamba3SsdPath) -> (Tensor<6>, Tensor<4>)
Run the selected merged-form (single-ssd) algorithm on this MIMO-first input.
Dispatches by Mamba3SsdPath variant to single_ssd_minimal,
single_ssd_serial, or single_ssd_serial_recalculated.
§Returns
- y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank], the merged-form accumulator at the last token (to be stored in the cache for streaming).