Module minimal

Matmul/segsum MIMO-first SSD with plain autodiff backward.

§Single-Pass SSD (Minimal / segsum variant)

This is the MIMO-first, single-pass SSD implementation of the Mamba-3 trapezoid recurrence. It is the Burn analogue of the official TileLang MIMO kernel and Triton SISO kernel; SISO is the degenerate case mimo_rank = 1.

§Background: the single-SSD recurrence

The double-SSD trapezoid recurrence for the hidden state is

  hₜ = αₜ hₜ₋₁ + βₜ (Bₜ₋₁ ⊗ xₜ₋₁) + γₜ (Bₜ ⊗ xₜ)

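Expanding the recurrence and grouping by (Bₛ ⊗ xₛ) shows that step s contributes to state t (for s < t) with coefficient (Πᵣ₌ₛ₊₁ᵗ αᵣ) · [γₛ + (1−λₛ₊₁)·Δₛ₊₁]. The γₛ term enters at step s; the (1−λₛ₊₁)·Δₛ₊₁ term enters one step later through βₛ₊₁, taken here to equal αₛ₊₁·(1−λₛ₊₁)·Δₛ₊₁ so that αₛ₊₁ folds into the product. At s = t the coefficient is just γₜ.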

Define scaleₜ = γₜ + (1−λₜ₊₁)·Δₜ₊₁ (with scaleₜ = γₜ at the last position). The single-SSD recurrence

  h'ₜ = αₜ h'ₜ₋₁ + scaleₜ (Bₜ ⊗ xₜ)

produces the same outputs yₜ = Cₜᵀ h'ₜ as the double-SSD recurrence, except on the same-step diagonal (s = t), where the single-SSD form carries scaleₜ instead of γₜ. We compensate by:

  1. Using a strict lower-triangular mask in the intra-chunk path (the s = t block is excluded from the trapezoid sum).
  2. Adding a separate γ-weighted same-step term γₜ · (Cₜᵀ Bₜ) · xₜ.
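
The equivalence described above can be checked numerically with a scalar toy model. The sketch below is illustrative only and is not the Burn implementation: it uses a state of size 1, plain f64 slices, and hypothetical names (Seq, double_ssd, single_ssd). It assumes βₜ = αₜ·(1−λₜ)·Δₜ, a zero initial state, and no previous-step term at t = 0.

```rust
/// Scalar toy inputs (state size 1, one head); every name here is illustrative.
struct Seq {
    alpha: Vec<f64>, // per-step decay α_t
    gamma: Vec<f64>, // same-step weight γ_t
    lam: Vec<f64>,   // trapezoid mixing λ_t
    delta: Vec<f64>, // step size Δ_t
    b: Vec<f64>,
    c: Vec<f64>,
    x: Vec<f64>,
}

/// Double-SSD reference: h_t = α_t h_{t-1} + β_t b_{t-1} x_{t-1} + γ_t b_t x_t.
fn double_ssd(s: &Seq) -> Vec<f64> {
    let mut h = 0.0_f64;
    let mut y = Vec::with_capacity(s.x.len());
    for t in 0..s.x.len() {
        // Assumption: β_t = α_t (1 - λ_t) Δ_t, and no previous-step term at t = 0.
        let prev = if t > 0 {
            s.alpha[t] * (1.0 - s.lam[t]) * s.delta[t] * s.b[t - 1] * s.x[t - 1]
        } else {
            0.0
        };
        h = s.alpha[t] * h + prev + s.gamma[t] * s.b[t] * s.x[t];
        y.push(s.c[t] * h);
    }
    y
}

/// Single-SSD form: strict-lower contribution plus the γ-weighted diagonal term.
fn single_ssd(s: &Seq) -> Vec<f64> {
    let n = s.x.len();
    let mut h = 0.0_f64; // holds h'_{t-1} at the top of each iteration
    let mut y = Vec::with_capacity(n);
    for t in 0..n {
        // scale_t = γ_t + (1 - λ_{t+1}) Δ_{t+1}, falling back to γ_t at the last step.
        let scale = if t + 1 < n {
            s.gamma[t] + (1.0 - s.lam[t + 1]) * s.delta[t + 1]
        } else {
            s.gamma[t]
        };
        // Contributions of all s < t reach step t through α_t h'_{t-1}.
        let y_lower = s.c[t] * s.alpha[t] * h;
        // Same-step correction: γ_t (c_t b_t) x_t.
        let y_diag = s.gamma[t] * s.c[t] * s.b[t] * s.x[t];
        y.push(y_lower + y_diag);
        h = s.alpha[t] * h + scale * s.b[t] * s.x[t];
    }
    y
}
```

Under these assumptions the two functions should agree up to floating-point rounding for any inputs. The scaled state h' feeds only the next step, while the current step's own input is added separately with weight γₜ.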

§Algorithm (per chunk, MIMO-first)

  K_scaled[t, m, h, n] = scaleₜ · B[t, m, h, n]     // K scaled inside the SSD

  y_lower  = (C ⊗ K_scaledᵀ ⊙ L_strict) · PsiV     // strict lower-tri
  y_diag   = γₜ · (C ⊗ Bᵀ at same step) · PsiV      // diagonal correction
  y_off    = C · h'_chunk_in · exp(da_cs)           // state-to-output

  y        = y_lower + y_diag + y_off

  h'_chunk_out = exp(da_cs_last) · h'_chunk_in
               + K_scaled · exp(da_cs_rev)ᵀ · PsiV   // standard state update
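
As a rough illustration of the per-chunk steps above, here is a scalar sketch (state size 1, mimo_rank = 1) in plain Rust. It is not the Burn tensor code: chunk_step and its parameters are hypothetical names, x stands in for PsiV, log_alpha plays the role of the per-step log-decay whose inclusive cumulative sum is da_cs, and the strict lower-triangular decay weights are written out as explicit loops rather than a segsum matmul. The scale values are assumed to be precomputed over the whole sequence, so the last position of a chunk still sees the next chunk's λ and Δ.

```rust
/// One chunk of a scalar single-SSD pass (illustrative sketch).
/// Returns the chunk outputs and the outgoing state h'_chunk_out.
fn chunk_step(
    log_alpha: &[f64], // log of the per-step decay α_t
    gamma: &[f64],     // γ_t, used only for the diagonal correction
    scale: &[f64],     // scale_t, precomputed over the full sequence
    b: &[f64],
    c: &[f64],
    x: &[f64],         // stands in for PsiV in this scalar sketch
    h_in: f64,         // h'_chunk_in
) -> (Vec<f64>, f64) {
    let n = x.len();

    // Inclusive cumulative sum of log-decays inside the chunk (da_cs).
    let mut da_cs = vec![0.0_f64; n];
    let mut acc = 0.0_f64;
    for t in 0..n {
        acc += log_alpha[t];
        da_cs[t] = acc;
    }

    let mut y = vec![0.0_f64; n];
    for t in 0..n {
        // Strict lower-triangular part: only s < t, weighted by the decay from s to t.
        let mut y_lower = 0.0_f64;
        for s in 0..t {
            y_lower += c[t] * (da_cs[t] - da_cs[s]).exp() * scale[s] * b[s] * x[s];
        }
        // Diagonal correction with γ_t instead of scale_t.
        let y_diag = gamma[t] * c[t] * b[t] * x[t];
        // Contribution of the incoming state, decayed up to step t.
        let y_off = c[t] * h_in * da_cs[t].exp();
        y[t] = y_lower + y_diag + y_off;
    }

    // State update: decay the incoming state over the whole chunk, then add
    // every scaled input decayed from its step to the end of the chunk.
    let da_last = if n > 0 { da_cs[n - 1] } else { 0.0 };
    let mut h_out = da_last.exp() * h_in;
    for s in 0..n {
        h_out += (da_last - da_cs[s]).exp() * scale[s] * b[s] * x[s];
    }
    (y, h_out)
}
```

Calling chunk_step on consecutive slices and passing each h_out as the next h_in should reproduce a step-by-step evaluation of h'ₜ = αₜ h'ₜ₋₁ + scaleₜ (Bₜ ⊗ xₜ) with the γ-weighted diagonal term, up to rounding.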

The MIMO causal mask matches the one in crate::mamba3::double_ssd::ssd::minimal except that the inequality is strict (i_time > j_time rather than i_time ≥ j_time).
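
To make the difference between the two inequalities concrete, the following sketch builds a boolean mask over a flattened (time, rank) axis. The function name mimo_causal_mask and the layout index = t · mimo_rank + m are assumptions chosen for illustration; the actual tensor layout in the crate may differ.

```rust
/// Causal mask over a flattened (time, rank) axis, assuming index = t * mimo_rank + m.
/// `strict = true` keeps only i_time > j_time; `strict = false` keeps i_time >= j_time.
fn mimo_causal_mask(seq_len: usize, mimo_rank: usize, strict: bool) -> Vec<Vec<bool>> {
    let n = seq_len * mimo_rank;
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| {
                    // Recover the time step of each flattened index.
                    let (i_time, j_time) = (i / mimo_rank, j / mimo_rank);
                    if strict {
                        i_time > j_time
                    } else {
                        i_time >= j_time
                    }
                })
                .collect()
        })
        .collect()
}
```

With the strict variant, every rank-by-rank block on the time diagonal is masked out, which is why the same-step contribution has to be added back as a separate term.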

Reference implementations:

  • SISO: refs/state-spaces/mamba/mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py
  • MIMO: refs/state-spaces/mamba/mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py