Skip to main content

burn_mamba/mamba3/double_ssd/ssd/
minimal.rs

//! ## The Chunkwise MIMO-SSD Algorithm (Minimal/Segsum variant)
//!
//! During training (and prefill), a naive sequential recurrence cannot
//! exploit GPU tensor cores.  The **chunkwise SSD algorithm** makes the
//! computation matmul-friendly by splitting the sequence into chunks of length
//! `chunk_len` and decomposing the computation into four steps.
//!
//! ```text
//!   Step 1  (intra-chunk, MIMO quadratic form)  →  Y_diag   [batch, nchunks, chunk_len*mimo_rank, nheads, per_head_dim]
//!   Step 2  (input → chunk state)               →  state    [batch, nchunks, nheads, per_head_dim, state_rank]
//!   Step 3  (inter-chunk state scan)            →  state    [batch, nchunks, nheads, per_head_dim, state_rank], final_state
//!   Step 4  (chunk state → output)              →  Y_off    [batch, nchunks, chunk_len*mimo_rank, nheads, per_head_dim]
//!
//!   Y = Y_diag + Y_off   →  reshape to [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
//! ```
//!
//! The MIMO causal mask `LM_mimo[i,j] = exp(cumA[i//m] - cumA[j//m])` for `i//m >= j//m`
//! (with `m = mimo_rank`) allows all `mimo_rank` ranks at the same time step to attend
//! to each other while maintaining causal ordering across time steps.
20
21use crate::mamba3::double_ssd::prelude::*;
22use crate::modules::segsum;
23use burn::prelude::*;
24
25impl Mamba3DoubleSsdInput {
26    /// MIMO-first chunkwise SSD — minimal/segsum variant.
27    ///
28    /// Implements the four-step decomposition for the MIMO (double-ssd) trapezoidal recurrence.
29    /// SISO (mimo_rank=1) is the degenerate case where the fused length equals the chunk length.
30    ///
31    /// No D skip is applied here — the caller handles it.
32    ///
33    /// # Shapes
34    /// - input: see [`Mamba3DoubleSsdInput`]
35    /// - output.0 `y_bnlrhp`:       `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
36    /// - output.1 `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
37    #[allow(non_snake_case)]
38    pub fn double_ssd_minimal(self) -> (Tensor<6>, Tensor<4>) {
39        let input = self;
40        let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = input.v_bnlmhp.dims();
41        let [.., state_rank] = input.b_bnlmhr.dims();
42        // note: L above denotes the chunk_len
43        let device = &input.v_bnlmhp.device();
44
45        assert!(nchunks >= 1, "sequence must be non-empty");
46        assert!(chunk_len > 0, "chunk_len must be positive");
47
48        // ── Fuse mimo_rank into chunk_len ────────────────────────────────────────
49        let b_bnLMhr =
50            input
51                .b_bnlmhr
52                .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
53        let c_bnLMhr =
54            input
55                .c_bnlmhr
56                .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
57        let v_bnLMhp =
58            input
59                .v_bnlmhp
60                .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
61
62        // Base per-time-step cumulative log-decay
63        let a_bhnl = input.da_bnlh.clone().permute([0, 3, 1, 2]);
64        let a_cumsum_bhnl = a_bhnl.clone().cumsum(3);
65
66        // =============================================================
67        // STEP 1: Intra-chunk outputs (Y_diag)
68        //
69        // Y_diag[m] = (L_mimo[m] ∘ C[m] B[m]ᵀ) · V[m]
70        // note: L above does not denote the chunk_len, but L in the Mamba-3 paper.
71        //
72        // MIMO mask: L_mimo[i,j] = exp(cumA[i//m] - cumA[j//m]) if i//m >= j//m, else 0
73        // =============================================================
74        let y_diag_bnLMhp = {
75            // CB = C @ B^T: contract over state_rank
76            let c_bnhLMr = c_bnLMhr.clone().permute([0, 1, 3, 2, 4]);
77            let b_bnhrLM = b_bnLMhr.clone().permute([0, 1, 3, 4, 2]);
78            // [batch, nchunks, nheads, chunk_len*mimo_rank, chunk_len*mimo_rank]
79            let cb_bnhLMLM = c_bnhLMr.matmul(b_bnhrLM);
80
81            // Build MIMO causal mask from segsum on base dimension, then interleave-expand.
82            // l_base_bhnll[i,j] = exp(cumA[i] - cumA[j]) if i >= j, else 0
83            let l_base_bhnll = segsum::<4, 5>(a_bhnl.clone()).exp();
84
85            // Interleave-expand
86            // L_mimo[i, j] = L_base[i//m, j//m]  (same decay for all ranks at a given time)
87            let l_mimo_bhnLMLM = l_base_bhnll
88                // row interleaving: insert mimo_rank copies of each l-row
89                .unsqueeze_dim::<6>(4) // l_base_bhnl1l
90                .expand([batch, nheads, nchunks, chunk_len, mimo_rank, chunk_len]) // l_base_bhnlml
91                .reshape([batch, nheads, nchunks, chunk_len * mimo_rank, chunk_len]) // l_base_bhnLMl
92                // col interleaving: insert mimo_rank copies of each l-col
93                .unsqueeze_dim::<6>(5) // l_base_bhnLMl1
94                .expand([
95                    batch,
96                    nheads,
97                    nchunks,
98                    chunk_len * mimo_rank,
99                    chunk_len,
100                    mimo_rank,
101                ]) // l_base_bhnLMlm
102                .reshape([
103                    batch,
104                    nheads,
105                    nchunks,
106                    chunk_len * mimo_rank,
107                    chunk_len * mimo_rank,
108                ]); // l_base_bhnLMLM
109
110            // Apply mask: (CB ∘ L_mimo) · V
111            let cb_bnLMhLM = cb_bnhLMLM.permute([0, 1, 3, 2, 4]);
112            let l_bnLMhLM = l_mimo_bhnLMLM.permute([0, 2, 3, 1, 4]);
113            let masked_cb_bnhLMLM = (cb_bnLMhLM * l_bnLMhLM).permute([0, 1, 3, 2, 4]);
114
115            let v_bnhLMp = v_bnLMhp.clone().permute([0, 1, 3, 2, 4]);
116            let y_diag_bnhLMp = masked_cb_bnhLMLM.matmul(v_bnhLMp);
117
118            y_diag_bnhLMp.permute([0, 1, 3, 2, 4])
119        };
120
121        // =============================================================
122        // STEP 2: Chunk state (input → state, zero initial state)
123        //
124        // s[n] = Σ_{t,r} exp(cumA[n,-1] - cumA[n,t]) · V[n,t*m+r] · B[n,t*m+r]ᵀ
125        //      (outer product over per_head_dim and state_rank)
126        // =============================================================
127        let state_bnhpr = {
128            // Decay from each fused position to end of chunk:
129            let a_cumsum_last_bhn1 = a_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
130            // Expand base cumsum to fused length (each time repeated mimo_rank times):
131            // [b, nheads, n, l] → [b, nheads, n, l, R] → [b, nheads, n, L]
132            let a_cumsum_bhnLM = a_cumsum_bhnl
133                .clone()
134                .unsqueeze_dim::<5>(4) // a_cumsum_bhnl1
135                .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // a_cumsum_bhnlm
136                .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); // a_cumsum_bhnLM
137            let decay_bhnLM = (a_cumsum_last_bhn1 - a_cumsum_bhnLM).exp();
138
139            // Multiply decay into V
140            let decay_bnLMh1 = decay_bhnLM
141                .permute([0, 2, 3, 1]) // decay_bnLMh
142                .unsqueeze_dim(4); // decay_bnLMh1
143            let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp.clone();
144
145            // state = decayed_V^T @ B
146            let decayed_v_bnhpLM = decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
147            let b_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
148            decayed_v_bnhpLM.matmul(b_bnhLMr) // state_bnhpr
149        };
150
151        // =============================================================
152        // STEP 3: Inter-chunk state scan (state passing via segsum)
153        //
154        // h[n] = Ā_chunk[n] · h[n-1] + s[n]
155        // =============================================================
156        let (state_bnhpr, final_state_bhpr) = {
157            let initial_state_b1hpr = input.initial_state_bhpr.unsqueeze_dim(1);
158            let initial_state_b1hpr = if let Some(init_hpr) = input.init_state_hpr {
159                let init_b1hpr = init_hpr.unsqueeze_dim::<4>(0).expand([
160                    batch,
161                    1,
162                    nheads,
163                    per_head_dim,
164                    state_rank,
165                ]);
166                initial_state_b1hpr + init_b1hpr
167            } else {
168                initial_state_b1hpr
169            };
170
171            // Prepend initial state: [batch, 1+nchunks, nheads, per_head_dim, state_rank]
172            let state_bNhpr = Tensor::cat(vec![initial_state_b1hpr, state_bnhpr], 1);
173
174            // Per-chunk cumulative decay (last position of each chunk)
175            let a_cumsum_last_bhn: Tensor<3> = a_cumsum_bhnl
176                .clone()
177                .slice(s![.., .., .., -1]) // a_cumsum_last_bhn1
178                .squeeze_dim(3); // a_cumsum_last_bhn
179            // Prepend zero for the initial state (no decay before chunk 0):
180            let a_chunk_pad_bhN = Tensor::cat(
181                vec![Tensor::zeros([batch, nheads, 1], device), a_cumsum_last_bhn],
182                2,
183            ); // [batch, nheads, 1+nchunks]
184
185            // Inter-chunk decay matrix via segsum: [batch, nheads, 1+nchunks, 1+nchunks]
186            let decay_chunk_bhNN = segsum::<3, 4>(a_chunk_pad_bhN).exp();
187
188            // Flatten (per_head_dim, state_rank) for matmul
189            let flat = per_head_dim * state_rank;
190            let state_bhNPR = state_bNhpr
191                .clone()
192                .permute([0, 2, 1, 3, 4]) // state_bhNpr
193                .reshape([batch, nheads, 1 + nchunks, flat]); // [batch, nheads, 1+nchunks, per_head_dim*state_rank]
194
195            let new_state_bhNPR = decay_chunk_bhNN.matmul(state_bhNPR);
196            let new_state_bhNpr =
197                new_state_bhNPR.reshape([batch, nheads, 1 + nchunks, per_head_dim, state_rank]);
198
199            // Split: chunk input states [0..n], final state [n]
200            let new_state_bnhpr = new_state_bhNpr
201                .clone()
202                .slice(s![.., .., 0..nchunks, .., ..]) // new_state_bhnpr
203                .permute([0, 2, 1, 3, 4]); // new_state_bnhpr
204            let last_state_bhpr: Tensor<4> = new_state_bhNpr
205                .slice(s![.., .., nchunks, .., ..]) // new_state_bh1pr
206                .squeeze_dim(2); // last_state_bhpr
207
208            (new_state_bnhpr, last_state_bhpr)
209        };
210
211        // =============================================================
212        // STEP 4: State-to-output (Y_off)
213        //
214        // Y_off[n, t*m+r] = C[t*m+r]ᵀ · exp(cumA[t]) · h[n-1]
215        // =============================================================
216        let y_off_bnLMhp = {
217            // Expand base cumsum to fused, then exp:
218            let state_decay_bhnLM = a_cumsum_bhnl
219                .clone()
220                .unsqueeze_dim::<5>(4) // a_cumsum_bhnl1
221                .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // a_cumsum_bhnlm
222                .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]) // a_cumsum_bhnLM
223                .exp();
224
225            // C
226            let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
227            let state_bnhrp = state_bnhpr.permute([0, 1, 2, 4, 3]);
228            let ch_bnhLMp = c_bnhLMr.matmul(state_bnhrp);
229
230            // Multiply by intra-chunk decay
231            let decay_bnhLM1 = state_decay_bhnLM
232                .permute([0, 2, 1, 3]) // state_decay_bnhLM
233                .unsqueeze_dim(4); // state_decay_bnhLM1
234            let y_off_bnhLMp = ch_bnhLMp * decay_bnhLM1;
235            y_off_bnhLMp.permute([0, 1, 3, 2, 4]) // y_off_bnLMhp
236        };
237
238        // ── Combine and reshape ───────────────────────────────────────────────
239        let y_bnLMhp = y_diag_bnLMhp + y_off_bnLMhp;
240        let y_bnlmhp =
241            y_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
242
243        (y_bnlmhp, final_state_bhpr)
244    }
245}