Skip to main content

burn_mamba/mamba3/ssd/
serial.rs

1#![allow(non_snake_case)]
2
3use crate::mamba3::prelude::*;
4use burn::prelude::*;
5
6impl<B: Backend> Mamba3<B> {
7    /// MIMO-first (Hybrid) Serial SSD.
8    ///
9    /// Implements K1-K5 with a sequential loop (K4) for the inter-chunk scan instead
10    /// of the quadratic segsum approach in [`ssd_minimal`](Self::ssd_minimal).
11    /// This is more memory-efficient for long sequences with many chunks.
12    ///
13    /// SISO (R=1) is the special case where the fused length equals the chunk length.
14    ///
15    /// # Returns
16    /// - `y_bnlrhp`: `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
17    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
18    pub fn ssd_serial(input: super::Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>) {
19        let [batch, nchunks, chunk_len, _mimo_rank, nheads, per_head_dim] = input.v_bnlrhp.dims();
20        let [.., state_rank] = input.b_bnlrhn.dims();
21
22        assert!(
23            input.init_state_hpr.is_none(),
24            "init_state_hpr is not yet supported in ssd_serial; use ssd_minimal instead"
25        );
26        assert!(nchunks > 0, "sequence length must be at least 1");
27
28        // ── K1: da_cumsum from da_bnlh ────────────────────────────────────────
29        let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(input.da_bnlh.clone());
30        assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
31        assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
32
33        // ── K2: CB matrix on fused tensors ────────────────────────────────────
34        let cb_bnhLL: Tensor<B, 5> = k2_ssd_bmm(input.c_bnlrhn.clone(), input.b_bnlrhn.clone());
35        // [b, n, H, L, L] where L = chunk_len * mimo_rank
36
37        // ── K3: intra-chunk state ─────────────────────────────────────────────
38        let intra_chunk_state_bnhpr: Tensor<B, 5> = k3_ssd_chunk_state(
39            input.v_bnlrhp.clone(),
40            input.b_bnlrhn.clone(),
41            da_cumsum_bhnl.clone(),
42        );
43        assert_eq!(
44            [batch, nchunks, nheads, per_head_dim, state_rank],
45            intra_chunk_state_bnhpr.dims()
46        );
47
48        // ── K4: state passing (sequential loop) ───────────────────────────────
49        let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<B, 5>, Tensor<B, 4>) =
50            k4_ssd_state_passing(
51                intra_chunk_state_bnhpr,
52                da_chunk_end_bhn,
53                input.initial_state_bhpr,
54            );
55        assert_eq!(
56            [batch, nchunks, nheads, per_head_dim, state_rank],
57            chunk_input_state_bnhpr.dims()
58        );
59        assert_eq!(
60            [batch, nheads, per_head_dim, state_rank],
61            final_state_bhpr.dims()
62        );
63
64        // ── K5: MIMO chunk scan ───────────────────────────────────────────────
65        let y_bnlrhp: Tensor<B, 6> = k5_ssd_chunk_scan(
66            da_cumsum_bhnl,
67            input.v_bnlrhp,
68            input.c_bnlrhn,
69            cb_bnhLL,
70            chunk_input_state_bnhpr,
71        );
72
73        (y_bnlrhp, final_state_bhpr)
74    }
75}
76
77// ---------------------------------------------------------------------------
78// K1 — chunk cumulative log-decay
79// ---------------------------------------------------------------------------
80
81/// Compute the intra-chunk cumulative log-decay and per-chunk decay totals.
82///
83/// # Arguments
84/// - `da_bnlh`: pre-combined `Δ·A`, shape `[batch, nchunks, chunk_len, nheads]`
85///
86/// # Returns
87/// - `da_cumsum_bhnl`: `[batch, nheads, nchunks, chunk_len]` — intra-chunk prefix sums
88/// - `da_chunk_end_bhn`: `[batch, nheads, nchunks]` — last prefix sum per chunk (total decay)
89pub fn k1_ssd_chunk_cumsum<B: Backend>(da_bnlh: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 3>) {
90    let [batch, nchunks, chunk_len, nheads] = da_bnlh.dims();
91    // Permute to [b, H, n, l] for the cumsum along the last dim
92    let da_bhnl = da_bnlh.permute([0, 3, 1, 2]);
93    let da_cumsum_bhnl = da_bhnl.cumsum(3);
94    assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
95
96    let da_chunk_end_bhn: Tensor<B, 3> = da_cumsum_bhnl
97        .clone()
98        .slice(s![.., .., .., -1]) // [b, H, n, 1]
99        .squeeze_dim(3); // [b, H, n]
100    assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
101
102    (da_cumsum_bhnl, da_chunk_end_bhn)
103}
104
105// ---------------------------------------------------------------------------
106// K2 — CB block matrix (C @ B^T on fused MIMO tensors)
107// ---------------------------------------------------------------------------
108
109/// Compute the intra-chunk CB matrix on fused (R-into-L) tensors.
110///
111/// # Arguments
112/// - `c_bnlrhn`: `[batch, nchunks, chunk_len, R, nheads, state_rank]`
113/// - `b_bnlrhn`: `[batch, nchunks, chunk_len, R, nheads, state_rank]`
114///
115/// # Returns
116/// - `cb_bnhLL`: `[batch, nchunks, nheads, L, L]` where `L = chunk_len * mimo_rank`
117pub fn k2_ssd_bmm<B: Backend>(c_bnlrhn: Tensor<B, 6>, b_bnlrhn: Tensor<B, 6>) -> Tensor<B, 5> {
118    let [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank] = c_bnlrhn.dims();
119    let fused_len = chunk_len * mimo_rank;
120
121    // Fuse R into chunk_len: [b, n, l, R, H, N] → [b, n, L, H, N]
122    let c_bnLhn = c_bnlrhn.reshape([batch, nchunks, fused_len, nheads, state_rank]);
123    let b_bnLhn = b_bnlrhn.reshape([batch, nchunks, fused_len, nheads, state_rank]);
124
125    // [b, n, H, L, N] @ [b, n, H, N, L] → [b, n, H, L, L]
126    let c_bnhLr = c_bnLhn.permute([0, 1, 3, 2, 4]); // [b, n, H, L, N]
127    let b_bnhrL = b_bnLhn.permute([0, 1, 3, 4, 2]); // [b, n, H, N, L]
128    let cb_bnhLL: Tensor<B, 5> = c_bnhLr.matmul(b_bnhrL);
129    assert_eq!(
130        [batch, nchunks, nheads, fused_len, fused_len],
131        cb_bnhLL.dims()
132    );
133    cb_bnhLL
134}
135
136// ---------------------------------------------------------------------------
137// K3 — intra-chunk state (chunk-end state assuming zero initial state)
138// ---------------------------------------------------------------------------
139
140/// Compute the SSM state at the end of each chunk, assuming zero initial hidden state.
141///
142/// Uses the pre-scaled V tensor — no `dt·B` scaling is performed here.
143///
144/// # Arguments
145/// - `v_bnlrhp`: `[batch, nchunks, chunk_len, R, nheads, per_head_dim]` — pre-scaled V
146/// - `b_bnlrhn`: `[batch, nchunks, chunk_len, R, nheads, state_rank]`
147/// - `da_cumsum_bhnl`: `[batch, nheads, nchunks, chunk_len]` — base (not fused)
148///
149/// # Returns
150/// - `intra_chunk_state_bnhpr`: `[batch, nchunks, nheads, per_head_dim, state_rank]`
151pub fn k3_ssd_chunk_state<B: Backend>(
152    v_bnlrhp: Tensor<B, 6>,
153    b_bnlrhn: Tensor<B, 6>,
154    da_cumsum_bhnl: Tensor<B, 4>,
155) -> Tensor<B, 5> {
156    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlrhp.dims();
157    let [.., state_rank] = b_bnlrhn.dims();
158    let fused_len = chunk_len * mimo_rank;
159
160    // Fuse R into chunk_len
161    let v_bnLhp = v_bnlrhp.reshape([batch, nchunks, fused_len, nheads, per_head_dim]);
162    let b_bnLhn = b_bnlrhn.reshape([batch, nchunks, fused_len, nheads, state_rank]);
163
164    // Decay from each fused position to end of chunk:
165    //   decay_fused[t*R+r] = exp(cumA_last - cumA_base[t])
166    let a_cumsum_last_bhn1 = da_cumsum_bhnl.clone().slice(s![.., .., .., -1]); // [b,H,n,1]
167    // Expand base cumsum to fused: [b, H, n, l] → [b, H, n, l, R] → [b, H, n, L]
168    let a_cumsum_fused_bhnL = da_cumsum_bhnl
169        .unsqueeze_dim::<5>(4)
170        .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
171        .reshape([batch, nheads, nchunks, fused_len]);
172    // Broadcast [b,H,n,1] - [b,H,n,L] → [b,H,n,L]
173    let decay_bhnL = (a_cumsum_last_bhn1 - a_cumsum_fused_bhnL).exp();
174
175    // decay * V: [b, n, L, H, 1] * [b, n, L, H, P]
176    let decay_bnLh1 = decay_bhnL.permute([0, 2, 3, 1]).unsqueeze_dim(4);
177    let decayed_v_bnLhp = decay_bnLh1 * v_bnLhp;
178
179    // state = decayed_V^T @ B: [b, n, H, P, L] × [b, n, H, L, N] → [b, n, H, P, N]
180    let decayed_v_bnhpL = decayed_v_bnLhp.permute([0, 1, 3, 4, 2]);
181    let b_bnhLN = b_bnLhn.permute([0, 1, 3, 2, 4]);
182    let intra_chunk_state_bnhpr: Tensor<B, 5> = decayed_v_bnhpL.matmul(b_bnhLN);
183    assert_eq!(
184        [batch, nchunks, nheads, per_head_dim, state_rank],
185        intra_chunk_state_bnhpr.dims()
186    );
187    intra_chunk_state_bnhpr
188}
189
190// ---------------------------------------------------------------------------
191// K4 — inter-chunk state scan (sequential loop)
192// ---------------------------------------------------------------------------
193
194/// Propagate hidden state across chunk boundaries using a sequential scan.
195///
196/// This kernel is independent of MIMO rank — it operates on the `[H, P, N]` state
197/// which is already aggregated over ranks.
198///
199/// # Arguments
200/// - `intra_chunk_state_bnhpr`: `[batch, nchunks, nheads, per_head_dim, state_rank]`
201/// - `da_chunk_end_bhn`: `[batch, nheads, nchunks]` — total log-decay per chunk
202/// - `initial_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
203///
204/// # Returns
205/// - `chunk_input_state_bnhpr`: `[batch, nchunks, nheads, per_head_dim, state_rank]`
206/// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
207pub fn k4_ssd_state_passing<B: Backend>(
208    intra_chunk_state_bnhpr: Tensor<B, 5>,
209    da_chunk_end_bhn: Tensor<B, 3>,
210    initial_state_bhpr: Tensor<B, 4>,
211) -> (Tensor<B, 5>, Tensor<B, 4>) {
212    let [batch, nchunks, nheads, per_head_dim, state_rank] = intra_chunk_state_bnhpr.dims();
213
214    let mut running_state_bhpr = initial_state_bhpr;
215    assert_eq!(
216        [batch, nheads, per_head_dim, state_rank],
217        running_state_bhpr.dims()
218    );
219
220    let mut chunk_input_state_vec_bhpr = Vec::with_capacity(nchunks + 1);
221    chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
222
223    for i_chunk in 0..nchunks {
224        let intra_state_bhpr: Tensor<B, 4> = intra_chunk_state_bnhpr
225            .clone()
226            .slice(s![.., i_chunk, .., .., ..])
227            .squeeze_dim(1);
228
229        let decay_bhpr = da_chunk_end_bhn
230            .clone()
231            .slice(s![.., .., i_chunk])
232            .exp()
233            .unsqueeze_dim::<4>(3)
234            .expand([batch, nheads, per_head_dim, state_rank]);
235
236        // SSM recurrence: h[n] = decay * h[n-1] + s[n]
237        running_state_bhpr = decay_bhpr * running_state_bhpr + intra_state_bhpr;
238        chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
239    }
240
241    let final_state_bhpr = chunk_input_state_vec_bhpr.pop().unwrap();
242    assert_eq!(
243        [batch, nheads, per_head_dim, state_rank],
244        final_state_bhpr.dims()
245    );
246
247    let chunk_input_state_bnhpr = Tensor::stack(chunk_input_state_vec_bhpr, 1);
248    assert_eq!(
249        [batch, nchunks, nheads, per_head_dim, state_rank],
250        chunk_input_state_bnhpr.dims()
251    );
252
253    (chunk_input_state_bnhpr, final_state_bhpr)
254}
255
256// ---------------------------------------------------------------------------
257// K5 — MIMO chunk scan (Y_diag + Y_off)
258// ---------------------------------------------------------------------------
259
260/// Compute the chunk output by combining the intra-chunk (diagonal) and
261/// inter-chunk (off-diagonal) contributions.
262///
263/// The MIMO causal mask uses interleaved time-step ordering:
264/// `L_mimo[i,j] = exp(cumA[i//R] - cumA[j//R])` if `i//R >= j//R`, else 0.
265///
266/// No D skip is applied — the caller handles it.
267///
268/// # Arguments
269/// - `da_cumsum_bhnl`: `[batch, nheads, nchunks, chunk_len]` — base (not fused)
270/// - `v_bnlrhp`: `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
271/// - `c_bnlrhn`: `[batch, nchunks, chunk_len, R, nheads, state_rank]`
272/// - `cb_bnhLL`: `[batch, nchunks, nheads, L, L]` from K2
273/// - `chunk_input_state_bnhpr`: `[batch, nchunks, nheads, per_head_dim, state_rank]`
274///
275/// # Returns
276/// - `y_bnlrhp`: `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
277pub fn k5_ssd_chunk_scan<B: Backend>(
278    da_cumsum_bhnl: Tensor<B, 4>,
279    v_bnlrhp: Tensor<B, 6>,
280    c_bnlrhn: Tensor<B, 6>,
281    cb_bnhLL: Tensor<B, 5>,
282    chunk_input_state_bnhpr: Tensor<B, 5>,
283) -> Tensor<B, 6> {
284    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlrhp.dims();
285    let [.., state_rank] = c_bnlrhn.dims();
286    let fused_len = chunk_len * mimo_rank;
287    let device = v_bnlrhp.device();
288
289    // Fuse R into chunk_len
290    let v_bnLhp = v_bnlrhp.reshape([batch, nchunks, fused_len, nheads, per_head_dim]);
291    let c_bnLhn = c_bnlrhn.reshape([batch, nchunks, fused_len, nheads, state_rank]);
292
293    // Expand base da_cumsum to fused length: [b, H, n, l] → [b, H, n, L]
294    let da_cumsum_fused_bhnL = da_cumsum_bhnl
295        .unsqueeze_dim::<5>(4)
296        .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
297        .reshape([batch, nheads, nchunks, fused_len]);
298
299    // ── BLUE (Y_off): exp(cumA_fused[i]) · C[i] · h[n-1] ─────────────────────
300    // [b, H, n, L] → [b, n, H, L, 1] → [b, n, H, L, P]
301    let exp_da_fused_bnhLp = da_cumsum_fused_bhnL
302        .clone()
303        .exp()
304        .permute([0, 2, 1, 3])
305        .unsqueeze_dim::<5>(4)
306        .expand([batch, nchunks, nheads, fused_len, per_head_dim]);
307
308    let c_bnhLr = c_bnLhn.permute([0, 1, 3, 2, 4]); // [b, n, H, L, N]
309    let state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]); // [b, n, H, N, P]
310    let ch_bnhLp = c_bnhLr.matmul(state_bnhrp); // [b, n, H, L, P]
311    let blue_bnhLp = ch_bnhLp * exp_da_fused_bnhLp; // [b, n, H, L, P]
312
313    // ── ORANGE (Y_diag): MIMO causal decay matrix · CB @ V ────────────────────
314    //
315    // MIMO pairwise decay: diff[i,j] = cumA_fused[i] - cumA_fused[j]
316    //                                = cumA_base[i//R] - cumA_base[j//R]
317    let da_fused_bnhL = da_cumsum_fused_bhnL.permute([0, 2, 1, 3]); // [b, n, H, L]
318    let da_target_bnhLL = da_fused_bnhL
319        .clone()
320        .unsqueeze_dim::<5>(4)
321        .expand([batch, nchunks, nheads, fused_len, fused_len]); // [b, n, H, L, L]
322    let da_source_bnhLL = da_fused_bnhL
323        .unsqueeze_dim::<5>(3)
324        .expand([batch, nchunks, nheads, fused_len, fused_len]); // [b, n, H, L, L]
325    let diff_bnhLL = da_target_bnhLL - da_source_bnhLL;
326
327    // MIMO causal neg-inf mask: −∞ where j//R > i//R (source strictly ahead of target in time).
328    // Build as interleaved expansion of the standard 2D upper-triangle mask.
329    let neg_inf_base_bnhll: Tensor<B, 5> = {
330        let zero_ll: Tensor<B, 2> = Tensor::zeros([chunk_len, chunk_len], &device);
331        Tensor::full_like(&zero_ll, f32::NEG_INFINITY)
332            .triu(1) // [l, l]: -inf above diagonal
333            .unsqueeze_dims::<5>(&[0, 1, 2])
334            .expand([batch, nchunks, nheads, chunk_len, chunk_len])
335    };
336    // Interleave-expand: [b, n, H, l, l] → [b, n, H, L, L]
337    let neg_inf_mimo_bnhLL: Tensor<B, 5> = neg_inf_base_bnhll
338        .unsqueeze_dim::<6>(4)
339        .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len])
340        .reshape([batch, nchunks, nheads, fused_len, chunk_len])
341        .unsqueeze_dim::<6>(5)
342        .expand([batch, nchunks, nheads, fused_len, chunk_len, mimo_rank])
343        .reshape([batch, nchunks, nheads, fused_len, fused_len]);
344
345    let decay_bnhLL = (diff_bnhLL + neg_inf_mimo_bnhLL).exp(); // [b, n, H, L, L]
346
347    let v_bnhLp = v_bnLhp.permute([0, 1, 3, 2, 4]); // [b, n, H, L, P]
348    let orange_bnhLp = (cb_bnhLL * decay_bnhLL).matmul(v_bnhLp); // [b, n, H, L, P]
349
350    // ── Combine and reshape ────────────────────────────────────────────────────
351    // [b, n, H, L, P] → [b, n, L, H, P] → [b, n, l, R, H, P]
352    let y_bnlrhp = (blue_bnhLp + orange_bnhLp)
353        .permute([0, 1, 3, 2, 4])
354        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
355
356    y_bnlrhp
357}