pub fn k1_ssd_chunk_cumsum(da_bnlh: Tensor<4>) -> (Tensor<4>, Tensor<3>)
Compute the intra-chunk cumulative log-decay and per-chunk decay totals.
§Arguments
da_bnlh: pre-combined Δ·A, shape [batch, nchunks, chunk_len, nheads]
§Returns
da_cumsum_bhnl: shape [batch, nheads, nchunks, chunk_len]; intra-chunk prefix sums
da_chunk_end_bhn: shape [batch, nheads, nchunks]; last prefix sum per chunk (total decay)
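§Example (reference sketch)
The shapes above imply an inclusive prefix sum along chunk_len, combined with a layout change from [batch, nchunks, chunk_len, nheads] to [batch, nheads, nchunks, chunk_len]. The sketch below illustrates that on flat row-major f32 buffers. It is not this crate's implementation: the name chunk_cumsum_reference, the explicit dimension parameters, the use of slices instead of Tensor, and accumulation in f32 are all assumptions made for illustration.

```rust
/// Hypothetical reference version of the chunk cumsum on flat buffers.
/// Input is row-major [batch, nchunks, chunk_len, nheads]; outputs are
/// row-major [batch, nheads, nchunks, chunk_len] and [batch, nheads, nchunks].
pub fn chunk_cumsum_reference(
    da_bnlh: &[f32],
    batch: usize,
    nchunks: usize,
    chunk_len: usize,
    nheads: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(
        da_bnlh.len(),
        batch * nchunks * chunk_len * nheads,
        "input length must match the given dimensions"
    );

    let mut da_cumsum_bhnl = vec![0.0f32; batch * nheads * nchunks * chunk_len];
    let mut da_chunk_end_bhn = vec![0.0f32; batch * nheads * nchunks];

    for b in 0..batch {
        for h in 0..nheads {
            for n in 0..nchunks {
                // The running sum restarts at every chunk boundary.
                let mut running = 0.0f32;
                for l in 0..chunk_len {
                    // Read from the [b, n, l, h] layout.
                    let src = ((b * nchunks + n) * chunk_len + l) * nheads + h;
                    running += da_bnlh[src];
                    // Write into the [b, h, n, l] layout.
                    let dst = ((b * nheads + h) * nchunks + n) * chunk_len + l;
                    da_cumsum_bhnl[dst] = running;
                }
                // The last inclusive prefix sum is the chunk's total log-decay.
                da_chunk_end_bhn[(b * nheads + h) * nchunks + n] = running;
            }
        }
    }

    (da_cumsum_bhnl, da_chunk_end_bhn)
}
```

With batch = 1, nchunks = 2, chunk_len = 2, nheads = 2 and input [-1, -2, -3, -4, -0.5, -0.25, -1.5, -0.75], head 0 of chunk 0 reads the values -1 and -3, so its prefix sums are -1 and -4 and its chunk-end entry is -4.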