burn_mamba/mamba3/single_ssd/ssd/minimal.rs
1//! # Single-Pass SSD (Minimal / segsum variant)
2//!
3//! This is the MIMO-first, single SSD pass implementation of the
4//! Mamba-3 trapezoid recurrence. It is the Burn analogue of the official
5//! Tilelang MIMO kernel and Triton SISO kernel; SISO is the `mimo_rank = 1`
6//! degenerate case.
7//!
8//! ## Background — the single-ssd recurrence
9//!
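//! The double-SSD trapezoid recurrence for the hidden state is
//!
//! ```text
//! hₜ = αₜ hₜ₋₁ + βₜ (Bₜ₋₁ ⊗ xₜ₋₁) + γₜ (Bₜ ⊗ xₜ)
//! ```
//!
//! Expanding the recurrence and grouping by `(Bₛ ⊗ xₛ)` gives the coefficient
//! `(Πᵣ₌ₛ₊₁ᵗ αᵣ) · [γₛ + (1−λₛ₊₁)·Δₛ₊₁]` for the contribution of step `s` to
//! state `t` (for `s < t`); this uses `βₜ = (1−λₜ)·Δₜ·αₜ`, so the `β` term of
//! step `s + 1` shares the decay product of the `γ` term of step `s`. At
//! `s = t` the coefficient is just `γₜ`, because the `β` term that carries
//! step `t` forward only appears at step `t + 1`.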
19//!
//! Define `scaleₜ = γₜ + (1−λₜ₊₁)·Δₜ₊₁` (with `scaleₜ = γₜ` at the last
//! position). The single-SSD recurrence
//!
//! ```text
//! h'ₜ = αₜ h'ₜ₋₁ + scaleₜ (Bₜ ⊗ xₜ)
//! ```
//!
//! reproduces, through `yₜ = Cₜᵀ h'ₜ`, every `s < t` contribution of the
//! double-SSD output exactly. Only the same-step diagonal (`s = t`) differs:
//! there the single-SSD form weights `(Bₜ ⊗ xₜ)` by `scaleₜ` instead of `γₜ`.
//! We compensate by:
30//!
31//! 1. Using a **strict** lower-triangular mask in the intra-chunk path (the
32//! `s = t` block is excluded from the trapezoid sum).
33//! 2. Adding a separate γ-weighted same-step term `γₜ · (Cₜᵀ Bₜ) · xₜ`.
34//!
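//! ## Algorithm (per chunk, MIMO-first)
//!
//! Below, `PsiV` is the value input (`v_bnlmhp`), `da_cs` is the within-chunk
//! cumulative sum of the log-decay `da`, and
//! `L_strict[t, s] = exp(da_cs[t] − da_cs[s])` for `s < t` (`0` otherwise).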
36//!
37//! ```text
38//! K_scaled[t, m, h, n] = scaleₜ · B[t, m, h, n] // K scaled inside the SSD
39//!
40//! y_lower = (C ⊗ K_scaledᵀ ⊙ L_strict) · PsiV // strict lower-tri
41//! y_diag = γₜ · (C ⊗ Bᵀ at same step) · PsiV // diagonal correction
42//! y_off = C · h'_chunk_in · exp(da_cs) // state-to-output
43//!
44//! y = y_lower + y_diag + y_off
45//!
46//! h'_chunk_out = exp(da_cs_last) · h'_chunk_in
47//! + K_scaled · exp(da_cs_rev)ᵀ · PsiV // standard state update
48//! ```
49//!
//! The MIMO causal mask is built the same way as in
//! [`crate::mamba3::double_ssd::ssd::minimal`], except that it is strict
//! (`i_time > j_time` rather than `i_time ≥ j_time`).
52//!
53//! Reference implementations:
54//! - SISO: `refs/state-spaces/mamba/mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py`
55//! - MIMO: `refs/state-spaces/mamba/mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`
56
57use crate::mamba3::single_ssd::prelude::*;
58use crate::modules::segsum;
59use burn::prelude::*;
60
61impl Mamba3SingleSsdInput {
62 /// MIMO-first single-SSD — segsum variant.
63 ///
64 /// See module documentation for the algorithm. Returns the chunked outputs
65 /// and the final single-ssd accumulator.
66 ///
67 /// # Shapes
68 /// - input: see [`Mamba3SingleSsdInput`]
69 /// - output `(y_bnlmhp, final_state_bhpr)`:
70 /// - `y_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
71 /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
72 #[allow(non_snake_case)]
73 pub fn single_ssd_minimal(self) -> (Tensor<6>, Tensor<4>) {
74 let input = self;
75 input.sanity();
76 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = input.v_bnlmhp.dims();
77 let [.., state_rank] = input.b_bnlmhr.dims();
78 let device = &input.v_bnlmhp.device();
79
80 assert!(nchunks >= 1, "sequence must be non-empty");
81 assert!(chunk_len > 0, "chunk_len must be positive");
82 assert_eq!(
83 [batch, nchunks, chunk_len, nheads],
84 input.gamma_bnlh.dims(),
85 "gamma must align with da"
86 );
87 assert_eq!(
88 [batch, nchunks, chunk_len, nheads],
89 input.scale_bnlh.dims(),
90 "scale must align with da"
91 );
92
93 // ── Fuse mimo_rank into chunk_len (matches `ssd_minimal`) ─────────────
94 let c_bnLMhr = input.c_bnlmhr.clone().reshape([
95 batch,
96 nchunks,
97 chunk_len * mimo_rank,
98 nheads,
99 state_rank,
100 ]);
101 let v_bnLMhp = input.v_bnlmhp.clone().reshape([
102 batch,
103 nchunks,
104 chunk_len * mimo_rank,
105 nheads,
106 per_head_dim,
107 ]);
108
109 // Per-time-step cumulative log-decay (used for L_strict, decay_states, y_off)
110 let a_bhnl = input.da_bnlh.permute([0, 3, 1, 2]);
111 let a_cumsum_bhnl = a_bhnl.clone().cumsum(3);
112
113 // K scaled for lower-triangular and state recurrence paths
114 // (the diagonal correction reuses the unscaled `b_bnlmhr`).
115 // scale_bnlh broadcast over (mimo_rank, state_rank):
116 let scale_bnlh11 = input
117 .scale_bnlh
118 .clone()
119 .unsqueeze_dims::<6>(&[3, 5]) // scale_bnlh -> scale_bnl1h1
120 ;
121 let k_scaled_bnlmhr = input.b_bnlmhr.clone() * scale_bnlh11;
122 let k_scaled_bnLMhr =
123 k_scaled_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
124
125 // =============================================================
126 // STEP 1a: Strict lower-triangular intra-chunk output (y_lower)
127 //
128 // y_lower[t1] = Σ_{t2 < t1} (C[t1] · K_scaled[t2]^T)
129 // · exp(cumA[t1] - cumA[t2]) · PsiV[t2]
130 // (block-diagonal in time t1 = t2 is excluded — handled by y_diag.)
131 // =============================================================
132 let y_lower_bnLMhp = {
133 let c_bnhLMr = c_bnLMhr.clone().permute([0, 1, 3, 2, 4]);
134 let k_bnhrLM = k_scaled_bnLMhr.clone().permute([0, 1, 3, 4, 2]);
135 // [batch, nchunks, nheads, chunk_len*mimo_rank, chunk_len*mimo_rank]
136 let cb_bnhLMLM = c_bnhLMr.matmul(k_bnhrLM);
137
138 // L_strict_base[i, j] = exp(cumA[i] - cumA[j]) for i > j, else 0.
139 //
140 // Like `segsum` but with -inf on the diagonal as well (so exp = 0
141 // there). Replaces the existing `triu(1)` masking with `triu(0)`.
142 let l_strict_base_bhnll = {
143 let x_cumsum = a_bhnl.clone().cumsum(3);
144 let row: Tensor<5> = x_cumsum.clone().unsqueeze_dim(4); // [..., l, 1]
145 let col: Tensor<5> = x_cumsum.unsqueeze_dim(3); // [..., 1, l]
146 let diff = row - col; // [..., l, l]
147 let neg_inf_strict = Tensor::full_like(&diff, f32::NEG_INFINITY).triu(0);
148 (diff + neg_inf_strict).exp()
149 };
150
151 // Interleave-expand to fused length (L_strict[i,j] = L_strict_base[i//m, j//m]):
152 let l_strict_bhnLMLM = l_strict_base_bhnll
153 .unsqueeze_dim::<6>(4)
154 .expand([batch, nheads, nchunks, chunk_len, mimo_rank, chunk_len])
155 .reshape([batch, nheads, nchunks, chunk_len * mimo_rank, chunk_len])
156 .unsqueeze_dim::<6>(5)
157 .expand([
158 batch,
159 nheads,
160 nchunks,
161 chunk_len * mimo_rank,
162 chunk_len,
163 mimo_rank,
164 ])
165 .reshape([
166 batch,
167 nheads,
168 nchunks,
169 chunk_len * mimo_rank,
170 chunk_len * mimo_rank,
171 ]);
172
173 // (CB ⊙ L_strict) · V (back in MIMO-fused layout)
174 let cb_bnLMhLM = cb_bnhLMLM.permute([0, 1, 3, 2, 4]);
175 let l_bnLMhLM = l_strict_bhnLMLM.permute([0, 2, 3, 1, 4]);
176 let masked_cb_bnhLMLM = (cb_bnLMhLM * l_bnLMhLM).permute([0, 1, 3, 2, 4]);
177
178 let v_bnhLMp = v_bnLMhp.clone().permute([0, 1, 3, 2, 4]);
179 let y_lower_bnhLMp = masked_cb_bnhLMLM.matmul(v_bnhLMp);
180
181 y_lower_bnhLMp.permute([0, 1, 3, 2, 4]) // y_lower_bnLMhp
182 };
183
184 // =============================================================
185 // STEP 1b: γ-weighted same-step diagonal correction (y_diag)
186 //
187 // y_diag[t, m_out, h, p] = γₜ · Σ_{m_in} (C[t, m_out, h, ·] · B[t, m_in, h, ·]) · PsiV[t, m_in, h, p]
188 // =============================================================
189 let y_diag_bnlmhp = {
190 // C @ B^T contracts over state_rank, leaving mimo_rank on both sides.
191 // c_bnlmhr [b, n, l, m, h, r] -> c_bnlhmr [b, n, l, h, m, r]
192 // b_bnlmhr [b, n, l, m, h, r] -> b_bnlhrm [b, n, l, h, r, m]
193 let c_bnlhmr = input.c_bnlmhr.permute([0, 1, 2, 4, 3, 5]);
194 let b_bnlhrm = input.b_bnlmhr.permute([0, 1, 2, 4, 5, 3]);
195 // qk_dot_bnlhmM [b, n, l, h, m_out, m_in]
196 let qk_dot_bnlhmM = c_bnlhmr.matmul(b_bnlhrm);
197
198 // V in [b, n, l, h, m_in, p] layout for the next matmul:
199 let v_bnlhmp = input.v_bnlmhp.permute([0, 1, 2, 4, 3, 5]);
200 // (qk_dot) · V → [b, n, l, h, m_out, p]
201 let y_d_bnlhmp = qk_dot_bnlhmM.matmul(v_bnlhmp);
202
203 // Multiply by γₜ (per (batch, nchunks, chunk_len, nheads)):
204 let gamma_bnlh11 = input.gamma_bnlh.clone().unsqueeze_dims::<6>(&[4, 5]);
205 let y_d_bnlhmp_scaled = y_d_bnlhmp * gamma_bnlh11;
206
207 // Back to [b, n, l, m, h, p]:
208 y_d_bnlhmp_scaled.permute([0, 1, 2, 4, 3, 5])
209 };
210 // Reshape to fused layout for combination with y_lower / y_off.
211 let y_diag_bnLMhp =
212 y_diag_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
213
214 // =============================================================
215 // STEP 2: Per-chunk single-ssd state (standard SSD with K_scaled)
216 //
217 // s[n] = Σ_{t,m} exp(cumA[n,-1] - cumA[n,t]) · V[n,t*M+m] · K_scaled[n,t*M+m]^T
218 // =============================================================
219 let state_bnhpr = {
220 let a_cumsum_last_bhn1 = a_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
221 let a_cumsum_bhnLM = a_cumsum_bhnl
222 .clone()
223 .unsqueeze_dim::<5>(4)
224 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
225 .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]);
226 let decay_bhnLM = (a_cumsum_last_bhn1 - a_cumsum_bhnLM).exp();
227
228 let decay_bnLMh1 = decay_bhnLM.permute([0, 2, 3, 1]).unsqueeze_dim(4);
229 let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp.clone();
230
231 let decayed_v_bnhpLM = decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
232 let k_scaled_bnhLMr = k_scaled_bnLMhr.permute([0, 1, 3, 2, 4]);
233 decayed_v_bnhpLM.matmul(k_scaled_bnhLMr) // state_bnhpr
234 };
235
236 // =============================================================
237 // STEP 3: Inter-chunk state scan (segsum-based state passing)
238 //
239 // h'[n] = Ā_chunk[n] · h'[n-1] + s[n]
240 // =============================================================
241 let (state_bnhpr, final_state_bhpr) = {
242 let initial_state_b1hpr = input.initial_state_bhpr.unsqueeze_dim(1);
243 let initial_state_b1hpr = if let Some(init_hpr) = input.init_state_hpr {
244 let init_b1hpr = init_hpr.unsqueeze_dim::<4>(0).expand([
245 batch,
246 1,
247 nheads,
248 per_head_dim,
249 state_rank,
250 ]);
251 initial_state_b1hpr + init_b1hpr
252 } else {
253 initial_state_b1hpr
254 };
255
256 let state_bNhpr = Tensor::cat(vec![initial_state_b1hpr, state_bnhpr], 1);
257
258 let a_cumsum_last_bhn: Tensor<3> = a_cumsum_bhnl
259 .clone()
260 .slice(s![.., .., .., -1])
261 .squeeze_dim(3);
262 let a_chunk_pad_bhN = Tensor::cat(
263 vec![Tensor::zeros([batch, nheads, 1], device), a_cumsum_last_bhn],
264 2,
265 );
266 let decay_chunk_bhNN = segsum::<3, 4>(a_chunk_pad_bhN).exp();
267
268 let flat = per_head_dim * state_rank;
269 let state_bhNPR = state_bNhpr.clone().permute([0, 2, 1, 3, 4]).reshape([
270 batch,
271 nheads,
272 1 + nchunks,
273 flat,
274 ]);
275
276 let new_state_bhNPR = decay_chunk_bhNN.matmul(state_bhNPR);
277 let new_state_bhNpr =
278 new_state_bhNPR.reshape([batch, nheads, 1 + nchunks, per_head_dim, state_rank]);
279
280 let new_state_bnhpr = new_state_bhNpr
281 .clone()
282 .slice(s![.., .., 0..nchunks, .., ..])
283 .permute([0, 2, 1, 3, 4]);
284 let last_state_bhpr: Tensor<4> = new_state_bhNpr
285 .slice(s![.., .., nchunks, .., ..])
286 .squeeze_dim(2);
287
288 (new_state_bnhpr, last_state_bhpr)
289 };
290
291 // =============================================================
292 // STEP 4: State-to-output (y_off)
293 //
294 // y_off[n, t*M+m] = C[t*M+m]^T · exp(cumA[t]) · h'[n-1]
295 // =============================================================
296 let y_off_bnLMhp = {
297 let state_decay_bhnLM = a_cumsum_bhnl
298 .unsqueeze_dim::<5>(4)
299 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
300 .reshape([batch, nheads, nchunks, chunk_len * mimo_rank])
301 .exp();
302
303 let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
304 let state_bnhrp = state_bnhpr.permute([0, 1, 2, 4, 3]);
305 let ch_bnhLMp = c_bnhLMr.matmul(state_bnhrp);
306
307 let decay_bnhLM1 = state_decay_bhnLM.permute([0, 2, 1, 3]).unsqueeze_dim(4);
308 let y_off_bnhLMp = ch_bnhLMp * decay_bnhLM1;
309 y_off_bnhLMp.permute([0, 1, 3, 2, 4])
310 };
311
312 // ── Combine and reshape ───────────────────────────────────────────────
313 let y_bnLMhp = y_lower_bnLMhp + y_diag_bnLMhp + y_off_bnLMhp;
314 let y_bnlmhp =
315 y_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
316
317 (y_bnlmhp, final_state_bhpr)
318 }
319}