pub fn k5_ssd_chunk_scan(
da_cumsum_bhnl: Tensor<4>,
v_bnlmhp: Tensor<6>,
c_bnlmhr: Tensor<6>,
cb_bnhLMLM: Tensor<5>,
chunk_input_state_bnhpr: Tensor<5>,
) -> Tensor<6>
Compute the chunk output by combining the intra-chunk (diagonal) and inter-chunk (off-diagonal) contributions.
The MIMO causal mask uses interleaved time-step ordering:
L_mimo[i,j] = exp(cumA[i//m] - cumA[j//m]) if i//m >= j//m, else 0.
No D skip is applied — the caller handles it.
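The mask formula above can be sketched for a single chunk and head as follows. This is only an illustration under stated assumptions: the helper name `mimo_causal_mask` is hypothetical, it uses plain `Vec<f32>` rather than the crate's `Tensor` type, and it takes `m` to be the MIMO rank written as R in the argument shapes.

```rust
/// Hypothetical helper (not part of the documented API): builds the
/// (chunk_len * m) x (chunk_len * m) MIMO causal mask for one chunk.
///
/// `cum_a` holds one cumulative-sum value per time step in the chunk.
/// `m` is assumed to be the MIMO rank; row and column indices interleave
/// the m channels of each time step, so index i maps to time step i / m.
fn mimo_causal_mask(cum_a: &[f32], m: usize) -> Vec<Vec<f32>> {
    let n = cum_a.len() * m;
    let mut mask = vec![vec![0.0f32; n]; n];
    for i in 0..n {
        for j in 0..n {
            // Integer division recovers the time step of each interleaved index.
            let (ti, tj) = (i / m, j / m);
            if ti >= tj {
                // Causal entries: decay between the two time steps.
                mask[i][j] = (cum_a[ti] - cum_a[tj]).exp();
            }
            // Entries with ti < tj stay at 0 (future time steps are masked).
        }
    }
    mask
}
```

With the formula as written, all `m * m` entries that share a time step equal `exp(0) = 1`, so the mask is block lower-triangular with blocks of size `m`, not strictly lower-triangular element by element.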
§Arguments
- da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len], base (not fused)
- v_bnlmhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
- c_bnlmhr: [batch, nchunks, chunk_len, R, nheads, state_rank]
- cb_bnhLMLM: [batch, nchunks, nheads, L, L] from K2
- chunk_input_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
§Returns
- y_bnlmhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]