
k1_ssd_chunk_cumsum

Function k1_ssd_chunk_cumsum 

pub fn k1_ssd_chunk_cumsum<B: Backend>(
    da_bnlh: Tensor<B, 4>,
) -> (Tensor<B, 4>, Tensor<B, 3>)

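Computes the intra-chunk cumulative log-decay (an inclusive prefix sum of Δ·A along each chunk's time axis, restarting at every chunk boundary) and the total log-decay of each chunk.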
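The following is a minimal sketch of the computation for one chunk of one batch element and one head. It assumes plain `f32` slices rather than Burn tensors, and `chunk_prefix_sum` is a hypothetical helper, not part of this crate. Because the inputs are log-decays, exponentiating the chunk total gives the product of the per-step decay factors exp(Δ_t·A).

```rust
/// Hypothetical helper (not this crate's API): inclusive prefix sum of the
/// per-step log-decays in one chunk, plus the chunk total (last prefix sum).
fn chunk_prefix_sum(da: &[f32]) -> (Vec<f32>, f32) {
    let mut acc = 0.0f32;
    let cumsum: Vec<f32> = da
        .iter()
        .map(|&x| {
            acc += x; // running sum up to and including this step
            acc
        })
        .collect();
    // An empty chunk has total log-decay 0 (decay factor exp(0) = 1).
    let total = cumsum.last().copied().unwrap_or(0.0);
    (cumsum, total)
}

fn main() {
    // Assumption: Δ > 0 and A < 0 (as in Mamba-2), so each Δ·A is negative.
    let da = [-0.5f32, -0.25, -1.0, -0.25];
    let (cumsum, total) = chunk_prefix_sum(&da);
    println!("{:?} {}", cumsum, total); // prints [-0.5, -0.75, -1.75, -2.0] -2

    // Multiplicative decay over the chunk: product of exp(da_t) == exp(total).
    let product: f32 = da.iter().map(|x| x.exp()).product();
    assert!((product - total.exp()).abs() < 1e-6);
}
```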

§Arguments

  • da_bnlh: pre-combined Δ·A, shape [batch, nchunks, chunk_len, nheads]

§Returns

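  • da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len], the inclusive prefix sums within each chunk (the head axis is moved ahead of the chunk axes)
  • da_chunk_end_bhn: [batch, nheads, nchunks], the last prefix sum of each chunk, i.e. da_cumsum_bhnl[.., chunk_len - 1], which is the chunk's total log-decay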
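The index bookkeeping implied by the input and output shapes can be sketched as a plain-Rust reference on flat buffers. This is not the Burn-based implementation: `chunk_cumsum_ref` is a hypothetical name, it takes `&[f32]` instead of `Tensor<B, 4>`, and row-major (last axis fastest) layout is an assumption. A reference like this can serve as a CPU oracle when testing the tensor version.

```rust
/// Hypothetical std-only reference (not this crate's API).
/// Input  da_bnlh: row-major [batch, nchunks, chunk_len, nheads].
/// Output cumsum_bhnl: row-major [batch, nheads, nchunks, chunk_len],
///        chunk_end_bhn: row-major [batch, nheads, nchunks].
fn chunk_cumsum_ref(
    da_bnlh: &[f32],
    batch: usize,
    nchunks: usize,
    chunk_len: usize,
    nheads: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(da_bnlh.len(), batch * nchunks * chunk_len * nheads);
    let mut cumsum_bhnl = vec![0.0f32; batch * nheads * nchunks * chunk_len];
    let mut chunk_end_bhn = vec![0.0f32; batch * nheads * nchunks];
    for b in 0..batch {
        for h in 0..nheads {
            for n in 0..nchunks {
                // The running sum restarts at every chunk boundary.
                let mut acc = 0.0f32;
                for l in 0..chunk_len {
                    let src = ((b * nchunks + n) * chunk_len + l) * nheads + h;
                    acc += da_bnlh[src];
                    let dst = ((b * nheads + h) * nchunks + n) * chunk_len + l;
                    cumsum_bhnl[dst] = acc;
                }
                // Chunk total = last inclusive prefix sum.
                chunk_end_bhn[(b * nheads + h) * nchunks + n] = acc;
            }
        }
    }
    (cumsum_bhnl, chunk_end_bhn)
}

fn main() {
    // batch = 1, nchunks = 2, chunk_len = 2, nheads = 2; each row is (h=0, h=1).
    let da = [
        -1.0f32, -10.0, // n=0, l=0
        -2.0, -20.0,    // n=0, l=1
        -3.0, -30.0,    // n=1, l=0
        -4.0, -40.0,    // n=1, l=1
    ];
    let (cumsum, chunk_end) = chunk_cumsum_ref(&da, 1, 2, 2, 2);
    // Ordered as h, n, l: h=0 -> [-1, -3 | -3, -7], h=1 -> [-10, -30 | -30, -70]
    println!("{:?}", cumsum);
    // Ordered as h, n: [-3, -7, -30, -70]
    println!("{:?}", chunk_end);
}
```

Note that the transpose from [b, n, l, h] to [b, h, n, l] happens in the same pass as the scan, so each head's prefix sums end up contiguous along chunk_len.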