Function k4_ssd_state_passing 

pub fn k4_ssd_state_passing(
    intra_chunk_state_bnhpr: Tensor<5>,
    da_chunk_end_bhn: Tensor<3>,
    initial_state_bhpr: Tensor<4>,
) -> (Tensor<5>, Tensor<4>)

Propagate hidden state across chunk boundaries using a sequential scan.
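As a worked statement of what the scan computes, here is the recurrence for one (batch b, head h) pair. It assumes the standard SSD (Mamba-2 style) state-passing rule, where each chunk's decay factor is the exponential of its log-decay. The symbols S, X, ΔA and N are notation introduced here, not names from the crate: X_c = intra_chunk_state_bnhpr[b, c, h], ΔA_c = da_chunk_end_bhn[b, h, c], and N = nchunks.

```latex
\begin{aligned}
S_0 &= \text{initial\_state}[b, h] \\
S_{c+1} &= \exp(\Delta A_c)\, S_c + X_c, \qquad c = 0, \dots, N-1 \\
\text{chunk\_input\_state}[b, c, h] &= S_c, \qquad \text{final\_state}[b, h] = S_N
\end{aligned}
```

Each step depends on the previous one, which is why the pass over chunks is sequential even though every (batch, head, state element) lane can run independently.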

This kernel does not depend on the MIMO rank: it operates on the [nheads, per_head_dim, state_rank] state, which has already been aggregated over ranks.

Arguments

  • intra_chunk_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank]
  • da_chunk_end_bhn: [batch, nheads, nchunks], the total log-decay of each chunk
  • initial_state_bhpr: [batch, nheads, per_head_dim, state_rank]
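The three argument shapes must agree on batch, nchunks, nheads, per_head_dim and state_rank, and da_chunk_end_bhn puts nheads before nchunks, unlike the state tensors. A minimal sketch of that consistency check, written against plain shape arrays rather than the crate's `Tensor<N>` type; `check_state_passing_shapes` is a hypothetical helper, not part of this API:

```rust
/// Hypothetical helper (not part of the crate): checks that the three argument
/// shapes are mutually consistent and returns [batch, nchunks, nheads, per_head_dim, state_rank].
pub fn check_state_passing_shapes(
    intra_chunk_state_bnhpr: [usize; 5],
    da_chunk_end_bhn: [usize; 3],
    initial_state_bhpr: [usize; 4],
) -> Result<[usize; 5], String> {
    let [b, n, h, p, r] = intra_chunk_state_bnhpr;
    // The decay tensor is ordered [batch, nheads, nchunks]: heads come before chunks.
    if da_chunk_end_bhn != [b, h, n] {
        return Err(format!(
            "da_chunk_end_bhn: expected {:?}, got {:?}",
            [b, h, n],
            da_chunk_end_bhn
        ));
    }
    // The initial state has no chunk axis: [batch, nheads, per_head_dim, state_rank].
    if initial_state_bhpr != [b, h, p, r] {
        return Err(format!(
            "initial_state_bhpr: expected {:?}, got {:?}",
            [b, h, p, r],
            initial_state_bhpr
        ));
    }
    Ok([b, n, h, p, r])
}
```

For example, with 2 sequences, 4 chunks, 8 heads, per_head_dim 64 and state_rank 16, the decay tensor must be [2, 8, 4]; passing [2, 4, 8] is rejected.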

Returns

  • chunk_input_state_bnhpr: [batch, nchunks, nheads, per_head_dim, state_rank], the state entering each chunk
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank], the state after the last chunk
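A plain CPU reference is handy for testing a kernel like this. The sketch below is not the crate's implementation. It assumes the recurrence state ← exp(da) · state + intra_chunk_state, takes contiguous row-major `f32` slices instead of `Tensor<N>`, and flattens per_head_dim × state_rank into a single axis `pr`. The name `ssd_state_passing_ref` is hypothetical.

```rust
/// Hypothetical CPU reference for SSD state passing on flat row-major buffers.
/// intra:   [batch, nchunks, nheads, pr]   (pr = per_head_dim * state_rank)
/// da:      [batch, nheads, nchunks]       (log-decay of each chunk)
/// initial: [batch, nheads, pr]
/// Returns (chunk_input_state [batch, nchunks, nheads, pr], final_state [batch, nheads, pr]).
pub fn ssd_state_passing_ref(
    intra: &[f32],
    da: &[f32],
    initial: &[f32],
    batch: usize,
    nchunks: usize,
    nheads: usize,
    pr: usize,
) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(intra.len(), batch * nchunks * nheads * pr);
    assert_eq!(da.len(), batch * nheads * nchunks);
    assert_eq!(initial.len(), batch * nheads * pr);

    let mut chunk_input = vec![0.0f32; intra.len()];
    let mut final_state = vec![0.0f32; initial.len()];

    // Every (batch, head) lane is independent; only the chunk loop is sequential.
    for b in 0..batch {
        for h in 0..nheads {
            let s0 = (b * nheads + h) * pr;
            // Running state for this lane, seeded with the initial state.
            let mut state: Vec<f32> = initial[s0..s0 + pr].to_vec();
            for c in 0..nchunks {
                let off = ((b * nchunks + c) * nheads + h) * pr;
                // Record the state entering chunk c before updating it.
                chunk_input[off..off + pr].copy_from_slice(&state);
                // Decay the carried state by the chunk's total decay, then add
                // the chunk's own contribution.
                let decay = da[(b * nheads + h) * nchunks + c].exp();
                for (s, x) in state.iter_mut().zip(&intra[off..off + pr]) {
                    *s = decay * *s + *x;
                }
            }
            // After the last chunk the running state is the final state.
            final_state[s0..s0 + pr].copy_from_slice(&state);
        }
    }
    (chunk_input, final_state)
}
```

With one lane, two chunks, zero log-decay (decay factor 1), intra contributions 1.0 and 2.0 and an initial state of 0.5, the chunk input states are 0.5 and 1.5 and the final state is 3.5.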