Module serial

Functions

k1_ssd_chunk_cumsum
Compute the intra-chunk cumulative log-decay and per-chunk decay totals.
k2_ssd_bmm
Compute the intra-chunk CB matrix on fused (R-into-L) tensors.
k3_ssd_chunk_state
Compute the SSM state at the end of each chunk, assuming zero initial hidden state.
k4_ssd_state_passing
Propagate hidden state across chunk boundaries using a sequential scan.
k5_ssd_chunk_scan
Compute the chunk output by combining the intra-chunk (diagonal) and inter-chunk (off-diagonal) contributions.
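
The first and fourth steps can be illustrated with a minimal scalar sketch. It is not this module's API: the names `chunk_cumsum` and `state_passing`, their signatures, and the reduction of tensors to one `f64` per time step or per chunk are all assumptions made for illustration. The sketch also assumes that the per-chunk decay total is the last value of the intra-chunk cumulative sum, and that the state at the end of chunk `c` is `exp(total_c) * h_prev + s_c`, where `s_c` is the chunk state computed from a zero initial hidden state.

```rust
/// Hypothetical scalar stand-in for the chunk cumsum step.
/// Computes an inclusive cumulative sum of per-step log-decay values that
/// restarts at every chunk boundary, plus the total log-decay of each chunk.
fn chunk_cumsum(log_decay: &[f64], chunk_size: usize) -> (Vec<f64>, Vec<f64>) {
    assert!(chunk_size > 0, "chunk_size must be positive");
    let mut cumsum = Vec::with_capacity(log_decay.len());
    let mut totals = Vec::new();
    // `chunks` yields a shorter final chunk when the length is not a multiple.
    for chunk in log_decay.chunks(chunk_size) {
        let mut running = 0.0;
        for &a in chunk {
            running += a;
            cumsum.push(running);
        }
        // Assumption: the chunk total equals the last intra-chunk cumsum value.
        totals.push(running);
    }
    (cumsum, totals)
}

/// Hypothetical scalar stand-in for the state passing step.
/// `chunk_states[c]` is the end-of-chunk state computed with a zero initial
/// hidden state; `chunk_totals[c]` is that chunk's total log-decay.
/// Returns the true hidden state at the end of every chunk.
fn state_passing(chunk_states: &[f64], chunk_totals: &[f64], initial: f64) -> Vec<f64> {
    assert_eq!(chunk_states.len(), chunk_totals.len());
    let mut h = initial;
    let mut out = Vec::with_capacity(chunk_states.len());
    // Sequential scan: each chunk depends on the state carried out of the previous one.
    for (&s, &t) in chunk_states.iter().zip(chunk_totals) {
        h = t.exp() * h + s;
        out.push(h);
    }
    out
}
```

In this sketch, calling `chunk_cumsum(&[1.0, 2.0, 3.0, 4.0, 5.0], 2)` restarts the running sum after every two steps, and the resulting totals can be fed to `state_passing` together with the per-chunk states. The scan is written as a plain loop because each iteration needs the state produced by the one before it.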