Function k3_ssd_chunk_state 

pub fn k3_ssd_chunk_state<B: Backend>(
    v_bnlrhp: Tensor<B, 6>,
    b_bnlrhn: Tensor<B, 6>,
    da_cumsum_bhnl: Tensor<B, 4>,
) -> Tensor<B, 5>

Compute the SSM state at the end of each chunk, assuming zero initial hidden state.

Takes V already pre-scaled by the caller, so no dt·B scaling is performed here.

§Arguments

  • v_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim] — pre-scaled V
  • b_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
  • da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len] — base (not fused)

§Returns

  • intra_chunk_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
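
The shapes above can be illustrated with a plain-Rust reference loop for a single (batch, chunk, head) slice. This is only a sketch under stated assumptions, not the crate's implementation: it assumes the usual SSD chunk-state form, in which each position `l` is weighted by `exp(da_cumsum[last] - da_cumsum[l])`, and it assumes the `R` axis is summed out together with `chunk_len`. The function name `chunk_state_slice` and the nested `Vec` layout are hypothetical and chosen for readability; the real function operates on Burn tensors.

```rust
/// Hypothetical reference for one (batch, chunk, head) slice.
///
/// v:         [chunk_len][R][per_head_dim]  (pre-scaled V)
/// b:         [chunk_len][R][state_rank]
/// da_cumsum: [chunk_len]
/// returns:   [per_head_dim][state_rank]
///
/// ASSUMPTION: decay weight is exp(da_cumsum[last] - da_cumsum[l]) and
/// the R axis is contracted; neither is confirmed by the documentation.
fn chunk_state_slice(
    v: &[Vec<Vec<f32>>],
    b: &[Vec<Vec<f32>>],
    da_cumsum: &[f32],
) -> Vec<Vec<f32>> {
    let chunk_len = da_cumsum.len();
    assert_eq!(v.len(), chunk_len, "v and da_cumsum must share chunk_len");
    assert_eq!(b.len(), chunk_len, "b and da_cumsum must share chunk_len");
    if chunk_len == 0 {
        return Vec::new();
    }

    let rank = v[0].len();
    let p_dim = v[0].first().map_or(0, |row| row.len());
    let n_dim = b[0].first().map_or(0, |row| row.len());

    // Zero initial hidden state, as stated in the description.
    let mut state = vec![vec![0.0f32; n_dim]; p_dim];
    let last = da_cumsum[chunk_len - 1];

    for l in 0..chunk_len {
        // Decay from position l to the end of the chunk.
        let decay = (last - da_cumsum[l]).exp();
        for r in 0..rank {
            for p in 0..p_dim {
                for n in 0..n_dim {
                    state[p][n] += decay * v[l][r][p] * b[l][r][n];
                }
            }
        }
    }
    state
}
```

Under these assumptions, the full output `[batch, nchunks, nheads, per_head_dim, state_rank]` would be obtained by applying the same computation independently to every (batch, chunk, head) triple.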