pub fn k3_ssd_chunk_state(
    v_bnlmhp: Tensor<6>,
    b_bnlmhr: Tensor<6>,
    da_cumsum_bhnl: Tensor<4>,
) -> Tensor<5>
Compute the SSM state at the end of each chunk, assuming zero initial hidden state.
Expects V to be pre-scaled by the caller; no dt·B scaling is performed here.
§Arguments
v_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] (pre-scaled V)
b_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len]
§Returns
intra_chunk_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
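The shape contract above can be illustrated with a naive reference loop. This is a minimal sketch, not the crate's implementation. It assumes the usual SSD chunk-state formula, in which each position is decayed by `exp(da_cumsum[last] - da_cumsum[l])` and the contributions are summed over both `chunk_len` and `mimo_rank`. The `Dims` struct, the `chunk_state_reference` name, and the flat row-major `f32` slices are hypothetical stand-ins for the real `Tensor` type.

```rust
/// Hypothetical shape holder for this sketch (not part of the documented API).
#[derive(Clone, Copy)]
pub struct Dims {
    pub b: usize, // batch
    pub n: usize, // nchunks
    pub l: usize, // chunk_len
    pub m: usize, // mimo_rank
    pub h: usize, // nheads
    pub p: usize, // per_head_dim
    pub r: usize, // state_rank
}

/// Naive reference over row-major slices. ASSUMED formula:
/// state[b,n,h,p,r] = sum over l, m of
///   exp(da_cumsum[b,h,n,L-1] - da_cumsum[b,h,n,l]) * v[b,n,l,m,h,p] * b[b,n,l,m,h,r]
pub fn chunk_state_reference(v: &[f32], bmat: &[f32], da_cumsum: &[f32], d: Dims) -> Vec<f32> {
    assert!(d.l > 0, "chunk_len must be non-zero");
    assert_eq!(v.len(), d.b * d.n * d.l * d.m * d.h * d.p);
    assert_eq!(bmat.len(), d.b * d.n * d.l * d.m * d.h * d.r);
    assert_eq!(da_cumsum.len(), d.b * d.h * d.n * d.l);

    // Zero-initialised output matches the "zero initial hidden state" assumption.
    let mut out = vec![0.0f32; d.b * d.n * d.h * d.p * d.r];

    for ib in 0..d.b {
        for ic in 0..d.n {
            for ih in 0..d.h {
                // da_cumsum is laid out as [batch, nheads, nchunks, chunk_len].
                let da_base = ((ib * d.h + ih) * d.n + ic) * d.l;
                let da_last = da_cumsum[da_base + d.l - 1];
                for il in 0..d.l {
                    // Decay from position `il` to the end of the chunk.
                    let decay = (da_last - da_cumsum[da_base + il]).exp();
                    for im in 0..d.m {
                        // Shared prefix index for v and b: [b, n, l, m, h].
                        let row = (((ib * d.n + ic) * d.l + il) * d.m + im) * d.h + ih;
                        for ip in 0..d.p {
                            let scaled_v = v[row * d.p + ip] * decay;
                            // Output is laid out as [b, n, h, p, r].
                            let out_base = (((ib * d.n + ic) * d.h + ih) * d.p + ip) * d.r;
                            for ir in 0..d.r {
                                out[out_base + ir] += scaled_v * bmat[row * d.r + ir];
                            }
                        }
                    }
                }
            }
        }
    }
    out
}
```

A loop like this is only useful as a correctness oracle for small shapes; a real kernel would express the same contraction as batched matrix products over the `chunk_len` and `mimo_rank` axes.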