Skip to main content

burn_mamba/mamba3/single_ssd/ssd/
minimal.rs

1//! # Single-Pass SSD (Minimal / segsum variant)
2//!
3//! This is the MIMO-first, single SSD pass implementation of the
4//! Mamba-3 trapezoid recurrence. It is the Burn analogue of the official
5//! Tilelang MIMO kernel and Triton SISO kernel; SISO is the `mimo_rank = 1`
6//! degenerate case.
7//!
8//! ## Background — the single-ssd recurrence
9//!
10//! The double-ssd trapezoid hidden state is
11//!
12//! ```text
13//!   hₜ = αₜ hₜ₋₁ + βₜ (Bₜ₋₁ ⊗ xₜ₋₁) + γₜ (Bₜ ⊗ xₜ)
14//! ```
15//!
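//! Expanding the recurrence and grouping by `(Bₛ ⊗ xₛ)` gives the coefficient
//! `(Πᵣ₌ₛ₊₁ᵗ αᵣ) · [γₛ + (1−λₛ₊₁)·Δₛ₊₁]` for the contribution of step `s` to
//! state `t` (for `s < t`). The bracket combines the `γₛ` term injected at step
//! `s` with the `βₛ₊₁` term injected at step `s+1`; the two share one decay
//! product because `βₜ = (1−λₜ)·Δₜ·αₜ` carries its own factor `αₜ`. At `s = t`
//! the coefficient is just `γₜ`.
//!
//! Define `scaleₜ = γₜ + (1−λₜ₊₁)·Δₜ₊₁` (with `scaleₜ = γₜ` at the final time
//! step, where no step `t+1` exists). The single-SSD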
22//!
23//! ```text
24//!   h'ₜ = αₜ h'ₜ₋₁ + scaleₜ (Bₜ ⊗ xₜ)
25//! ```
26//!
27//! produces the same outputs `yₜ = Cₜᵀ h'ₜ` as the double-ssd one **except**
28//! at the same-step diagonal (`s = t`), where the single-ssd form has `scaleₜ`
29//! instead of `γₜ`. We compensate by:
30//!
31//! 1. Using a **strict** lower-triangular mask in the intra-chunk path (the
32//!    `s = t` block is excluded from the trapezoid sum).
33//! 2. Adding a separate γ-weighted same-step term `γₜ · (Cₜᵀ Bₜ) · xₜ`.
34//!
35//! ## Algorithm (per chunk, MIMO-first)
36//!
37//! ```text
38//!   K_scaled[t, m, h, n] = scaleₜ · B[t, m, h, n]     // K scaled inside the SSD
39//!
40//!   y_lower  = (C ⊗ K_scaledᵀ ⊙ L_strict) · PsiV     // strict lower-tri
41//!   y_diag   = γₜ · (C ⊗ Bᵀ at same step) · PsiV      // diagonal correction
42//!   y_off    = C · h'_chunk_in · exp(da_cs)           // state-to-output
43//!
44//!   y        = y_lower + y_diag + y_off
45//!
46//!   h'_chunk_out = exp(da_cs_last) · h'_chunk_in
47//!                + K_scaled · exp(da_cs_rev)ᵀ · PsiV   // standard state update
48//! ```
49//!
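//! The MIMO causal mask has the same time-block structure as the one in
//! [`crate::mamba3::double_ssd::ssd::minimal`], but uses a strict inequality
//! (`i_time > j_time` rather than `i_time ≥ j_time`): same-step blocks are
//! excluded here and supplied by the `y_diag` term instead.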
52//!
53//! Reference implementations:
54//! - SISO: `refs/state-spaces/mamba/mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py`
55//! - MIMO: `refs/state-spaces/mamba/mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`
56
57use crate::mamba3::single_ssd::prelude::*;
58use crate::modules::segsum;
59use burn::prelude::*;
60
61impl Mamba3SingleSsdInput {
62    /// MIMO-first single-SSD — segsum variant.
63    ///
64    /// See module documentation for the algorithm. Returns the chunked outputs
65    /// and the final single-ssd accumulator.
66    ///
67    /// # Shapes
68    /// - input: see [`Mamba3SingleSsdInput`]
69    /// - output `(y_bnlmhp, final_state_bhpr)`:
70    ///   - `y_bnlmhp`:           `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
71    ///   - `final_state_bhpr`:   `[batch, nheads, per_head_dim, state_rank]`
72    #[allow(non_snake_case)]
73    pub fn single_ssd_minimal(self) -> (Tensor<6>, Tensor<4>) {
74        let input = self;
75        input.sanity();
76        let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = input.v_bnlmhp.dims();
77        let [.., state_rank] = input.b_bnlmhr.dims();
78        let device = &input.v_bnlmhp.device();
79
80        assert!(nchunks >= 1, "sequence must be non-empty");
81        assert!(chunk_len > 0, "chunk_len must be positive");
82        assert_eq!(
83            [batch, nchunks, chunk_len, nheads],
84            input.gamma_bnlh.dims(),
85            "gamma must align with da"
86        );
87        assert_eq!(
88            [batch, nchunks, chunk_len, nheads],
89            input.scale_bnlh.dims(),
90            "scale must align with da"
91        );
92
93        // ── Fuse mimo_rank into chunk_len (matches `ssd_minimal`) ─────────────
94        let c_bnLMhr = input.c_bnlmhr.clone().reshape([
95            batch,
96            nchunks,
97            chunk_len * mimo_rank,
98            nheads,
99            state_rank,
100        ]);
101        let v_bnLMhp = input.v_bnlmhp.clone().reshape([
102            batch,
103            nchunks,
104            chunk_len * mimo_rank,
105            nheads,
106            per_head_dim,
107        ]);
108
109        // Per-time-step cumulative log-decay (used for L_strict, decay_states, y_off)
110        let a_bhnl = input.da_bnlh.permute([0, 3, 1, 2]);
111        let a_cumsum_bhnl = a_bhnl.clone().cumsum(3);
112
113        // K scaled for lower-triangular and state recurrence paths
114        // (the diagonal correction reuses the unscaled `b_bnlmhr`).
115        // scale_bnlh broadcast over (mimo_rank, state_rank):
116        let scale_bnlh11 = input
117            .scale_bnlh
118            .clone()
119            .unsqueeze_dims::<6>(&[3, 5]) // scale_bnlh -> scale_bnl1h1
120            ;
121        let k_scaled_bnlmhr = input.b_bnlmhr.clone() * scale_bnlh11;
122        let k_scaled_bnLMhr =
123            k_scaled_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
124
125        // =============================================================
126        // STEP 1a: Strict lower-triangular intra-chunk output (y_lower)
127        //
128        // y_lower[t1] = Σ_{t2 < t1}  (C[t1] · K_scaled[t2]^T)
129        //              · exp(cumA[t1] - cumA[t2])  · PsiV[t2]
130        // (block-diagonal in time t1 = t2 is excluded — handled by y_diag.)
131        // =============================================================
132        let y_lower_bnLMhp = {
133            let c_bnhLMr = c_bnLMhr.clone().permute([0, 1, 3, 2, 4]);
134            let k_bnhrLM = k_scaled_bnLMhr.clone().permute([0, 1, 3, 4, 2]);
135            // [batch, nchunks, nheads, chunk_len*mimo_rank, chunk_len*mimo_rank]
136            let cb_bnhLMLM = c_bnhLMr.matmul(k_bnhrLM);
137
138            // L_strict_base[i, j] = exp(cumA[i] - cumA[j]) for i > j, else 0.
139            //
140            // Like `segsum` but with -inf on the diagonal as well (so exp = 0
141            // there). Replaces the existing `triu(1)` masking with `triu(0)`.
142            let l_strict_base_bhnll = {
143                let x_cumsum = a_bhnl.clone().cumsum(3);
144                let row: Tensor<5> = x_cumsum.clone().unsqueeze_dim(4); // [..., l, 1]
145                let col: Tensor<5> = x_cumsum.unsqueeze_dim(3); // [..., 1, l]
146                let diff = row - col; // [..., l, l]
147                let neg_inf_strict = Tensor::full_like(&diff, f32::NEG_INFINITY).triu(0);
148                (diff + neg_inf_strict).exp()
149            };
150
151            // Interleave-expand to fused length (L_strict[i,j] = L_strict_base[i//m, j//m]):
152            let l_strict_bhnLMLM = l_strict_base_bhnll
153                .unsqueeze_dim::<6>(4)
154                .expand([batch, nheads, nchunks, chunk_len, mimo_rank, chunk_len])
155                .reshape([batch, nheads, nchunks, chunk_len * mimo_rank, chunk_len])
156                .unsqueeze_dim::<6>(5)
157                .expand([
158                    batch,
159                    nheads,
160                    nchunks,
161                    chunk_len * mimo_rank,
162                    chunk_len,
163                    mimo_rank,
164                ])
165                .reshape([
166                    batch,
167                    nheads,
168                    nchunks,
169                    chunk_len * mimo_rank,
170                    chunk_len * mimo_rank,
171                ]);
172
173            // (CB ⊙ L_strict) · V    (back in MIMO-fused layout)
174            let cb_bnLMhLM = cb_bnhLMLM.permute([0, 1, 3, 2, 4]);
175            let l_bnLMhLM = l_strict_bhnLMLM.permute([0, 2, 3, 1, 4]);
176            let masked_cb_bnhLMLM = (cb_bnLMhLM * l_bnLMhLM).permute([0, 1, 3, 2, 4]);
177
178            let v_bnhLMp = v_bnLMhp.clone().permute([0, 1, 3, 2, 4]);
179            let y_lower_bnhLMp = masked_cb_bnhLMLM.matmul(v_bnhLMp);
180
181            y_lower_bnhLMp.permute([0, 1, 3, 2, 4]) // y_lower_bnLMhp
182        };
183
184        // =============================================================
185        // STEP 1b: γ-weighted same-step diagonal correction (y_diag)
186        //
187        // y_diag[t, m_out, h, p] = γₜ · Σ_{m_in} (C[t, m_out, h, ·] · B[t, m_in, h, ·]) · PsiV[t, m_in, h, p]
188        // =============================================================
189        let y_diag_bnlmhp = {
190            // C @ B^T contracts over state_rank, leaving mimo_rank on both sides.
191            // c_bnlmhr  [b, n, l, m, h, r] -> c_bnlhmr  [b, n, l, h, m, r]
192            // b_bnlmhr  [b, n, l, m, h, r] -> b_bnlhrm  [b, n, l, h, r, m]
193            let c_bnlhmr = input.c_bnlmhr.permute([0, 1, 2, 4, 3, 5]);
194            let b_bnlhrm = input.b_bnlmhr.permute([0, 1, 2, 4, 5, 3]);
195            // qk_dot_bnlhmM [b, n, l, h, m_out, m_in]
196            let qk_dot_bnlhmM = c_bnlhmr.matmul(b_bnlhrm);
197
198            // V in [b, n, l, h, m_in, p] layout for the next matmul:
199            let v_bnlhmp = input.v_bnlmhp.permute([0, 1, 2, 4, 3, 5]);
200            // (qk_dot) · V → [b, n, l, h, m_out, p]
201            let y_d_bnlhmp = qk_dot_bnlhmM.matmul(v_bnlhmp);
202
203            // Multiply by γₜ (per (batch, nchunks, chunk_len, nheads)):
204            let gamma_bnlh11 = input.gamma_bnlh.clone().unsqueeze_dims::<6>(&[4, 5]);
205            let y_d_bnlhmp_scaled = y_d_bnlhmp * gamma_bnlh11;
206
207            // Back to [b, n, l, m, h, p]:
208            y_d_bnlhmp_scaled.permute([0, 1, 2, 4, 3, 5])
209        };
210        // Reshape to fused layout for combination with y_lower / y_off.
211        let y_diag_bnLMhp =
212            y_diag_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
213
214        // =============================================================
215        // STEP 2: Per-chunk single-ssd state (standard SSD with K_scaled)
216        //
217        // s[n] = Σ_{t,m} exp(cumA[n,-1] - cumA[n,t]) · V[n,t*M+m] · K_scaled[n,t*M+m]^T
218        // =============================================================
219        let state_bnhpr = {
220            let a_cumsum_last_bhn1 = a_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
221            let a_cumsum_bhnLM = a_cumsum_bhnl
222                .clone()
223                .unsqueeze_dim::<5>(4)
224                .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
225                .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]);
226            let decay_bhnLM = (a_cumsum_last_bhn1 - a_cumsum_bhnLM).exp();
227
228            let decay_bnLMh1 = decay_bhnLM.permute([0, 2, 3, 1]).unsqueeze_dim(4);
229            let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp.clone();
230
231            let decayed_v_bnhpLM = decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
232            let k_scaled_bnhLMr = k_scaled_bnLMhr.permute([0, 1, 3, 2, 4]);
233            decayed_v_bnhpLM.matmul(k_scaled_bnhLMr) // state_bnhpr
234        };
235
236        // =============================================================
237        // STEP 3: Inter-chunk state scan (segsum-based state passing)
238        //
239        // h'[n] = Ā_chunk[n] · h'[n-1] + s[n]
240        // =============================================================
241        let (state_bnhpr, final_state_bhpr) = {
242            let initial_state_b1hpr = input.initial_state_bhpr.unsqueeze_dim(1);
243            let initial_state_b1hpr = if let Some(init_hpr) = input.init_state_hpr {
244                let init_b1hpr = init_hpr.unsqueeze_dim::<4>(0).expand([
245                    batch,
246                    1,
247                    nheads,
248                    per_head_dim,
249                    state_rank,
250                ]);
251                initial_state_b1hpr + init_b1hpr
252            } else {
253                initial_state_b1hpr
254            };
255
256            let state_bNhpr = Tensor::cat(vec![initial_state_b1hpr, state_bnhpr], 1);
257
258            let a_cumsum_last_bhn: Tensor<3> = a_cumsum_bhnl
259                .clone()
260                .slice(s![.., .., .., -1])
261                .squeeze_dim(3);
262            let a_chunk_pad_bhN = Tensor::cat(
263                vec![Tensor::zeros([batch, nheads, 1], device), a_cumsum_last_bhn],
264                2,
265            );
266            let decay_chunk_bhNN = segsum::<3, 4>(a_chunk_pad_bhN).exp();
267
268            let flat = per_head_dim * state_rank;
269            let state_bhNPR = state_bNhpr.clone().permute([0, 2, 1, 3, 4]).reshape([
270                batch,
271                nheads,
272                1 + nchunks,
273                flat,
274            ]);
275
276            let new_state_bhNPR = decay_chunk_bhNN.matmul(state_bhNPR);
277            let new_state_bhNpr =
278                new_state_bhNPR.reshape([batch, nheads, 1 + nchunks, per_head_dim, state_rank]);
279
280            let new_state_bnhpr = new_state_bhNpr
281                .clone()
282                .slice(s![.., .., 0..nchunks, .., ..])
283                .permute([0, 2, 1, 3, 4]);
284            let last_state_bhpr: Tensor<4> = new_state_bhNpr
285                .slice(s![.., .., nchunks, .., ..])
286                .squeeze_dim(2);
287
288            (new_state_bnhpr, last_state_bhpr)
289        };
290
291        // =============================================================
292        // STEP 4: State-to-output (y_off)
293        //
294        // y_off[n, t*M+m] = C[t*M+m]^T · exp(cumA[t]) · h'[n-1]
295        // =============================================================
296        let y_off_bnLMhp = {
297            let state_decay_bhnLM = a_cumsum_bhnl
298                .unsqueeze_dim::<5>(4)
299                .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
300                .reshape([batch, nheads, nchunks, chunk_len * mimo_rank])
301                .exp();
302
303            let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
304            let state_bnhrp = state_bnhpr.permute([0, 1, 2, 4, 3]);
305            let ch_bnhLMp = c_bnhLMr.matmul(state_bnhrp);
306
307            let decay_bnhLM1 = state_decay_bhnLM.permute([0, 2, 1, 3]).unsqueeze_dim(4);
308            let y_off_bnhLMp = ch_bnhLMp * decay_bnhLM1;
309            y_off_bnhLMp.permute([0, 1, 3, 2, 4])
310        };
311
312        // ── Combine and reshape ───────────────────────────────────────────────
313        let y_bnLMhp = y_lower_bnLMhp + y_diag_bnLMhp + y_off_bnLMhp;
314        let y_bnlmhp =
315            y_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
316
317        (y_bnlmhp, final_state_bhpr)
318    }
319}