Module minimal

Matmul/segsum-based, MIMO-first SSD with a plain autodiff backward pass.

The Chunkwise MIMO-SSD Algorithm (Minimal/Segsum Variant)

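During training (and prefill), a naive sequential recurrence over time steps cannot exploit GPU tensor cores. The chunkwise SSD algorithm addresses this by splitting the sequence into chunks of length chunk_len and decomposing the computation into four steps, so that most of the work becomes batched matrix multiplication and only a short scan over chunks remains sequential.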
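For reference, the naive sequential recurrence can be sketched as below for a single batch element and head. This assumes the MIMO recurrence h_t = exp(A_t) * h_{t-1} + sum_r x_{t,r} B_{t,r}^T with outputs y_{t,r} = h_t C_{t,r}, which is inferred from the mask LM_mimo rather than taken from this module's code; the function name mimo_ssd_sequential and the flat row-major Vec<f64> layout are hypothetical. Each step depends on the previous state, so the loop over t is inherently serial.

```rust
/// Hypothetical naive MIMO recurrence for one batch element and one head.
/// Assumed recurrence (inferred from the mask, not this crate's code):
///   h_t     = exp(a_t) * h_{t-1} + sum_r x_{t,r} b_{t,r}^T   (h_t is p x n)
///   y_{t,r} = h_t c_{t,r}
/// Row-major layouts: x, y: [seq_len, m, p]; b, c: [seq_len, m, n]; a: [seq_len].
pub fn mimo_ssd_sequential(
    x: &[f64], a: &[f64], b: &[f64], c: &[f64],
    m: usize, p: usize, n: usize,
) -> Vec<f64> {
    let seq_len = a.len();
    let mut h = vec![0.0; p * n]; // running state, row-major [p, n]
    let mut y = vec![0.0; seq_len * m * p];
    for t in 0..seq_len {
        // Decay the previous state by exp(a_t).
        let decay = a[t].exp();
        for v in h.iter_mut() {
            *v *= decay;
        }
        // Add the rank-m input update sum_r x_{t,r} b_{t,r}^T.
        for r in 0..m {
            let row = t * m + r;
            for i in 0..p {
                for k in 0..n {
                    h[i * n + k] += x[row * p + i] * b[row * n + k];
                }
            }
        }
        // Every rank at step t reads the same updated state.
        for r in 0..m {
            let row = t * m + r;
            for i in 0..p {
                y[row * p + i] = (0..n).map(|k| h[i * n + k] * c[row * n + k]).sum::<f64>();
            }
        }
    }
    y
}
```

For example, with m = 2, p = n = 1, A = 0 and B = C = 1, inputs x = [1, 2, 3, 4] give y = [3, 3, 10, 10]: both ranks of a time step see the same state.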

  Step 1  (intra-chunk, MIMO quadratic form)  →  Y_diag   [batch, nchunks, chunk_len*mimo_rank, nheads, per_head_dim]
  Step 2  (input → chunk state)               →  state    [batch, nchunks, nheads, per_head_dim, state_rank]
  Step 3  (inter-chunk state scan)            →  state    [batch, nchunks, nheads, per_head_dim, state_rank], final_state
  Step 4  (chunk state → output)              →  Y_off    [batch, nchunks, chunk_len*mimo_rank, nheads, per_head_dim]

  Y = Y_diag + Y_off   →  reshape to [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
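The four steps can be sketched for a single batch element and head with explicit loops. This is a hypothetical illustration, not this module's implementation (which presumably expresses Steps 1, 2 and 4 as batched tensor matmuls): the name mimo_ssd_chunked, the flat row-major layouts, and the underlying recurrence h_t = exp(A_t) * h_{t-1} + sum_r x_{t,r} B_{t,r}^T, y_{t,r} = h_t C_{t,r} are all assumptions, and seq_len must be a multiple of chunk_len.

```rust
/// Hypothetical chunkwise evaluation of the assumed MIMO recurrence (one batch element, one head).
/// Row-major layouts: x, y: [seq_len, m, p]; b, c: [seq_len, m, n]; a: [seq_len].
/// A row g = t * m + r indexes the flattened (time, rank) axis, so g / m is its time step.
/// Returns (y, final_state), with final_state laid out as [p, n].
pub fn mimo_ssd_chunked(
    x: &[f64], a: &[f64], b: &[f64], c: &[f64],
    m: usize, p: usize, n: usize, chunk_len: usize,
) -> (Vec<f64>, Vec<f64>) {
    let seq_len = a.len();
    assert!(chunk_len > 0 && seq_len % chunk_len == 0, "seq_len must be a multiple of chunk_len");
    let nchunks = seq_len / chunk_len;
    let rows = chunk_len * m; // flattened rows per chunk
    // cum[t]: cumulative sum of the log-decays a, restarted at every chunk boundary.
    let mut cum = vec![0.0; seq_len];
    for t in 0..seq_len {
        let prev = if t % chunk_len == 0 { 0.0 } else { cum[t - 1] };
        cum[t] = prev + a[t];
    }
    let cb = |i: usize, j: usize| -> f64 { (0..n).map(|k| c[i * n + k] * b[j * n + k]).sum() };
    let mut y = vec![0.0; seq_len * m * p];

    // Step 1 (Y_diag): y_i += (C_i . B_j) * exp(cum[i/m] - cum[j/m]) * x_j, same chunk, j/m <= i/m.
    for ch in 0..nchunks {
        let base = ch * rows;
        for i in base..base + rows {
            for j in base..(i / m + 1) * m {
                let w = cb(i, j) * (cum[i / m] - cum[j / m]).exp();
                for d in 0..p {
                    y[i * p + d] += w * x[j * p + d];
                }
            }
        }
    }
    // Step 2: each chunk's own inputs folded into a [p, n] state, decayed to the chunk's last step.
    let mut chunk_states = vec![vec![0.0; p * n]; nchunks];
    for ch in 0..nchunks {
        let last = (ch + 1) * chunk_len - 1;
        for j in ch * rows..(ch + 1) * rows {
            let w = (cum[last] - cum[j / m]).exp();
            for d in 0..p {
                for k in 0..n {
                    chunk_states[ch][d * n + k] += w * x[j * p + d] * b[j * n + k];
                }
            }
        }
    }
    // Step 3: sequential scan over chunks; entering[ch] is the state just before chunk ch.
    let mut entering = vec![vec![0.0; p * n]; nchunks];
    let mut carry = vec![0.0; p * n];
    for ch in 0..nchunks {
        entering[ch] = carry.clone();
        let chunk_decay = cum[(ch + 1) * chunk_len - 1].exp(); // decay across the whole chunk
        for (s, add) in carry.iter_mut().zip(&chunk_states[ch]) {
            *s = *s * chunk_decay + *add;
        }
    }
    // Step 4 (Y_off): y_i += exp(cum[i/m]) * (entering state) C_i, completing Y = Y_diag + Y_off.
    for i in 0..seq_len * m {
        let ch = (i / m) / chunk_len;
        let decay = cum[i / m].exp();
        for d in 0..p {
            let off: f64 = (0..n).map(|k| entering[ch][d * n + k] * c[i * n + k]).sum();
            y[i * p + d] += decay * off;
        }
    }
    (y, carry) // after the scan, carry holds final_state
}
```

Steps 1, 2 and 4 only read data from one chunk (plus, for Step 4, that chunk's entering state), so they can run for all chunks at once; only the Step 3 scan is sequential, and it runs over nchunks rather than seq_len. Under the assumed recurrence, a plain loop over time steps yields the same y and final_state up to floating-point rounding.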

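The MIMO causal mask is LM_mimo[i,j] = exp(cumA[i//m] - cumA[j//m]) for i//m >= j//m, and 0 otherwise. Here m = mimo_rank, i and j index the flattened chunk_len*mimo_rank axis (so i//m is the time step within the chunk), and cumA is the cumulative sum of the per-step log-decays A within the chunk, so each entry is the decay accumulated between the two time steps (a segment sum, hence "segsum"). Entries with i//m == j//m equal exp(0) = 1, so all mimo_rank ranks at the same time step attend to each other, while causal ordering across time steps is preserved.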
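The mask for one chunk can be materialized as a dense (chunk_len*m) x (chunk_len*m) matrix, sketched below. lm_mimo is a hypothetical helper, and cum_a is assumed to hold the within-chunk cumulative sum of the log-decays, indexed by time step.

```rust
/// Hypothetical helper: dense MIMO causal decay mask for one chunk.
/// `cum_a[t]` is the within-chunk cumulative sum of the log-decays (assumed), indexed by time step.
/// Row/column index i maps to time step i / m and rank i % m.
pub fn lm_mimo(cum_a: &[f64], m: usize) -> Vec<Vec<f64>> {
    let rows = cum_a.len() * m;
    let mut mask = vec![vec![0.0; rows]; rows];
    for i in 0..rows {
        for j in 0..rows {
            let (ti, tj) = (i / m, j / m);
            if ti >= tj {
                // exp(segsum): total decay between time steps tj and ti (1.0 when ti == tj).
                mask[i][j] = (cum_a[ti] - cum_a[tj]).exp();
            }
            // Entries with ti < tj stay 0.0 (future time steps are masked out).
        }
    }
    mask
}
```

With m = 1 this reduces to the ordinary lower-triangular decay mask exp(segsum(A)); for m > 1 every time step expands into an m x m block, and the blocks on the diagonal are all ones.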