pub fn k1_ssd_chunk_cumsum<B: Backend>(
da_bnlh: Tensor<B, 4>,
) -> (Tensor<B, 4>, Tensor<B, 3>)
Compute the intra-chunk cumulative log-decay and per-chunk decay totals.
§Arguments
da_bnlh: pre-combined Δ·A, shape [batch, nchunks, chunk_len, nheads]
§Returns
da_cumsum_bhnl: intra-chunk prefix sums, shape [batch, nheads, nchunks, chunk_len]
da_chunk_end_bhn: last prefix sum per chunk (total decay), shape [batch, nheads, nchunks]
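
The shapes above imply two steps: a prefix sum along chunk_len within each chunk, and a layout change from [batch, nchunks, chunk_len, nheads] to [batch, nheads, nchunks, chunk_len]. The following is a minimal plain-Rust sketch of that semantics on flat row-major slices. It is not this crate's implementation and does not use the Backend or Tensor types. The function name chunk_cumsum_reference and the explicit dimension parameters are hypothetical, and the sketch assumes the prefix sum is inclusive, which is what "last prefix sum per chunk (total decay)" suggests.

```rust
/// Hypothetical scalar reference for the documented semantics.
/// Input:  flat row-major [batch, nchunks, chunk_len, nheads].
/// Output: flat row-major [batch, nheads, nchunks, chunk_len] and
///         flat row-major [batch, nheads, nchunks].
fn chunk_cumsum_reference(
    da_bnlh: &[f32],
    batch: usize,
    nchunks: usize,
    chunk_len: usize,
    nheads: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(da_bnlh.len(), batch * nchunks * chunk_len * nheads);

    let mut da_cumsum_bhnl = vec![0.0f32; batch * nheads * nchunks * chunk_len];
    let mut da_chunk_end_bhn = vec![0.0f32; batch * nheads * nchunks];

    for b in 0..batch {
        for h in 0..nheads {
            for n in 0..nchunks {
                // The running sum restarts at every chunk boundary.
                let mut running = 0.0f32;
                for l in 0..chunk_len {
                    // Read in [b, n, l, h] order, write in [b, h, n, l] order.
                    let src = ((b * nchunks + n) * chunk_len + l) * nheads + h;
                    let dst = ((b * nheads + h) * nchunks + n) * chunk_len + l;
                    running += da_bnlh[src];
                    da_cumsum_bhnl[dst] = running; // inclusive prefix sum (assumed)
                }
                // The last prefix sum of the chunk is its total log-decay.
                da_chunk_end_bhn[(b * nheads + h) * nchunks + n] = running;
            }
        }
    }

    (da_cumsum_bhnl, da_chunk_end_bhn)
}

fn main() {
    // batch = 1, nchunks = 2, chunk_len = 3, nheads = 1
    let da = [-1.0, -2.0, -3.0, -0.5, -0.5, -1.0];
    let (cumsum, chunk_end) = chunk_cumsum_reference(&da, 1, 2, 3, 1);
    println!("{:?}", cumsum);
    println!("{:?}", chunk_end);
}
```

With a single head, the two layouts coincide, so the example input makes the per-chunk restart easy to see: the second chunk's prefix sums start again from its own first element rather than continuing from the first chunk's total.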