burn_mamba/mamba3/double_ssd/ssd/minimal.rs
//! ## The Chunkwise MIMO-SSD Algorithm (Minimal/Segsum variant)
//!
//! During training (and prefill), a naive sequential recurrence cannot exploit GPU tensor
//! cores. The **chunkwise SSD algorithm** recovers matmul-friendly parallelism by splitting the
//! sequence into chunks of length `chunk_len` and decomposing the computation into four steps.
//!
//! ```text
//! Step 1 (intra-chunk, MIMO quadratic form) → Y_diag [batch, nchunks, chunk_len*mimo_rank, nheads, per_head_dim]
//! Step 2 (input → chunk state)              → state  [batch, nchunks, nheads, per_head_dim, state_rank]
//! Step 3 (inter-chunk state scan)           → state  [batch, nchunks, nheads, per_head_dim, state_rank]
//!                                              (state entering each chunk), final_state [batch, nheads, per_head_dim, state_rank]
//! Step 4 (chunk state → output)             → Y_off  [batch, nchunks, chunk_len*mimo_rank, nheads, per_head_dim]
//!
//! Y = Y_diag + Y_off → reshape to [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
//! ```
//!
//! The MIMO causal mask `L_mimo[i, j] = exp(cumA[i//m] - cumA[j//m])` for `i//m >= j//m`
//! (and 0 otherwise), where `m = mimo_rank`, lets all `mimo_rank` ranks at the same time step
//! attend to each other while preserving causal ordering across time steps.
20
21use crate::mamba3::double_ssd::prelude::*;
22use crate::modules::segsum;
23use burn::prelude::*;
24
25impl Mamba3DoubleSsdInput {
26 /// MIMO-first chunkwise SSD — minimal/segsum variant.
27 ///
28 /// Implements the four-step decomposition for the MIMO (double-ssd) trapezoidal recurrence.
29 /// SISO (mimo_rank=1) is the degenerate case where the fused length equals the chunk length.
30 ///
31 /// No D skip is applied here — the caller handles it.
32 ///
33 /// # Shapes
34 /// - input: see [`Mamba3DoubleSsdInput`]
35 /// - output.0 `y_bnlrhp`: `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
36 /// - output.1 `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
37 #[allow(non_snake_case)]
38 pub fn double_ssd_minimal(self) -> (Tensor<6>, Tensor<4>) {
39 let input = self;
40 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = input.v_bnlmhp.dims();
41 let [.., state_rank] = input.b_bnlmhr.dims();
42 // note: L above denotes the chunk_len
43 let device = &input.v_bnlmhp.device();
44
45 assert!(nchunks >= 1, "sequence must be non-empty");
46 assert!(chunk_len > 0, "chunk_len must be positive");
47
48 // ── Fuse mimo_rank into chunk_len ────────────────────────────────────────
49 let b_bnLMhr =
50 input
51 .b_bnlmhr
52 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
53 let c_bnLMhr =
54 input
55 .c_bnlmhr
56 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
57 let v_bnLMhp =
58 input
59 .v_bnlmhp
60 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
61
62 // Base per-time-step cumulative log-decay
63 let a_bhnl = input.da_bnlh.clone().permute([0, 3, 1, 2]);
64 let a_cumsum_bhnl = a_bhnl.clone().cumsum(3);
65
66 // =============================================================
67 // STEP 1: Intra-chunk outputs (Y_diag)
68 //
69 // Y_diag[m] = (L_mimo[m] ∘ C[m] B[m]ᵀ) · V[m]
70 // note: L above does not denote the chunk_len, but L in the Mamba-3 paper.
71 //
72 // MIMO mask: L_mimo[i,j] = exp(cumA[i//m] - cumA[j//m]) if i//m >= j//m, else 0
73 // =============================================================
74 let y_diag_bnLMhp = {
75 // CB = C @ B^T: contract over state_rank
76 let c_bnhLMr = c_bnLMhr.clone().permute([0, 1, 3, 2, 4]);
77 let b_bnhrLM = b_bnLMhr.clone().permute([0, 1, 3, 4, 2]);
78 // [batch, nchunks, nheads, chunk_len*mimo_rank, chunk_len*mimo_rank]
79 let cb_bnhLMLM = c_bnhLMr.matmul(b_bnhrLM);
80
81 // Build MIMO causal mask from segsum on base dimension, then interleave-expand.
82 // l_base_bhnll[i,j] = exp(cumA[i] - cumA[j]) if i >= j, else 0
83 let l_base_bhnll = segsum::<4, 5>(a_bhnl.clone()).exp();
84
85 // Interleave-expand
86 // L_mimo[i, j] = L_base[i//m, j//m] (same decay for all ranks at a given time)
87 let l_mimo_bhnLMLM = l_base_bhnll
88 // row interleaving: insert mimo_rank copies of each l-row
89 .unsqueeze_dim::<6>(4) // l_base_bhnl1l
90 .expand([batch, nheads, nchunks, chunk_len, mimo_rank, chunk_len]) // l_base_bhnlml
91 .reshape([batch, nheads, nchunks, chunk_len * mimo_rank, chunk_len]) // l_base_bhnLMl
92 // col interleaving: insert mimo_rank copies of each l-col
93 .unsqueeze_dim::<6>(5) // l_base_bhnLMl1
94 .expand([
95 batch,
96 nheads,
97 nchunks,
98 chunk_len * mimo_rank,
99 chunk_len,
100 mimo_rank,
101 ]) // l_base_bhnLMlm
102 .reshape([
103 batch,
104 nheads,
105 nchunks,
106 chunk_len * mimo_rank,
107 chunk_len * mimo_rank,
108 ]); // l_base_bhnLMLM
109
110 // Apply mask: (CB ∘ L_mimo) · V
111 let cb_bnLMhLM = cb_bnhLMLM.permute([0, 1, 3, 2, 4]);
112 let l_bnLMhLM = l_mimo_bhnLMLM.permute([0, 2, 3, 1, 4]);
113 let masked_cb_bnhLMLM = (cb_bnLMhLM * l_bnLMhLM).permute([0, 1, 3, 2, 4]);
114
115 let v_bnhLMp = v_bnLMhp.clone().permute([0, 1, 3, 2, 4]);
116 let y_diag_bnhLMp = masked_cb_bnhLMLM.matmul(v_bnhLMp);
117
118 y_diag_bnhLMp.permute([0, 1, 3, 2, 4])
119 };
120
121 // =============================================================
122 // STEP 2: Chunk state (input → state, zero initial state)
123 //
124 // s[n] = Σ_{t,r} exp(cumA[n,-1] - cumA[n,t]) · V[n,t*m+r] · B[n,t*m+r]ᵀ
125 // (outer product over per_head_dim and state_rank)
126 // =============================================================
127 let state_bnhpr = {
128 // Decay from each fused position to end of chunk:
129 let a_cumsum_last_bhn1 = a_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
130 // Expand base cumsum to fused length (each time repeated mimo_rank times):
131 // [b, nheads, n, l] → [b, nheads, n, l, R] → [b, nheads, n, L]
132 let a_cumsum_bhnLM = a_cumsum_bhnl
133 .clone()
134 .unsqueeze_dim::<5>(4) // a_cumsum_bhnl1
135 .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // a_cumsum_bhnlm
136 .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); // a_cumsum_bhnLM
137 let decay_bhnLM = (a_cumsum_last_bhn1 - a_cumsum_bhnLM).exp();
138
139 // Multiply decay into V
140 let decay_bnLMh1 = decay_bhnLM
141 .permute([0, 2, 3, 1]) // decay_bnLMh
142 .unsqueeze_dim(4); // decay_bnLMh1
143 let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp.clone();
144
145 // state = decayed_V^T @ B
146 let decayed_v_bnhpLM = decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
147 let b_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
148 decayed_v_bnhpLM.matmul(b_bnhLMr) // state_bnhpr
149 };
150
151 // =============================================================
152 // STEP 3: Inter-chunk state scan (state passing via segsum)
153 //
154 // h[n] = Ā_chunk[n] · h[n-1] + s[n]
155 // =============================================================
156 let (state_bnhpr, final_state_bhpr) = {
157 let initial_state_b1hpr = input.initial_state_bhpr.unsqueeze_dim(1);
158 let initial_state_b1hpr = if let Some(init_hpr) = input.init_state_hpr {
159 let init_b1hpr = init_hpr.unsqueeze_dim::<4>(0).expand([
160 batch,
161 1,
162 nheads,
163 per_head_dim,
164 state_rank,
165 ]);
166 initial_state_b1hpr + init_b1hpr
167 } else {
168 initial_state_b1hpr
169 };
170
171 // Prepend initial state: [batch, 1+nchunks, nheads, per_head_dim, state_rank]
172 let state_bNhpr = Tensor::cat(vec![initial_state_b1hpr, state_bnhpr], 1);
173
174 // Per-chunk cumulative decay (last position of each chunk)
175 let a_cumsum_last_bhn: Tensor<3> = a_cumsum_bhnl
176 .clone()
177 .slice(s![.., .., .., -1]) // a_cumsum_last_bhn1
178 .squeeze_dim(3); // a_cumsum_last_bhn
179 // Prepend zero for the initial state (no decay before chunk 0):
180 let a_chunk_pad_bhN = Tensor::cat(
181 vec![Tensor::zeros([batch, nheads, 1], device), a_cumsum_last_bhn],
182 2,
183 ); // [batch, nheads, 1+nchunks]
184
185 // Inter-chunk decay matrix via segsum: [batch, nheads, 1+nchunks, 1+nchunks]
186 let decay_chunk_bhNN = segsum::<3, 4>(a_chunk_pad_bhN).exp();
187
188 // Flatten (per_head_dim, state_rank) for matmul
189 let flat = per_head_dim * state_rank;
190 let state_bhNPR = state_bNhpr
191 .clone()
192 .permute([0, 2, 1, 3, 4]) // state_bhNpr
193 .reshape([batch, nheads, 1 + nchunks, flat]); // [batch, nheads, 1+nchunks, per_head_dim*state_rank]
194
195 let new_state_bhNPR = decay_chunk_bhNN.matmul(state_bhNPR);
196 let new_state_bhNpr =
197 new_state_bhNPR.reshape([batch, nheads, 1 + nchunks, per_head_dim, state_rank]);
198
199 // Split: chunk input states [0..n], final state [n]
200 let new_state_bnhpr = new_state_bhNpr
201 .clone()
202 .slice(s![.., .., 0..nchunks, .., ..]) // new_state_bhnpr
203 .permute([0, 2, 1, 3, 4]); // new_state_bnhpr
204 let last_state_bhpr: Tensor<4> = new_state_bhNpr
205 .slice(s![.., .., nchunks, .., ..]) // new_state_bh1pr
206 .squeeze_dim(2); // last_state_bhpr
207
208 (new_state_bnhpr, last_state_bhpr)
209 };
210
211 // =============================================================
212 // STEP 4: State-to-output (Y_off)
213 //
214 // Y_off[n, t*m+r] = C[t*m+r]ᵀ · exp(cumA[t]) · h[n-1]
215 // =============================================================
216 let y_off_bnLMhp = {
217 // Expand base cumsum to fused, then exp:
218 let state_decay_bhnLM = a_cumsum_bhnl
219 .clone()
220 .unsqueeze_dim::<5>(4) // a_cumsum_bhnl1
221 .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // a_cumsum_bhnlm
222 .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]) // a_cumsum_bhnLM
223 .exp();
224
225 // C
226 let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
227 let state_bnhrp = state_bnhpr.permute([0, 1, 2, 4, 3]);
228 let ch_bnhLMp = c_bnhLMr.matmul(state_bnhrp);
229
230 // Multiply by intra-chunk decay
231 let decay_bnhLM1 = state_decay_bhnLM
232 .permute([0, 2, 1, 3]) // state_decay_bnhLM
233 .unsqueeze_dim(4); // state_decay_bnhLM1
234 let y_off_bnhLMp = ch_bnhLMp * decay_bnhLM1;
235 y_off_bnhLMp.permute([0, 1, 3, 2, 4]) // y_off_bnLMhp
236 };
237
238 // ── Combine and reshape ───────────────────────────────────────────────
239 let y_bnLMhp = y_diag_bnLMhp + y_off_bnLMhp;
240 let y_bnlmhp =
241 y_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
242
243 (y_bnlmhp, final_state_bhpr)
244 }
245}