Function k3_ssd_chunk_state 

pub fn k3_ssd_chunk_state(
    v_bnlmhp: Tensor<6>,
    b_bnlmhr: Tensor<6>,
    da_cumsum_bhnl: Tensor<4>,
) -> Tensor<5>

Compute the SSM state at the end of each chunk, assuming a zero initial hidden state.

Expects V to be pre-scaled, so no dt·B scaling is performed here.
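As a hedged reading of the computation: assuming this follows the Mamba-2 SSD chunk-state convention (decay measured from each position to the last position of its chunk) and that the mimo_rank axis is contracted together with the position axis, the output can be written as below. Here A denotes da_cumsum, L is chunk_len and R is mimo_rank; these symbols are introduced for illustration only.

$$
S_{b,c,h,p,n} = \sum_{l=0}^{L-1} \sum_{r=0}^{R-1} \exp\!\left(A_{b,h,c,L-1} - A_{b,h,c,l}\right) V_{b,c,l,r,h,p}\, B_{b,c,l,r,h,n}
$$

Because the initial hidden state is zero, only tokens inside chunk c contribute to S for that chunk; no dt factor appears because V is already pre-scaled.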

§Arguments

  • v_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim], the pre-scaled V
  • b_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
  • da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len]
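The three argument shapes share axes in a fixed way, which the following sketch makes explicit. `chunk_state_output_shape` is a hypothetical helper, not part of this API; it assumes shapes are given as plain `usize` arrays in the axis order listed above and derives the documented return shape.

```rust
/// Hypothetical helper: derive the output shape of `k3_ssd_chunk_state`
/// from its input shapes, checking that shared axes agree.
fn chunk_state_output_shape(
    v: [usize; 6],  // [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
    b: [usize; 6],  // [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
    da: [usize; 4], // [batch, nheads, nchunks, chunk_len]
) -> Result<[usize; 5], String> {
    let [batch, nchunks, chunk_len, _mimo_rank, nheads, head_dim] = v;
    // B must match V on every axis except the last one.
    if b[..5] != v[..5] {
        return Err(format!("B leading dims {:?} != V leading dims {:?}", &b[..5], &v[..5]));
    }
    let state_rank = b[5];
    // da_cumsum uses a head-major layout: [batch, nheads, nchunks, chunk_len].
    let expected_da = [batch, nheads, nchunks, chunk_len];
    if da != expected_da {
        return Err(format!("da_cumsum shape {:?} != {:?}", da, expected_da));
    }
    // chunk_len and mimo_rank are summed out; the result is per chunk and per head.
    Ok([batch, nchunks, nheads, head_dim, state_rank])
}
```

Note that mimo_rank does not appear in the result: it is reduced together with chunk_len.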

§Returns

  • intra_chunk_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
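For checking a kernel on small inputs, the contraction can be sketched as a naive loop nest. This is a hypothetical reference, not the crate's implementation: it assumes the Mamba-2-style decay exp(da_cumsum[last] - da_cumsum[l]) and that mimo_rank is summed, and it works on flat row-major `f32` slices rather than the crate's `Tensor` type. `Dims` and `chunk_state_reference` are illustrative names.

```rust
/// Hypothetical axis sizes for the flat-buffer reference below.
#[derive(Clone, Copy, Debug)]
struct Dims {
    batch: usize,
    nchunks: usize,
    chunk_len: usize,
    mimo_rank: usize,
    nheads: usize,
    head_dim: usize,   // per_head_dim
    state_rank: usize,
}

/// Naive reference for the chunk-state contraction (assumed semantics).
/// Inputs and output are row-major flat buffers in the documented axis orders.
fn chunk_state_reference(v: &[f32], b_mat: &[f32], da_cumsum: &[f32], d: Dims) -> Vec<f32> {
    let Dims { batch, nchunks, chunk_len: cl, mimo_rank: mr, nheads: nh, head_dim: hp, state_rank: sr } = d;
    assert_eq!(v.len(), batch * nchunks * cl * mr * nh * hp);
    assert_eq!(b_mat.len(), batch * nchunks * cl * mr * nh * sr);
    assert_eq!(da_cumsum.len(), batch * nh * nchunks * cl);

    // Output: [batch, nchunks, nheads, per_head_dim, state_rank], zero-initialised.
    let mut out = vec![0.0f32; batch * nchunks * nh * hp * sr];
    if cl == 0 {
        return out; // empty chunks leave the zero initial state unchanged
    }
    for bi in 0..batch {
        for c in 0..nchunks {
            for h in 0..nh {
                // da_cumsum is [batch, nheads, nchunks, chunk_len].
                let da_base = ((bi * nh + h) * nchunks + c) * cl;
                let da_last = da_cumsum[da_base + cl - 1];
                for l in 0..cl {
                    // Decay from position l to the end of the chunk.
                    let decay = (da_last - da_cumsum[da_base + l]).exp();
                    for r in 0..mr {
                        let v_base = ((((bi * nchunks + c) * cl + l) * mr + r) * nh + h) * hp;
                        let b_base = ((((bi * nchunks + c) * cl + l) * mr + r) * nh + h) * sr;
                        for p in 0..hp {
                            let vp = decay * v[v_base + p];
                            let o_base = (((bi * nchunks + c) * nh + h) * hp + p) * sr;
                            // Outer product of V (over p) and B (over n), accumulated.
                            for n in 0..sr {
                                out[o_base + n] += vp * b_mat[b_base + n];
                            }
                        }
                    }
                }
            }
        }
    }
    out
}
```

Each output element costs O(chunk_len · mimo_rank) work, so this loop nest is only suitable as a correctness oracle; an efficient kernel would express the same reduction as one matrix multiply per (batch, chunk, head).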