pub fn k4_ssd_state_passing<B: Backend>(
    intra_chunk_state_bnhpr: Tensor<B, 5>,
    da_chunk_end_bhn: Tensor<B, 3>,
    initial_state_bhpr: Tensor<B, 4>,
) -> (Tensor<B, 5>, Tensor<B, 4>)
Propagate hidden state across chunk boundaries using a sequential scan.
This kernel is independent of the MIMO rank: it operates on the [H, P, N] state (nheads, per_head_dim, state_rank), which has already been aggregated over ranks.
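The docs do not spell out the recurrence, so the following is a minimal sketch under stated assumptions, not this crate's implementation. It assumes the usual SSD convention: each intra-chunk state was accumulated from a zero state inside its chunk, and `exp(da_chunk_end)` is the factor by which a state entering a chunk decays by that chunk's end. It shows the scan for a single (batch, head) pair on plain `Vec<f32>` buffers. The function name `state_passing_single_head` is hypothetical.

```rust
/// Hypothetical CPU sketch of the chunk-boundary scan for ONE (batch, head) pair.
/// Each state is a [per_head_dim, state_rank] matrix flattened row-major.
/// Assumes intra-chunk states start from zero and exp(da) is the per-chunk decay.
/// Returns (state entering each chunk, state after the last chunk).
fn state_passing_single_head(
    intra_chunk_states: &[Vec<f32>], // one flattened state per chunk
    da_chunk_end: &[f32],            // total log-decay of each chunk
    initial_state: &[f32],
) -> (Vec<Vec<f32>>, Vec<f32>) {
    assert_eq!(intra_chunk_states.len(), da_chunk_end.len());
    let mut state = initial_state.to_vec();
    let mut chunk_inputs = Vec::with_capacity(intra_chunk_states.len());
    for (intra, &da) in intra_chunk_states.iter().zip(da_chunk_end) {
        assert_eq!(intra.len(), state.len());
        // Record the state as it enters this chunk.
        chunk_inputs.push(state.clone());
        // Decay the carried state across the whole chunk, then add the
        // contribution the chunk produced on its own.
        let decay = da.exp();
        for (s, &x) in state.iter_mut().zip(intra) {
            *s = decay * *s + x;
        }
    }
    (chunk_inputs, state)
}

fn main() {
    // Two chunks, P * N = 2; a log-decay of 0 means a decay factor of 1.
    let intra = vec![vec![1.0, 2.0], vec![10.0, 20.0]];
    let (inputs, last) = state_passing_single_head(&intra, &[0.0, 0.0], &[0.5, 0.5]);
    println!("chunk inputs: {:?}, final: {:?}", inputs, last);
}
```

The loop is inherently sequential over chunks, because each chunk's input state depends on the previous chunk's output. The parallelism lies across batch, heads, and the [P, N] entries.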
Arguments
- intra_chunk_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
- da_chunk_end_bhn: [batch, nheads, nchunks] (total log-decay per chunk)
- initial_state_bhpr: [batch, nheads, per_head_dim, state_rank]
Returns
- chunk_input_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
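As an illustration of how the documented shapes index into memory, here is a hypothetical CPU reference (`state_passing_flat` is not part of this crate) that runs the same scan on row-major flat buffers laid out exactly as listed above. It makes the same assumption about the recurrence as any plain SSD state pass: the state entering chunk c is recorded, then decayed by `exp(da)` and incremented by chunk c's intra-chunk state. Note that `da_chunk_end_bhn` puts `nheads` before `nchunks`, unlike the 5-D tensors.

```rust
/// Hypothetical CPU reference on row-major flat buffers. Layouts:
///   intra / chunk_input: [batch, nchunks, nheads, p, r]
///   da:                  [batch, nheads, nchunks]
///   initial / final:     [batch, nheads, p, r]
fn state_passing_flat(
    intra: &[f32],
    da: &[f32],
    initial: &[f32],
    (batch, nchunks, nheads, p, r): (usize, usize, usize, usize, usize),
) -> (Vec<f32>, Vec<f32>) {
    let pr = p * r;
    assert_eq!(intra.len(), batch * nchunks * nheads * pr);
    assert_eq!(da.len(), batch * nheads * nchunks);
    assert_eq!(initial.len(), batch * nheads * pr);

    let mut chunk_input = vec![0.0f32; intra.len()];
    // The running state starts as the initial state and ends as the final state.
    let mut final_state = initial.to_vec();

    for b in 0..batch {
        for h in 0..nheads {
            // Offset of the [p, r] state for (b, h) in the 4-D buffers.
            let s_off = (b * nheads + h) * pr;
            let state = &mut final_state[s_off..s_off + pr];
            for c in 0..nchunks {
                // Offset of the [p, r] block for (b, c, h) in the 5-D buffers.
                let x_off = ((b * nchunks + c) * nheads + h) * pr;
                chunk_input[x_off..x_off + pr].copy_from_slice(state);
                // da is indexed [b, h, c], not [b, c, h].
                let decay = da[(b * nheads + h) * nchunks + c].exp();
                for (s, &x) in state.iter_mut().zip(&intra[x_off..x_off + pr]) {
                    *s = decay * *s + x;
                }
            }
        }
    }
    (chunk_input, final_state)
}

fn main() {
    // batch = 1, nchunks = 2, nheads = 2, p = r = 1; zero log-decay everywhere.
    let (ci, fin) =
        state_passing_flat(&[1.0, 2.0, 3.0, 4.0], &[0.0; 4], &[10.0, 20.0], (1, 2, 2, 1, 1));
    println!("chunk_input: {:?}, final: {:?}", ci, fin);
}
```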