Function k5_ssd_chunk_scan 

pub fn k5_ssd_chunk_scan(
    da_cumsum_bhnl: Tensor<4>,
    v_bnlmhp: Tensor<6>,
    c_bnlmhr: Tensor<6>,
    cb_bnhLMLM: Tensor<5>,
    chunk_input_state_bnhpr: Tensor<5>,
) -> Tensor<6>

Compute the chunk output by combining the intra-chunk (diagonal) and inter-chunk (off-diagonal) contributions.

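The MIMO causal mask uses interleaved time-step ordering, where i and j index rows of the flattened (time step, MIMO index) axis and m is the number of MIMO rows per time step: L_mimo[i,j] = exp(cumA[i//m] - cumA[j//m]) if i//m >= j//m, else 0. All m rows that belong to the same time step therefore share one decay factor, and rows within a time step see each other with weight 1.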
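The mask formula above can be sketched for a single chunk in plain Rust. This is an illustration only, not the crate's code: the function name `mimo_causal_mask`, the use of `f64` slices instead of `Tensor`, and the assumption that `cum_a` holds one cumulative-sum value per time step are all hypothetical.

```rust
/// Hypothetical helper: build the (L*m) x (L*m) MIMO decay mask for one chunk.
/// `cum_a[t]` is assumed to be the cumulative sum of dA up to time step `t`;
/// `m` is the number of interleaved MIMO rows per time step.
fn mimo_causal_mask(cum_a: &[f64], m: usize) -> Vec<Vec<f64>> {
    let n = cum_a.len() * m;
    let mut mask = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..n {
            // Interleaved ordering: row index -> time step via integer division.
            let (ti, tj) = (i / m, j / m);
            if ti >= tj {
                mask[i][j] = (cum_a[ti] - cum_a[tj]).exp();
            }
            // Otherwise the entry stays 0 (future time steps are masked out).
        }
    }
    mask
}
```

With `m = 2`, rows 0 and 1 map to time step 0 and rows 2 and 3 to time step 1, so the mask is block lower-triangular with 2x2 blocks rather than strictly lower-triangular.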

No D skip connection is applied here; the caller is responsible for adding it.

§Arguments

  • da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len] — base (not fused)
  • v_bnlmhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
  • c_bnlmhr: [batch, nchunks, chunk_len, R, nheads, state_rank]
  • cb_bnhLMLM: [batch, nchunks, nheads, L, L] from K2
  • chunk_input_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]

§Returns

  • y_bnlmhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
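To make the combination of the two contributions concrete, here is a minimal scalar reference for one batch element, one head, and one chunk. It follows the usual SSD chunk-scan formulation and is an assumption about what this function computes, not a copy of its implementation. The name `chunk_scan_ref`, the flattened row layout (`row = t * m + k`), the square `(L*m) x (L*m)` shape used for `cb`, and the decay factor `exp(cum_a[t])` applied to the incoming state are all hypothetical.

```rust
/// Hypothetical single-batch, single-head, single-chunk reference.
/// Rows are flattened (time step, MIMO index) pairs: row = t * m + k.
/// Assumed shapes: cum_a [L], v [L*m][P], c [L*m][N], cb [L*m][L*m], state [P][N].
fn chunk_scan_ref(
    cum_a: &[f64],
    m: usize,
    v: &[Vec<f64>],
    c: &[Vec<f64>],
    cb: &[Vec<f64>],
    state: &[Vec<f64>],
) -> Vec<Vec<f64>> {
    let rows = cum_a.len() * m;
    let p_dim = state.len();
    let mut y = vec![vec![0.0; p_dim]; rows];
    for i in 0..rows {
        let ti = i / m;
        for p in 0..p_dim {
            // Inter-chunk (off-diagonal) term: project the incoming chunk state
            // through C and decay it from the chunk start to time step ti.
            let from_state: f64 = c[i].iter().zip(&state[p]).map(|(ci, s)| ci * s).sum();
            let mut acc = cum_a[ti].exp() * from_state;
            // Intra-chunk (diagonal) term: masked, decayed sum over earlier rows.
            for j in 0..rows {
                let tj = j / m;
                if ti >= tj {
                    acc += cb[i][j] * (cum_a[ti] - cum_a[tj]).exp() * v[j][p];
                }
            }
            y[i][p] = acc;
        }
    }
    y
}
```

A reference like this is mainly useful for checking a batched kernel on tiny inputs; as in the function documented above, no D skip term is added.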