
Function k1_ssd_chunk_cumsum 

pub fn k1_ssd_chunk_cumsum(da_bnlh: Tensor<4>) -> (Tensor<4>, Tensor<3>)

Computes the intra-chunk cumulative log-decay (prefix sums of Δ·A along the chunk_len axis) and each chunk's total log-decay, returning both in head-major layout.
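In index notation, with L = chunk_len and assuming inclusive prefix sums (which the "last prefix sum per chunk (total decay)" return description implies), the two outputs are:

```latex
\mathrm{da\_cumsum}[b,h,c,l] = \sum_{k=0}^{l} \mathrm{da}[b,c,k,h],
\qquad
\mathrm{da\_chunk\_end}[b,h,c] = \mathrm{da\_cumsum}[b,h,c,L-1] = \sum_{k=0}^{L-1} \mathrm{da}[b,c,k,h]
```

Note the index order: the input is read as [b, c, l, h] and the outputs are written as [b, h, c, l] and [b, h, c].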

Arguments

  • da_bnlh: pre-combined Δ·A, shape [batch, nchunks, chunk_len, nheads]

Returns

  • da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len]: inclusive intra-chunk prefix sums of Δ·A
  • da_chunk_end_bhn: [batch, nheads, nchunks]: last prefix sum of each chunk, i.e. the chunk's total log-decay
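The following is a minimal CPU reference sketch of the computation described above, useful for checking a kernel's output. It does not use the crate's `Tensor<4>` type; instead it assumes a flat, row-major `f32` buffer in the [batch, nchunks, chunk_len, nheads] layout of `da_bnlh`. The function name `chunk_cumsum_ref` is hypothetical and is not the crate's implementation.

```rust
/// Hypothetical CPU reference (not the crate's API).
/// `da` is assumed to be a flat row-major buffer of shape [b, c, l, h].
/// Returns (cumsum in [b, h, c, l], chunk_end in [b, h, c]).
fn chunk_cumsum_ref(
    da: &[f32],
    b: usize,
    c: usize,
    l: usize,
    h: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(da.len(), b * c * l * h, "input length must equal b*c*l*h");
    // Inclusive prefix sums along chunk_len, written in [b, h, c, l] order.
    let mut cumsum = vec![0.0f32; b * h * c * l];
    // Total log-decay of each chunk, in [b, h, c] order.
    let mut chunk_end = vec![0.0f32; b * h * c];
    for bi in 0..b {
        for ci in 0..c {
            for hi in 0..h {
                let mut acc = 0.0f32;
                for li in 0..l {
                    // Read from the input layout [b, c, l, h].
                    acc += da[((bi * c + ci) * l + li) * h + hi];
                    // Write to the transposed layout [b, h, c, l].
                    cumsum[((bi * h + hi) * c + ci) * l + li] = acc;
                }
                // The last inclusive prefix sum is the sum over the whole chunk.
                chunk_end[(bi * h + hi) * c + ci] = acc;
            }
        }
    }
    (cumsum, chunk_end)
}

fn main() {
    // batch=1, nchunks=2, chunk_len=3, nheads=2; values listed in [b, c, l, h] order.
    let da = vec![
        -1.0, -0.5, // chunk 0, l=0: heads 0, 1
        -2.0, -0.5, // chunk 0, l=1
        -3.0, -0.5, // chunk 0, l=2
        -0.25, -1.0, // chunk 1, l=0
        -0.5, -1.0, // chunk 1, l=1
        -0.25, -1.0, // chunk 1, l=2
    ];
    let (cumsum, chunk_end) = chunk_cumsum_ref(&da, 1, 2, 3, 2);
    println!("cumsum    = {:?}", cumsum);
    println!("chunk_end = {:?}", chunk_end);
}
```

For this input, head 0 of chunk 0 yields prefix sums -1, -3, -6, so its chunk_end entry is -6. The transpose from [b, c, l, h] to [b, h, c, l] is folded into the write index rather than performed as a separate pass; the values chosen are exactly representable in `f32`, so results can be compared exactly.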