

Function segsum 

pub fn segsum<const D: usize, const D2: usize>(x: Tensor<D>) -> Tensor<D2>

Computes numerically stable segment sums for constructing the 1-semiseparable mask.

Given a tensor x of shape [..., sequence], returns a tensor of shape [..., sequence, sequence] where:

  out[..., i, j] = Σ_{k=j+1}^{i} x[..., k]   for i ≥ j  (lower triangle; on the diagonal the sum is empty, so 0)
  out[..., i, j] = -∞                        for i < j  (strict upper triangle)
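To make the definition concrete, here is a minimal sketch for a single 1-D sequence (no leading batch dimensions). It assumes f32 values and plain slices rather than this crate's Tensor type; `segsum_naive` is a hypothetical helper that sums each segment directly.

```rust
/// Hypothetical reference implementation of the definition above for one
/// sequence. Not this crate's `segsum`; it works on plain slices.
fn segsum_naive(x: &[f32]) -> Vec<Vec<f32>> {
    let n = x.len();
    // Start with every entry masked to -inf (the strict upper triangle).
    let mut out = vec![vec![f32::NEG_INFINITY; n]; n];
    for i in 0..n {
        for j in 0..=i {
            // Sum x[k] for k in (j, i]; the range is empty (sum 0.0) when i == j.
            out[i][j] = x[j + 1..=i].iter().sum();
        }
    }
    out
}
```

For x = [1.0, 2.0, 3.0], entry [2][0] is x[1] + x[2] = 5.0, entry [2][1] is x[2] = 3.0, the diagonal is 0.0, and every entry above the diagonal is -∞.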

Implementation

A naive computation of every segment product A[j+1]·...·A[i] would underflow for long sequences (e.g. 0.9^1000 ≈ 1.7×10⁻⁴⁶, smaller than the smallest positive f32, about 1.4×10⁻⁴⁵). Working in log-space, with x[k] = log A[k], and taking differences of prefix sums avoids this:

  segsum(x)[i, j] = cumsum(x)[i] - cumsum(x)[j]   for i ≥ j, where cumsum is inclusive
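A minimal sketch contrasting the two approaches, assuming f32 and a single 1-D sequence; `direct_product` and `segsum_cumsum` are hypothetical helpers, not part of this crate's API. With x[k] = ln 0.9 for 1001 steps, `segsum_cumsum(&x)[1000][0]` is an ordinary float near 1000·ln 0.9 ≈ -105.36, while `direct_product(0.9, 1000)` drops below `f32::MIN_POSITIVE` (the smallest normal f32) and has lost most of its precision.

```rust
/// Naive approach: multiply the decay `a` into an f32 accumulator `n` times.
/// For a = 0.9 and n = 1000 the true value (~1.7e-46) is below the smallest
/// positive f32, so the computed result is no longer a normal f32.
fn direct_product(a: f32, n: usize) -> f32 {
    (0..n).fold(1.0f32, |acc, _| acc * a)
}

/// Hypothetical single-sequence sketch of segsum on plain slices (not the
/// crate's Tensor API): out[i][j] = c[i] - c[j] for i >= j, -inf otherwise,
/// where c is the inclusive prefix sum of x.
fn segsum_cumsum(x: &[f32]) -> Vec<Vec<f32>> {
    let n = x.len();
    // Inclusive prefix sums: c[i] = x[0] + ... + x[i].
    let mut c = Vec::with_capacity(n);
    let mut running = 0.0f32;
    for &v in x {
        running += v;
        c.push(running);
    }
    // Start fully masked, then fill the causal (lower) triangle.
    let mut out = vec![vec![f32::NEG_INFINITY; n]; n];
    for i in 0..n {
        for j in 0..=i {
            out[i][j] = c[i] - c[j];
        }
    }
    out
}
```

After one O(n) prefix pass, each output entry costs O(1), instead of O(n) per entry when every segment is summed directly.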

The strict upper triangle is masked to -∞ so that exp(segsum(x)) is exactly 0 at non-causal positions (i < j): the strict upper triangle of the mask L = exp(segsum(x)) must be zero.
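The masking step can be sketched the same way. This assumes f32 and a single sequence; `semiseparable_mask` is a hypothetical helper that computes L = exp(segsum(x)) directly from prefix sums. Because exp(-∞) = 0, the strict upper triangle comes out as exact zeros, and the diagonal is exp(0) = 1.

```rust
/// Hypothetical sketch: build the 1-semiseparable mask L = exp(segsum(x))
/// for one sequence. exp(-inf) == 0.0, so the strict upper triangle is zero.
fn semiseparable_mask(x: &[f32]) -> Vec<Vec<f32>> {
    let n = x.len();
    // Inclusive prefix sums of the log-decays x[k] = ln A[k].
    let c: Vec<f32> = x
        .iter()
        .scan(0.0f32, |acc, &v| {
            *acc += v;
            Some(*acc)
        })
        .collect();
    (0..n)
        .map(|i| {
            (0..n)
                .map(|j| {
                    // Causal (i >= j): log of A[j+1]·...·A[i]; otherwise masked to -inf.
                    let s = if i >= j { c[i] - c[j] } else { f32::NEG_INFINITY };
                    s.exp()
                })
                .collect::<Vec<f32>>()
        })
        .collect()
}
```

With a constant decay A[k] = 0.5 (x[k] = ln 0.5), row i of L holds 0.5^(i-j) for j ≤ i and zeros to the right of the diagonal.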