Matmul/segsum MIMO-first SSD with plain autodiff backward.
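The "segsum" named in the summary is, in the Mamba-2 minimal SSD reference code, a segment sum over per-step log-decays. Entry (i, j) is the total log-decay accumulated between steps j and i, and exponentiating it gives the lower-triangular decay mask. The Rust sketch below assumes that definition; `segsum` is a hypothetical standalone helper, not this module's API, and it works on a plain slice rather than on tensors.

```rust
/// Hypothetical helper: segment sum of per-step log-decays `x`.
/// out[i][j] = x[j + 1] + ... + x[i] for i >= j (0 on the diagonal) and -inf for i < j,
/// so exp(out[i][j]) equals the decay product exp(x[j + 1]) * ... * exp(x[i]).
fn segsum(x: &[f64]) -> Vec<Vec<f64>> {
    // Inclusive prefix sums: cum[i] = x[0] + ... + x[i].
    let cum: Vec<f64> = x
        .iter()
        .scan(0.0, |acc, &v| {
            *acc += v;
            Some(*acc)
        })
        .collect();
    let t = x.len();
    let mut out = vec![vec![f64::NEG_INFINITY; t]; t];
    for i in 0..t {
        for j in 0..=i {
            out[i][j] = cum[i] - cum[j];
        }
    }
    out
}

fn main() {
    let log_decay = [-0.1, -0.2, -0.3, -0.4];
    // Causal decay mask L = exp(segsum): ones on the diagonal, zeros above it.
    let mask: Vec<Vec<f64>> = segsum(&log_decay)
        .iter()
        .map(|row| row.iter().map(|v| v.exp()).collect())
        .collect();
    for row in &mask {
        println!("{row:?}");
    }
}
```

Subtracting two prefix sums keeps the sketch short; over long chunks that subtraction can lose precision, which is why some implementations sum only the masked segment instead.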
The Chunkwise MIMO-SSD Algorithm (Minimal/Segsum variant)
During training (and prefill), a naive sequential recurrence processes one time step after another and therefore cannot exploit GPU tensor cores. The chunkwise SSD algorithm turns most of the work into dense matrix multiplications by splitting the sequence into chunks of length chunk_len and decomposing the computation into four steps:
Step 1 (intra-chunk, MIMO quadratic form) → Y_diag [batch, nchunks, chunk_len*mimo_rank, nheads, per_head_dim]
Step 2 (input → chunk state) → state [batch, nchunks, nheads, per_head_dim, state_rank]
Step 3 (inter-chunk state scan) → state [batch, nchunks, nheads, per_head_dim, state_rank], final_state
Step 4 (chunk state → output) → Y_off [batch, nchunks, chunk_len*mimo_rank, nheads, per_head_dim]
Y = Y_diag + Y_off → reshape to [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]

The MIMO causal mask LM_mimo[i,j] = exp(cumA[i//m] - cumA[j//m]) for i//m >= j//m (and 0 otherwise), where m = mimo_rank, lets all mimo_rank ranks at the same time step attend to each other (those entries equal exp(0) = 1) while maintaining causal ordering across time steps.
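The mask formula can be checked by materializing LM_mimo for one chunk. A minimal Rust sketch, assuming cumA holds the inclusive cumulative sum of the per-step log-decays within the chunk and that masked entries are 0; `lm_mimo` is a hypothetical helper, not this module's API:

```rust
/// Hypothetical helper: the (chunk_len * mimo_rank)^2 MIMO causal mask for one chunk.
/// LM_mimo[i][j] = exp(cum_a[i / m] - cum_a[j / m]) if i / m >= j / m, else 0,
/// where m = mimo_rank and cum_a is the inclusive cumulative sum of log-decays.
fn lm_mimo(cum_a: &[f64], mimo_rank: usize) -> Vec<Vec<f64>> {
    let size = cum_a.len() * mimo_rank;
    let mut mask = vec![vec![0.0; size]; size];
    for i in 0..size {
        for j in 0..size {
            // Time steps of flattened positions i and j.
            let (ti, tj) = (i / mimo_rank, j / mimo_rank);
            if ti >= tj {
                mask[i][j] = (cum_a[ti] - cum_a[tj]).exp();
            }
        }
    }
    mask
}

fn main() {
    let cum_a = [-0.5, -1.0]; // chunk_len = 2
    let mask = lm_mimo(&cum_a, 2); // mimo_rank = 2, so the mask is 4 x 4
    for row in &mask {
        println!("{row:?}");
    }
}
```

With chunk_len = 2 and mimo_rank = 2 the mask is 4 × 4: the two 2 × 2 diagonal blocks (one per time step) are all ones, so ranks of the same step see each other; the lower-left block is exp(cumA[1] - cumA[0]) = exp(-0.5) ≈ 0.607; the upper-right block is 0.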
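To see that the four steps reproduce the recurrence they replace, the sketch below runs both on the same inputs. It is a hypothetical scalar reduction (one batch element, one head, per_head_dim = 1, plain Vec<f64> instead of tensors) and it assumes the recurrence implied by the mask: h_t = exp(a_t) h_{t-1} + Σ_r B_{t,r} x_{t,r}, y_{t,r} = C_{t,r} · h_t. The names ssd_sequential, ssd_chunked, lcg and random_rows are illustrative, not this module's functions.

```rust
// Hypothetical scalar sketch: one batch element, one head, per_head_dim = 1.
// Positions are flattened as p = t * r + k, with r = mimo_rank (m in the mask formula).
// Assumed recurrence: h_t = exp(a_t) * h_{t-1} + sum_k B[p] x[p],  y[p] = C[p] . h_t.
fn dot(u: &[f64], v: &[f64]) -> f64 {
    u.iter().zip(v).map(|(p, q)| p * q).sum()
}

/// Reference: one time step at a time, no matmuls.
fn ssd_sequential(a: &[f64], b: &[Vec<f64>], c: &[Vec<f64>], x: &[f64], r: usize) -> Vec<f64> {
    let n = b[0].len();
    let (mut h, mut y) = (vec![0.0f64; n], vec![0.0f64; x.len()]);
    for t in 0..a.len() {
        h.iter_mut().for_each(|hk| *hk *= a[t].exp());
        // All r ranks of step t write into the state before any rank reads it.
        for p in t * r..(t + 1) * r {
            for k in 0..n { h[k] += b[p][k] * x[p]; }
        }
        for p in t * r..(t + 1) * r { y[p] = dot(&c[p], &h); }
    }
    y
}

/// Chunkwise form following Steps 1-4; returns (Y, final_state).
/// Assumes a.len() is a multiple of chunk_len.
fn ssd_chunked(a: &[f64], b: &[Vec<f64>], c: &[Vec<f64>], x: &[f64], r: usize,
               chunk_len: usize) -> (Vec<f64>, Vec<f64>) {
    let (n, nchunks, width) = (b[0].len(), a.len() / chunk_len, chunk_len * r);
    let mut y = vec![0.0f64; x.len()];
    let (mut chunk_states, mut cums) = (Vec::new(), Vec::new());
    for ch in 0..nchunks {
        // Inclusive cumulative log-decay in the chunk: cum[i] = a[t0] + ... + a[t0 + i].
        let cum: Vec<f64> = a[ch * chunk_len..(ch + 1) * chunk_len]
            .iter()
            .scan(0.0, |acc, &v| { *acc += v; Some(*acc) })
            .collect();
        let base = ch * width; // first flattened position of this chunk
        // Step 1 (Y_diag): LM_mimo[i][j] * (C_i . B_j) * x_j over all j with j / r <= i / r.
        for i in 0..width {
            for j in (0..width).filter(|&j| i / r >= j / r) {
                let lm = (cum[i / r] - cum[j / r]).exp();
                y[base + i] += lm * dot(&c[base + i], &b[base + j]) * x[base + j];
            }
        }
        // Step 2: state produced by this chunk alone, decayed to its last step.
        let mut s = vec![0.0f64; n];
        for j in 0..width {
            let w = (cum[chunk_len - 1] - cum[j / r]).exp() * x[base + j];
            for k in 0..n { s[k] += w * b[base + j][k]; }
        }
        chunk_states.push(s);
        cums.push(cum);
    }
    // Step 3: scan over chunks; h_in[ch] is the state entering chunk ch.
    let (mut h, mut h_in) = (vec![0.0f64; n], Vec::new());
    for ch in 0..nchunks {
        h_in.push(h.clone());
        let total = cums[ch][chunk_len - 1].exp();
        for k in 0..n { h[k] = total * h[k] + chunk_states[ch][k]; }
    }
    // Step 4 (Y_off): decay the incoming state to each step, project with C_i, add.
    for ch in 0..nchunks {
        for i in 0..width {
            let p = ch * width + i;
            y[p] += cums[ch][i / r].exp() * dot(&c[p], &h_in[ch]);
        }
    }
    (y, h) // h is final_state
}

/// Small LCG so the example needs no external crates; values lie in [-0.5, 0.5).
fn lcg(seed: &mut u64) -> f64 {
    *seed = (*seed).wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
    (*seed >> 11) as f64 / (1u64 << 53) as f64 - 0.5
}

fn random_rows(seed: &mut u64, rows: usize, cols: usize) -> Vec<Vec<f64>> {
    let mut out = vec![vec![0.0; cols]; rows];
    for v in out.iter_mut().flatten() { *v = lcg(seed); }
    out
}

fn main() {
    // steps = sequence length, r = mimo_rank, n = state_rank.
    let (steps, r, n, chunk_len) = (8, 2, 3, 4);
    let mut seed = 42u64;
    let a: Vec<f64> = (0..steps).map(|_| -0.5 * (lcg(&mut seed) + 0.5)).collect(); // in (-0.5, 0]
    let b = random_rows(&mut seed, steps * r, n);
    let c = random_rows(&mut seed, steps * r, n);
    let x: Vec<f64> = (0..steps * r).map(|_| lcg(&mut seed)).collect();
    let y_ref = ssd_sequential(&a, &b, &c, &x, r);
    let (y, final_state) = ssd_chunked(&a, &b, &c, &x, r, chunk_len);
    let max_err = y.iter().zip(&y_ref).map(|(p, q)| (p - q).abs()).fold(0.0, f64::max);
    println!("max |chunked - sequential| = {max_err:e}, final_state = {final_state:?}");
}
```

Only Step 3 is sequential, and it runs over nchunks rather than over every time step; Steps 1, 2 and 4 are independent per chunk, which is what lets a tensor implementation batch them as matmuls.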