Module serial
- k1_ssd_chunk_cumsum: Compute the intra-chunk cumulative log-decay and the per-chunk decay totals.
- k2_ssd_bmm: Compute the intra-chunk CB matrix (C·Bᵀ) on fused (R-into-L) tensors.
- k3_ssd_chunk_state: Compute the SSM state at the end of each chunk, assuming a zero initial hidden state.
- k4_ssd_state_passing: Propagate the hidden state across chunk boundaries with a sequential scan.
- k5_ssd_chunk_scan: Compute the chunk output by combining the intra-chunk (diagonal) and inter-chunk (off-diagonal) contributions.
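How the five stages fit together can be sketched in pure Python for one sequence and one head. This is a minimal reference under stated assumptions, not this module's implementation: it assumes a scalar per-head decay `A` (negative), per-step `dt` that is already positive, `B` and `C` rows of state size N, `x` rows of head dim P, and chunk size `Q`. The function names (`chunk_cumsum`, `bmm`, `chunk_state`, `state_passing`, `chunk_scan`, `naive_scan`) and their signatures are hypothetical, loosely mirroring k1 through k5, and the fused (R-into-L) layout used by k2_ssd_bmm is not modeled. The chunked output is checked against a naive step-by-step recurrence.

```python
import math

# Hypothetical pure-Python reference for the five SSD stages (single sequence,
# single head). Names and signatures are illustrative, not this module's API.

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def chunk_cumsum(dt, A, Q):
    # k1: inclusive cumsum of dA = dt * A inside each chunk, plus chunk totals.
    cums, totals = [], []
    for start in range(0, len(dt), Q):
        acc, row = 0.0, []
        for t in range(start, min(start + Q, len(dt))):
            acc += dt[t] * A
            row.append(acc)
        cums.append(row)
        totals.append(acc)
    return cums, totals

def bmm(C, B, Q):
    # k2: CB[c][l][s] = C[l] . B[s] for positions l, s inside chunk c.
    out = []
    for start in range(0, len(C), Q):
        idx = range(start, min(start + Q, len(C)))
        out.append([[dot(C[l], B[s]) for s in idx] for l in idx])
    return out

def chunk_state(x, B, dt, cums, Q):
    # k3: N x P state at each chunk's end, starting from a zero state.
    N, P = len(B[0]), len(x[0])
    states = []
    for c, row in enumerate(cums):
        S = [[0.0] * P for _ in range(N)]
        for i, a in enumerate(row):
            t = c * Q + i
            w = math.exp(row[-1] - a) * dt[t]  # decay from step t to chunk end
            for n in range(N):
                for p in range(P):
                    S[n][p] += w * B[t][n] * x[t][p]
        states.append(S)
    return states

def state_passing(states, totals):
    # k4: sequential scan h_c = exp(total_c) * h_{c-1} + states[c].
    # Returns the state entering each chunk and the final state.
    N, P = len(states[0]), len(states[0][0])
    h = [[0.0] * P for _ in range(N)]
    prev = []
    for S, tot in zip(states, totals):
        prev.append([r[:] for r in h])
        d = math.exp(tot)
        h = [[d * h[n][p] + S[n][p] for p in range(P)] for n in range(N)]
    return prev, h

def chunk_scan(x, C, dt, cums, CB, prev, Q):
    # k5: y = off-diagonal (incoming state) + diagonal (causal intra-chunk) terms.
    N, P = len(C[0]), len(x[0])
    y = []
    for c, row in enumerate(cums):
        for l, al in enumerate(row):
            t = c * Q + l
            yl = [math.exp(al) * sum(C[t][n] * prev[c][n][p] for n in range(N))
                  for p in range(P)]
            for s in range(l + 1):
                w = CB[c][l][s] * math.exp(al - row[s]) * dt[c * Q + s]
                for p in range(P):
                    yl[p] += w * x[c * Q + s][p]
            y.append(yl)
    return y

def naive_scan(x, B, C, dt, A):
    # Recurrence h_t = exp(dt_t*A) h_{t-1} + dt_t B_t x_t^T, y_t = C_t^T h_t.
    N, P = len(B[0]), len(x[0])
    h = [[0.0] * P for _ in range(N)]
    y = []
    for t in range(len(x)):
        d = math.exp(dt[t] * A)
        h = [[d * h[n][p] + dt[t] * B[t][n] * x[t][p] for p in range(P)] for n in range(N)]
        y.append([sum(C[t][n] * h[n][p] for n in range(N)) for p in range(P)])
    return y, h

# Example: length 5, chunk size 2 (last chunk is partial), N = P = 2.
x = [[1.0, 0.5], [0.2, -1.0], [0.3, 0.3], [-0.4, 0.9], [0.7, 0.1]]
B = [[0.5, 1.0], [1.0, -0.5], [0.2, 0.3], [0.1, 0.4], [-0.3, 0.6]]
C = [[1.0, 0.0], [0.3, 0.7], [-0.2, 0.5], [0.6, 0.6], [0.9, -0.1]]
dt, A, Q = [0.1, 0.5, 0.2, 0.3, 0.4], -1.5, 2

cums, totals = chunk_cumsum(dt, A, Q)
CB = bmm(C, B, Q)
prev, final = state_passing(chunk_state(x, B, dt, cums, Q), totals)
y = chunk_scan(x, C, dt, cums, CB, prev, Q)
y_ref, h_ref = naive_scan(x, B, C, dt, A)
assert all(abs(a - b) < 1e-9 for u, v in zip(y, y_ref) for a, b in zip(u, v))
```

In this sketch the only dependency that crosses chunk boundaries is the scan in `state_passing`; the cumsum, CB, chunk-state, and chunk-scan stages each work on one chunk at a time, which is what makes a chunked decomposition worthwhile.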