The Chunkwise MIMO-SSD Algorithm (Minimal/Segsum variant)
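During training (and prefill), a naive sequential recurrence cannot exploit GPU tensor cores: each step depends on the previous state and involves only elementwise and outer-product updates, not large matrix multiplications. The chunkwise SSD algorithm recovers tensor-core utilization by splitting the sequence into chunks of length Q, so that most of the work becomes batched matrix multiplications, and by decomposing the computation into four steps.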
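Inside a chunk, the decay between every pair of positions can be computed at once as a segment-sum ("segsum") matrix, which is what lets the intra-chunk work be expressed as dense matrix products. The sketch below assumes `a` holds the per-step log-decays of one chunk, whose inclusive prefix sum is cumA; `segsum` is an illustrative helper modeled on the one used in minimal SSD reference code, not this crate's API.

```rust
/// Illustrative segment-sum helper (hypothetical name): for one chunk of
/// per-step log-decays `a` (length Q), returns the Q x Q matrix
///   segsum[i][j] = a[j+1] + ... + a[i]   for i >= j,
///   segsum[i][j] = -inf                  for i < j,
/// so that exp(segsum) is the causal decay matrix within the chunk.
fn segsum(a: &[f64]) -> Vec<Vec<f64>> {
    // Inclusive prefix sum: cum[t] = a[0] + ... + a[t] (cumA in the text).
    let cum: Vec<f64> = a
        .iter()
        .scan(0.0, |acc, &v| {
            *acc += v;
            Some(*acc)
        })
        .collect();
    let q = a.len();
    let mut out = vec![vec![f64::NEG_INFINITY; q]; q];
    for i in 0..q {
        for j in 0..=i {
            out[i][j] = cum[i] - cum[j];
        }
    }
    out
}

fn main() {
    let a = [-0.5, -1.0, -0.25];
    let s = segsum(&a);
    // exp(-inf) = 0 masks future positions; exp(0) = 1 on the diagonal.
    let decay: Vec<Vec<f64>> = s
        .iter()
        .map(|row| row.iter().map(|v| v.exp()).collect())
        .collect();
    for row in &decay {
        println!("{:?}", row);
    }
}
```

Taking differences of one prefix sum keeps the cost at O(Q) additions for cum plus one subtraction per matrix entry, instead of summing each segment separately.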
For MIMO (mimo_rank=R>1), the rank dimension is fused into the chunk_len
dimension via an interleaved reshape: position t*R+r represents
(time=t, rank=r). This gives a fused sequence length L = Q·R per chunk.
SISO (R=1) is the special case where L = Q.
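A small sketch of the interleaved fusion, assuming one chunk is laid out time-major as `[Q][R]`; `fuse`, `unfuse`, and `interleave` are hypothetical helpers, not functions from this crate. Because rank is the fastest-varying index, flattening the rows in order produces exactly the t*R + r layout, and integer division by R recovers the time step that indexes cumA.

```rust
/// Fused position of (time t, rank r) for MIMO rank `rank` (hypothetical helper).
fn fuse(t: usize, r: usize, rank: usize) -> usize {
    t * rank + r
}

/// Inverse mapping: fused position i -> (time, rank).
fn unfuse(i: usize, rank: usize) -> (usize, usize) {
    (i / rank, i % rank)
}

/// Flatten a time-major [Q][R] block into the fused length-L axis (L = Q * R).
fn interleave<T: Copy>(x: &[Vec<T>]) -> Vec<T> {
    x.iter().flat_map(|row| row.iter().copied()).collect()
}

fn main() {
    let (q, rank) = (3, 2); // Q = 3 time steps, R = 2 ranks, so L = 6
    // Label x[t][r] as 10 * t + r so the fused order is easy to read.
    let x: Vec<Vec<u32>> = (0..q)
        .map(|t| (0..rank).map(|r| (10 * t + r) as u32).collect())
        .collect();
    let fused = interleave(&x);
    println!("{:?}", fused); // [0, 1, 10, 11, 20, 21]
    for i in 0..q * rank {
        let (t, r) = unfuse(i, rank);
        assert_eq!(fuse(t, r, rank), i);
        assert_eq!(fused[i], x[t][r]);
    }
}
```

With R = 1 both mappings reduce to the identity, which is the SISO case.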
Step 1 (intra-chunk, MIMO quadratic form) → Y_diag [b, n, L, H, P]
Step 2 (input → chunk state) → state [b, n, H, P, N]
Step 3 (inter-chunk state scan) → state [b, n, H, P, N], final_state
Step 4 (chunk state → output) → Y_off [b, n, L, H, P]
Y = Y_diag + Y_off → reshape to [b, n, Q, R, H, P]

The MIMO causal mask L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R]) for i//R >= j//R (and 0 otherwise)
allows all R ranks at the same time step to attend to each other while
maintaining causal ordering across time steps.
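To check the four steps and the mask together, the sketch below compares the chunkwise computation with the naive recurrence for a single head (batch 1, H = 1, P = 1). It assumes the MIMO recurrence h_t = exp(a_t)·h_{t-1} + Σ_r x_{t,r}·B_{t,r} with y_{t,r} = C_{t,r}·h_t, which is one reading consistent with the mask above rather than something this section states. `ssd_sequential`, `ssd_chunked`, and `max_err` are hypothetical names, and nested loops stand in for the batched matmuls a real kernel would use.

```rust
// Minimal single-head check (batch = 1, H = 1, P = 1, state size N).
// Assumed MIMO recurrence, fused index i = t * R + r:
//   h_t = exp(a_t) * h_{t-1} + sum_r x[i] * B[i],   y[i] = C[i] . h_t

fn dot(u: &[f64], v: &[f64]) -> f64 {
    u.iter().zip(v).map(|(p, q)| p * q).sum()
}

/// Reference path: naive sequential recurrence, one time step at a time.
fn ssd_sequential(a: &[f64], x: &[f64], b: &[Vec<f64>], c: &[Vec<f64>], rank: usize) -> (Vec<f64>, Vec<f64>) {
    let mut h = vec![0.0; b[0].len()];
    let mut y = vec![0.0; x.len()];
    for (t, &a_t) in a.iter().enumerate() {
        h.iter_mut().for_each(|v| *v *= a_t.exp());
        let fused = t * rank..(t + 1) * rank;
        // All R ranks of step t write into the state before any of them reads it.
        for i in fused.clone() {
            h.iter_mut().zip(&b[i]).for_each(|(v, bk)| *v += x[i] * bk);
        }
        for i in fused {
            y[i] = dot(&c[i], &h);
        }
    }
    (y, h)
}

/// Chunkwise path: chunks of `q` time steps, i.e. fused length q * rank.
fn ssd_chunked(a: &[f64], x: &[f64], b: &[Vec<f64>], c: &[Vec<f64>], rank: usize, q: usize) -> (Vec<f64>, Vec<f64>) {
    let n = b[0].len();
    let mut y = vec![0.0; x.len()];
    let mut state = vec![0.0; n]; // state entering the current chunk
    for start in (0..a.len()).step_by(q) {
        let end = (start + q).min(a.len());
        // Inclusive cumulative log-decay within the chunk (cumA in the text).
        let cum: Vec<f64> = a[start..end]
            .iter()
            .scan(0.0, |acc, &v| {
                *acc += v;
                Some(*acc)
            })
            .collect();
        let total = cum[cum.len() - 1];
        let mut chunk_state = vec![0.0; n];
        for i in start * rank..end * rank {
            let ti = i / rank - start; // local time step of fused position i
            // Step 1: intra-chunk quadratic form under the MIMO causal mask,
            // which admits every position j whose time step is <= that of i.
            for j in start * rank..end * rank {
                let tj = j / rank - start;
                if tj <= ti {
                    y[i] += (cum[ti] - cum[tj]).exp() * dot(&c[i], &b[j]) * x[j];
                }
            }
            // Step 4: output contribution of the state carried in from earlier chunks.
            y[i] += cum[ti].exp() * dot(&c[i], &state);
            // Step 2: decay input i to the end of the chunk and add it to the chunk state.
            let w = (total - cum[ti]).exp() * x[i];
            chunk_state.iter_mut().zip(&b[i]).for_each(|(s, bk)| *s += w * bk);
        }
        // Step 3: inter-chunk scan, decaying the incoming state across the whole chunk.
        for (s, cs) in state.iter_mut().zip(&chunk_state) {
            *s = total.exp() * *s + cs;
        }
    }
    (y, state)
}

/// Build deterministic inputs and return the max absolute difference between
/// the two paths, over all outputs and the final state.
fn max_err(t_len: usize, rank: usize, n: usize, q: usize) -> f64 {
    let l = t_len * rank;
    let a: Vec<f64> = (0..t_len).map(|t| -0.1 - 0.05 * (t as f64).sin().abs()).collect();
    let x: Vec<f64> = (0..l).map(|i| (0.7 * i as f64).cos()).collect();
    let b: Vec<Vec<f64>> = (0..l).map(|i| (0..n).map(|k| (0.3 * (i + k) as f64).sin()).collect()).collect();
    let c: Vec<Vec<f64>> = (0..l).map(|i| (0..n).map(|k| (0.2 * (i * n + k) as f64).cos()).collect()).collect();
    let (y_ref, h_ref) = ssd_sequential(&a, &x, &b, &c, rank);
    let (y_chk, h_chk) = ssd_chunked(&a, &x, &b, &c, rank, q);
    y_ref.iter().zip(&y_chk)
        .chain(h_ref.iter().zip(&h_chk))
        .map(|(u, v)| (u - v).abs())
        .fold(0.0, f64::max)
}

fn main() {
    // T = 7 with Q = 3 leaves a partial final chunk; R = 2, N = 3.
    println!("max |sequential - chunked| = {:e}", max_err(7, 2, 3, 3));
}
```

Because this sketch visits chunks in order, Step 3's scan is folded into the chunk loop. A batched implementation would more typically compute Steps 1 and 2 for all chunks in parallel, run the scan over the per-chunk states, and then apply Step 4.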