Function k5_ssd_chunk_scan 

pub fn k5_ssd_chunk_scan<B: Backend>(
    da_cumsum_bhnl: Tensor<B, 4>,
    v_bnlrhp: Tensor<B, 6>,
    c_bnlrhn: Tensor<B, 6>,
    cb_bnhLL: Tensor<B, 5>,
    chunk_input_state_bnhpr: Tensor<B, 5>,
) -> Tensor<B, 6>

Compute the chunk output by combining the intra-chunk (diagonal) and inter-chunk (off-diagonal) contributions.

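The MIMO causal mask uses interleaved time-step ordering: each time step contributes R consecutive indices, so index i belongs to time step i//R (integer division). The mask is L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R]) if i//R >= j//R, else 0.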
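The mask formula above can be sketched in plain Rust for a single chunk and head. This is an illustrative reference only, not the crate's implementation: the helper name `mimo_causal_mask` is hypothetical, it uses `Vec<Vec<f64>>` rather than Burn tensors, and it assumes the indices i and j range over chunk_len * R.

```rust
/// Hypothetical reference helper (not part of the crate): builds the dense
/// MIMO causal decay mask for one chunk and one head.
///
/// `cum_a` holds the cumulative log-decay per time step (length = chunk_len).
/// `r` is the MIMO rank R. Index `i` maps to time step `i / r`.
/// Assumption: the mask is square with side chunk_len * R.
pub fn mimo_causal_mask(cum_a: &[f64], r: usize) -> Vec<Vec<f64>> {
    assert!(r > 0, "R must be positive");
    let n = cum_a.len() * r;
    let mut mask = vec![vec![0.0_f64; n]; n];
    for i in 0..n {
        let ti = i / r; // time step of row i
        for j in 0..n {
            let tj = j / r; // time step of column j
            if ti >= tj {
                // Causal entry: decay accumulated between steps tj and ti.
                mask[i][j] = (cum_a[ti] - cum_a[tj]).exp();
            }
            // Otherwise the entry stays 0 (future time step).
        }
    }
    mask
}
```

Because the comparison is on i//R and j//R, all R x R entries that share a time step equal exp(0) = 1, so the mask is block lower-triangular rather than strictly lower-triangular.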

No D skip connection is applied; the caller is responsible for it.

§Arguments

  • da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len] — base (not fused)
  • v_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
  • c_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
  • cb_bnhLL: [batch, nchunks, nheads, L, L] from K2
  • chunk_input_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]

§Returns

  • y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]