pub fn k3_ssd_chunk_state<B: Backend>(
    v_bnlrhp: Tensor<B, 6>,
    b_bnlrhn: Tensor<B, 6>,
    da_cumsum_bhnl: Tensor<B, 4>,
) -> Tensor<B, 5>
Compute the SSM state at the end of each chunk, assuming zero initial hidden state.
Uses the pre-scaled V tensor — no dt·B scaling is performed here.
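As a hedged sketch, assuming this follows the standard SSD chunk-state recurrence (the decay-weighted contraction used by Mamba-2's chunk_state) and that the R axis is summed out together with chunk_len (neither appears in the output), the state for batch b, chunk c and head h would be:

```latex
S_{b,c,h,p,n} = \sum_{l=0}^{Q-1} \sum_{r=0}^{R-1}
  \exp\!\left(A_{b,h,c,Q-1} - A_{b,h,c,l}\right)\,
  V_{b,c,l,r,h,p}\, B_{b,c,l,r,h,n}
```

Here Q is chunk_len, A is da_cumsum_bhnl, V is v_bnlrhp and B is b_bnlrhn. Because the initial hidden state is zero, no carried-in term from the previous chunk appears; each position is only decayed forward to the last position of its own chunk.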
§Arguments
- v_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim] (pre-scaled V)
- b_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
- da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len] (base, not fused)
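Since the five-axis result is obtained by contracting chunk_len and R, the three argument shapes must agree on every axis they share. The hypothetical checker below (chunk_state_output_shape is an illustrative name, not part of this crate's API) spells those constraints out for the documented layouts and returns the expected output shape.

```rust
/// Hypothetical helper (not part of the crate): validates that the argument
/// shapes are mutually consistent and returns the documented output shape
/// [batch, nchunks, nheads, per_head_dim, state_rank].
pub fn chunk_state_output_shape(
    v_bnlrhp: [usize; 6],
    b_bnlrhn: [usize; 6],
    da_cumsum_bhnl: [usize; 4],
) -> Result<[usize; 5], String> {
    let [batch, nchunks, chunk_len, r, nheads, head_dim] = v_bnlrhp;
    let [b_batch, b_nchunks, b_len, b_r, b_heads, state_rank] = b_bnlrhn;

    // V and B share the leading five axes; only the last axis differs.
    if [b_batch, b_nchunks, b_len, b_r, b_heads] != [batch, nchunks, chunk_len, r, nheads] {
        return Err(format!(
            "b_bnlrhn {:?} does not match v_bnlrhp {:?} on the leading five axes",
            b_bnlrhn, v_bnlrhp
        ));
    }

    // da_cumsum uses a head-major layout: [batch, nheads, nchunks, chunk_len].
    let expected_da = [batch, nheads, nchunks, chunk_len];
    if da_cumsum_bhnl != expected_da {
        return Err(format!(
            "da_cumsum_bhnl {:?} should be [batch, nheads, nchunks, chunk_len] = {:?}",
            da_cumsum_bhnl, expected_da
        ));
    }

    // chunk_len and R are contracted away.
    Ok([batch, nchunks, nheads, head_dim, state_rank])
}
```

Note that da_cumsum_bhnl puts nheads before nchunks, unlike the two 6-D inputs; passing it in chunk-major order is the kind of mistake the second check is meant to catch.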
§Returns
- intra_chunk_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
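A naive, loop-based reference can be useful for testing a tensor implementation against. The sketch below operates on flat row-major f32 buffers in the documented layouts. Dims and chunk_state_reference are hypothetical names, this is not the crate's implementation, and it assumes the standard SSD chunk-state contraction: each position l is weighted by exp(da_cumsum[last] - da_cumsum[l]), and the chunk_len and R axes are summed out.

```rust
/// Hypothetical dimension bundle for this sketch (not part of the crate).
#[derive(Clone, Copy, Debug)]
pub struct Dims {
    pub batch: usize,
    pub nchunks: usize,
    pub chunk_len: usize,
    pub r: usize,
    pub nheads: usize,
    pub head_dim: usize,   // per_head_dim
    pub state_rank: usize,
}

/// Naive reference: returns a flat [batch, nchunks, nheads, per_head_dim, state_rank]
/// buffer, assuming a zero initial state and an already-scaled V.
pub fn chunk_state_reference(v: &[f32], b: &[f32], da_cumsum: &[f32], d: &Dims) -> Vec<f32> {
    let (bs, c, l, r, h, p, n) =
        (d.batch, d.nchunks, d.chunk_len, d.r, d.nheads, d.head_dim, d.state_rank);
    assert_eq!(v.len(), bs * c * l * r * h * p, "v must be [B, C, L, R, H, P]");
    assert_eq!(b.len(), bs * c * l * r * h * n, "b must be [B, C, L, R, H, N]");
    assert_eq!(da_cumsum.len(), bs * h * c * l, "da_cumsum must be [B, H, C, L]");

    let mut out = vec![0.0f32; bs * c * h * p * n];
    if l == 0 {
        return out; // empty chunks contribute nothing
    }
    for ib in 0..bs {
        for ic in 0..c {
            for ih in 0..h {
                // da_cumsum layout: [batch, nheads, nchunks, chunk_len]
                let da_base = ((ib * h + ih) * c + ic) * l;
                let da_last = da_cumsum[da_base + l - 1];
                for il in 0..l {
                    // Decay from position il to the end of the chunk.
                    let decay = (da_last - da_cumsum[da_base + il]).exp();
                    for ir in 0..r {
                        // v/b layout: [batch, nchunks, chunk_len, R, nheads, last]
                        let row = (((ib * c + ic) * l + il) * r + ir) * h + ih;
                        let (v_base, b_base) = (row * p, row * n);
                        for ip in 0..p {
                            let w = decay * v[v_base + ip];
                            // out layout: [batch, nchunks, nheads, per_head_dim, state_rank]
                            let o_base = (((ib * c + ic) * h + ih) * p + ip) * n;
                            for i_n in 0..n {
                                out[o_base + i_n] += w * b[b_base + i_n];
                            }
                        }
                    }
                }
            }
        }
    }
    out
}
```

For example, with one chunk of length 2, all other axes of size 1, v = [1, 2], b = [3, 4] and da_cumsum = [0, 0], both decays are 1 and the state is 1·3 + 2·4 = 11.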