Skip to main content

burn_mamba/mamba2/ssd/
ssd_path.rs

1use crate::mamba2::prelude::*;
2use burn::prelude::*;
3
/// SSD algorithm selection.
///
/// Every variant carries an *optional* chunk length Q for the SSD algorithm.
/// `Some(q)` fixes the chunk length; `None` leaves it unspecified, in which case
/// [`chunk_len_or_optimal`](Self::chunk_len_or_optimal) falls back to
/// [`optimal_default`](Self::optimal_default).
///
/// Larger values increase the intra-chunk GEMM work and reduce the
/// inter-chunk scan length.
/// Optimal value is approximately `√(state_rank · per_head_dim)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Mamba2SsdPath {
    /// Minimal SSD.
    ///
    /// This algorithm mostly uses batched matmuls. For the backward operation, this relies on
    /// autodiff.
    /// See [`chunked_selective_scan`] for more info.
    ///
    /// For training, you may prefer using [SerialRecalculated](Self::SerialRecalculated) instead.
    ///
    /// Based on `/mamba_ssm/modules/ssd_minimal.py` from the `state-spaces/mamba` github reference.
    Minimal(Option<usize>),
    /// (Hybrid) Serial SSD.
    ///
    /// This algorithm uses a serial loop over the nchunks, besides batched matmuls.
    /// For the backward operation, this relies on autodiff.
    /// For a custom backwards that saves memory, see [SerialRecalculated](Self::SerialRecalculated).
    ///
    /// Based on 5 kernels on `/mamba_ssm/ops/triton/` from the `state-spaces/mamba` github reference:
    /// - `ssd_chunk_state.py` (K1, K3).
    /// - `ssd_bmm.py` (K2).
    /// - `ssd_state_passing.py` (K4).
    /// - `ssd_chunk_scan.py` (K5).
    Serial(Option<usize>),
    /// (Hybrid) Serial SSD that triggers recalculations for the backward pass.
    ///
    /// This algorithm uses a serial loop over the nchunks, besides batched matmuls.
    /// Contains a custom backward operation that saves memory.
    /// For an autodiff backwards, see [Serial](Self::Serial).
    ///
    /// Based on the combined kernel `/mamba_ssm/ops/triton/ssd_combined.py` from the
    /// `state-spaces/mamba` github reference.
    SerialRecalculated(Option<usize>),
}
43
44pub struct Mamba2SsdInput<B: Backend> {
45    /// # Shape
46    /// - [batch, nchunks, chunk_len, nheads, per_head_dim]
47    pub x_bnlhp: Tensor<B, 5>,
48    /// # Shape
49    /// - [batch, nchunks, chunk_len, nheads]
50    pub dt_bnlh: Tensor<B, 4>,
51    /// # Shape
52    /// - [nheads]
53    pub a_decay_h: Tensor<B, 1>,
54    /// # Shape
55    /// - [batch, nchunks, chunk_len, ngroups, state_rank]
56    pub b_bnlgr: Tensor<B, 5>,
57    /// # Shape
58    /// - [batch, nchunks, chunk_len, ngroups, state_rank]
59    pub c_bnlgr: Tensor<B, 5>,
60    /// # Shape
61    /// - [nheads]
62    pub d_h: Tensor<B, 1>,
63    /// # Shape
64    /// - [batch, nheads, per_head_dim, state_rank]
65    pub initial_state_bhpr: Tensor<B, 4>,
66    /// # Shape
67    /// - [nheads, per_head_dim, state_rank]
68    pub init_state_hpr: Option<Tensor<B, 3>>,
69}
70
71impl<B: Backend> Mamba2SsdInput<B> {
72    pub fn sanity(&self) {
73        use crate::utils::sanity::sanity as san;
74        san(&self.x_bnlhp);
75        san(&self.dt_bnlh);
76        san(&self.a_decay_h);
77        san(&self.b_bnlgr);
78        san(&self.c_bnlgr);
79        san(&self.d_h);
80        san(&self.initial_state_bhpr);
81        if let Some(ref init_state_hpr) = self.init_state_hpr {
82            san(init_state_hpr);
83        }
84    }
85}
86
87impl Mamba2SsdPath {
88    /// Optimal chunk length is approximately `√(state_rank · per_head_dim)`.
89    pub fn optimal_default(state_rank: usize, per_head_dim: usize) -> usize {
90        (state_rank * per_head_dim)
91            .isqrt()
92            .next_multiple_of(32) // rule-of-thumb: common plane dimension.
93            .min(512) // rule-of-thumb: ceiling at 512.
94    }
95
96    /// Optimal Minimal variant.
97    ///
98    /// See [optimal_default](Self::optimal_default) for more info.
99    pub fn core_optimal(state_rank: usize, per_head_dim: usize) -> Self {
100        let optim = Self::optimal_default(state_rank, per_head_dim);
101        Self::Minimal(Some(optim))
102    }
103
104    /// Optimal Minimal variant.
105    ///
106    /// See [optimal_default](Self::optimal_default) for more info.
107    pub fn core_optimal_from_block<B: Backend>(block: &Mamba2<B>) -> Self {
108        Self::core_optimal(block.state_rank, block.per_head_dim())
109    }
110
111    /// Optimal Serial variant.
112    ///
113    /// See [optimal_default](Self::optimal_default) for more info.
114    pub fn chunked_optimal(state_rank: usize, per_head_dim: usize) -> Self {
115        let optim = Self::optimal_default(state_rank, per_head_dim);
116        Self::Serial(Some(optim))
117    }
118
119    /// Optimal Serial variant.
120    ///
121    /// See [optimal_default](Self::optimal_default) for more info.
122    pub fn chunked_optimal_from_block<B: Backend>(block: &Mamba2<B>) -> Self {
123        Self::chunked_optimal(block.state_rank, block.per_head_dim())
124    }
125
126    /// Optimal Serial variant.
127    ///
128    /// See [optimal_default](Self::optimal_default) for more info.
129    pub fn chunked_recalculated_optimal(state_rank: usize, per_head_dim: usize) -> Self {
130        let optim = Self::optimal_default(state_rank, per_head_dim);
131        Self::SerialRecalculated(Some(optim))
132    }
133
134    /// Optimal Serial Recalculated variant.
135    ///
136    /// See [optimal_default](Self::optimal_default) for more info.
137    pub fn chunked_recalculated_optimal_from_block<B: Backend>(block: &Mamba2<B>) -> Self {
138        Self::chunked_recalculated_optimal(block.state_rank, block.per_head_dim())
139    }
140
141    pub fn chunk_len(&self) -> Option<usize> {
142        match self {
143            Mamba2SsdPath::Minimal(chunk_len) => *chunk_len,
144            Mamba2SsdPath::Serial(chunk_len) => *chunk_len,
145            Mamba2SsdPath::SerialRecalculated(chunk_len) => *chunk_len,
146        }
147    }
148
149    pub fn chunk_len_or_optimal(&self, state_rank: usize, per_head_dim: usize) -> usize {
150        match self {
151            Mamba2SsdPath::Minimal(chunk_len) => {
152                chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
153            }
154            Mamba2SsdPath::Serial(chunk_len) => {
155                chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
156            }
157            Mamba2SsdPath::SerialRecalculated(chunk_len) => {
158                chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
159            }
160        }
161    }
162
163    /// Run the SSD algorithm on the given input.
164    ///
165    /// Dispatches to `ssd_minimal`, `ssd_serial`, or `ssd_serial_recalculated` based on the variant.
166    ///
167    /// # Returns
168    /// - `y_bnlhp`: `[batch, nchunks, chunk_len, nheads, per_head_dim]`
169    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
170    pub fn run<B: Backend + Mamba2BackendExt>(
171        &self,
172        input: Mamba2SsdInput<B>,
173    ) -> (Tensor<B, 5>, Tensor<B, 4>) {
174        match self {
175            Mamba2SsdPath::Minimal(_) => Mamba2::<B>::ssd_minimal(input),
176            Mamba2SsdPath::Serial(_) => Mamba2::<B>::ssd_serial(input),
177            Mamba2SsdPath::SerialRecalculated(_) => Mamba2::<B>::ssd_serial_recalculated(input),
178        }
179    }
180}
181
182impl Default for Mamba2SsdPath {
183    fn default() -> Mamba2SsdPath {
184        // Mamba2SsdPath defaults to the SerialRecalculated algorithm with the optimal chunk length.
185        Mamba2SsdPath::SerialRecalculated(None)
186    }
187}
188
189// ---------------------------------------------------------------------------
190// Tests
191// ---------------------------------------------------------------------------
192
193#[cfg(all(test, feature = "backend-flex"))]
194mod tests {
195    use super::*;
196    use burn::backend::{Autodiff, Flex};
197    use burn::module::Param;
198    use burn::tensor::Distribution;
199
200    /// Inner (non-autodiff) backend used for materialising values and
201    /// extracted gradients.
202    type InnerB = Flex;
203    /// Autodiff-wrapped backend used to drive `.backward()`.
204    type B = Autodiff<InnerB>;
205
206    type Device = <InnerB as burn::tensor::backend::BackendTypes>::Device;
207
208    /// Build a randomised set of tensors on the inner backend (no grad
209    /// tracking yet — `Param::from_tensor` is applied per-path below to
210    /// give each path its own fresh autodiff graph).
211    ///
212    /// `dt` is drawn from a positive distribution (softplus-like) and `a_decay`
213    /// from a negative range so that the implied per-token decay `exp(dt·a)`
214    /// stays in `(0, 1]`, matching how the upstream block produces them.
215    fn random_input(
216        batch: usize,
217        nchunks: usize,
218        chunk_len: usize,
219        nheads: usize,
220        per_head_dim: usize,
221        ngroups: usize,
222        state_rank: usize,
223        device: &Device,
224    ) -> (
225        Tensor<InnerB, 5>,
226        Tensor<InnerB, 4>,
227        Tensor<InnerB, 1>,
228        Tensor<InnerB, 5>,
229        Tensor<InnerB, 5>,
230        Tensor<InnerB, 1>,
231        Tensor<InnerB, 4>,
232    ) {
233        let x = Tensor::<InnerB, 5>::random(
234            [batch, nchunks, chunk_len, nheads, per_head_dim],
235            Distribution::Normal(0.0, 1.0),
236            device,
237        );
238        let dt = Tensor::<InnerB, 4>::random(
239            [batch, nchunks, chunk_len, nheads],
240            Distribution::Uniform(0.05, 0.3),
241            device,
242        );
243        let a_decay =
244            Tensor::<InnerB, 1>::random([nheads], Distribution::Uniform(-1.0, -0.5), device);
245        let b = Tensor::<InnerB, 5>::random(
246            [batch, nchunks, chunk_len, ngroups, state_rank],
247            Distribution::Normal(0.0, 1.0),
248            device,
249        );
250        let c = Tensor::<InnerB, 5>::random(
251            [batch, nchunks, chunk_len, ngroups, state_rank],
252            Distribution::Normal(0.0, 1.0),
253            device,
254        );
255        let d = Tensor::<InnerB, 1>::random([nheads], Distribution::Normal(0.0, 0.1), device);
256        let initial_state = Tensor::<InnerB, 4>::random(
257            [batch, nheads, per_head_dim, state_rank],
258            Distribution::Normal(0.0, 0.1),
259            device,
260        );
261        (x, dt, a_decay, b, c, d, initial_state)
262    }
263
264    /// Inputs wrapped as `Param`s so each tensor becomes an autodiff leaf
265    /// with `require_grad`. One `Inputs` is built per path, sharing the same
266    /// underlying inner values but its own autodiff graph.
267    struct Inputs {
268        x: Param<Tensor<B, 5>>,
269        dt: Param<Tensor<B, 4>>,
270        a_decay: Param<Tensor<B, 1>>,
271        b: Param<Tensor<B, 5>>,
272        c: Param<Tensor<B, 5>>,
273        d: Param<Tensor<B, 1>>,
274        initial_state: Param<Tensor<B, 4>>,
275    }
276
277    impl Inputs {
278        #[allow(clippy::too_many_arguments)]
279        fn from_inner(
280            x: Tensor<InnerB, 5>,
281            dt: Tensor<InnerB, 4>,
282            a_decay: Tensor<InnerB, 1>,
283            b: Tensor<InnerB, 5>,
284            c: Tensor<InnerB, 5>,
285            d: Tensor<InnerB, 1>,
286            initial_state: Tensor<InnerB, 4>,
287        ) -> Self {
288            Self {
289                x: Param::from_tensor(Tensor::from_inner(x)),
290                dt: Param::from_tensor(Tensor::from_inner(dt)),
291                a_decay: Param::from_tensor(Tensor::from_inner(a_decay)),
292                b: Param::from_tensor(Tensor::from_inner(b)),
293                c: Param::from_tensor(Tensor::from_inner(c)),
294                d: Param::from_tensor(Tensor::from_inner(d)),
295                initial_state: Param::from_tensor(Tensor::from_inner(initial_state)),
296            }
297        }
298
299        fn ssd_input(&self) -> Mamba2SsdInput<B> {
300            Mamba2SsdInput {
301                x_bnlhp: self.x.val(),
302                dt_bnlh: self.dt.val(),
303                a_decay_h: self.a_decay.val(),
304                b_bnlgr: self.b.val(),
305                c_bnlgr: self.c.val(),
306                d_h: self.d.val(),
307                initial_state_bhpr: self.initial_state.val(),
308                // Serial paths assert this is None — see ssd_serial / ssd_serial_recalculated.
309                init_state_hpr: None,
310            }
311        }
312    }
313
314    /// Collected forward outputs and input gradients for a single SSD path run.
315    struct PathRun {
316        y: Tensor<InnerB, 5>,
317        state: Tensor<InnerB, 4>,
318        d_x: Tensor<InnerB, 5>,
319        d_dt: Tensor<InnerB, 4>,
320        d_a_decay: Tensor<InnerB, 1>,
321        d_b: Tensor<InnerB, 5>,
322        d_c: Tensor<InnerB, 5>,
323        d_d: Tensor<InnerB, 1>,
324        d_init_state: Tensor<InnerB, 4>,
325    }
326
327    /// Combine `y` and `final_state` into a single deterministic scalar loss
328    /// using fixed (non-tracked) random "head" tensors. The two heads differ so
329    /// that gradients for the y-branch and the state-branch are independent
330    /// (a mistake in either path shows up in the parameter grads).
331    fn loss_from_outputs(
332        y_bnlhp: Tensor<B, 5>,
333        final_state_bhpr: Tensor<B, 4>,
334        y_head: Tensor<InnerB, 5>,
335        s_head: Tensor<InnerB, 4>,
336    ) -> Tensor<B, 1> {
337        let y_head = Tensor::from_inner(y_head);
338        let s_head = Tensor::from_inner(s_head);
339        (y_bnlhp * y_head).sum() + (final_state_bhpr * s_head).sum()
340    }
341
342    /// Run a single SSD path and extract the gradients of all 7 inputs.
343    fn run_path(
344        path: Mamba2SsdPath,
345        inputs: &Inputs,
346        y_head: Tensor<InnerB, 5>,
347        s_head: Tensor<InnerB, 4>,
348    ) -> PathRun {
349        let (y, state) = path.run(inputs.ssd_input());
350        let y_inner = y.clone().inner();
351        let state_inner = state.clone().inner();
352
353        let loss = loss_from_outputs(y, state, y_head, s_head);
354        let grads = loss.backward();
355
356        // Inline grad extraction (a closure cannot be reused here since the
357        // gradient tensor rank varies per call).
358        PathRun {
359            y: y_inner,
360            state: state_inner,
361            d_x: inputs.x.val().grad(&grads).expect("grad x"),
362            d_dt: inputs.dt.val().grad(&grads).expect("grad dt"),
363            d_a_decay: inputs.a_decay.val().grad(&grads).expect("grad a_decay"),
364            d_b: inputs.b.val().grad(&grads).expect("grad b"),
365            d_c: inputs.c.val().grad(&grads).expect("grad c"),
366            d_d: inputs.d.val().grad(&grads).expect("grad d"),
367            d_init_state: inputs
368                .initial_state
369                .val()
370                .grad(&grads)
371                .expect("grad initial_state"),
372        }
373    }
374
375    /// Run the same input through `Minimal`, `Serial`, and `SerialRecalculated`
376    /// and assert that all three agree on:
377    ///   1. the forward outputs (`y`, `final_state`)
378    ///   2. the gradients of every input through a fixed scalar loss.
379    ///
380    /// All three are chunkwise reformulations of the same SSD, so both the
381    /// values and their gradients must agree up to floating-point noise.
382    fn run_minimal_matches_serial(
383        batch: usize,
384        nchunks: usize,
385        chunk_len: usize,
386        nheads: usize,
387        per_head_dim: usize,
388        ngroups: usize,
389        state_rank: usize,
390    ) {
391        let device: Device = Default::default();
392        let (x, dt, a_decay, b, c, d, init) = random_input(
393            batch,
394            nchunks,
395            chunk_len,
396            nheads,
397            per_head_dim,
398            ngroups,
399            state_rank,
400            &device,
401        );
402
403        // Fixed (non-tracked) "downstream heads" for the loss. Two distinct
404        // random tensors so y- and state-gradient paths are exercised
405        // independently.
406        let y_head = Tensor::<InnerB, 5>::random(
407            [batch, nchunks, chunk_len, nheads, per_head_dim],
408            Distribution::Normal(0.0, 1.0),
409            &device,
410        );
411        let s_head = Tensor::<InnerB, 4>::random(
412            [batch, nheads, per_head_dim, state_rank],
413            Distribution::Normal(0.0, 1.0),
414            &device,
415        );
416
417        // Each path gets its own fresh autodiff graph (Param leaves).
418        let inputs_min = Inputs::from_inner(
419            x.clone(),
420            dt.clone(),
421            a_decay.clone(),
422            b.clone(),
423            c.clone(),
424            d.clone(),
425            init.clone(),
426        );
427        let inputs_ser = Inputs::from_inner(
428            x.clone(),
429            dt.clone(),
430            a_decay.clone(),
431            b.clone(),
432            c.clone(),
433            d.clone(),
434            init.clone(),
435        );
436        let inputs_rec = Inputs::from_inner(x, dt, a_decay, b, c, d, init);
437
438        let r_min = run_path(
439            Mamba2SsdPath::Minimal(Some(chunk_len)),
440            &inputs_min,
441            y_head.clone(),
442            s_head.clone(),
443        );
444        let r_ser = run_path(
445            Mamba2SsdPath::Serial(Some(chunk_len)),
446            &inputs_ser,
447            y_head.clone(),
448            s_head.clone(),
449        );
450        let r_rec = run_path(
451            Mamba2SsdPath::SerialRecalculated(Some(chunk_len)),
452            &inputs_rec,
453            y_head,
454            s_head,
455        );
456
457        // ── Forward agreement ────────────────────────────────────────────
458        let tol = 1e-4;
459        let dy_ser = (r_min.y.clone() - r_ser.y.clone())
460            .abs()
461            .max()
462            .into_scalar();
463        let ds_ser = (r_min.state.clone() - r_ser.state.clone())
464            .abs()
465            .max()
466            .into_scalar();
467        let dy_rec = (r_min.y.clone() - r_rec.y.clone())
468            .abs()
469            .max()
470            .into_scalar();
471        let ds_rec = (r_min.state.clone() - r_rec.state.clone())
472            .abs()
473            .max()
474            .into_scalar();
475        assert!(
476            dy_ser < tol,
477            "Minimal vs Serial: y max abs diff = {dy_ser:.6} (tol {tol})"
478        );
479        assert!(
480            ds_ser < tol,
481            "Minimal vs Serial: final_state max abs diff = {ds_ser:.6} (tol {tol})"
482        );
483        assert!(
484            dy_rec < tol,
485            "Minimal vs SerialRecalculated: y max abs diff = {dy_rec:.6} (tol {tol})"
486        );
487        assert!(
488            ds_rec < tol,
489            "Minimal vs SerialRecalculated: final_state max abs diff = {ds_rec:.6} (tol {tol})"
490        );
491
492        // ── Gradient agreement ───────────────────────────────────────────
493        // Looser tolerance: every path computes the same mathematical
494        // gradients, but the chunkwise reformulations accumulate the
495        // summations in different orders, so small drift is expected.
496        let grad_tol = 1e-3;
497
498        let mut failures: Vec<String> = Vec::new();
499        macro_rules! diff {
500            ($a:expr, $b:expr) => {
501                ($a.clone() - $b.clone()).abs().max().into_scalar()
502            };
503        }
504        macro_rules! check_grad {
505            ($field:ident, $name:expr) => {{
506                let d_ser = diff!(r_min.$field, r_ser.$field);
507                let d_rec = diff!(r_min.$field, r_rec.$field);
508                eprintln!(
509                    "grad {:>14} | min↔ser = {:>10.6} | min↔rec = {:>10.6}",
510                    $name, d_ser, d_rec
511                );
512                if d_ser >= grad_tol {
513                    failures.push(format!(
514                        "Minimal vs Serial: grad of {} max abs diff = {:.6} (tol {})",
515                        $name, d_ser, grad_tol
516                    ));
517                }
518                if d_rec >= grad_tol {
519                    failures.push(format!(
520                        "Minimal vs SerialRecalculated: grad of {} max abs diff = {:.6} (tol {})",
521                        $name, d_rec, grad_tol
522                    ));
523                }
524            }};
525        }
526        check_grad!(d_x, "x");
527        check_grad!(d_dt, "dt");
528        check_grad!(d_a_decay, "a_decay");
529        check_grad!(d_b, "b");
530        check_grad!(d_c, "c");
531        check_grad!(d_d, "d");
532        check_grad!(d_init_state, "initial_state");
533
534        assert!(
535            failures.is_empty(),
536            "gradient mismatches:\n  {}",
537            failures.join("\n  ")
538        );
539    }
540
541    #[test]
542    fn paths_agree_no_gqa() {
543        // ngroups == nheads (no GQA expansion): B/C are per-head.
544        run_minimal_matches_serial(2, 3, 4, 2, 8, 2, 8);
545    }
546
547    #[test]
548    fn paths_agree_gqa() {
549        // ngroups < nheads: B/C are shared across `heads_per_group` heads.
550        run_minimal_matches_serial(2, 3, 4, 4, 8, 1, 8);
551    }
552
553    #[test]
554    fn paths_agree_single_chunk() {
555        // nchunks=1 — no inter-chunk scan; checks the intra-chunk + state-passing
556        // boundary case where K4 runs a single iteration.
557        run_minimal_matches_serial(2, 1, 4, 2, 8, 2, 8);
558    }
559}