Skip to main content

burn_mamba/mamba3/ssd/
ssd_path.rs

1use crate::mamba3::prelude::*;
2use burn::prelude::*;
3
/// SSD algorithm selection.
///
/// Each variant carries the chunk length Q for the SSD algorithm.
/// Larger values increase the intra-chunk GEMM work and reduce the
/// inter-chunk scan length.
/// Optimal value is approximately `√(state_rank · per_head_dim)`.
///
/// A chunk length of `None` means "use the heuristic": see
/// [`chunk_len_or_optimal`](Self::chunk_len_or_optimal), which then falls back to
/// [`optimal_default`](Self::optimal_default).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Mamba3SsdPath {
    /// Minimal SSD.
    ///
    /// This algorithm mostly uses batched matmuls. For the backward operation, this relies on autodiff.
    /// See [`chunked_selective_scan`] for more info.
    ///
    /// For training, you may prefer using [SerialRecalculated](Self::SerialRecalculated) instead.
    ///
    /// Based on `/mamba_ssm/modules/ssd_minimal.py` from the `state-spaces/mamba` github reference.
    Minimal(Option<usize>),
    /// (Hybrid) Serial SSD.
    ///
    /// This algorithm uses a serial loop over the nchunks, besides batched matmuls.
    /// For the backward operation, this relies on autodiff.
    /// For a custom backwards that saves memory, see [SerialRecalculated](Self::SerialRecalculated).
    ///
    /// Based on 5 kernels on `/mamba_ssm/ops/triton/` from the `state-spaces/mamba` github reference:
    /// - `ssd_chunk_state.py` (K1, K3).
    /// - `ssd_bmm.py` (K2).
    /// - `ssd_state_passing.py` (K4).
    /// - `ssd_chunk_scan.py` (K5).
    Serial(Option<usize>),
    /// (Hybrid) Serial SSD that triggers recalculations for the backward pass.
    ///
    /// This algorithm uses a serial loop over the nchunks, besides batched matmuls.
    /// Contains a custom backward operation that saves memory.
    /// For an autodiff backwards, see [Serial](Self::Serial).
    ///
    /// Based on the combined kernel `/mamba_ssm/ops/triton/ssd_combined.py` from the `state-spaces/mamba`
    /// github reference.
    SerialRecalculated(Option<usize>),
}
43
44/// MIMO-first SSD input.
45///
46/// All tensors are pre-processed: B/C are already QK-normed, RoPE-applied, bias-added, and
47/// expanded to per-head (not per-group). V is already scaled by the trapezoidal coefficient
48/// (γ or β). The combined log-decay `da = Δ·A` is pre-computed. D skip is handled by the caller.
49pub struct Mamba3SsdInput<B: Backend> {
50    /// Value tensor, already scaled by trapezoidal coefficient (γ or β).
51    ///
52    /// # Shape
53    /// - [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
54    pub v_bnlrhp: Tensor<B, 6>,
55
56    /// Pre-combined log-decay `Δ·A` (negative).
57    ///
58    /// # Shape
59    /// - [batch, nchunks, chunk_len, nheads]
60    pub da_bnlh: Tensor<B, 4>,
61
62    /// Key/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head, per-rank.
63    ///
64    /// # Shape
65    /// - [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
66    pub b_bnlrhn: Tensor<B, 6>,
67
68    /// Query/C tensor: same processing as B.
69    ///
70    /// # Shape
71    /// - [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
72    pub c_bnlrhn: Tensor<B, 6>,
73
74    /// Initial SSM hidden state.
75    ///
76    /// # Shape
77    /// - [batch, nheads, per_head_dim, state_rank]
78    pub initial_state_bhpr: Tensor<B, 4>,
79
80    /// Optional learnable initial state (broadcast over batch).
81    ///
82    /// # Shape
83    /// - [nheads, per_head_dim, state_rank]
84    pub init_state_hpr: Option<Tensor<B, 3>>,
85}
86
87impl<B: Backend> Mamba3SsdInput<B> {
88    pub fn sanity(&self) {
89        use crate::utils::sanity::sanity as san;
90        san(&self.v_bnlrhp);
91        san(&self.da_bnlh);
92        san(&self.b_bnlrhn);
93        san(&self.c_bnlrhn);
94        san(&self.initial_state_bhpr);
95        if let Some(ref init_state_hpr) = self.init_state_hpr {
96            san(init_state_hpr);
97        }
98    }
99}
100
101impl Mamba3SsdPath {
102    /// Optimal chunk length is approximately `√(state_rank · per_head_dim)`.
103    pub fn optimal_default(state_rank: usize, per_head_dim: usize) -> usize {
104        (state_rank * per_head_dim)
105            .isqrt()
106            .next_multiple_of(32) // rule-of-thumb: common plane dimension.
107            .min(512) // rule-of-thumb: ceiling at 512.
108    }
109
110    /// Optimal Minimal variant.
111    ///
112    /// See [optimal_default](Self::optimal_default) for more info.
113    pub fn core_optimal(state_rank: usize, per_head_dim: usize) -> Self {
114        let optim = Self::optimal_default(state_rank, per_head_dim);
115        Self::Minimal(Some(optim))
116    }
117
118    /// Optimal Minimal variant.
119    ///
120    /// See [optimal_default](Self::optimal_default) for more info.
121    pub fn core_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self {
122        Self::core_optimal(block.state_rank, block.per_head_dim())
123    }
124
125    /// Optimal Serial variant.
126    ///
127    /// See [optimal_default](Self::optimal_default) for more info.
128    pub fn chunked_optimal(state_rank: usize, per_head_dim: usize) -> Self {
129        let optim = Self::optimal_default(state_rank, per_head_dim);
130        Self::Serial(Some(optim))
131    }
132
133    /// Optimal Serial variant.
134    ///
135    /// See [optimal_default](Self::optimal_default) for more info.
136    pub fn chunked_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self {
137        Self::chunked_optimal(block.state_rank, block.per_head_dim())
138    }
139
140    /// Optimal Serial variant.
141    ///
142    /// See [optimal_default](Self::optimal_default) for more info.
143    pub fn chunked_recalculated_optimal(state_rank: usize, per_head_dim: usize) -> Self {
144        let optim = Self::optimal_default(state_rank, per_head_dim);
145        Self::SerialRecalculated(Some(optim))
146    }
147
148    /// Optimal Serial Recalculated variant.
149    ///
150    /// See [optimal_default](Self::optimal_default) for more info.
151    pub fn chunked_recalculated_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self {
152        Self::chunked_recalculated_optimal(block.state_rank, block.per_head_dim())
153    }
154
155    pub fn chunk_len(&self) -> Option<usize> {
156        match self {
157            Mamba3SsdPath::Minimal(chunk_len) => *chunk_len,
158            Mamba3SsdPath::Serial(chunk_len) => *chunk_len,
159            Mamba3SsdPath::SerialRecalculated(chunk_len) => *chunk_len,
160        }
161    }
162
163    pub fn chunk_len_or_optimal(&self, state_rank: usize, per_head_dim: usize) -> usize {
164        match self {
165            Mamba3SsdPath::Minimal(chunk_len) => {
166                chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
167            }
168            Mamba3SsdPath::Serial(chunk_len) => {
169                chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
170            }
171            Mamba3SsdPath::SerialRecalculated(chunk_len) => {
172                chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
173            }
174        }
175    }
176
177    /// Run the SSD algorithm on the given MIMO-first input.
178    ///
179    /// Dispatches to `ssd_minimal`, `ssd_serial`, or `ssd_serial_recalculated` based on the variant.
180    ///
181    /// # Returns
182    /// - `y_bnlrhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
183    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
184    pub fn run<B: Backend + Mamba3BackendExt>(
185        &self,
186        input: Mamba3SsdInput<B>,
187    ) -> (Tensor<B, 6>, Tensor<B, 4>) {
188        match self {
189            Mamba3SsdPath::Minimal(_) => Mamba3::<B>::ssd_minimal(input),
190            Mamba3SsdPath::Serial(_) => Mamba3::<B>::ssd_serial(input),
191            Mamba3SsdPath::SerialRecalculated(_) => Mamba3::<B>::ssd_serial_recalculated(input),
192        }
193    }
194}
195
196impl Default for Mamba3SsdPath {
197    fn default() -> Mamba3SsdPath {
198        // Mamba3SsdPath defaults to the SerialRecalculated algorithm with the optimal chunk length.
199        Mamba3SsdPath::SerialRecalculated(None)
200    }
201}
202
203// ---------------------------------------------------------------------------
204// Tests
205// ---------------------------------------------------------------------------
206
#[cfg(all(test, feature = "backend-flex"))]
mod tests {
    use super::*;
    use burn::backend::{Autodiff, Flex};
    use burn::module::Param;
    use burn::tensor::Distribution;

