Skip to main content

burn_mamba/mamba3/double_ssd/ssd/
serial.rs

1//! # Serial-over-chunks SSD (Mamba-3 double-SSD pathway)
2//!
3//! The MIMO-first chunkwise scan as a serial loop over chunks, structured like
4//! the five Mamba-2 kernels ([`super::super`](crate::mamba3::double_ssd)) but
5//! generalised over the `mimo_rank` axis: the chunk and rank axes are fused into
6//! a single length `L·M` for the intra-chunk products.  The standard kernels
7//! here are reused by **both** the γ-pass and the β-pass of the double-SSD
8//! decomposition (the caller pre-scales `v` by γ or β and shifts the β inputs).
9//!
10//! - **K1** [`k1_ssd_chunk_cumsum`] — per-chunk cumulative `Δ·A` decays.
11//! - **K2** [`k2_ssd_bmm`] — the intra-chunk `C·Bᵀ` block matmul (fused `L·M`).
12//! - **K3** [`k3_ssd_chunk_state`] — each chunk's end-state contribution.
13//! - **K4** `k4_ssd_state_passing` — the serial inter-chunk scan.
14//! - **K5** [`k5_ssd_chunk_scan`] — combines intra- and inter-chunk parts into `y`.
15//!
16//! Produces identical values/gradients to [`super::minimal`]; SISO
17//! (`mimo_rank = 1`) is the special case where the fused length equals the chunk
18//! length.  Gradients flow through plain autodiff.
19
20#![allow(non_snake_case)]
21
22use crate::mamba3::double_ssd::prelude::*;
23use burn::prelude::*;
24
25impl Mamba3DoubleSsdInput {
26    /// MIMO-first (Hybrid) Serial SSD.
27    ///
28    /// Implements K1-K5 with a sequential loop (K4) for the inter-chunk scan instead
29    /// of the quadratic segsum approach in [`Self::double_ssd_minimal`].
30    /// This is more memory-efficient for long sequences with many chunks.
31    ///
32    /// SISO (mimo_rank=1) is the special case where the fused length equals the chunk length.
33    ///
34    /// # Returns
35    /// - `y_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
36    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
37    pub fn double_ssd_serial(self) -> (Tensor<6>, Tensor<4>) {
38        let input = self;
39        let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = input.v_bnlmhp.dims();
40        let [.., state_rank] = input.b_bnlmhr.dims();
41
42        assert!(
43            input.init_state_hpr.is_none(),
44            "init_state_hpr is not yet supported in ssd_serial; use ssd_minimal instead"
45        );
46        assert!(nchunks > 0, "sequence length must be at least 1");
47
48        // ── K1: da_cumsum from da_bnlh ────────────────────────────────────────
49        let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(input.da_bnlh.clone());
50        assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
51        assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
52
53        // ── K2: CB matrix on fused tensors ────────────────────────────────────
54        let cb_bnhLMLM: Tensor<5> = k2_ssd_bmm(input.c_bnlmhr.clone(), input.b_bnlmhr.clone());
55        assert_eq!(
56            [
57                batch,
58                nchunks,
59                nheads,
60                chunk_len * mimo_rank,
61                chunk_len * mimo_rank
62            ],
63            cb_bnhLMLM.dims()
64        );
65
66        // ── K3: intra-chunk state ─────────────────────────────────────────────
67        let intra_chunk_state_bnhpr: Tensor<5> = k3_ssd_chunk_state(
68            input.v_bnlmhp.clone(),
69            input.b_bnlmhr.clone(),
70            da_cumsum_bhnl.clone(),
71        );
72        assert_eq!(
73            [batch, nchunks, nheads, per_head_dim, state_rank],
74            intra_chunk_state_bnhpr.dims()
75        );
76
77        // ── K4: state passing (sequential loop) ───────────────────────────────
78        let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<5>, Tensor<4>) =
79            k4_ssd_state_passing(
80                intra_chunk_state_bnhpr,
81                da_chunk_end_bhn,
82                input.initial_state_bhpr,
83            );
84        assert_eq!(
85            [batch, nchunks, nheads, per_head_dim, state_rank],
86            chunk_input_state_bnhpr.dims()
87        );
88        assert_eq!(
89            [batch, nheads, per_head_dim, state_rank],
90            final_state_bhpr.dims()
91        );
92
93        // ── K5: MIMO chunk scan ───────────────────────────────────────────────
94        let y_bnlmhp: Tensor<6> = k5_ssd_chunk_scan(
95            da_cumsum_bhnl,
96            input.v_bnlmhp,
97            input.c_bnlmhr,
98            cb_bnhLMLM,
99            chunk_input_state_bnhpr,
100        );
101
102        (y_bnlmhp, final_state_bhpr)
103    }
104}
105
106// ---------------------------------------------------------------------------
107// K1 — chunk cumulative log-decay
108// ---------------------------------------------------------------------------
109
110/// Compute the intra-chunk cumulative log-decay and per-chunk decay totals.
111///
112/// # Arguments
113/// - `da_bnlh`: pre-combined `Δ·A`, shape `[batch, nchunks, chunk_len, nheads]`
114///
115/// # Returns
116/// - `da_cumsum_bhnl`: `[batch, nheads, nchunks, chunk_len]` — intra-chunk prefix sums
117/// - `da_chunk_end_bhn`: `[batch, nheads, nchunks]` — last prefix sum per chunk (total decay)
118pub fn k1_ssd_chunk_cumsum(da_bnlh: Tensor<4>) -> (Tensor<4>, Tensor<3>) {
119    let [batch, nchunks, chunk_len, nheads] = da_bnlh.dims();
120    // Permute to [batch, nheads, nchunks, chunk_len] for the cumsum along the last dim
121    let da_bhnl = da_bnlh.permute([0, 3, 1, 2]);
122    let da_cumsum_bhnl = da_bhnl.cumsum(3);
123    assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
124
125    let da_chunk_end_bhn: Tensor<3> = da_cumsum_bhnl
126        .clone()
127        .slice(s![.., .., .., -1]) // da_cumsum_end_bhn1
128        .squeeze_dim(3); // da_cumsum_end_bhn
129    assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
130
131    (da_cumsum_bhnl, da_chunk_end_bhn)
132}
133
134// ---------------------------------------------------------------------------
135// K2 — CB block matrix (C @ B^T on fused MIMO tensors)
136// ---------------------------------------------------------------------------
137
138/// Compute the intra-chunk CB matrix on fused (mimo_rank-into-chunk_len) tensors.
139///
140/// # Arguments
141/// - `c_bnlmhr`: `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
142/// - `b_bnlmhr`: `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
143///
144/// # Returns
145/// - `cb_bnhLMLM`: `[batch, nchunks, nheads, chunk_len*mimo_rank, chunk_len*mimo_rank]`
146pub fn k2_ssd_bmm(c_bnlmhr: Tensor<6>, b_bnlmhr: Tensor<6>) -> Tensor<5> {
147    let [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank] = c_bnlmhr.dims();
148
149    // Fuse R into chunk_len
150    let c_bnLMhr = c_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
151    let b_bnLMhr = b_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
152
153    let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
154    let b_bnhrLM = b_bnLMhr.permute([0, 1, 3, 4, 2]);
155    let cb_bnhLMLM: Tensor<5> = c_bnhLMr.matmul(b_bnhrLM);
156    assert_eq!(
157        [
158            batch,
159            nchunks,
160            nheads,
161            chunk_len * mimo_rank,
162            chunk_len * mimo_rank
163        ],
164        cb_bnhLMLM.dims()
165    );
166    cb_bnhLMLM
167}
168
169// ---------------------------------------------------------------------------
170// K3 — intra-chunk state (chunk-end state assuming zero initial state)
171// ---------------------------------------------------------------------------
172
173/// Compute the SSM state at the end of each chunk, assuming zero initial hidden state.
174///
175/// Uses the pre-scaled V tensor — no `dt·B` scaling is performed here.
176///
177/// # Arguments
178/// - `v_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]` — pre-scaled V
179/// - `b_bnlmhr`: `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
180/// - `da_cumsum_bhnl`: `[batch, nheads, nchunks, chunk_len]`
181///
182/// # Returns
183/// - `intra_chunk_state_bnhpr`: `[batch, nchunks, nheads, per_head_dim, state_rank]`
184pub fn k3_ssd_chunk_state(
185    v_bnlmhp: Tensor<6>,
186    b_bnlmhr: Tensor<6>,
187    da_cumsum_bhnl: Tensor<4>,
188) -> Tensor<5> {
189    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
190    let [.., state_rank] = b_bnlmhr.dims();
191
192    // Fuse mimo_rank into chunk_len
193    let v_bnLMhp = v_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
194    let b_bnLMhr = b_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
195
196    // Decay from each fused position to end of chunk
197    let a_cumsum_last_bhn1 = da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
198    // Expand base cumsum to fused
199    let a_cumsum_bhnLM = da_cumsum_bhnl
200        .unsqueeze_dim::<5>(4) // da_cumsum_bhnl1
201        .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // da_cumsum_bhnlm
202        .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); // da_cumsum_bhnLM
203    let decay_bhnLM = (a_cumsum_last_bhn1 - a_cumsum_bhnLM).exp();
204
205    // decay * V
206    let decay_bnLMh1 = decay_bhnLM
207        .permute([0, 2, 3, 1]) // decay_bnLMh
208        .unsqueeze_dim(4); // decay_bnLMh1
209    let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp;
210
211    // state = decayed_V^T @ B
212    let decayed_v_bnhpLM = decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
213    let b_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
214    let intra_chunk_state_bnhpr: Tensor<5> = decayed_v_bnhpLM.matmul(b_bnhLMr);
215    assert_eq!(
216        [batch, nchunks, nheads, per_head_dim, state_rank],
217        intra_chunk_state_bnhpr.dims()
218    );
219    intra_chunk_state_bnhpr
220}
221
222// ---------------------------------------------------------------------------
223// K4 — inter-chunk state scan (sequential loop)
224// ---------------------------------------------------------------------------
225
226/// Propagate hidden state across chunk boundaries using a sequential scan.
227///
228/// This kernel is independent of MIMO rank — it operates on the `[nheads, per_head_dim, state_rank]` state
229/// which is already aggregated over ranks.
230///
231/// # Arguments
232/// - `intra_chunk_state_bnhpr`: `[batch, nchunks, nheads, per_head_dim, state_rank]`
233/// - `da_chunk_end_bhn`: `[batch, nheads, nchunks]` — total log-decay per chunk
234/// - `initial_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
235///
236/// # Returns
237/// - `chunk_input_state_bnhpr`: `[batch, nchunks, nheads, per_head_dim, state_rank]`
238/// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
239pub fn k4_ssd_state_passing(
240    intra_chunk_state_bnhpr: Tensor<5>,
241    da_chunk_end_bhn: Tensor<3>,
242    initial_state_bhpr: Tensor<4>,
243) -> (Tensor<5>, Tensor<4>) {
244    let [batch, nchunks, nheads, per_head_dim, state_rank] = intra_chunk_state_bnhpr.dims();
245
246    let mut running_state_bhpr = initial_state_bhpr;
247    assert_eq!(
248        [batch, nheads, per_head_dim, state_rank],
249        running_state_bhpr.dims()
250    );
251
252    let mut chunk_input_state_vec_bhpr = Vec::with_capacity(nchunks + 1);
253    chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
254
255    for i_chunk in 0..nchunks {
256        let intra_state_bhpr: Tensor<4> = intra_chunk_state_bnhpr
257            .clone()
258            .slice(s![.., i_chunk, .., .., ..]) // intra_chunk_state_b1hpr
259            .squeeze_dim(1); // intra_state_bhpr
260
261        let decay_bhpr = da_chunk_end_bhn
262            .clone()
263            .slice(s![.., .., i_chunk]) // da_chunk_end_bh1
264            .unsqueeze_dim::<4>(3) // da_chunk_end_bh
265            .exp()
266            .expand([batch, nheads, per_head_dim, state_rank]); // decay_bhpr
267
268        // SSM recurrence: h[n] = decay * h[n-1] + s[n]
269        running_state_bhpr = decay_bhpr * running_state_bhpr + intra_state_bhpr;
270        chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
271    }
272
273    let final_state_bhpr = chunk_input_state_vec_bhpr.pop().unwrap();
274    assert_eq!(
275        [batch, nheads, per_head_dim, state_rank],
276        final_state_bhpr.dims()
277    );
278
279    let chunk_input_state_bnhpr = Tensor::stack(chunk_input_state_vec_bhpr, 1);
280    assert_eq!(
281        [batch, nchunks, nheads, per_head_dim, state_rank],
282        chunk_input_state_bnhpr.dims()
283    );
284
285    (chunk_input_state_bnhpr, final_state_bhpr)
286}
287
288// ---------------------------------------------------------------------------
289// K5 — MIMO chunk scan (Y_diag + Y_off)
290// ---------------------------------------------------------------------------
291
292/// Compute the chunk output by combining the intra-chunk (diagonal) and
293/// inter-chunk (off-diagonal) contributions.
294///
295/// The MIMO causal mask uses interleaved time-step ordering:
296/// `L_mimo[i,j] = exp(cumA[i//m] - cumA[j//m])` if `i//m >= j//m`, else 0.
297///
298/// No D skip is applied — the caller handles it.
299///
300/// # Arguments
301/// - `da_cumsum_bhnl`: `[batch, nheads, nchunks, chunk_len]` — base (not fused)
302/// - `v_bnlmhp`: `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
303/// - `c_bnlmhr`: `[batch, nchunks, chunk_len, R, nheads, state_rank]`
304/// - `cb_bnhLMLM`: `[batch, nchunks, nheads, L, L]` from K2
305/// - `chunk_input_state_bnhpr`: `[batch, nchunks, nheads, per_head_dim, state_rank]`
306///
307/// # Returns
308/// - `y_bnlmhp`: `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
309pub fn k5_ssd_chunk_scan(
310    da_cumsum_bhnl: Tensor<4>,
311    v_bnlmhp: Tensor<6>,
312    c_bnlmhr: Tensor<6>,
313    cb_bnhLMLM: Tensor<5>,
314    chunk_input_state_bnhpr: Tensor<5>,
315) -> Tensor<6> {
316    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
317    let [.., state_rank] = c_bnlmhr.dims();
318    let device = v_bnlmhp.device();
319
320    // Fuse mimo_rank into chunk_len
321    let v_bnLMhp = v_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
322    let c_bnLMhr = c_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
323
324    // Expand base da_cumsum to fused length: [b, nheads, n, l] → [b, nheads, n, L]
325    let da_cumsum_bhnLM = da_cumsum_bhnl
326        .unsqueeze_dim::<5>(4) // da_cumsum_bhnl1
327        .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // da_cumsum_bhnlm
328        .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); // da_cumsum_bhnLM
329
330    // ── BLUE (Y_off): exp(cumA[i]) · C[i] · h[n-1] ─────────────────────
331    let exp_da_bnhLMp = da_cumsum_bhnLM
332        .clone()
333        .exp()
334        .permute([0, 2, 1, 3]) // exp_da_bnhLM
335        .unsqueeze_dim::<5>(4) // // exp_da_bnhLM1
336        .expand([batch, nchunks, nheads, chunk_len * mimo_rank, per_head_dim]); // exp_da_bnhLMp
337
338    let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
339    let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
340    let ch_bnhLMp = c_bnhLMr.matmul(chunk_input_state_bnhrp);
341    let blue_bnhLMp = ch_bnhLMp * exp_da_bnhLMp;
342
343    // ── ORANGE (Y_diag): MIMO causal decay matrix · CB @ V ────────────────────
344    //
345    // MIMO pairwise decay: diff[i,j] = cumA[i] - cumA[j]
346    //                                = cumA_base[i//m] - cumA_base[j//m]
347    let da_cumsum_bnhLM = da_cumsum_bhnLM.permute([0, 2, 1, 3]);
348    let target_da_cumsum_bnhLMLM = da_cumsum_bnhLM
349        .clone()
350        .unsqueeze_dim::<5>(4) // da_cumsum_bnhLM1
351        .expand([
352            batch,
353            nchunks,
354            nheads,
355            chunk_len * mimo_rank,
356            chunk_len * mimo_rank,
357        ]);
358    let source_da_cumsum_bnhLMLM = da_cumsum_bnhLM
359        .unsqueeze_dim::<5>(3) // da_cumsum_bnh1LM
360        .expand([
361            batch,
362            nchunks,
363            nheads,
364            chunk_len * mimo_rank,
365            chunk_len * mimo_rank,
366        ]);
367    let diff_da_cumsum_bnhLMLM = target_da_cumsum_bnhLMLM - source_da_cumsum_bnhLMLM;
368
369    // MIMO causal neg-inf mask: −∞ where j//m > i//m (source strictly ahead of target in time).
370    // Build as interleaved expansion of the standard 2-dimensional upper-triangle mask.
371    let neg_inf_base_bnhll: Tensor<5> =
372        Tensor::<2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device)
373            .triu(1) // [chunk_len, chunk_len]: -inf above diagonal
374            .unsqueeze_dims::<5>(&[0, 1, 2]) // neg_inf_base_111ll
375            .expand([batch, nchunks, nheads, chunk_len, chunk_len]); // neg_inf_base_bnhll
376    // Interleave-expand
377    let neg_inf_bnhLMLM: Tensor<5> = neg_inf_base_bnhll
378        .unsqueeze_dim::<6>(4) // neg_inf_base_bnhl1l
379        .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len]) // neg_inf_base_bnhlml
380        .reshape([batch, nchunks, nheads, chunk_len * mimo_rank, chunk_len]) // neg_inf_base_bnhLMl
381        .unsqueeze_dim::<6>(5) // neg_inf_base_bnhLMl1
382        .expand([
383            batch,
384            nchunks,
385            nheads,
386            chunk_len * mimo_rank,
387            chunk_len,
388            mimo_rank,
389        ]) // neg_inf_base_bnhLMlm
390        .reshape([
391            batch,
392            nchunks,
393            nheads,
394            chunk_len * mimo_rank,
395            chunk_len * mimo_rank,
396        ]); // neg_inf_bnhLMLM
397
398    let decay_bnhLMLM = (diff_da_cumsum_bnhLMLM + neg_inf_bnhLMLM).exp();
399
400    let v_bnhLMp = v_bnLMhp.permute([0, 1, 3, 2, 4]);
401    let orange_bnhLMp = (cb_bnhLMLM * decay_bnhLMLM).matmul(v_bnhLMp);
402
403    // ── Combine and reshape ────────────────────────────────────────────────────
404    let y_bnlmhp = (blue_bnhLMp + orange_bnhLMp)
405        .permute([0, 1, 3, 2, 4]) // y_bnLMhp
406        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]); // y_bnlmhp
407
408    y_bnlmhp
409}