Skip to main content

burn_mamba/mamba2/ssd/
minimal.rs

1//! ## The Chunkwise SSD Algorithm
2//!
3//! During training (and prefill), a naive sequential recurrence cannot
4//! exploit GPU tensor cores.  The **chunkwise SSD algorithm** (§4 of the
5//! paper) achieves this by splitting the sequence into chunks of length Q
6//! and decomposing the computation into four steps:
7//!
8//! ```text
9//!   Step 1  (intra-chunk, quadratic form)   →  Y_diag
10//!   Step 2  (input → chunk state)           →  state_bnhpr
11//!   Step 3  (inter-chunk state scan)        →  state_bnhpr, final_state
12//!   Step 4  (chunk state → output)          →  Y_off
13//!
14//!   Y = Y_diag + Y_off
15//! ```
16//!
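//! Steps 1, 2 and 4 are fully parallel across chunks and use batched matrix
//! multiplications (exploiting tensor cores).  Step 3 is a recurrence over
//! only `T/Q` chunk states rather than `T` tokens; `ssd_minimal` evaluates it
//! as one small matrix multiplication with a `(1 + T/Q) × (1 + T/Q)` decay
//! matrix instead of a sequential loop.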
20
21use crate::mamba2::prelude::*;
22use crate::utils::sanity::sanity as san;
23use burn::prelude::*;
24
25impl<B: Backend> Mamba2<B> {
26    // -----------------------------------------------------------------------
27    // chunked_selective_scan
28    // -----------------------------------------------------------------------
29
30    /// Minimal chunkwise SSD algorithm.
31    ///
32    /// Implements the four-step decomposition of the semiseparable matrix
33    /// multiplication described in §4 of the paper.  The sequence of length T
34    /// is split into `nchunks = ⌈T/Q⌉` chunks of length Q.
35    ///
36    /// ## The four steps
37    ///
38    /// ### Step 1 — Intra-chunk outputs (Y_diag)
39    ///
40    /// Within each chunk, compute the output assuming the initial hidden state
41    /// is zero.  This is the *quadratic attention form* of the SSD layer
42    /// restricted to a window of Q tokens (§4.1):
43    ///
44    /// ```text
45    ///   Y_diag[n] = (L[n] ∘ C[n] B[n]ᵀ) · X[n]
46    /// ```
47    ///
48    /// where `L[n]` is the Q×Q 1-semiseparable mask for chunk n.
49    /// This step is a batched GEMM (exploits tensor cores).
50    ///
51    /// ### Step 2 — Chunk state (state_bnhpr)
52    ///
53    /// Compute the final SSM state of each chunk assuming zero initial state
54    /// (§4.1, Eq. 20):
55    ///
56    /// ```text
57    ///   s[n] = Σ_{t ∈ chunk n}  exp(A_cum[end] - A_cum[t]) · B̄[t] · x[t]ᵀ
58    /// ```
59    ///
60    /// This is also a batched GEMM and is fully parallel across chunks.
61    ///
62    /// ### Step 3 — Inter-chunk state scan (state passing)
63    ///
64    /// Propagate the true hidden state across chunk boundaries using the
65    /// recurrence (§4.1, Eq. 22):
66    ///
67    /// ```text
68    ///   h[n] = Ā[n]_chunk · h[n-1] + s[n]
69    /// ```
70    ///
71    /// where `Ā[n]_chunk = exp(Σ_{t ∈ chunk n} Δₜ · A)` is the cumulative
72    /// decay over the whole chunk.  This step is implemented as a single
73    /// batched matrix multiplication using the 1-semiseparable structure of
74    /// the inter-chunk decay matrix (same `segsum` trick, now over chunks).
75    /// The scan has length `nchunks = T/Q` rather than T, so its cost is
76    /// negligible for typical chunk sizes.
77    ///
78    /// ### Step 4 — State-to-output (Y_off)
79    ///
80    /// For each chunk n, compute the contribution of the true initial state
81    /// `h[n-1]` to the outputs within that chunk (§4.1, Eq. 23):
82    ///
83    /// ```text
84    ///   Y_off[n, t] = C[n, t]ᵀ · exp(A_cum[t]) · h[n-1]
85    /// ```
86    ///
87    /// This is again a batched GEMM.
88    ///
89    /// ### Final output (with D skip-connection)
90    ///
91    /// ```text
92    ///   Y = Y_diag + Y_off + D · X
93    /// ```
94    #[allow(non_snake_case)]
95    pub fn ssd_minimal(input: super::Mamba2SsdInput<B>) -> (Tensor<B, 5>, Tensor<B, 4>) {
96        let [batch, nchunks, chunk_len, nheads, per_head_dim] = input.x_bnlhp.dims();
97        let [.., ngroups, state_rank] = input.b_bnlgr.dims();
98        let device = &input.x_bnlhp.device();
99
100        assert_eq!(nheads % ngroups, 0);
101        assert!(nchunks >= 1, "sequence must be non-empty");
102        assert!(chunk_len > 0, "chunk_len must be positive");
103
104        // ── Compute discretised parameters ────────────────────────────────────
105        // Ā = exp(Δ · A)   stored in log-space as  a_bnlh = Δ · A  (negative)
106        // B̄ = Δ · B        (Euler/ZOH approximation)
107
108        // Expand B and C from ngroups to nheads by repeating each group's
109        // projection across all heads_per_group heads in that group.
110        let heads_per_group = nheads / ngroups;
111
112        // b_bnlgr → b_bnlhr  [batch, nchunks, chunk_len, nheads, state_rank]
113        let b_bnlhr = input
114            .b_bnlgr
115            .clone()
116            .unsqueeze_dim::<6>(4) // b_bnlg1r
117            .expand([
118                batch,
119                nchunks,
120                chunk_len,
121                ngroups,
122                heads_per_group,
123                state_rank,
124            ]) // b_bnlgHr
125            .reshape([batch, nchunks, chunk_len, nheads, state_rank]);
126
127        // c_bnlgr → c_bnlhr  [batch, nchunks, chunk_len, nheads, state_rank]
128        let c_bnlhr = input
129            .c_bnlgr
130            .clone()
131            .unsqueeze_dim::<6>(4) // c_bnlg1r
132            .expand([
133                batch,
134                nchunks,
135                chunk_len,
136                ngroups,
137                heads_per_group,
138                state_rank,
139            ]) // c_bnlgHr
140            .reshape([batch, nchunks, chunk_len, nheads, state_rank]);
141
142        // B̄ₜ = Δₜ · Bₜ   [batch, nchunks, chunk_len, nheads, state_rank]
143        let delta_b_bnlhr = input.dt_bnlh.clone().unsqueeze_dim(4) * b_bnlhr.clone();
144        assert_eq!(
145            [batch, nchunks, chunk_len, nheads, state_rank],
146            delta_b_bnlhr.dims()
147        );
148        san(&delta_b_bnlhr);
149
150        // Ā in log-space: a_bnlh = Δₜ · A
151        let a_bnlh = input.dt_bnlh.clone()
152            * input
153                .a_decay_h
154                .clone()
155                .unsqueeze_dims::<4>(&[0, 1, 2]) // a_head_decay_111h
156                .expand([batch, nchunks, chunk_len, nheads]);
157        san(&a_bnlh);
158
159        // ── Reshape ───────────────────────────────────────────────────────────
160        // a (log-decay): [B, nchunks, chunk_len, H] → [B, H, nchunks, chunk_len]
161        let a_bhnl = a_bnlh.permute([0, 3, 1, 2]);
162        assert_eq!([batch, nheads, nchunks, chunk_len], a_bhnl.dims());
163
164        // Cumulative sum of log-decays within each chunk.
165        // a_cumsum_bhnl[b, h, n, t] = Σ_{k=0..t} Δ_{n,k} · A
166        // This is the log of the cumulative decay factor from the start of the
167        // chunk to position t (inclusive).
168        let a_cumsum_bhnl = a_bhnl.clone().cumsum(3);
169        assert_eq!([batch, nheads, nchunks, chunk_len], a_cumsum_bhnl.dims());
170        san(&a_cumsum_bhnl);
171
172        // =============================================================
173        // STEP 1: Intra-chunk outputs (diagonal blocks, Y_diag)
174        // =============================================================
175        //
176        // For each chunk n, compute Y_diag[n] = (L[n] ∘ C[n] B[n]ᵀ) · X[n]
177        // where L[n] ∈ ℝ^{Q×Q} is the 1-semiseparable mask for the chunk.
178        //
179        // L[n]_{i,j} = exp(Σ_{k=j+1..i} a_{n,k})  for i ≥ j
180        //            = exp(a_cumsum[n,i] - a_cumsum[n,j])   (using segsum trick)
181        //
182        // Implementation uses three batched matmuls:
183        //   (a) C[n] · B[n]ᵀ  (contract over state_rank N)  → temp1 [B, nchunks, H, Q, Q]
184        //   (b) temp1 ∘ L[n]                                 → temp2 [B, nchunks, H, Q, Q]
185        //   (c) temp2 · X[n]  (contract over Q)              → Y_diag [B, nchunks, Q, H, P]
186        let y_diag_bnlhp = {
187            // Permute to [B, nchunks, H, Q, N] for the matmul along Q and N.
188            let b_bnhlr = delta_b_bnlhr.clone().permute([0, 1, 3, 2, 4]);
189            let c_bnhlr = c_bnlhr.clone().permute([0, 1, 3, 2, 4]);
190            assert_eq!(
191                [batch, nchunks, nheads, chunk_len, state_rank],
192                b_bnhlr.dims()
193            );
194            assert_eq!(
195                [batch, nchunks, nheads, chunk_len, state_rank],
196                c_bnhlr.dims()
197            );
198
199            // (a) C[n] · B[n]ᵀ → [B, nchunks, H, Q, Q]
200            //     Contracts over state_rank N.
201            let b_bnhrl = b_bnhlr.permute([0, 1, 2, 4, 3]); // [B, n, H, N, Q]
202            let cb_bnhll = c_bnhlr.matmul(b_bnhrl); // [B, n, H, Q, Q]
203            assert_eq!(
204                [batch, nchunks, nheads, chunk_len, chunk_len],
205                cb_bnhll.dims()
206            );
207            san(&cb_bnhll);
208
209            // (b) Element-wise multiply with the 1-SS mask L.
210            //     L = exp(segsum(a_bhnl))  [B, H, nchunks, Q, Q]
211            //     Lᵢⱼ = exp(a_cumsum[n,i] - a_cumsum[n,j])  (Eq. 4–5)
212            let l_bhnll = segsum(a_bhnl.clone()).exp();
213            assert_eq!(
214                [batch, nheads, nchunks, chunk_len, chunk_len],
215                l_bhnll.dims()
216            );
217            san(&l_bhnll);
218
219            // Permute both to [B, n, Q, H, Q] for the broadcast multiply.
220            let cb_bnlhl = cb_bnhll.permute([0, 1, 3, 2, 4]);
221            assert_eq!(
222                [batch, nchunks, chunk_len, nheads, chunk_len],
223                cb_bnlhl.dims()
224            );
225            let l_bnlhl = l_bhnll.permute([0, 2, 3, 1, 4]);
226            assert_eq!(
227                [batch, nchunks, chunk_len, nheads, chunk_len],
228                l_bnlhl.dims()
229            );
230            san(&cb_bnlhl);
231            san(&l_bnlhl);
232            let masked_cb_bnlhl = cb_bnlhl * l_bnlhl;
233            san(&masked_cb_bnlhl);
234
235            // (c) masked_CB · X → Y_diag.
236            //     Contract over the last Q dimension.
237            let masked_cb_bnhll = masked_cb_bnlhl.permute([0, 1, 3, 2, 4]);
238            assert_eq!(
239                [batch, nchunks, nheads, chunk_len, chunk_len],
240                masked_cb_bnhll.dims()
241            );
242
243            let x_bnhlp = input.x_bnlhp.clone().permute([0, 1, 3, 2, 4]); // [B, n, H, Q, P]
244            assert_eq!(
245                [batch, nchunks, nheads, chunk_len, per_head_dim],
246                x_bnhlp.dims()
247            );
248
249            let y_diag_bnhlp = masked_cb_bnhll.matmul(x_bnhlp);
250            assert_eq!(
251                [batch, nchunks, nheads, chunk_len, per_head_dim],
252                y_diag_bnhlp.dims()
253            );
254            san(&y_diag_bnhlp);
255
256            y_diag_bnhlp.permute([0, 1, 3, 2, 4]) // → [B, n, Q, H, P]
257        };
258        assert_eq!(
259            [batch, nchunks, chunk_len, nheads, per_head_dim],
260            y_diag_bnlhp.dims()
261        );
262
263        // =============================================================
264        // STEP 2: Compute chunk state (input → state)
265        // =============================================================
266        //
267        // For each chunk n, compute the SSM state at the end of the chunk
268        // assuming the initial state is zero (Eq. 20):
269        //
270        //   s[n] = Σ_{t ∈ [0, Q)} exp(a_cumsum[n,-1] - a_cumsum[n,t]) · B̄[n,t] · x[n,t]ᵀ
271        //
272        // Equivalently:
273        //   decay_state[n, t] = exp(a_cum_last[n] - a_cum[n, t])
274        //   s[n] = Σ_t  decay_state[n, t] · x[n, t]ᵀ · B̄[n, t]     (outer product over P and N)
275        //
276        // This is a batched GEMM, fully parallel across n and b.
277        let state_bnhpr = {
278            // Decay from each position t to the end of the chunk:
279            //   decay_state[n, t] = exp(a_cum[n, Q-1] - a_cum[n, t])
280            let a_cumsum_last_bhn1 = a_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
281            assert_eq!([batch, nheads, nchunks, 1], a_cumsum_last_bhn1.dims());
282
283            let decay_state_bhnl = (a_cumsum_last_bhn1 - a_cumsum_bhnl.clone()).exp();
284            assert_eq!([batch, nheads, nchunks, chunk_len], decay_state_bhnl.dims());
285            san(&decay_state_bhnl);
286
287            // Multiply decay into x: decay[n, t] · x[n, t]  → [B, n, Q, H, P]
288            let decay_state_bnlh1 = decay_state_bhnl.permute([0, 2, 3, 1]).unsqueeze_dim(4);
289            assert_eq!(
290                [batch, nchunks, chunk_len, nheads, 1],
291                decay_state_bnlh1.dims()
292            );
293            let decayed_x_bnlhp = decay_state_bnlh1 * input.x_bnlhp.clone();
294            assert_eq!(
295                [batch, nchunks, chunk_len, nheads, per_head_dim],
296                decayed_x_bnlhp.dims()
297            );
298            san(&decayed_x_bnlhp);
299
300            // Contract over Q: (decayed_x[n, :, h, :])ᵀ · B̄[n, :, h, :]
301            //   [B, n, H, P, Q] × [B, n, H, Q, N] → [B, n, H, P, N]
302            let decayed_x_bnhpl = decayed_x_bnlhp.permute([0, 1, 3, 4, 2]);
303            assert_eq!(
304                [batch, nchunks, nheads, per_head_dim, chunk_len],
305                decayed_x_bnhpl.dims()
306            );
307            let b_bnhlr = delta_b_bnlhr.clone().permute([0, 1, 3, 2, 4]);
308            assert_eq!(
309                [batch, nchunks, nheads, chunk_len, state_rank],
310                b_bnhlr.dims()
311            );
312
313            decayed_x_bnhpl.matmul(b_bnhlr)
314        };
315        assert_eq!(
316            [batch, nchunks, nheads, per_head_dim, state_rank],
317            state_bnhpr.dims()
318        );
319        san(&state_bnhpr);
320
321        // =============================================================
322        // STEP 3: Inter-chunk state scan (state passing)
323        // =============================================================
324        //
325        // Propagate hidden state across chunk boundaries.  The recurrence is
326        //
327        //   h[n] = Ā_chunk[n] · h[n-1] + s[n]     (Eq. 22)
328        //
329        // where Ā_chunk[n] = exp(Σ_{t ∈ chunk n} Δₜ · A) = exp(a_cum[n, Q-1]).
330        //
331        // Unrolling the recurrence gives a matrix form identical to Step 2 but
332        // at the chunk level: each new state is a weighted sum of all previous
333        // chunk state.  We implement this with the same 1-SS segsum trick,
334        // now applied over the nchunks dimension.
335        //
336        // The result is `new_state[n]`, the true hidden state entering chunk n,
337        // for n ∈ {0, ..., nchunks-1}, plus the final state after all chunks.
338        let (state_bnhpr, final_state_bnpr) = {
339            // Prepend the initial state h₀ to the array of chunk state.
340            // Shape: [B, 1+nchunks, H, P, N]
341            let initial_state_b1hpr = input.initial_state_bhpr.unsqueeze_dim(1);
342            assert_eq!(
343                [batch, 1, nheads, per_head_dim, state_rank],
344                initial_state_b1hpr.dims()
345            );
346
347            // Optionally add learnable initial state (broadcast over batch).
348            let initial_state_b1hpr = if let Some(init_hpr) = input.init_state_hpr {
349                let init_b1hpr = init_hpr.unsqueeze_dim::<4>(0).expand([
350                    batch,
351                    1,
352                    nheads,
353                    per_head_dim,
354                    state_rank,
355                ]);
356                initial_state_b1hpr + init_b1hpr
357            } else {
358                initial_state_b1hpr
359            };
360            san(&initial_state_b1hpr);
361
362            let state_bNhpr = Tensor::cat(vec![initial_state_b1hpr, state_bnhpr], 1);
363            assert_eq!(
364                [batch, 1 + nchunks, nheads, per_head_dim, state_rank],
365                state_bNhpr.dims()
366            );
367
368            // Build the inter-chunk decay matrix using segsum.
369            // a_cum_last[n] = Σ_{t ∈ chunk n} Δₜ · A   (the total log-decay of chunk n)
370            let a_cumsum_last_bhn = a_cumsum_bhnl
371                .clone()
372                .slice(s![.., .., .., -1])
373                .squeeze_dim(3); // [B, H, nchunks]
374            assert_eq!([batch, nheads, nchunks], a_cumsum_last_bhn.dims());
375
376            // Prepend a zero for the initial state (no decay before chunk 0).
377            let a_chunk_pad_bhN = Tensor::cat(
378                vec![
379                    Tensor::zeros(Shape::new([batch, nheads, 1]), device),
380                    a_cumsum_last_bhn,
381                ],
382                2,
383            ); // [B, H, 1+nchunks]
384            assert_eq!([batch, nheads, 1 + nchunks], a_chunk_pad_bhN.dims());
385
386            // 1-SS inter-chunk decay matrix.
387            //   decay_chunk[i, j] = exp(Σ_{k=j+1..i} a_cum_last[k])  (i ≥ j)
388            // Row i of this matrix, when multiplied by the state vector,
389            // gives the true hidden state entering chunk i.
390            let decay_chunk_bhNN = segsum(a_chunk_pad_bhN).exp();
391            assert_eq!(
392                [batch, nheads, 1 + nchunks, 1 + nchunks],
393                decay_chunk_bhNN.dims()
394            );
395            san(&decay_chunk_bhNN);
396
397            // Flatten the state's (P, N) dimensions for the matmul.
398            let flat_state_dim = per_head_dim * state_rank; // f = P·N
399            let state_bhNf = state_bNhpr
400                .clone()
401                .permute([0, 2, 1, 3, 4]) // [B, H, 1+n, P, N]
402                .reshape([batch, nheads, 1 + nchunks, flat_state_dim]);
403            assert_eq!(
404                [batch, nheads, 1 + nchunks, flat_state_dim],
405                state_bhNf.dims()
406            );
407
408            // Matmul: [B, H, 1+n, 1+n] × [B, H, 1+n, f] → [B, H, 1+n, f]
409            let new_state_bhNf = decay_chunk_bhNN.matmul(state_bhNf);
410            assert_eq!(
411                [batch, nheads, 1 + nchunks, flat_state_dim],
412                new_state_bhNf.dims()
413            );
414            san(&new_state_bhNf);
415
416            let new_state_bhNpr =
417                new_state_bhNf.reshape([batch, nheads, 1 + nchunks, per_head_dim, state_rank]);
418
419            // Slice to get:
420            //   state[0..nchunks]  — the initial state entering each chunk
421            //   state[nchunks]     — the final state after the last real token
422            //
423            // For padded sequences the padding steps are identity operations
424            // (Δ=0 ⇒ Ā=1, B̄=0), so the state is carried unchanged through the
425            // pad region, and `state[nchunks]` is the correct final state.
426            let state_bhnpr = new_state_bhNpr
427                .clone()
428                .slice(s![.., .., 0..nchunks, .., ..]);
429            let final_state_bhpr = new_state_bhNpr
430                .slice(s![.., .., nchunks, .., ..])
431                .squeeze_dim(2);
432
433            (
434                state_bhnpr.permute([0, 2, 1, 3, 4]), // → [B, n, H, P, N]
435                final_state_bhpr,
436            )
437        };
438        assert_eq!(
439            [batch, nchunks, nheads, per_head_dim, state_rank],
440            state_bnhpr.dims()
441        );
442        assert_eq!(
443            [batch, nheads, per_head_dim, state_rank],
444            final_state_bnpr.dims()
445        );
446
447        // =============================================================
448        // STEP 4: State-to-output contribution (Y_off)
449        // =============================================================
450        //
451        // For each chunk n, compute the contribution of the true initial state
452        // h[n-1] to the outputs within that chunk (Eq. 23):
453        //
454        //   Y_off[n, t] = C[n, t]ᵀ · exp(a_cumsum[n, t]) · h[n-1]
455        //               = exp(a_cum[n,t]) · (C[n,t]ᵀ · h[n-1])
456        //
457        // where the scalar `exp(a_cum[n,t])` is the cumulative decay from the
458        // start of the chunk to position t.
459        //
460        // Implementation:
461        //   (a) C[n] · h[n-1]ᵀ  (contract over N)  → [B, n, H, Q, P]
462        //   (b) element-wise multiply with exp(a_cum)
463        let y_off_bnlhp = {
464            // exp(a_cumsum[n, t]): decay from start of chunk to position t.
465            let state_decay_out_bhnl = a_cumsum_bhnl.exp();
466            assert_eq!(
467                [batch, nheads, nchunks, chunk_len],
468                state_decay_out_bhnl.dims()
469            );
470            san(&state_decay_out_bhnl);
471
472            // (a) C[n] · h[n-1]ᵀ  → [B, n, H, Q, P]
473            //   C: [B, n, H, Q, N],  h: [B, n, H, N, P]  (transposed from [B,n,H,P,N])
474            let c_bnhlr = c_bnlhr.permute([0, 1, 3, 2, 4]); // [B, n, H, Q, N]
475            assert_eq!(
476                [batch, nchunks, nheads, chunk_len, state_rank],
477                c_bnhlr.dims()
478            );
479
480            let state_bnhrp = state_bnhpr.permute([0, 1, 2, 4, 3]); // [B, n, H, N, P]
481            assert_eq!(
482                [batch, nchunks, nheads, state_rank, per_head_dim],
483                state_bnhrp.dims()
484            );
485
486            let ch_bnhlp = c_bnhlr.matmul(state_bnhrp); // [B, n, H, Q, P]
487            assert_eq!(
488                [batch, nchunks, nheads, chunk_len, per_head_dim],
489                ch_bnhlp.dims()
490            );
491            san(&ch_bnhlp);
492
493            // (b) Multiply by the intra-chunk cumulative decay.
494            let state_decay_out_bnhl1 = state_decay_out_bhnl.permute([0, 2, 1, 3]).unsqueeze_dim(4);
495            assert_eq!(
496                [batch, nchunks, nheads, chunk_len, 1],
497                state_decay_out_bnhl1.dims()
498            );
499
500            let y_off_bnhlp = ch_bnhlp * state_decay_out_bnhl1;
501            assert_eq!(
502                [batch, nchunks, nheads, chunk_len, per_head_dim],
503                y_off_bnhlp.dims()
504            );
505            san(&y_off_bnhlp);
506
507            y_off_bnhlp.permute([0, 1, 3, 2, 4]) // → [B, n, Q, H, P]
508        };
509        assert_eq!(
510            [batch, nchunks, chunk_len, nheads, per_head_dim],
511            y_off_bnlhp.dims()
512        );
513
514        // ── Combine Y_diag and Y_off, undo padding ────────────────────────────
515        let y_bnlhp = y_diag_bnlhp + y_off_bnlhp;
516        san(&y_bnlhp);
517
518        // ── D skip connection ─────────────────────────────────────────────────
519        // yₜ += D · xₜ
520        // D is a per-head scalar; broadcast over batch, sequence, and per_head_dim.
521        let d_bnlhp = input
522            .d_h
523            .unsqueeze_dims::<5>(&[0, 1, 2, 4]) // d_111h1
524            .expand([batch, nchunks, chunk_len, nheads, per_head_dim]);
525        let y_bnlhp = y_bnlhp + d_bnlhp * input.x_bnlhp;
526        san(&y_bnlhp);
527
528        (y_bnlhp, final_state_bnpr)
529    }
530}
531
532// ---------------------------------------------------------------------------
533// segsum  (stable segment sum for the 1-SS mask)
534// ---------------------------------------------------------------------------
535
536/// Compute stable segment sums for constructing the 1-semiseparable mask.
537///
538/// Given a tensor `x` of shape `[..., T]`, returns a tensor of shape
539/// `[..., T, T]` where:
540///
541/// ```text
542///   out[..., i, j] = Σ_{k=j+1}^{i} x[..., k]     for i ≥ j   (lower triangle)
543///   out[..., i, j] = -∞                             for i < j   (upper triangle)
544/// ```
545///
546/// The 1-semiseparable mask is then obtained by exponentiating:
547///
548/// ```text
549///   L = exp(segsum(log_A))
550///   L[i, j] = exp(log_A[j+1] + ... + log_A[i])
551///            = A[j+1] · A[j+2] · ... · A[i]       (Eq. 4–5 in the paper)
552/// ```
553///
554/// ## Implementation
555///
556/// A naive computation of all pairwise products `A[j+1]·...·A[i]` would
557/// suffer from underflow for long sequences (e.g. `0.9^1000 ≈ 2.6×10⁻⁴⁶`).
558/// Working in log-space and computing differences of prefix sums avoids this:
559///
560/// ```text
561///   segsum(x)[i, j] = cumsum(x)[i] - cumsum(x)[j]
562/// ```
563///
564/// The upper triangle is masked to -∞ so that `exp(segsum(...))` gives 0
565/// for non-causal positions (the strict upper triangle of L must be zero).
566///
567/// ## Const-generic dimension handling
568///
569/// This function is generic over the input rank `D` and returns a tensor of
570/// rank `D + 1`.  Burn requires the output rank to be known at compile time,
571/// which is achieved through the const generic expression `{ D + 1 }`.
572fn segsum<B: Backend, const D: usize, const D2: usize>(x: Tensor<B, D>) -> Tensor<B, D2> {
573    assert!(D > 0);
574    assert_eq!(D + 1, D2);
575
576    // cumsum[..., t] = x[..., 0] + x[..., 1] + ... + x[..., t]
577    let x_cumsum = x.cumsum(D - 1);
578    san(&x_cumsum);
579
580    // Broadcast along two different axes to compute all pairwise differences:
581    //   x_cumsum_row[..., i, j] = cumsum[..., i]   (i varies along axis D)
582    //   x_cumsum_col[..., i, j] = cumsum[..., j]   (j varies along axis D-1)
583    let x_cumsum_row = x_cumsum.clone().unsqueeze_dim(D); // [..., T, 1]
584    let x_cumsum_col = x_cumsum.unsqueeze_dim(D - 1); // [..., 1, T]
585
586    // diff[..., i, j] = cumsum[i] - cumsum[j]
587    //                 = x[j+1] + ... + x[i]    for i ≥ j
588    let diff = x_cumsum_row - x_cumsum_col; // [..., T, T]
589    san(&diff);
590
591    // Mask the strict upper triangle (i < j) with -∞.
592    // triu(1) returns a tensor that is -∞ above the main diagonal and 0
593    // elsewhere; adding it to `diff` zeroes out the upper triangle of exp(diff).
594    let neg_inf_mask = Tensor::full_like(&diff, f32::NEG_INFINITY).triu(1);
595    diff + neg_inf_mask
596}