    /// Inner (non-autodiff) backend used for materialising values and
    /// extracted gradients.
    type InnerB = Flex;
    /// Autodiff-wrapped backend used to drive `.backward()`.
    type B = Autodiff<InnerB>;

    type Device = <InnerB as burn::tensor::backend::BackendTypes>::Device;

    /// Build a randomised set of tensors on the inner backend.
    ///
    /// Returns `(v, da, b, c, initial_state)` with the shapes documented on
    /// [`Mamba3SsdInput`]. `Param`s wrapping these are built per-path so each
    /// path gets a fresh autodiff graph.
    ///
    /// `da` is drawn from `Normal(-0.5, 0.1)`. Its mean sits five standard
    /// deviations below zero, so the implied per-token decay `exp(da)` lies in
    /// `(0, 1)` with overwhelming probability. This matches how the upstream
    /// block produces `Δ · A` with `A < 0`. Positive samples are not strictly
    /// excluded.
    ///
    /// NOTE(review): no RNG seed is set here, so every run draws different
    /// inputs; the tolerances used by the callers must hold for any draw.
    // One argument per SSD dimension plus the device. Bundling them into a
    // struct would only obscure the call sites, so silence the clippy lint.
    #[allow(clippy::too_many_arguments)]
    fn random_input(
        batch: usize,
        nchunks: usize,
        chunk_len: usize,
        mimo_rank: usize,
        nheads: usize,
        per_head_dim: usize,
        state_rank: usize,
        device: &Device,
    ) -> (
        Tensor<InnerB, 6>,
        Tensor<InnerB, 4>,
        Tensor<InnerB, 6>,
        Tensor<InnerB, 6>,
        Tensor<InnerB, 4>,
    ) {
        let v = Tensor::<InnerB, 6>::random(
            [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let da = Tensor::<InnerB, 4>::random(
            [batch, nchunks, chunk_len, nheads],
            Distribution::Normal(-0.5, 0.1),
            device,
        );
        let b = Tensor::<InnerB, 6>::random(
            [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let c = Tensor::<InnerB, 6>::random(
            [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let initial_state = Tensor::<InnerB, 4>::random(
            [batch, nheads, per_head_dim, state_rank],
            Distribution::Normal(0.0, 0.1),
            device,
        );
        (v, da, b, c, initial_state)
    }

    /// Inputs wrapped as `Param`s so each tensor becomes an autodiff leaf
    /// with `require_grad`.  A fresh `Inputs` is built per path so each path
    /// runs with its own independent autodiff graph.
    struct Inputs {
        v: Param<Tensor<B, 6>>,
        da: Param<Tensor<B, 4>>,
        b: Param<Tensor<B, 6>>,
        c: Param<Tensor<B, 6>>,
        initial_state: Param<Tensor<B, 4>>,
    }

    impl Inputs {
        /// Lift inner-backend tensors into autodiff `Param` leaves.
        fn from_inner(
            v: Tensor<InnerB, 6>,
            da: Tensor<InnerB, 4>,
            b: Tensor<InnerB, 6>,
            c: Tensor<InnerB, 6>,
            initial_state: Tensor<InnerB, 4>,
        ) -> Self {
            Self {
                v: Param::from_tensor(Tensor::from_inner(v)),
                da: Param::from_tensor(Tensor::from_inner(da)),
                b: Param::from_tensor(Tensor::from_inner(b)),
                c: Param::from_tensor(Tensor::from_inner(c)),
                initial_state: Param::from_tensor(Tensor::from_inner(initial_state)),
            }
        }

        /// Build the SSD input from the current `Param` values.
        fn ssd_input(&self) -> Mamba3SsdInput<B> {
            Mamba3SsdInput {
                v_bnlrhp: self.v.val(),
                da_bnlh: self.da.val(),
                b_bnlrhn: self.b.val(),
                c_bnlrhn: self.c.val(),
                initial_state_bhpr: self.initial_state.val(),
                // Serial paths assert this is None — see ssd_serial / ssd_serial_recalculated.
                init_state_hpr: None,
            }
        }
    }

    /// Collected forward outputs and input gradients for a single SSD path run.
    struct PathRun {
        y: Tensor<InnerB, 6>,
        state: Tensor<InnerB, 4>,
        d_v: Tensor<InnerB, 6>,
        d_da: Tensor<InnerB, 4>,
        d_b: Tensor<InnerB, 6>,
        d_c: Tensor<InnerB, 6>,
        d_init_state: Tensor<InnerB, 4>,
    }

    /// Combine `y` and `final_state` into a single deterministic scalar loss
    /// using fixed (non-tracked) random "head" tensors. Two distinct heads so
    /// that gradients for the y-branch and the state-branch are independent.
    fn loss_from_outputs(
        y_bnlrhp: Tensor<B, 6>,
        final_state_bhpr: Tensor<B, 4>,
        y_head: Tensor<InnerB, 6>,
        s_head: Tensor<InnerB, 4>,
    ) -> Tensor<B, 1> {
        let y_head = Tensor::from_inner(y_head);
        let s_head = Tensor::from_inner(s_head);
        (y_bnlrhp * y_head).sum() + (final_state_bhpr * s_head).sum()
    }

    /// Run a single SSD path and extract the gradients of all 5 inputs.
    fn run_path(
        path: Mamba3SsdPath,
        inputs: &Inputs,
        y_head: Tensor<InnerB, 6>,
        s_head: Tensor<InnerB, 4>,
    ) -> PathRun {
        let (y, state) = path.run(inputs.ssd_input());
        let y_inner = y.clone().inner();
        let state_inner = state.clone().inner();

        let loss = loss_from_outputs(y, state, y_head, s_head);
        let grads = loss.backward();

        PathRun {
            y: y_inner,
            state: state_inner,
            d_v: inputs.v.val().grad(&grads).expect("grad v"),
            d_da: inputs.da.val().grad(&grads).expect("grad da"),
            d_b: inputs.b.val().grad(&grads).expect("grad b"),
            d_c: inputs.c.val().grad(&grads).expect("grad c"),
            d_init_state: inputs
                .initial_state
                .val()
                .grad(&grads)
                .expect("grad initial_state"),
        }
    }

    /// Run the same input through `Minimal`, `Serial`, and `SerialRecalculated`
    /// and assert that both serial paths agree with `Minimal` on:
    ///   1. the forward outputs (`y`, `final_state`)
    ///   2. the gradients of every input through a fixed scalar loss.
    ///
    /// All three are chunkwise reformulations of the same MIMO-first SSD, so
    /// both the values and their gradients must agree up to floating-point
    /// noise.
    fn run_minimal_matches_serial(
        batch: usize,
        nchunks: usize,
        chunk_len: usize,
        mimo_rank: usize,
        nheads: usize,
        per_head_dim: usize,
        state_rank: usize,
    ) {
        let device: Device = Default::default();
        let (v, da, b, c, init) = random_input(
            batch,
            nchunks,
            chunk_len,
            mimo_rank,
            nheads,
            per_head_dim,
            state_rank,
            &device,
        );

        // Fixed (non-tracked) "downstream heads" for the loss.
        let y_head = Tensor::<InnerB, 6>::random(
            [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim],
            Distribution::Normal(0.0, 1.0),
            &device,
        );
        let s_head = Tensor::<InnerB, 4>::random(
            [batch, nheads, per_head_dim, state_rank],
            Distribution::Normal(0.0, 1.0),
            &device,
        );

        // Each path gets its own fresh autodiff graph (Param leaves).
        let inputs_min =
            Inputs::from_inner(v.clone(), da.clone(), b.clone(), c.clone(), init.clone());
        let inputs_ser =
            Inputs::from_inner(v.clone(), da.clone(), b.clone(), c.clone(), init.clone());
        let inputs_rec = Inputs::from_inner(v, da, b, c, init);

        let r_min = run_path(
            Mamba3SsdPath::Minimal(Some(chunk_len)),
            &inputs_min,
            y_head.clone(),
            s_head.clone(),
        );
        let r_ser = run_path(
            Mamba3SsdPath::Serial(Some(chunk_len)),
            &inputs_ser,
            y_head.clone(),
            s_head.clone(),
        );
        let r_rec = run_path(
            Mamba3SsdPath::SerialRecalculated(Some(chunk_len)),
            &inputs_rec,
            y_head,
            s_head,
        );

        // Max absolute element-wise difference between two inner tensors.
        macro_rules! diff {
            ($a:expr, $b:expr) => {
                ($a.clone() - $b.clone()).abs().max().into_scalar()
            };
        }

        // ── Forward agreement ────────────────────────────────────────────
        let tol = 1e-4;
        let dy_ser = diff!(r_min.y, r_ser.y);
        let ds_ser = diff!(r_min.state, r_ser.state);
        let dy_rec = diff!(r_min.y, r_rec.y);
        let ds_rec = diff!(r_min.state, r_rec.state);
        assert!(
            dy_ser < tol,
            "Minimal vs Serial: y max abs diff = {dy_ser:.6} (tol {tol})"
        );
        assert!(
            ds_ser < tol,
            "Minimal vs Serial: final_state max abs diff = {ds_ser:.6} (tol {tol})"
        );
        assert!(
            dy_rec < tol,
            "Minimal vs SerialRecalculated: y max abs diff = {dy_rec:.6} (tol {tol})"
        );
        assert!(
            ds_rec < tol,
            "Minimal vs SerialRecalculated: final_state max abs diff = {ds_rec:.6} (tol {tol})"
        );

        // ── Gradient agreement ───────────────────────────────────────────
        // Looser tolerance: every path computes the same mathematical
        // gradients, but the chunkwise reformulations accumulate sums in
        // different orders, so small drift is expected.
        let grad_tol = 1e-3;

        // Collect every mismatch before failing so one run reports all of them.
        let mut failures: Vec<String> = Vec::new();
        macro_rules! check_grad {
            ($field:ident, $name:expr) => {{
                let d_ser = diff!(r_min.$field, r_ser.$field);
                let d_rec = diff!(r_min.$field, r_rec.$field);
                eprintln!(
                    "grad {:>14} | min↔ser = {:>10.6} | min↔rec = {:>10.6}",
                    $name, d_ser, d_rec
                );
                if d_ser >= grad_tol {
                    failures.push(format!(
                        "Minimal vs Serial: grad of {} max abs diff = {:.6} (tol {})",
                        $name, d_ser, grad_tol
                    ));
                }
                if d_rec >= grad_tol {
                    failures.push(format!(
                        "Minimal vs SerialRecalculated: grad of {} max abs diff = {:.6} (tol {})",
                        $name, d_rec, grad_tol
                    ));
                }
            }};
        }
        check_grad!(d_v, "v");
        check_grad!(d_da, "da");
        check_grad!(d_b, "b");
        check_grad!(d_c, "c");
        check_grad!(d_init_state, "initial_state");

        assert!(
            failures.is_empty(),
            "gradient mismatches:\n  {}",
            failures.join("\n  ")
        );
    }

    #[test]
    fn paths_agree_siso() {
        // batch=2, nchunks=3, chunk_len=4, mimo_rank=1, nheads=2, per_head_dim=8, state_rank=8
        run_minimal_matches_serial(2, 3, 4, 1, 2, 8, 8);
    }

    #[test]
    fn paths_agree_mimo() {
        // mimo_rank=2 exercises the fused-L (= chunk_len · R) reshape shared by all three paths.
        run_minimal_matches_serial(2, 3, 4, 2, 2, 8, 8);
    }

    #[test]
    fn paths_agree_single_chunk() {
        // nchunks=1 — no inter-chunk scan; checks the intra-chunk + state-passing
        // boundary case where K4 runs a single iteration.
        run_minimal_matches_serial(2, 1, 4, 1, 2, 8, 8);
    }
}