Skip to main content

burn_mamba/mamba3/ssd/
minimal.rs

1//! ## The Chunkwise MIMO-SSD Algorithm (Minimal/Segsum variant)
2//!
3//! During training (and prefill), a naive sequential recurrence cannot
4//! exploit GPU tensor cores.  The **chunkwise SSD algorithm** achieves this
5//! by splitting the sequence into chunks of length Q and decomposing the
6//! computation into four steps.
7//!
8//! For MIMO (mimo_rank=R>1), the rank dimension is fused into the chunk_len
9//! dimension via an interleaved reshape: position `t*R+r` represents
10//! (time=t, rank=r).  This gives a fused sequence length `L = Q·R` per chunk.
11//! SISO (R=1) is the special case where `L = Q`.
12//!
13//! ```text
14//!   Step 1  (intra-chunk, MIMO quadratic form)  →  Y_diag   [b, n, L, H, P]
15//!   Step 2  (input → chunk state)               →  state    [b, n, H, P, N]
16//!   Step 3  (inter-chunk state scan)            →  state    [b, n, H, P, N], final_state
17//!   Step 4  (chunk state → output)              →  Y_off    [b, n, L, H, P]
18//!
19//!   Y = Y_diag + Y_off   →  reshape to [b, n, l, R, H, P]
20//! ```
21//!
22//! The MIMO causal mask `L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R])` for `i//R >= j//R`
23//! allows all R ranks at the same time step to attend to each other while
24//! maintaining causal ordering across time steps.
25
26use crate::mamba3::prelude::*;
27use burn::prelude::*;
28
29impl<B: Backend> Mamba3<B> {
30    /// MIMO-first chunkwise SSD — minimal/segsum variant.
31    ///
32    /// Implements the four-step decomposition for the MIMO trapezoidal recurrence.
33    /// SISO (R=1) is the degenerate case where the fused length equals the chunk length.
34    ///
35    /// No D skip is applied here — the caller handles it.
36    ///
37    /// # Shapes
38    /// - input: see [`Mamba3SsdInput`]
39    /// - output.0 `y_bnlrhp`:       `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
40    /// - output.1 `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
41    #[allow(non_snake_case)]
42    pub fn ssd_minimal(input: super::Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>) {
43        let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = input.v_bnlrhp.dims();
44        let [.., state_rank] = input.b_bnlrhn.dims();
45        let fused_len = chunk_len * mimo_rank; // L = Q·R
46        let device = &input.v_bnlrhp.device();
47
48        assert!(nchunks >= 1, "sequence must be non-empty");
49        assert!(chunk_len > 0, "chunk_len must be positive");
50
51        // ── Fuse R into chunk_len ─────────────────────────────────────────────
52        // [b, n, l, R, H, N] → [b, n, L, H, N]  where L = l*R
53        let b_bnLhn = input
54            .b_bnlrhn
55            .reshape([batch, nchunks, fused_len, nheads, state_rank]);
56        let c_bnLhn = input
57            .c_bnlrhn
58            .reshape([batch, nchunks, fused_len, nheads, state_rank]);
59        // [b, n, l, R, H, P] → [b, n, L, H, P]
60        let v_bnLhp = input
61            .v_bnlrhp
62            .reshape([batch, nchunks, fused_len, nheads, per_head_dim]);
63
64        // Base per-time-step cumulative log-decay: [b, H, n, l]
65        let a_bhnl = input.da_bnlh.clone().permute([0, 3, 1, 2]);
66        let a_cumsum_bhnl = a_bhnl.clone().cumsum(3);
67
68        // =============================================================
69        // STEP 1: Intra-chunk outputs (Y_diag)
70        //
71        // Y_diag[n] = (L_mimo[n] ∘ C[n] B[n]ᵀ) · V[n]
72        //
73        // MIMO mask: L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R]) if i//R >= j//R, else 0
74        // =============================================================
75        let y_diag_bnLhp = {
76            // CB = C @ B^T: contract over state_rank N
77            let c_bnhLr = c_bnLhn.clone().permute([0, 1, 3, 2, 4]); // [b, n, H, L, N]
78            let b_bnhLr = b_bnLhn.clone().permute([0, 1, 3, 2, 4]); // [b, n, H, L, N]
79            let b_bnhrL = b_bnhLr.permute([0, 1, 2, 4, 3]); // [b, n, H, N, L]
80            let cb_bnhLL = c_bnhLr.matmul(b_bnhrL); // [b, n, H, L, L]
81
82            // Build MIMO causal mask from segsum on base dimension, then interleave-expand.
83            // l_base_bhnll[i,j] = exp(cumA[i] - cumA[j]) if i >= j, else 0
84            let l_base_bhnll = segsum::<B, 4, 5>(a_bhnl.clone()).exp(); // [b, H, n, l, l]
85
86            // Interleave-expand: [b, H, n, l, l] → [b, H, n, L, L]
87            // L_mimo[i, j] = L_base[i//R, j//R]  (same decay for all ranks at a given time)
88            let l_mimo_bhnLL = l_base_bhnll
89                // row interleaving: insert R copies of each l-row
90                .unsqueeze_dim::<6>(4)
91                .expand([batch, nheads, nchunks, chunk_len, mimo_rank, chunk_len])
92                .reshape([batch, nheads, nchunks, fused_len, chunk_len])
93                // col interleaving: insert R copies of each l-col
94                .unsqueeze_dim::<6>(5)
95                .expand([batch, nheads, nchunks, fused_len, chunk_len, mimo_rank])
96                .reshape([batch, nheads, nchunks, fused_len, fused_len]);
97
98            // Apply mask: (CB ∘ L_mimo) · V
99            let cb_bnLhL = cb_bnhLL.permute([0, 1, 3, 2, 4]); // [b, n, L, H, L]
100            let l_bnLhL = l_mimo_bhnLL.permute([0, 2, 3, 1, 4]); // [b, n, L, H, L]
101            let masked_cb_bnhLL = (cb_bnLhL * l_bnLhL).permute([0, 1, 3, 2, 4]); // [b, n, H, L, L]
102
103            let v_bnhLp = v_bnLhp.clone().permute([0, 1, 3, 2, 4]); // [b, n, H, L, P]
104            let y_diag_bnhLp = masked_cb_bnhLL.matmul(v_bnhLp); // [b, n, H, L, P]
105
106            y_diag_bnhLp.permute([0, 1, 3, 2, 4]) // [b, n, L, H, P]
107        };
108
109        // =============================================================
110        // STEP 2: Chunk state (input → state, zero initial state)
111        //
112        // s[n] = Σ_{t,r} exp(cumA[n,-1] - cumA[n,t]) · V[n,t*R+r] · B[n,t*R+r]ᵀ
113        //      (outer product over P and N)
114        // =============================================================
115        let state_bnhpr = {
116            // Decay from each fused position to end of chunk:
117            //   decay_fused[t*R+r] = exp(cumA_last - cumA_base[t])
118            let a_cumsum_last_bhn1 = a_cumsum_bhnl.clone().slice(s![.., .., .., -1]); // [b,H,n,1]
119            // Expand base cumsum to fused length (each time repeated R times):
120            // [b, H, n, l] → [b, H, n, l, R] → [b, H, n, L]
121            let a_cumsum_fused_bhnL = a_cumsum_bhnl
122                .clone()
123                .unsqueeze_dim::<5>(4)
124                .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
125                .reshape([batch, nheads, nchunks, fused_len]);
126            // (cumA_last - cumA_fused): broadcasts [b,H,n,1] - [b,H,n,L]
127            let decay_bhnL = (a_cumsum_last_bhn1 - a_cumsum_fused_bhnL).exp();
128
129            // Multiply decay into V: [b, n, L, H, 1] * [b, n, L, H, P]
130            let decay_bnLh1 = decay_bhnL.permute([0, 2, 3, 1]).unsqueeze_dim(4);
131            let decayed_v_bnLhp = decay_bnLh1 * v_bnLhp.clone();
132
133            // state = decayed_V^T @ B:  [b, n, H, P, L] × [b, n, H, L, N] → [b, n, H, P, N]
134            let decayed_v_bnhpL = decayed_v_bnLhp.permute([0, 1, 3, 4, 2]);
135            let b_bnhLN = b_bnLhn.permute([0, 1, 3, 2, 4]);
136            decayed_v_bnhpL.matmul(b_bnhLN) // [b, n, H, P, N]
137        };
138
139        // =============================================================
140        // STEP 3: Inter-chunk state scan (state passing via segsum)
141        //
142        // h[n] = Ā_chunk[n] · h[n-1] + s[n]
143        // =============================================================
144        let (state_bnhpr, final_state_bhpr) = {
145            let initial_state_b1hpr = input.initial_state_bhpr.unsqueeze_dim(1);
146            let initial_state_b1hpr = if let Some(init_hpr) = input.init_state_hpr {
147                let init_b1hpr = init_hpr.unsqueeze_dim::<4>(0).expand([
148                    batch,
149                    1,
150                    nheads,
151                    per_head_dim,
152                    state_rank,
153                ]);
154                initial_state_b1hpr + init_b1hpr
155            } else {
156                initial_state_b1hpr
157            };
158
159            // Prepend initial state: [b, 1+n, H, P, N]
160            let state_bNhpr = Tensor::cat(vec![initial_state_b1hpr, state_bnhpr], 1);
161
162            // Per-chunk cumulative decay (last position of each chunk): [b, H, n]
163            let a_cumsum_last_bhn: Tensor<B, 3> = a_cumsum_bhnl
164                .clone()
165                .slice(s![.., .., .., -1])
166                .squeeze_dim(3);
167            // Prepend zero for the initial state (no decay before chunk 0):
168            let a_chunk_pad_bhN = Tensor::cat(
169                vec![Tensor::zeros([batch, nheads, 1], device), a_cumsum_last_bhn],
170                2,
171            ); // [b, H, 1+n]
172
173            // Inter-chunk decay matrix via segsum: [b, H, 1+n, 1+n]
174            let decay_chunk_bhNN = segsum::<B, 3, 4>(a_chunk_pad_bhN).exp();
175
176            // Flatten (P, N) for matmul
177            let flat = per_head_dim * state_rank;
178            let state_bhNf = state_bNhpr.clone().permute([0, 2, 1, 3, 4]).reshape([
179                batch,
180                nheads,
181                1 + nchunks,
182                flat,
183            ]);
184
185            // [b, H, 1+n, 1+n] × [b, H, 1+n, P·N] → [b, H, 1+n, P·N]
186            let new_state_bhNf = decay_chunk_bhNN.matmul(state_bhNf);
187            let new_state_bhNpr =
188                new_state_bhNf.reshape([batch, nheads, 1 + nchunks, per_head_dim, state_rank]);
189
190            // Split: chunk input states [0..n], final state [n]
191            let s_bhnpr = new_state_bhNpr
192                .clone()
193                .slice(s![.., .., 0..nchunks, .., ..]);
194            let f_bhpr: Tensor<B, 4> = new_state_bhNpr
195                .slice(s![.., .., nchunks, .., ..])
196                .squeeze_dim(2);
197
198            (s_bhnpr.permute([0, 2, 1, 3, 4]), f_bhpr) // [b, n, H, P, N], [b, H, P, N]
199        };
200
201        // =============================================================
202        // STEP 4: State-to-output (Y_off)
203        //
204        // Y_off[n, t*R+r] = C[t*R+r]ᵀ · exp(cumA[t]) · h[n-1]
205        // =============================================================
206        let y_off_bnLhp = {
207            // Expand base cumsum to fused, then exp:
208            let state_decay_bhnL = a_cumsum_bhnl
209                .clone()
210                .unsqueeze_dim::<5>(4)
211                .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
212                .reshape([batch, nheads, nchunks, fused_len])
213                .exp(); // [b, H, n, L]
214
215            // C: [b, n, H, L, N], state: [b, n, H, N, P]
216            let c_bnhLr = c_bnLhn.permute([0, 1, 3, 2, 4]);
217            let state_bnhrp = state_bnhpr.permute([0, 1, 2, 4, 3]);
218            let ch_bnhLp = c_bnhLr.matmul(state_bnhrp); // [b, n, H, L, P]
219
220            // Multiply by intra-chunk decay: [b, n, H, L, 1]
221            let decay_bnhL1 = state_decay_bhnL.permute([0, 2, 1, 3]).unsqueeze_dim(4);
222            let y_off_bnhLp = ch_bnhLp * decay_bnhL1;
223            y_off_bnhLp.permute([0, 1, 3, 2, 4]) // [b, n, L, H, P]
224        };
225
226        // ── Combine and reshape ───────────────────────────────────────────────
227        let y_bnLhp = y_diag_bnLhp + y_off_bnLhp; // [b, n, L, H, P]
228        // Reshape: [b, n, L, H, P] = [b, n, l*R, H, P] → [b, n, l, R, H, P]
229        let y_bnlrhp =
230            y_bnLhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
231
232        (y_bnlrhp, final_state_bhpr)
233    }
234}
235
236// ---------------------------------------------------------------------------
237// segsum  (stable segment sum for the 1-SS mask)
238// ---------------------------------------------------------------------------
239
240/// Compute stable segment sums for constructing the 1-semiseparable mask.
241///
242/// Given a tensor `x` of shape `[..., T]`, returns a tensor of shape `[..., T, T]` where:
243///
244/// ```text
245///   out[..., i, j] = Σ_{k=j+1}^{i} x[..., k]   for i ≥ j  (lower triangle)
246///   out[..., i, j] = -∞                           for i < j  (upper triangle)
247/// ```
248pub(super) fn segsum<B: Backend, const D: usize, const D2: usize>(
249    x: Tensor<B, D>,
250) -> Tensor<B, D2> {
251    assert_eq!(D + 1, D2);
252
253    let x_cumsum = x.cumsum(D - 1);
254    let x_cumsum_row = x_cumsum.clone().unsqueeze_dim(D); // [..., T, 1]
255    let x_cumsum_col = x_cumsum.unsqueeze_dim(D - 1); // [..., 1, T]
256
257    let diff = x_cumsum_row - x_cumsum_col; // [..., T, T]
258    let neg_inf_mask = Tensor::full_like(&diff, f32::NEG_INFINITY).triu(1);
259    diff + neg_inf_mask
260}