burn_mamba/mamba2/ssd/minimal.rs
1//! ## The Chunkwise SSD Algorithm
2//!
//! During training (and prefill), a naive sequential recurrence cannot
//! exploit GPU tensor cores. The **chunkwise SSD algorithm** (§4 of the
//! paper) removes this limitation by splitting the sequence into chunks of
//! length Q and decomposing the computation into four steps:
7//!
8//! ```text
9//! Step 1 (intra-chunk, quadratic form) → Y_diag
10//! Step 2 (input → chunk state) → state_bnhpr
11//! Step 3 (inter-chunk state scan) → state_bnhpr, final_state
12//! Step 4 (chunk state → output) → Y_off
13//!
14//! Y = Y_diag + Y_off
15//! ```
16//!
17//! Steps 1, 2, 4 are fully parallel across chunks and use batched matrix
18//! multiplications (exploiting tensor cores). Step 3 is a short sequential
19//! scan over `T/Q` elements rather than `T`.
20
21use crate::mamba2::prelude::*;
22use crate::utils::sanity::sanity as san;
23use burn::prelude::*;
24
25impl<B: Backend> Mamba2<B> {
26 // -----------------------------------------------------------------------
27 // chunked_selective_scan
28 // -----------------------------------------------------------------------
29
30 /// Minimal chunkwise SSD algorithm.
31 ///
32 /// Implements the four-step decomposition of the semiseparable matrix
33 /// multiplication described in §4 of the paper. The sequence of length T
34 /// is split into `nchunks = ⌈T/Q⌉` chunks of length Q.
35 ///
36 /// ## The four steps
37 ///
38 /// ### Step 1 — Intra-chunk outputs (Y_diag)
39 ///
40 /// Within each chunk, compute the output assuming the initial hidden state
41 /// is zero. This is the *quadratic attention form* of the SSD layer
42 /// restricted to a window of Q tokens (§4.1):
43 ///
44 /// ```text
45 /// Y_diag[n] = (L[n] ∘ C[n] B[n]ᵀ) · X[n]
46 /// ```
47 ///
48 /// where `L[n]` is the Q×Q 1-semiseparable mask for chunk n.
49 /// This step is a batched GEMM (exploits tensor cores).
50 ///
51 /// ### Step 2 — Chunk state (state_bnhpr)
52 ///
53 /// Compute the final SSM state of each chunk assuming zero initial state
54 /// (§4.1, Eq. 20):
55 ///
56 /// ```text
57 /// s[n] = Σ_{t ∈ chunk n} exp(A_cum[end] - A_cum[t]) · B̄[t] · x[t]ᵀ
58 /// ```
59 ///
60 /// This is also a batched GEMM and is fully parallel across chunks.
61 ///
62 /// ### Step 3 — Inter-chunk state scan (state passing)
63 ///
64 /// Propagate the true hidden state across chunk boundaries using the
65 /// recurrence (§4.1, Eq. 22):
66 ///
67 /// ```text
68 /// h[n] = Ā[n]_chunk · h[n-1] + s[n]
69 /// ```
70 ///
71 /// where `Ā[n]_chunk = exp(Σ_{t ∈ chunk n} Δₜ · A)` is the cumulative
72 /// decay over the whole chunk. This step is implemented as a single
73 /// batched matrix multiplication using the 1-semiseparable structure of
74 /// the inter-chunk decay matrix (same `segsum` trick, now over chunks).
75 /// The scan has length `nchunks = T/Q` rather than T, so its cost is
76 /// negligible for typical chunk sizes.
77 ///
78 /// ### Step 4 — State-to-output (Y_off)
79 ///
80 /// For each chunk n, compute the contribution of the true initial state
81 /// `h[n-1]` to the outputs within that chunk (§4.1, Eq. 23):
82 ///
83 /// ```text
84 /// Y_off[n, t] = C[n, t]ᵀ · exp(A_cum[t]) · h[n-1]
85 /// ```
86 ///
87 /// This is again a batched GEMM.
88 ///
89 /// ### Final output (with D skip-connection)
90 ///
91 /// ```text
92 /// Y = Y_diag + Y_off + D · X
93 /// ```
94 #[allow(non_snake_case)]
95 pub fn ssd_minimal(input: super::Mamba2SsdInput<B>) -> (Tensor<B, 5>, Tensor<B, 4>) {
96 let [batch, nchunks, chunk_len, nheads, per_head_dim] = input.x_bnlhp.dims();
97 let [.., ngroups, state_rank] = input.b_bnlgr.dims();
98 let device = &input.x_bnlhp.device();
99
100 assert_eq!(nheads % ngroups, 0);
101 assert!(nchunks >= 1, "sequence must be non-empty");
102 assert!(chunk_len > 0, "chunk_len must be positive");
103
104 // ── Compute discretised parameters ────────────────────────────────────
105 // Ā = exp(Δ · A) stored in log-space as a_bnlh = Δ · A (negative)
106 // B̄ = Δ · B (Euler/ZOH approximation)
107
108 // Expand B and C from ngroups to nheads by repeating each group's
109 // projection across all heads_per_group heads in that group.
110 let heads_per_group = nheads / ngroups;
111
112 // b_bnlgr → b_bnlhr [batch, nchunks, chunk_len, nheads, state_rank]
113 let b_bnlhr = input
114 .b_bnlgr
115 .clone()
116 .unsqueeze_dim::<6>(4) // b_bnlg1r
117 .expand([
118 batch,
119 nchunks,
120 chunk_len,
121 ngroups,
122 heads_per_group,
123 state_rank,
124 ]) // b_bnlgHr
125 .reshape([batch, nchunks, chunk_len, nheads, state_rank]);
126
127 // c_bnlgr → c_bnlhr [batch, nchunks, chunk_len, nheads, state_rank]
128 let c_bnlhr = input
129 .c_bnlgr
130 .clone()
131 .unsqueeze_dim::<6>(4) // c_bnlg1r
132 .expand([
133 batch,
134 nchunks,
135 chunk_len,
136 ngroups,
137 heads_per_group,
138 state_rank,
139 ]) // c_bnlgHr
140 .reshape([batch, nchunks, chunk_len, nheads, state_rank]);
141
142 // B̄ₜ = Δₜ · Bₜ [batch, nchunks, chunk_len, nheads, state_rank]
143 let delta_b_bnlhr = input.dt_bnlh.clone().unsqueeze_dim(4) * b_bnlhr.clone();
144 assert_eq!(
145 [batch, nchunks, chunk_len, nheads, state_rank],
146 delta_b_bnlhr.dims()
147 );
148 san(&delta_b_bnlhr);
149
150 // Ā in log-space: a_bnlh = Δₜ · A
151 let a_bnlh = input.dt_bnlh.clone()
152 * input
153 .a_decay_h
154 .clone()
155 .unsqueeze_dims::<4>(&[0, 1, 2]) // a_head_decay_111h
156 .expand([batch, nchunks, chunk_len, nheads]);
157 san(&a_bnlh);
158
159 // ── Reshape ───────────────────────────────────────────────────────────
160 // a (log-decay): [B, nchunks, chunk_len, H] → [B, H, nchunks, chunk_len]
161 let a_bhnl = a_bnlh.permute([0, 3, 1, 2]);
162 assert_eq!([batch, nheads, nchunks, chunk_len], a_bhnl.dims());
163
164 // Cumulative sum of log-decays within each chunk.
165 // a_cumsum_bhnl[b, h, n, t] = Σ_{k=0..t} Δ_{n,k} · A
166 // This is the log of the cumulative decay factor from the start of the
167 // chunk to position t (inclusive).
168 let a_cumsum_bhnl = a_bhnl.clone().cumsum(3);
169 assert_eq!([batch, nheads, nchunks, chunk_len], a_cumsum_bhnl.dims());
170 san(&a_cumsum_bhnl);
171
172 // =============================================================
173 // STEP 1: Intra-chunk outputs (diagonal blocks, Y_diag)
174 // =============================================================
175 //
176 // For each chunk n, compute Y_diag[n] = (L[n] ∘ C[n] B[n]ᵀ) · X[n]
177 // where L[n] ∈ ℝ^{Q×Q} is the 1-semiseparable mask for the chunk.
178 //
179 // L[n]_{i,j} = exp(Σ_{k=j+1..i} a_{n,k}) for i ≥ j
180 // = exp(a_cumsum[n,i] - a_cumsum[n,j]) (using segsum trick)
181 //
182 // Implementation uses three batched matmuls:
183 // (a) C[n] · B[n]ᵀ (contract over state_rank N) → temp1 [B, nchunks, H, Q, Q]
184 // (b) temp1 ∘ L[n] → temp2 [B, nchunks, H, Q, Q]
185 // (c) temp2 · X[n] (contract over Q) → Y_diag [B, nchunks, Q, H, P]
186 let y_diag_bnlhp = {
187 // Permute to [B, nchunks, H, Q, N] for the matmul along Q and N.
188 let b_bnhlr = delta_b_bnlhr.clone().permute([0, 1, 3, 2, 4]);
189 let c_bnhlr = c_bnlhr.clone().permute([0, 1, 3, 2, 4]);
190 assert_eq!(
191 [batch, nchunks, nheads, chunk_len, state_rank],
192 b_bnhlr.dims()
193 );
194 assert_eq!(
195 [batch, nchunks, nheads, chunk_len, state_rank],
196 c_bnhlr.dims()
197 );
198
199 // (a) C[n] · B[n]ᵀ → [B, nchunks, H, Q, Q]
200 // Contracts over state_rank N.
201 let b_bnhrl = b_bnhlr.permute([0, 1, 2, 4, 3]); // [B, n, H, N, Q]
202 let cb_bnhll = c_bnhlr.matmul(b_bnhrl); // [B, n, H, Q, Q]
203 assert_eq!(
204 [batch, nchunks, nheads, chunk_len, chunk_len],
205 cb_bnhll.dims()
206 );
207 san(&cb_bnhll);
208
209 // (b) Element-wise multiply with the 1-SS mask L.
210 // L = exp(segsum(a_bhnl)) [B, H, nchunks, Q, Q]
211 // Lᵢⱼ = exp(a_cumsum[n,i] - a_cumsum[n,j]) (Eq. 4–5)
212 let l_bhnll = segsum(a_bhnl.clone()).exp();
213 assert_eq!(
214 [batch, nheads, nchunks, chunk_len, chunk_len],
215 l_bhnll.dims()
216 );
217 san(&l_bhnll);
218
219 // Permute both to [B, n, Q, H, Q] for the broadcast multiply.
220 let cb_bnlhl = cb_bnhll.permute([0, 1, 3, 2, 4]);
221 assert_eq!(
222 [batch, nchunks, chunk_len, nheads, chunk_len],
223 cb_bnlhl.dims()
224 );
225 let l_bnlhl = l_bhnll.permute([0, 2, 3, 1, 4]);
226 assert_eq!(
227 [batch, nchunks, chunk_len, nheads, chunk_len],
228 l_bnlhl.dims()
229 );
230 san(&cb_bnlhl);
231 san(&l_bnlhl);
232 let masked_cb_bnlhl = cb_bnlhl * l_bnlhl;
233 san(&masked_cb_bnlhl);
234
235 // (c) masked_CB · X → Y_diag.
236 // Contract over the last Q dimension.
237 let masked_cb_bnhll = masked_cb_bnlhl.permute([0, 1, 3, 2, 4]);
238 assert_eq!(
239 [batch, nchunks, nheads, chunk_len, chunk_len],
240 masked_cb_bnhll.dims()
241 );
242
243 let x_bnhlp = input.x_bnlhp.clone().permute([0, 1, 3, 2, 4]); // [B, n, H, Q, P]
244 assert_eq!(
245 [batch, nchunks, nheads, chunk_len, per_head_dim],
246 x_bnhlp.dims()
247 );
248
249 let y_diag_bnhlp = masked_cb_bnhll.matmul(x_bnhlp);
250 assert_eq!(
251 [batch, nchunks, nheads, chunk_len, per_head_dim],
252 y_diag_bnhlp.dims()
253 );
254 san(&y_diag_bnhlp);
255
256 y_diag_bnhlp.permute([0, 1, 3, 2, 4]) // → [B, n, Q, H, P]
257 };
258 assert_eq!(
259 [batch, nchunks, chunk_len, nheads, per_head_dim],
260 y_diag_bnlhp.dims()
261 );
262
263 // =============================================================
264 // STEP 2: Compute chunk state (input → state)
265 // =============================================================
266 //
267 // For each chunk n, compute the SSM state at the end of the chunk
268 // assuming the initial state is zero (Eq. 20):
269 //
270 // s[n] = Σ_{t ∈ [0, Q)} exp(a_cumsum[n,-1] - a_cumsum[n,t]) · B̄[n,t] · x[n,t]ᵀ
271 //
272 // Equivalently:
273 // decay_state[n, t] = exp(a_cum_last[n] - a_cum[n, t])
274 // s[n] = Σ_t decay_state[n, t] · x[n, t]ᵀ · B̄[n, t] (outer product over P and N)
275 //
276 // This is a batched GEMM, fully parallel across n and b.
277 let state_bnhpr = {
278 // Decay from each position t to the end of the chunk:
279 // decay_state[n, t] = exp(a_cum[n, Q-1] - a_cum[n, t])
280 let a_cumsum_last_bhn1 = a_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
281 assert_eq!([batch, nheads, nchunks, 1], a_cumsum_last_bhn1.dims());
282
283 let decay_state_bhnl = (a_cumsum_last_bhn1 - a_cumsum_bhnl.clone()).exp();
284 assert_eq!([batch, nheads, nchunks, chunk_len], decay_state_bhnl.dims());
285 san(&decay_state_bhnl);
286
287 // Multiply decay into x: decay[n, t] · x[n, t] → [B, n, Q, H, P]
288 let decay_state_bnlh1 = decay_state_bhnl.permute([0, 2, 3, 1]).unsqueeze_dim(4);
289 assert_eq!(
290 [batch, nchunks, chunk_len, nheads, 1],
291 decay_state_bnlh1.dims()
292 );
293 let decayed_x_bnlhp = decay_state_bnlh1 * input.x_bnlhp.clone();
294 assert_eq!(
295 [batch, nchunks, chunk_len, nheads, per_head_dim],
296 decayed_x_bnlhp.dims()
297 );
298 san(&decayed_x_bnlhp);
299
300 // Contract over Q: (decayed_x[n, :, h, :])ᵀ · B̄[n, :, h, :]
301 // [B, n, H, P, Q] × [B, n, H, Q, N] → [B, n, H, P, N]
302 let decayed_x_bnhpl = decayed_x_bnlhp.permute([0, 1, 3, 4, 2]);
303 assert_eq!(
304 [batch, nchunks, nheads, per_head_dim, chunk_len],
305 decayed_x_bnhpl.dims()
306 );
307 let b_bnhlr = delta_b_bnlhr.clone().permute([0, 1, 3, 2, 4]);
308 assert_eq!(
309 [batch, nchunks, nheads, chunk_len, state_rank],
310 b_bnhlr.dims()
311 );
312
313 decayed_x_bnhpl.matmul(b_bnhlr)
314 };
315 assert_eq!(
316 [batch, nchunks, nheads, per_head_dim, state_rank],
317 state_bnhpr.dims()
318 );
319 san(&state_bnhpr);
320
321 // =============================================================
322 // STEP 3: Inter-chunk state scan (state passing)
323 // =============================================================
324 //
325 // Propagate hidden state across chunk boundaries. The recurrence is
326 //
327 // h[n] = Ā_chunk[n] · h[n-1] + s[n] (Eq. 22)
328 //
329 // where Ā_chunk[n] = exp(Σ_{t ∈ chunk n} Δₜ · A) = exp(a_cum[n, Q-1]).
330 //
331 // Unrolling the recurrence gives a matrix form identical to Step 2 but
332 // at the chunk level: each new state is a weighted sum of all previous
333 // chunk state. We implement this with the same 1-SS segsum trick,
334 // now applied over the nchunks dimension.
335 //
336 // The result is `new_state[n]`, the true hidden state entering chunk n,
337 // for n ∈ {0, ..., nchunks-1}, plus the final state after all chunks.
338 let (state_bnhpr, final_state_bnpr) = {
339 // Prepend the initial state h₀ to the array of chunk state.
340 // Shape: [B, 1+nchunks, H, P, N]
341 let initial_state_b1hpr = input.initial_state_bhpr.unsqueeze_dim(1);
342 assert_eq!(
343 [batch, 1, nheads, per_head_dim, state_rank],
344 initial_state_b1hpr.dims()
345 );
346
347 // Optionally add learnable initial state (broadcast over batch).
348 let initial_state_b1hpr = if let Some(init_hpr) = input.init_state_hpr {
349 let init_b1hpr = init_hpr.unsqueeze_dim::<4>(0).expand([
350 batch,
351 1,
352 nheads,
353 per_head_dim,
354 state_rank,
355 ]);
356 initial_state_b1hpr + init_b1hpr
357 } else {
358 initial_state_b1hpr
359 };
360 san(&initial_state_b1hpr);
361
362 let state_bNhpr = Tensor::cat(vec![initial_state_b1hpr, state_bnhpr], 1);
363 assert_eq!(
364 [batch, 1 + nchunks, nheads, per_head_dim, state_rank],
365 state_bNhpr.dims()
366 );
367
368 // Build the inter-chunk decay matrix using segsum.
369 // a_cum_last[n] = Σ_{t ∈ chunk n} Δₜ · A (the total log-decay of chunk n)
370 let a_cumsum_last_bhn = a_cumsum_bhnl
371 .clone()
372 .slice(s![.., .., .., -1])
373 .squeeze_dim(3); // [B, H, nchunks]
374 assert_eq!([batch, nheads, nchunks], a_cumsum_last_bhn.dims());
375
376 // Prepend a zero for the initial state (no decay before chunk 0).
377 let a_chunk_pad_bhN = Tensor::cat(
378 vec![
379 Tensor::zeros(Shape::new([batch, nheads, 1]), device),
380 a_cumsum_last_bhn,
381 ],
382 2,
383 ); // [B, H, 1+nchunks]
384 assert_eq!([batch, nheads, 1 + nchunks], a_chunk_pad_bhN.dims());
385
386 // 1-SS inter-chunk decay matrix.
387 // decay_chunk[i, j] = exp(Σ_{k=j+1..i} a_cum_last[k]) (i ≥ j)
388 // Row i of this matrix, when multiplied by the state vector,
389 // gives the true hidden state entering chunk i.
390 let decay_chunk_bhNN = segsum(a_chunk_pad_bhN).exp();
391 assert_eq!(
392 [batch, nheads, 1 + nchunks, 1 + nchunks],
393 decay_chunk_bhNN.dims()
394 );
395 san(&decay_chunk_bhNN);
396
397 // Flatten the state's (P, N) dimensions for the matmul.
398 let flat_state_dim = per_head_dim * state_rank; // f = P·N
399 let state_bhNf = state_bNhpr
400 .clone()
401 .permute([0, 2, 1, 3, 4]) // [B, H, 1+n, P, N]
402 .reshape([batch, nheads, 1 + nchunks, flat_state_dim]);
403 assert_eq!(
404 [batch, nheads, 1 + nchunks, flat_state_dim],
405 state_bhNf.dims()
406 );
407
408 // Matmul: [B, H, 1+n, 1+n] × [B, H, 1+n, f] → [B, H, 1+n, f]
409 let new_state_bhNf = decay_chunk_bhNN.matmul(state_bhNf);
410 assert_eq!(
411 [batch, nheads, 1 + nchunks, flat_state_dim],
412 new_state_bhNf.dims()
413 );
414 san(&new_state_bhNf);
415
416 let new_state_bhNpr =
417 new_state_bhNf.reshape([batch, nheads, 1 + nchunks, per_head_dim, state_rank]);
418
419 // Slice to get:
420 // state[0..nchunks] — the initial state entering each chunk
421 // state[nchunks] — the final state after the last real token
422 //
423 // For padded sequences the padding steps are identity operations
424 // (Δ=0 ⇒ Ā=1, B̄=0), so the state is carried unchanged through the
425 // pad region, and `state[nchunks]` is the correct final state.
426 let state_bhnpr = new_state_bhNpr
427 .clone()
428 .slice(s![.., .., 0..nchunks, .., ..]);
429 let final_state_bhpr = new_state_bhNpr
430 .slice(s![.., .., nchunks, .., ..])
431 .squeeze_dim(2);
432
433 (
434 state_bhnpr.permute([0, 2, 1, 3, 4]), // → [B, n, H, P, N]
435 final_state_bhpr,
436 )
437 };
438 assert_eq!(
439 [batch, nchunks, nheads, per_head_dim, state_rank],
440 state_bnhpr.dims()
441 );
442 assert_eq!(
443 [batch, nheads, per_head_dim, state_rank],
444 final_state_bnpr.dims()
445 );
446
447 // =============================================================
448 // STEP 4: State-to-output contribution (Y_off)
449 // =============================================================
450 //
451 // For each chunk n, compute the contribution of the true initial state
452 // h[n-1] to the outputs within that chunk (Eq. 23):
453 //
454 // Y_off[n, t] = C[n, t]ᵀ · exp(a_cumsum[n, t]) · h[n-1]
455 // = exp(a_cum[n,t]) · (C[n,t]ᵀ · h[n-1])
456 //
457 // where the scalar `exp(a_cum[n,t])` is the cumulative decay from the
458 // start of the chunk to position t.
459 //
460 // Implementation:
461 // (a) C[n] · h[n-1]ᵀ (contract over N) → [B, n, H, Q, P]
462 // (b) element-wise multiply with exp(a_cum)
463 let y_off_bnlhp = {
464 // exp(a_cumsum[n, t]): decay from start of chunk to position t.
465 let state_decay_out_bhnl = a_cumsum_bhnl.exp();
466 assert_eq!(
467 [batch, nheads, nchunks, chunk_len],
468 state_decay_out_bhnl.dims()
469 );
470 san(&state_decay_out_bhnl);
471
472 // (a) C[n] · h[n-1]ᵀ → [B, n, H, Q, P]
473 // C: [B, n, H, Q, N], h: [B, n, H, N, P] (transposed from [B,n,H,P,N])
474 let c_bnhlr = c_bnlhr.permute([0, 1, 3, 2, 4]); // [B, n, H, Q, N]
475 assert_eq!(
476 [batch, nchunks, nheads, chunk_len, state_rank],
477 c_bnhlr.dims()
478 );
479
480 let state_bnhrp = state_bnhpr.permute([0, 1, 2, 4, 3]); // [B, n, H, N, P]
481 assert_eq!(
482 [batch, nchunks, nheads, state_rank, per_head_dim],
483 state_bnhrp.dims()
484 );
485
486 let ch_bnhlp = c_bnhlr.matmul(state_bnhrp); // [B, n, H, Q, P]
487 assert_eq!(
488 [batch, nchunks, nheads, chunk_len, per_head_dim],
489 ch_bnhlp.dims()
490 );
491 san(&ch_bnhlp);
492
493 // (b) Multiply by the intra-chunk cumulative decay.
494 let state_decay_out_bnhl1 = state_decay_out_bhnl.permute([0, 2, 1, 3]).unsqueeze_dim(4);
495 assert_eq!(
496 [batch, nchunks, nheads, chunk_len, 1],
497 state_decay_out_bnhl1.dims()
498 );
499
500 let y_off_bnhlp = ch_bnhlp * state_decay_out_bnhl1;
501 assert_eq!(
502 [batch, nchunks, nheads, chunk_len, per_head_dim],
503 y_off_bnhlp.dims()
504 );
505 san(&y_off_bnhlp);
506
507 y_off_bnhlp.permute([0, 1, 3, 2, 4]) // → [B, n, Q, H, P]
508 };
509 assert_eq!(
510 [batch, nchunks, chunk_len, nheads, per_head_dim],
511 y_off_bnlhp.dims()
512 );
513
514 // ── Combine Y_diag and Y_off, undo padding ────────────────────────────
515 let y_bnlhp = y_diag_bnlhp + y_off_bnlhp;
516 san(&y_bnlhp);
517
518 // ── D skip connection ─────────────────────────────────────────────────
519 // yₜ += D · xₜ
520 // D is a per-head scalar; broadcast over batch, sequence, and per_head_dim.
521 let d_bnlhp = input
522 .d_h
523 .unsqueeze_dims::<5>(&[0, 1, 2, 4]) // d_111h1
524 .expand([batch, nchunks, chunk_len, nheads, per_head_dim]);
525 let y_bnlhp = y_bnlhp + d_bnlhp * input.x_bnlhp;
526 san(&y_bnlhp);
527
528 (y_bnlhp, final_state_bnpr)
529 }
530}
531
532// ---------------------------------------------------------------------------
533// segsum (stable segment sum for the 1-SS mask)
534// ---------------------------------------------------------------------------
535
536/// Compute stable segment sums for constructing the 1-semiseparable mask.
537///
538/// Given a tensor `x` of shape `[..., T]`, returns a tensor of shape
539/// `[..., T, T]` where:
540///
541/// ```text
542/// out[..., i, j] = Σ_{k=j+1}^{i} x[..., k] for i ≥ j (lower triangle)
543/// out[..., i, j] = -∞ for i < j (upper triangle)
544/// ```
545///
546/// The 1-semiseparable mask is then obtained by exponentiating:
547///
548/// ```text
549/// L = exp(segsum(log_A))
550/// L[i, j] = exp(log_A[j+1] + ... + log_A[i])
551/// = A[j+1] · A[j+2] · ... · A[i] (Eq. 4–5 in the paper)
552/// ```
553///
554/// ## Implementation
555///
556/// A naive computation of all pairwise products `A[j+1]·...·A[i]` would
557/// suffer from underflow for long sequences (e.g. `0.9^1000 ≈ 2.6×10⁻⁴⁶`).
558/// Working in log-space and computing differences of prefix sums avoids this:
559///
560/// ```text
561/// segsum(x)[i, j] = cumsum(x)[i] - cumsum(x)[j]
562/// ```
563///
564/// The upper triangle is masked to -∞ so that `exp(segsum(...))` gives 0
565/// for non-causal positions (the strict upper triangle of L must be zero).
566///
567/// ## Const-generic dimension handling
568///
569/// This function is generic over the input rank `D` and returns a tensor of
570/// rank `D + 1`. Burn requires the output rank to be known at compile time,
571/// which is achieved through the const generic expression `{ D + 1 }`.
572fn segsum<B: Backend, const D: usize, const D2: usize>(x: Tensor<B, D>) -> Tensor<B, D2> {
573 assert!(D > 0);
574 assert_eq!(D + 1, D2);
575
576 // cumsum[..., t] = x[..., 0] + x[..., 1] + ... + x[..., t]
577 let x_cumsum = x.cumsum(D - 1);
578 san(&x_cumsum);
579
580 // Broadcast along two different axes to compute all pairwise differences:
581 // x_cumsum_row[..., i, j] = cumsum[..., i] (i varies along axis D)
582 // x_cumsum_col[..., i, j] = cumsum[..., j] (j varies along axis D-1)
583 let x_cumsum_row = x_cumsum.clone().unsqueeze_dim(D); // [..., T, 1]
584 let x_cumsum_col = x_cumsum.unsqueeze_dim(D - 1); // [..., 1, T]
585
586 // diff[..., i, j] = cumsum[i] - cumsum[j]
587 // = x[j+1] + ... + x[i] for i ≥ j
588 let diff = x_cumsum_row - x_cumsum_col; // [..., T, T]
589 san(&diff);
590
591 // Mask the strict upper triangle (i < j) with -∞.
592 // triu(1) returns a tensor that is -∞ above the main diagonal and 0
593 // elsewhere; adding it to `diff` zeroes out the upper triangle of exp(diff).
594 let neg_inf_mask = Tensor::full_like(&diff, f32::NEG_INFINITY).triu(1);
595 diff + neg_inf_mask
596}