Module minimal

The Chunkwise MIMO-SSD Algorithm (Minimal/Segsum variant)

During training (and prefill), a naive sequential recurrence cannot exploit GPU tensor cores, because each step depends on the hidden state produced by the previous step. The chunkwise SSD algorithm recovers matmul-friendly work by splitting the sequence into chunks of length Q and decomposing the computation into four steps.
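For reference, here is a minimal sketch of the kind of sequential recurrence the chunkwise algorithm replaces. It assumes the standard Mamba-2 SSD form with a scalar state (one head, P = N = 1), a per-step decay exp(A_t) consistent with the mask definition below, and a zero initial state; the function name `sequential_scan` is hypothetical.

```rust
/// Naive sequential SSD recurrence for one head with a scalar state
/// (P = N = 1). Assumed form: h_t = exp(A_t) * h_{t-1} + B_t * x_t,
/// y_t = C_t * h_t, starting from h = 0.
fn sequential_scan(a: &[f64], b: &[f64], c: &[f64], x: &[f64]) -> (Vec<f64>, f64) {
    let mut h = 0.0; // hidden state, zero initial state assumed
    let mut y = Vec::with_capacity(x.len());
    for t in 0..x.len() {
        // Each step reads the previous h: this loop-carried dependency
        // is what prevents batching the work into large matmuls.
        h = a[t].exp() * h + b[t] * x[t];
        y.push(c[t] * h);
    }
    (y, h) // per-step outputs and the final state
}

fn main() {
    let a = [-0.5, -0.1, -0.2, -0.3];
    let b = [1.0, 0.5, 2.0, 1.0];
    let c = [1.0, 1.0, 0.5, 2.0];
    let x = [1.0, 2.0, 3.0, 4.0];
    let (y, h) = sequential_scan(&a, &b, &c, &x);
    println!("y = {:?}, final h = {}", y, h);
}
```

The loop does O(T) dependent steps of tiny work each; the chunkwise decomposition keeps only a short dependent scan over chunks.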

For MIMO (mimo_rank=R>1), the rank dimension is fused into the chunk_len dimension via an interleaved reshape: position t*R+r represents (time=t, rank=r). This gives a fused sequence length L = Q·R per chunk. SISO (R=1) is the special case where L = Q.
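The interleaved layout can be illustrated with plain index arithmetic. This sketch assumes a contiguous row-major [Q, R] chunk, for which flattening to length L = Q·R is exactly the interleave; the helper names `fuse`, `unfuse` and `interleave` are hypothetical, not part of the module.

```rust
/// Fused position for (time t, rank r) under the interleaved layout.
fn fuse(t: usize, r: usize, rank: usize) -> usize {
    t * rank + r
}

/// Inverse mapping: fused position i -> (time, rank).
fn unfuse(i: usize, rank: usize) -> (usize, usize) {
    (i / rank, i % rank)
}

/// Flatten a [Q][R] chunk (time-major, rank-minor) into length L = Q * R.
fn interleave<T: Copy>(chunk: &[Vec<T>]) -> Vec<T> {
    chunk.iter().flat_map(|row| row.iter().copied()).collect()
}

fn main() {
    let (q, r) = (3, 2);
    // Each element records its own (time, rank) label.
    let chunk: Vec<Vec<(usize, usize)>> =
        (0..q).map(|t| (0..r).map(|k| (t, k)).collect()).collect();
    let fused = interleave(&chunk);
    assert_eq!(fused.len(), q * r); // L = Q * R
    for (i, &(t, k)) in fused.iter().enumerate() {
        assert_eq!(fuse(t, k, r), i);
        assert_eq!(unfuse(i, r), (t, k));
    }
    println!("{:?}", fused);
}
```

With R = 1, `fuse(t, 0, 1) == t`, which is the SISO case L = Q.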

  Step 1  (intra-chunk, MIMO quadratic form)  →  Y_diag   [b, n, L, H, P]
  Step 2  (input → chunk state)               →  state    [b, n, H, P, N]
  Step 3  (inter-chunk state scan)            →  state    [b, n, H, P, N], final_state
  Step 4  (chunk state → output)              →  Y_off    [b, n, L, H, P]

  Y = Y_diag + Y_off   →  reshape to [b, n, Q, R, H, P]
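The four steps can be checked end to end on a toy problem. This is a sketch, not the module's implementation: it assumes one batch and one head with scalar P = N = 1, the recurrence h_t = exp(A_t)·h_{t−1} + Σ_r B[t,r]·x[t,r] with y[t,r] = C[t,r]·h_t, fused-layout inputs (index t*R + r), and a zero initial state. It walks the chunks in one loop, applying Step 4 with the incoming state before Step 3 advances it, whereas the batched algorithm runs Steps 1, 2 and 4 across all chunks at once and only Step 3 sequentially over n. The function names are hypothetical.

```rust
/// Sequential reference: h_t = exp(A_t) h_{t-1} + sum_r B[t,r] x[t,r]; y[t,r] = C[t,r] h_t.
fn sequential(a: &[f64], b: &[f64], c: &[f64], x: &[f64], r: usize) -> (Vec<f64>, f64) {
    let mut h = 0.0;
    let mut y = vec![0.0; x.len()];
    for t in 0..a.len() {
        // All R ranks at time t write into the same state.
        let u: f64 = (0..r).map(|k| b[t * r + k] * x[t * r + k]).sum();
        h = a[t].exp() * h + u;
        for k in 0..r {
            y[t * r + k] = c[t * r + k] * h;
        }
    }
    (y, h)
}

/// Chunkwise evaluation of the same recurrence with chunk length q.
fn chunked(a: &[f64], b: &[f64], c: &[f64], x: &[f64], r: usize, q: usize) -> (Vec<f64>, f64) {
    assert_eq!(a.len() % q, 0, "sequence length must be a multiple of Q");
    let l = q * r; // fused chunk length L = Q * R
    let mut y = vec![0.0; x.len()];
    let mut state = 0.0; // state entering the current chunk
    for ch in 0..a.len() / q {
        // Inclusive cumulative sum of A over the chunk's time steps.
        let mut cum = vec![0.0; q];
        let mut acc = 0.0;
        for t in 0..q {
            acc += a[ch * q + t];
            cum[t] = acc;
        }
        let off = ch * l; // offset of this chunk in the fused arrays
        // Step 1: intra-chunk quadratic form under the MIMO causal mask.
        for i in 0..l {
            for j in 0..l {
                if j / r <= i / r {
                    let mask = (cum[i / r] - cum[j / r]).exp();
                    y[off + i] += c[off + i] * b[off + j] * mask * x[off + j];
                }
            }
        }
        // Step 2: chunk inputs decayed to the chunk's last time step.
        let chunk_state: f64 = (0..l)
            .map(|j| b[off + j] * x[off + j] * (cum[q - 1] - cum[j / r]).exp())
            .sum();
        // Step 4: incoming state decayed from the chunk start to each time step.
        for i in 0..l {
            y[off + i] += c[off + i] * cum[i / r].exp() * state;
        }
        // Step 3: carry the state across the chunk boundary.
        state = cum[q - 1].exp() * state + chunk_state;
    }
    (y, state) // outputs and final_state
}

fn main() {
    let (r, q, t_len) = (2, 2, 6);
    let a: Vec<f64> = (0..t_len).map(|t| -0.1 * (t as f64 + 1.0)).collect();
    let b: Vec<f64> = (0..t_len * r).map(|i| 0.5 + 0.1 * i as f64).collect();
    let c: Vec<f64> = (0..t_len * r).map(|i| 1.0 - 0.05 * i as f64).collect();
    let x: Vec<f64> = (0..t_len * r).map(|i| (i as f64).sin()).collect();
    let (y_seq, h_seq) = sequential(&a, &b, &c, &x, r);
    let (y_chk, h_chk) = chunked(&a, &b, &c, &x, r, q);
    for (p, s) in y_chk.iter().zip(&y_seq) {
        assert!((p - s).abs() < 1e-9);
    }
    assert!((h_chk - h_seq).abs() < 1e-9);
    println!("chunked output matches the sequential recurrence");
}
```

Only the Step 3 carry remains sequential, and it runs over n = T/Q chunks rather than T time steps; with the head and state dimensions restored, the loops in Steps 1, 2 and 4 become dense matmuls.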

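The MIMO causal mask is L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R]) for i//R >= j//R and 0 otherwise, where cumA is the cumulative sum of A over the chunk's time steps. Positions that share a time step (i//R == j//R) get weight exp(0) = 1, so all R ranks at the same time step attend to each other, while causal ordering across time steps is preserved.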
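A direct construction of the mask for one chunk might look like the sketch below. It assumes `cum_a` already holds the cumulative sum of A over the Q time steps; `mimo_mask` is a hypothetical helper. The difference-of-cumsums form is the simplest to read; a segsum-style implementation may instead sum the A values between j//R + 1 and i//R directly, which avoids cancellation when the cumulative sums grow large.

```rust
/// Build the L x L MIMO causal mask for one chunk, where L = Q * R.
/// `cum_a` holds the cumulative sum of A over the chunk's Q time steps.
fn mimo_mask(cum_a: &[f64], r: usize) -> Vec<Vec<f64>> {
    let l = cum_a.len() * r;
    (0..l)
        .map(|i| {
            (0..l)
                .map(|j| {
                    // Compare time steps, not fused positions.
                    let (ti, tj) = (i / r, j / r);
                    if ti >= tj {
                        (cum_a[ti] - cum_a[tj]).exp() // decay from tj to ti
                    } else {
                        0.0 // future time step: masked out
                    }
                })
                .collect()
        })
        .collect()
}

fn main() {
    let cum_a = [-0.5, -0.8]; // Q = 2 time steps
    let m = mimo_mask(&cum_a, 2); // R = 2, so L = 4
    for row in &m {
        println!("{:?}", row);
    }
}
```

The result is block lower-triangular with R x R all-ones blocks on the diagonal; with R = 1 it reduces to the ordinary SISO decay mask.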