Matmul/segsum MIMO-first SSD with plain autodiff backward.
§Single-Pass SSD (Minimal / segsum variant)
This is the MIMO-first, single-SSD-pass implementation of the
Mamba-3 trapezoid recurrence. It is the Burn analogue of the official
TileLang MIMO kernel and the Triton SISO kernel; SISO is the degenerate
`mimo_rank = 1` case.
§Background — the single-ssd recurrence
The double-SSD trapezoid hidden state is

```text
hₜ = αₜ hₜ₋₁ + βₜ (Bₜ₋₁ ⊗ xₜ₋₁) + γₜ (Bₜ ⊗ xₜ)
```

Expanding the recurrence and grouping by (Bₛ ⊗ xₛ), with βₜ = αₜ·(1−λₜ)·Δₜ,
gives the coefficient (Πᵣ₌ₛ₊₁ᵗ αᵣ) · [γₛ + (1−λₛ₊₁)·Δₛ₊₁] for the
contribution of step s to state t when s < t. At s = t the coefficient is just γₜ.
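The regrouping step can be written out explicitly. This derivation assumes the Mamba-3 parameterization βₜ = αₜ·(1−λₜ)·Δₜ (the form that makes the bracketed coefficient come out as stated), takes B₋₁ ⊗ x₋₁ = 0, and treats an empty product of α as 1.

```latex
\begin{aligned}
h_t &= \sum_{s=0}^{t} \Big(\prod_{r=s+1}^{t} \alpha_r\Big)\,\gamma_s\,(B_s \otimes x_s)
     + \sum_{s=0}^{t-1} \Big(\prod_{r=s+2}^{t} \alpha_r\Big)\,\beta_{s+1}\,(B_s \otimes x_s) \\
    &= \gamma_t\,(B_t \otimes x_t)
     + \sum_{s=0}^{t-1} \Big(\prod_{r=s+1}^{t} \alpha_r\Big)
       \big[\gamma_s + (1-\lambda_{s+1})\,\Delta_{s+1}\big]\,(B_s \otimes x_s)
\end{aligned}
```

The second line uses β₍ₛ₊₁₎ = α₍ₛ₊₁₎·(1−λ₍ₛ₊₁₎)·Δ₍ₛ₊₁₎, which absorbs the missing α₍ₛ₊₁₎ into the decay product so both sums share the factor Πᵣ₌ₛ₊₁ᵗ αᵣ.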
Define scaleₜ = γₜ + (1−λₜ₊₁)·Δₜ₊₁, with scaleₜ = γₜ at the last
position. The single-SSD recurrence

```text
h'ₜ = αₜ h'ₜ₋₁ + scaleₜ (Bₜ ⊗ xₜ)
```

read out as yₜ = Cₜᵀ h'ₜ, reproduces every cross-step (s < t) contribution of
the double-SSD output. It differs only at the same-step diagonal (s = t), where
the single-SSD form carries scaleₜ instead of γₜ. We compensate by:
- Using a strict lower-triangular mask in the intra-chunk path (the
  `s = t` block is excluded from the trapezoid sum).
- Adding a separate γ-weighted same-step term γₜ · (Cₜᵀ Bₜ) · xₜ.
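The compensation can be checked numerically on a toy problem. This is a minimal sketch in plain Rust, not the crate's Burn code, assuming a scalar setting (state and head dimensions 1, `mimo_rank = 1`) and the Mamba-3 parameterization γₜ = λₜ·Δₜ, βₜ = αₜ·(1−λₜ)·Δₜ. `Tok`, `scale`, `double_ssd_y`, `naive_single_ssd_y`, and `single_ssd_y` are hypothetical names.

```rust
// Toy scalar check (N = P = 1, mimo_rank = 1) of the single-SSD compensation.
// All names are hypothetical illustrations, not crate APIs.

/// One time step: decay αₜ, trapezoid weight λₜ, step Δₜ, scalar Bₜ, Cₜ, xₜ.
#[derive(Clone, Copy, Debug)]
pub struct Tok {
    pub alpha: f64,
    pub lambda: f64,
    pub delta: f64,
    pub b: f64,
    pub c: f64,
    pub x: f64,
}

/// Assumed parameterization: γₜ = λₜ·Δₜ, βₜ = αₜ·(1−λₜ)·Δₜ.
pub fn gamma(k: &Tok) -> f64 {
    k.lambda * k.delta
}
pub fn beta(k: &Tok) -> f64 {
    k.alpha * (1.0 - k.lambda) * k.delta
}

/// scaleₜ = γₜ + (1−λₜ₊₁)·Δₜ₊₁, falling back to γₜ at the last position.
pub fn scale(toks: &[Tok], t: usize) -> f64 {
    match toks.get(t + 1) {
        Some(next) => gamma(&toks[t]) + (1.0 - next.lambda) * next.delta,
        None => gamma(&toks[t]),
    }
}

/// Reference: the double-SSD recurrence, read out as yₜ = Cₜ·hₜ.
pub fn double_ssd_y(toks: &[Tok]) -> Vec<f64> {
    let (mut h, mut prev_bx) = (0.0_f64, 0.0_f64); // h₋₁ = 0, B₋₁ ⊗ x₋₁ = 0
    toks.iter()
        .map(|k| {
            h = k.alpha * h + beta(k) * prev_bx + gamma(k) * k.b * k.x;
            prev_bx = k.b * k.x;
            k.c * h
        })
        .collect()
}

/// Uncompensated single-SSD readout yₜ = Cₜ·h'ₜ (wrong on the diagonal).
pub fn naive_single_ssd_y(toks: &[Tok]) -> Vec<f64> {
    let mut h = 0.0_f64;
    (0..toks.len())
        .map(|t| {
            h = toks[t].alpha * h + scale(toks, t) * toks[t].b * toks[t].x;
            toks[t].c * h
        })
        .collect()
}

/// Compensated single-SSD: strict lower-triangular sum weighted by scaleₛ,
/// plus the γ-weighted same-step term γₜ·(Cₜ Bₜ)·xₜ.
pub fn single_ssd_y(toks: &[Tok]) -> Vec<f64> {
    (0..toks.len())
        .map(|t| {
            let k = &toks[t];
            let mut y = gamma(k) * k.c * k.b * k.x; // diagonal correction
            for s in 0..t {
                // strict mask: only s < t contributes here
                let decay: f64 = toks[s + 1..=t].iter().map(|r| r.alpha).product();
                y += k.c * decay * scale(toks, s) * toks[s].b * toks[s].x;
            }
            y
        })
        .collect()
}
```

On any input, `single_ssd_y` should agree with `double_ssd_y` up to rounding, while `naive_single_ssd_y` is off at each step by exactly Cₜ·(scaleₜ − γₜ)·Bₜ·xₜ and agrees at the last position, where scaleₜ = γₜ.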
§Algorithm (per chunk, MIMO-first)
```text
K_scaled[t, m, h, n] = scaleₜ · B[t, m, h, n]       // K scaled inside the SSD
y_lower = (C ⊗ K_scaledᵀ ⊙ L_strict) · PsiV        // strict lower-tri
y_diag  = γₜ · (C ⊗ Bᵀ at same step) · PsiV         // diagonal correction
y_off   = C · h'_chunk_in · exp(da_cs)              // state-to-output
y       = y_lower + y_diag + y_off
h'_chunk_out = exp(da_cs_last) · h'_chunk_in
             + K_scaled · exp(da_cs_rev)ᵀ · PsiV    // standard state update
```

The MIMO causal mask is identical to the one in
`crate::mamba3::double_ssd::ssd::minimal`, but with a strict inequality
(`i_time > j_time` rather than `i_time ≥ j_time`).
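A loop-based sketch of the per-chunk algorithm, under the same toy assumptions (N = P = 1, `mimo_rank = 1`, γₜ = λₜ·Δₜ, βₜ = αₜ·(1−λₜ)·Δₜ). It builds `da_cs` as a within-chunk cumulative sum of ln αₜ, applies the strict mask to the intra-chunk term, adds the γ diagonal and the carried-state term, then updates h' for the next chunk. All names are hypothetical; the actual implementation works on Burn tensors with matmuls rather than scalar loops.

```rust
// Toy scalar (N = P = 1, mimo_rank = 1) chunked single-SSD forward pass.
// All names are hypothetical; plain Rust, not the Burn tensor code.

#[derive(Clone, Copy, Debug)]
pub struct Tok { pub alpha: f64, pub lambda: f64, pub delta: f64, pub b: f64, pub c: f64, pub x: f64 }

// Assumed parameterization: γₜ = λₜ·Δₜ, βₜ = αₜ·(1−λₜ)·Δₜ.
pub fn gamma(k: &Tok) -> f64 { k.lambda * k.delta }
pub fn beta(k: &Tok) -> f64 { k.alpha * (1.0 - k.lambda) * k.delta }

/// scaleₜ looks one step ahead, even across a chunk boundary; γₜ at the end.
pub fn scale(toks: &[Tok], t: usize) -> f64 {
    match toks.get(t + 1) {
        Some(n) => gamma(&toks[t]) + (1.0 - n.lambda) * n.delta,
        None => gamma(&toks[t]),
    }
}

/// Causal mask: `i > j` when `strict` (single-SSD), `i >= j` otherwise (double-SSD).
pub fn causal_mask(q: usize, strict: bool) -> Vec<Vec<bool>> {
    (0..q).map(|i| (0..q).map(|j| if strict { i > j } else { i >= j }).collect()).collect()
}

/// Sequential double-SSD reference, yₜ = Cₜ·hₜ.
pub fn double_ssd_y(toks: &[Tok]) -> Vec<f64> {
    let (mut h, mut prev_bx) = (0.0_f64, 0.0_f64);
    toks.iter()
        .map(|k| {
            h = k.alpha * h + beta(k) * prev_bx + gamma(k) * k.b * k.x;
            prev_bx = k.b * k.x;
            k.c * h
        })
        .collect()
}

/// Chunked single-SSD: y = y_lower + y_diag + y_off per chunk, carrying h'.
/// Requires `chunk > 0` and every αₜ > 0 (decay is accumulated in log space).
pub fn chunked_single_ssd_y(toks: &[Tok], chunk: usize) -> Vec<f64> {
    let mut y = vec![0.0_f64; toks.len()];
    let mut h = 0.0_f64; // h'_chunk_in
    for start in (0..toks.len()).step_by(chunk) {
        let end = (start + chunk).min(toks.len());
        let q = end - start;
        // da_cs[i] = Σ ln α over chunk positions 0..=i.
        let mut da_cs = vec![0.0_f64; q];
        let mut acc = 0.0_f64;
        for i in 0..q {
            acc += toks[start + i].alpha.ln();
            da_cs[i] = acc;
        }
        // K_scaled = scale · B for each position in the chunk.
        let k_scaled: Vec<f64> = (start..end).map(|t| scale(toks, t) * toks[t].b).collect();
        let mask = causal_mask(q, true);
        for i in 0..q {
            let tok = &toks[start + i];
            let mut y_lower = 0.0;
            for j in 0..q {
                if mask[i][j] {
                    // segsum decay: exp(da_cs[i] − da_cs[j]) = product of α over (j, i]
                    y_lower += tok.c * k_scaled[j] * (da_cs[i] - da_cs[j]).exp() * toks[start + j].x;
                }
            }
            let y_diag = gamma(tok) * tok.c * tok.b * tok.x;
            let y_off = tok.c * h * da_cs[i].exp();
            y[start + i] = y_lower + y_diag + y_off;
        }
        // h'_chunk_out = exp(da_cs_last)·h'_chunk_in + Σ K_scaled·exp(da_cs_last − da_cs)·x
        let last = da_cs[q - 1];
        let mut h_out = last.exp() * h;
        for j in 0..q {
            h_out += k_scaled[j] * (last - da_cs[j]).exp() * toks[start + j].x;
        }
        h = h_out;
    }
    y
}
```

Comparing `chunked_single_ssd_y` against `double_ssd_y` for several chunk sizes is a cheap way to exercise the boundary handling, in particular scaleₜ reading λ and Δ from the first position of the next chunk.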
Reference implementations:
- SISO: `refs/state-spaces/mamba/mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py`
- MIMO: `refs/state-spaces/mamba/mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`