burn_mamba/mamba3/ssd/minimal.rs
1//! ## The Chunkwise MIMO-SSD Algorithm (Minimal/Segsum variant)
2//!
3//! During training (and prefill), a naive sequential recurrence cannot
4//! exploit GPU tensor cores. The **chunkwise SSD algorithm** achieves this
5//! by splitting the sequence into chunks of length Q and decomposing the
6//! computation into four steps.
7//!
8//! For MIMO (mimo_rank=R>1), the rank dimension is fused into the chunk_len
9//! dimension via an interleaved reshape: position `t*R+r` represents
10//! (time=t, rank=r). This gives a fused sequence length `L = Q·R` per chunk.
11//! SISO (R=1) is the special case where `L = Q`.
12//!
13//! ```text
14//! Step 1 (intra-chunk, MIMO quadratic form) → Y_diag [b, n, L, H, P]
15//! Step 2 (input → chunk state) → state [b, n, H, P, N]
16//! Step 3 (inter-chunk state scan) → state [b, n, H, P, N], final_state
17//! Step 4 (chunk state → output) → Y_off [b, n, L, H, P]
18//!
19//! Y = Y_diag + Y_off → reshape to [b, n, l, R, H, P]
20//! ```
21//!
22//! The MIMO causal mask `L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R])` for `i//R >= j//R`
23//! allows all R ranks at the same time step to attend to each other while
24//! maintaining causal ordering across time steps.
25
26use crate::mamba3::prelude::*;
27use burn::prelude::*;
28
29impl<B: Backend> Mamba3<B> {
30 /// MIMO-first chunkwise SSD — minimal/segsum variant.
31 ///
32 /// Implements the four-step decomposition for the MIMO trapezoidal recurrence.
33 /// SISO (R=1) is the degenerate case where the fused length equals the chunk length.
34 ///
35 /// No D skip is applied here — the caller handles it.
36 ///
37 /// # Shapes
38 /// - input: see [`Mamba3SsdInput`]
39 /// - output.0 `y_bnlrhp`: `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
40 /// - output.1 `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
41 #[allow(non_snake_case)]
42 pub fn ssd_minimal(input: super::Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>) {
43 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = input.v_bnlrhp.dims();
44 let [.., state_rank] = input.b_bnlrhn.dims();
45 let fused_len = chunk_len * mimo_rank; // L = Q·R
46 let device = &input.v_bnlrhp.device();
47
48 assert!(nchunks >= 1, "sequence must be non-empty");
49 assert!(chunk_len > 0, "chunk_len must be positive");
50
51 // ── Fuse R into chunk_len ─────────────────────────────────────────────
52 // [b, n, l, R, H, N] → [b, n, L, H, N] where L = l*R
53 let b_bnLhn = input
54 .b_bnlrhn
55 .reshape([batch, nchunks, fused_len, nheads, state_rank]);
56 let c_bnLhn = input
57 .c_bnlrhn
58 .reshape([batch, nchunks, fused_len, nheads, state_rank]);
59 // [b, n, l, R, H, P] → [b, n, L, H, P]
60 let v_bnLhp = input
61 .v_bnlrhp
62 .reshape([batch, nchunks, fused_len, nheads, per_head_dim]);
63
64 // Base per-time-step cumulative log-decay: [b, H, n, l]
65 let a_bhnl = input.da_bnlh.clone().permute([0, 3, 1, 2]);
66 let a_cumsum_bhnl = a_bhnl.clone().cumsum(3);
67
68 // =============================================================
69 // STEP 1: Intra-chunk outputs (Y_diag)
70 //
71 // Y_diag[n] = (L_mimo[n] ∘ C[n] B[n]ᵀ) · V[n]
72 //
73 // MIMO mask: L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R]) if i//R >= j//R, else 0
74 // =============================================================
75 let y_diag_bnLhp = {
76 // CB = C @ B^T: contract over state_rank N
77 let c_bnhLr = c_bnLhn.clone().permute([0, 1, 3, 2, 4]); // [b, n, H, L, N]
78 let b_bnhLr = b_bnLhn.clone().permute([0, 1, 3, 2, 4]); // [b, n, H, L, N]
79 let b_bnhrL = b_bnhLr.permute([0, 1, 2, 4, 3]); // [b, n, H, N, L]
80 let cb_bnhLL = c_bnhLr.matmul(b_bnhrL); // [b, n, H, L, L]
81
82 // Build MIMO causal mask from segsum on base dimension, then interleave-expand.
83 // l_base_bhnll[i,j] = exp(cumA[i] - cumA[j]) if i >= j, else 0
84 let l_base_bhnll = segsum::<B, 4, 5>(a_bhnl.clone()).exp(); // [b, H, n, l, l]
85
86 // Interleave-expand: [b, H, n, l, l] → [b, H, n, L, L]
87 // L_mimo[i, j] = L_base[i//R, j//R] (same decay for all ranks at a given time)
88 let l_mimo_bhnLL = l_base_bhnll
89 // row interleaving: insert R copies of each l-row
90 .unsqueeze_dim::<6>(4)
91 .expand([batch, nheads, nchunks, chunk_len, mimo_rank, chunk_len])
92 .reshape([batch, nheads, nchunks, fused_len, chunk_len])
93 // col interleaving: insert R copies of each l-col
94 .unsqueeze_dim::<6>(5)
95 .expand([batch, nheads, nchunks, fused_len, chunk_len, mimo_rank])
96 .reshape([batch, nheads, nchunks, fused_len, fused_len]);
97
98 // Apply mask: (CB ∘ L_mimo) · V
99 let cb_bnLhL = cb_bnhLL.permute([0, 1, 3, 2, 4]); // [b, n, L, H, L]
100 let l_bnLhL = l_mimo_bhnLL.permute([0, 2, 3, 1, 4]); // [b, n, L, H, L]
101 let masked_cb_bnhLL = (cb_bnLhL * l_bnLhL).permute([0, 1, 3, 2, 4]); // [b, n, H, L, L]
102
103 let v_bnhLp = v_bnLhp.clone().permute([0, 1, 3, 2, 4]); // [b, n, H, L, P]
104 let y_diag_bnhLp = masked_cb_bnhLL.matmul(v_bnhLp); // [b, n, H, L, P]
105
106 y_diag_bnhLp.permute([0, 1, 3, 2, 4]) // [b, n, L, H, P]
107 };
108
109 // =============================================================
110 // STEP 2: Chunk state (input → state, zero initial state)
111 //
112 // s[n] = Σ_{t,r} exp(cumA[n,-1] - cumA[n,t]) · V[n,t*R+r] · B[n,t*R+r]ᵀ
113 // (outer product over P and N)
114 // =============================================================
115 let state_bnhpr = {
116 // Decay from each fused position to end of chunk:
117 // decay_fused[t*R+r] = exp(cumA_last - cumA_base[t])
118 let a_cumsum_last_bhn1 = a_cumsum_bhnl.clone().slice(s![.., .., .., -1]); // [b,H,n,1]
119 // Expand base cumsum to fused length (each time repeated R times):
120 // [b, H, n, l] → [b, H, n, l, R] → [b, H, n, L]
121 let a_cumsum_fused_bhnL = a_cumsum_bhnl
122 .clone()
123 .unsqueeze_dim::<5>(4)
124 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
125 .reshape([batch, nheads, nchunks, fused_len]);
126 // (cumA_last - cumA_fused): broadcasts [b,H,n,1] - [b,H,n,L]
127 let decay_bhnL = (a_cumsum_last_bhn1 - a_cumsum_fused_bhnL).exp();
128
129 // Multiply decay into V: [b, n, L, H, 1] * [b, n, L, H, P]
130 let decay_bnLh1 = decay_bhnL.permute([0, 2, 3, 1]).unsqueeze_dim(4);
131 let decayed_v_bnLhp = decay_bnLh1 * v_bnLhp.clone();
132
133 // state = decayed_V^T @ B: [b, n, H, P, L] × [b, n, H, L, N] → [b, n, H, P, N]
134 let decayed_v_bnhpL = decayed_v_bnLhp.permute([0, 1, 3, 4, 2]);
135 let b_bnhLN = b_bnLhn.permute([0, 1, 3, 2, 4]);
136 decayed_v_bnhpL.matmul(b_bnhLN) // [b, n, H, P, N]
137 };
138
139 // =============================================================
140 // STEP 3: Inter-chunk state scan (state passing via segsum)
141 //
142 // h[n] = Ā_chunk[n] · h[n-1] + s[n]
143 // =============================================================
144 let (state_bnhpr, final_state_bhpr) = {
145 let initial_state_b1hpr = input.initial_state_bhpr.unsqueeze_dim(1);
146 let initial_state_b1hpr = if let Some(init_hpr) = input.init_state_hpr {
147 let init_b1hpr = init_hpr.unsqueeze_dim::<4>(0).expand([
148 batch,
149 1,
150 nheads,
151 per_head_dim,
152 state_rank,
153 ]);
154 initial_state_b1hpr + init_b1hpr
155 } else {
156 initial_state_b1hpr
157 };
158
159 // Prepend initial state: [b, 1+n, H, P, N]
160 let state_bNhpr = Tensor::cat(vec![initial_state_b1hpr, state_bnhpr], 1);
161
162 // Per-chunk cumulative decay (last position of each chunk): [b, H, n]
163 let a_cumsum_last_bhn: Tensor<B, 3> = a_cumsum_bhnl
164 .clone()
165 .slice(s![.., .., .., -1])
166 .squeeze_dim(3);
167 // Prepend zero for the initial state (no decay before chunk 0):
168 let a_chunk_pad_bhN = Tensor::cat(
169 vec![Tensor::zeros([batch, nheads, 1], device), a_cumsum_last_bhn],
170 2,
171 ); // [b, H, 1+n]
172
173 // Inter-chunk decay matrix via segsum: [b, H, 1+n, 1+n]
174 let decay_chunk_bhNN = segsum::<B, 3, 4>(a_chunk_pad_bhN).exp();
175
176 // Flatten (P, N) for matmul
177 let flat = per_head_dim * state_rank;
178 let state_bhNf = state_bNhpr.clone().permute([0, 2, 1, 3, 4]).reshape([
179 batch,
180 nheads,
181 1 + nchunks,
182 flat,
183 ]);
184
185 // [b, H, 1+n, 1+n] × [b, H, 1+n, P·N] → [b, H, 1+n, P·N]
186 let new_state_bhNf = decay_chunk_bhNN.matmul(state_bhNf);
187 let new_state_bhNpr =
188 new_state_bhNf.reshape([batch, nheads, 1 + nchunks, per_head_dim, state_rank]);
189
190 // Split: chunk input states [0..n], final state [n]
191 let s_bhnpr = new_state_bhNpr
192 .clone()
193 .slice(s![.., .., 0..nchunks, .., ..]);
194 let f_bhpr: Tensor<B, 4> = new_state_bhNpr
195 .slice(s![.., .., nchunks, .., ..])
196 .squeeze_dim(2);
197
198 (s_bhnpr.permute([0, 2, 1, 3, 4]), f_bhpr) // [b, n, H, P, N], [b, H, P, N]
199 };
200
201 // =============================================================
202 // STEP 4: State-to-output (Y_off)
203 //
204 // Y_off[n, t*R+r] = C[t*R+r]ᵀ · exp(cumA[t]) · h[n-1]
205 // =============================================================
206 let y_off_bnLhp = {
207 // Expand base cumsum to fused, then exp:
208 let state_decay_bhnL = a_cumsum_bhnl
209 .clone()
210 .unsqueeze_dim::<5>(4)
211 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
212 .reshape([batch, nheads, nchunks, fused_len])
213 .exp(); // [b, H, n, L]
214
215 // C: [b, n, H, L, N], state: [b, n, H, N, P]
216 let c_bnhLr = c_bnLhn.permute([0, 1, 3, 2, 4]);
217 let state_bnhrp = state_bnhpr.permute([0, 1, 2, 4, 3]);
218 let ch_bnhLp = c_bnhLr.matmul(state_bnhrp); // [b, n, H, L, P]
219
220 // Multiply by intra-chunk decay: [b, n, H, L, 1]
221 let decay_bnhL1 = state_decay_bhnL.permute([0, 2, 1, 3]).unsqueeze_dim(4);
222 let y_off_bnhLp = ch_bnhLp * decay_bnhL1;
223 y_off_bnhLp.permute([0, 1, 3, 2, 4]) // [b, n, L, H, P]
224 };
225
226 // ── Combine and reshape ───────────────────────────────────────────────
227 let y_bnLhp = y_diag_bnLhp + y_off_bnLhp; // [b, n, L, H, P]
228 // Reshape: [b, n, L, H, P] = [b, n, l*R, H, P] → [b, n, l, R, H, P]
229 let y_bnlrhp =
230 y_bnLhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
231
232 (y_bnlrhp, final_state_bhpr)
233 }
234}
235
236// ---------------------------------------------------------------------------
237// segsum (stable segment sum for the 1-SS mask)
238// ---------------------------------------------------------------------------
239
240/// Compute stable segment sums for constructing the 1-semiseparable mask.
241///
242/// Given a tensor `x` of shape `[..., T]`, returns a tensor of shape `[..., T, T]` where:
243///
244/// ```text
245/// out[..., i, j] = Σ_{k=j+1}^{i} x[..., k] for i ≥ j (lower triangle)
246/// out[..., i, j] = -∞ for i < j (upper triangle)
247/// ```
248pub(super) fn segsum<B: Backend, const D: usize, const D2: usize>(
249 x: Tensor<B, D>,
250) -> Tensor<B, D2> {
251 assert_eq!(D + 1, D2);
252
253 let x_cumsum = x.cumsum(D - 1);
254 let x_cumsum_row = x_cumsum.clone().unsqueeze_dim(D); // [..., T, 1]
255 let x_cumsum_col = x_cumsum.unsqueeze_dim(D - 1); // [..., 1, T]
256
257 let diff = x_cumsum_row - x_cumsum_col; // [..., T, T]
258 let neg_inf_mask = Tensor::full_like(&diff, f32::NEG_INFINITY).triu(1);
259 diff + neg_inf_mask
260}