burn_mamba/modules/misc/segsum.rs
//! Log-space segment-sum, the building block of the SSD 1-semiseparable mask `L`.
//!
//! `L[i, j] = exp(segsum(a)[i, j])` is the causal decay from step `j` to `i`.
//! Computing it in log space (summing `a` over each segment instead of
//! multiplying decay factors) avoids underflow over long sequences. See
//! [`segsum`] for the math.
6
7use crate::modules::sanity as san;
8use burn::prelude::*;
9
10// ---------------------------------------------------------------------------
11// segsum (stable segment sum for the 1-SS mask)
12// ---------------------------------------------------------------------------
13
14/// Compute stable segment sums for constructing the 1-semiseparable mask.
15///
16/// Given a tensor `x` of shape `[..., sequence]`, returns a tensor of shape `[..., sequence, sequence]` where:
17///
18/// ```text
19/// out[..., i, j] = Σ_{k=j+1}^{i} x[..., k] for i ≥ j (lower triangle)
20/// out[..., i, j] = -∞ for i < j (upper triangle)
21/// ```
22///
23/// ## Implementation
24///
25/// A naive computation of all pairwise products `A[j+1]·...·A[i]` would
26/// suffer from underflow for long sequences (e.g. `0.9^1000 ≈ 2.6×10⁻⁴⁶`).
27/// Working in log-space and computing differences of prefix sums avoids this:
28///
29/// ```text
30/// segsum(x)[i, j] = cumsum(x)[i] - cumsum(x)[j]
31/// ```
32///
33/// The upper triangle is masked to -∞ so that `exp(segsum(...))` gives 0
34/// for non-causal positions (the strict upper triangle of L must be zero).
35pub fn segsum<const D: usize, const D2: usize>(x: Tensor<D>) -> Tensor<D2> {
36 assert_eq!(D + 1, D2);
37
38 let x_cumsum = x.cumsum(D - 1);
39 san(&x_cumsum);
40 let x_cumsum_row = x_cumsum.clone().unsqueeze_dim(D); // [..., sequence, 1]
41 let x_cumsum_col = x_cumsum.unsqueeze_dim(D - 1); // [..., 1, sequence]
42
43 let diff = x_cumsum_row - x_cumsum_col; // [..., sequence, sequence]
44 san(&diff);
45 let neg_inf_mask = Tensor::full_like(&diff, f32::NEG_INFINITY).triu(1);
46 diff + neg_inf_mask
47}