pub fn k5_ssd_chunk_scan<B: Backend>(
    da_cumsum_bhnl: Tensor<B, 4>,
    v_bnlrhp: Tensor<B, 6>,
    c_bnlrhn: Tensor<B, 6>,
    cb_bnhLL: Tensor<B, 5>,
    chunk_input_state_bnhpr: Tensor<B, 5>,
) -> Tensor<B, 6>
Compute the chunk output by combining the intra-chunk (diagonal) and inter-chunk (off-diagonal) contributions.
The MIMO causal mask uses interleaved time-step ordering:
L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R]) if i//R >= j//R, else 0.
No D skip is applied — the caller handles it.
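The interleaved mask above can be sketched in plain Rust for a single (batch, head, chunk) slice. This is only an illustration of the formula, not the crate's implementation: the helper name `mimo_causal_mask` and the use of `Vec<Vec<f64>>` in place of a `Tensor` are assumptions made for the example.

```rust
/// Build the interleaved MIMO causal decay mask for one (batch, head, chunk).
///
/// `cum_a` holds the cumulative log-decay per time step (length `chunk_len`);
/// `r` is the MIMO rank R. Row/column index `i` maps to time step `i / r`.
/// Hypothetical helper for illustration; not part of the crate's API.
fn mimo_causal_mask(cum_a: &[f64], r: usize) -> Vec<Vec<f64>> {
    let n = cum_a.len() * r; // interleaved length: chunk_len * R
    let mut mask = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..n {
            let (ti, tj) = (i / r, j / r); // integer division, i.e. i//R and j//R
            if ti >= tj {
                mask[i][j] = (cum_a[ti] - cum_a[tj]).exp();
            }
        }
    }
    mask
}
```

Note that the comparison is on time steps, not on raw indices: entries with `i / r == j / r` are `exp(0) = 1` even when `i < j`, so the mask is block lower-triangular with `R x R` blocks rather than strictly lower-triangular.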
§Arguments
da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len], base (not fused)
v_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
c_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
cb_bnhLL: [batch, nchunks, nheads, L, L], from K2
chunk_input_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
§Returns
y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
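As a rough reference for how the two contributions combine, the sketch below computes the output for one (batch, chunk, head) slice using the standard SSD chunk-scan form: a masked, decayed sum over `v` plus a read-out of the incoming state through `C`. It assumes the `(chunk_len, R)` axes are flattened in interleaved order (index `i = l * R + r`) and that `cb` spans that flattened length; the function name, argument layout, and plain `Vec` types are hypothetical, and the real kernel may organize the computation differently.

```rust
/// Reference chunk scan for one (batch, chunk, head) slice.
/// All names and the flattened layout are assumptions for illustration.
///
/// - `cum_a`: cumulative log-decay, length `chunk_len`
/// - `v`: `[chunk_len * r][p]` values in interleaved order
/// - `c`: `[chunk_len * r][n]` C projections in interleaved order
/// - `cb`: `[chunk_len * r][chunk_len * r]` precomputed C-B scores
/// - `state`: `[p][n]` state entering the chunk
fn chunk_scan_slice(
    cum_a: &[f64],
    r: usize,
    v: &[Vec<f64>],
    c: &[Vec<f64>],
    cb: &[Vec<f64>],
    state: &[Vec<f64>],
) -> Vec<Vec<f64>> {
    let len = cum_a.len() * r;
    let p_dim = state.len();
    let mut y = vec![vec![0.0; p_dim]; len];
    for i in 0..len {
        let ti = i / r;
        // Intra-chunk (diagonal) term: masked, decayed scores applied to v.
        for j in 0..len {
            let tj = j / r;
            if ti >= tj {
                let w = cb[i][j] * (cum_a[ti] - cum_a[tj]).exp();
                for p in 0..p_dim {
                    y[i][p] += w * v[j][p];
                }
            }
        }
        // Inter-chunk (off-diagonal) term: read the incoming state through C,
        // decayed from the chunk start to time step ti.
        let decay = cum_a[ti].exp();
        for p in 0..p_dim {
            let dot: f64 = c[i].iter().zip(&state[p]).map(|(a, b)| a * b).sum();
            y[i][p] += decay * dot;
        }
    }
    y // no D skip term is added here, matching the note above
}
```

With a single time step, `R = 1`, `cum_a = [0.0]`, `cb = [[0.5]]`, `v = [[2.0]]`, `c = [[3.0]]`, and `state = [[4.0]]`, the diagonal term is `0.5 * 2.0` and the off-diagonal term is `3.0 * 4.0`.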