Skip to main content

burn_mamba/mamba2/ssd/serial_recalculated/
serial_recalculated.rs

1use crate::mamba2::prelude::*;
2use crate::mamba2::ssd::serial;
3use burn::prelude::*;
4use burn::tensor::{Tensor, TensorPrimitive, ops::FloatTensor};
5
6impl<B: Backend + Mamba2BackendExt> Mamba2<B> {
7    /// Forward pass for the Mamba-2 SSD module.
8    ///
9    /// Returns:
10    /// - `y_bnlhp`.
11    /// - `final_state_bhpr`.
12    #[allow(non_snake_case)]
13    pub fn ssd_serial_recalculated(
14        input: super::super::Mamba2SsdInput<B>,
15    ) -> (Tensor<B, 5>, Tensor<B, 4>) {
16        // Must use a backend-dependent method.
17        //
18        // For inference, this will ultimately replicate Mamba2::ssd_serial;
19        // For autodiff, this will call the custom implementation.
20
21        let [batch, nchunks, chunk_len, nheads, _per_head_dim] = input.x_bnlhp.dims();
22        let [.., ngroups, _state_rank] = input.b_bnlgr.dims();
23        assert_ne!(ngroups, 0);
24        assert_eq!(nheads % ngroups, 0);
25        assert!(nchunks > 0, "sequence length must be at least 1");
26
27        assert!(
28            input.init_state_hpr.is_none(),
29            "init_state_hpr not yet implemented"
30        );
31
32        // ── Permutes ──────────────────────────────────────────────────────────────────
33        // Note: dt_bnlh calculation (originally in Kernel 1) moved to Step 4 (before padding).
34        let dt_discretized_bhnl = input.dt_bnlh.permute([0, 3, 1, 2]);
35        assert_eq!(
36            [batch, nheads, nchunks, chunk_len],
37            dt_discretized_bhnl.dims()
38        );
39
40        // K1 is now computed inside the custom op (both forward and backward).
41        // a_decay_h is passed directly; da_cumsum is no longer an autodiff-tracked
42        // intermediate crossing the boundary.
43        let (y_bnlhp, final_state_bhpr) = <B as Mamba2BackendExt>::ssd_serial_recalculated(
44            input.x_bnlhp.into_primitive().tensor(),
45            dt_discretized_bhnl.into_primitive().tensor(),
46            input.b_bnlgr.into_primitive().tensor(),
47            input.c_bnlgr.into_primitive().tensor(),
48            input.d_h.into_primitive().tensor(),
49            input.initial_state_bhpr.into_primitive().tensor(),
50            input.a_decay_h.into_primitive().tensor(),
51        );
52        let y_bnlhp = Tensor::from_primitive(TensorPrimitive::Float(y_bnlhp));
53        let final_state_bhpr = Tensor::from_primitive(TensorPrimitive::Float(final_state_bhpr));
54        (y_bnlhp, final_state_bhpr)
55    }
56}
57
58/// Extends the backend and wraps it for `burn`.
59pub trait Mamba2BackendExt: burn::tensor::backend::Backend {
60    /// Returns:
61    /// - `y_bnlhp`.
62    /// - `final_state_bhpr`.
63    fn ssd_serial_recalculated(
64        x_bnlhp: FloatTensor<Self>,
65        dt_discretized_bhnl: FloatTensor<Self>,
66        b_bnlgr: FloatTensor<Self>,
67        c_bnlgr: FloatTensor<Self>,
68        d_h: FloatTensor<Self>,
69        initial_state_bhpr: FloatTensor<Self>,
70        a_decay_h: FloatTensor<Self>,
71    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
72        // Default impl essentially replicates Mamba2::ssd_serial.
73
74        let x_bnlhp: Tensor<Self, 5> = mk(x_bnlhp);
75        let dt_discretized_bhnl: Tensor<Self, 4> = mk(dt_discretized_bhnl);
76        let b_bnlgr: Tensor<Self, 5> = mk(b_bnlgr);
77        let c_bnlgr: Tensor<Self, 5> = mk(c_bnlgr);
78        let d_h: Tensor<Self, 1> = mk(d_h);
79        let initial_state_bhpr: Tensor<Self, 4> = mk(initial_state_bhpr);
80        let a_decay_h: Tensor<Self, 1> = mk(a_decay_h);
81
82        let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.dims();
83        let [.., ngroups, state_rank] = b_bnlgr.dims();
84        assert_ne!(ngroups, 0);
85        assert_eq!(nheads % ngroups, 0);
86        assert!(nchunks > 0, "sequence length must be at least 1");
87        // `heads_per_group` is called `nheads_ngroups_ratio` in every Triton kernel.
88        // It is the compile-time constant used by GQA (Grouped Query Attention) to map
89        // a head index to its B/C group: `group_idx = head_idx / heads_per_group`.
90
91        // ── Kernel 1 ──────────────────────────────────────────────────────────
92        let (da_cumsum_bhnl, da_chunk_end_bhn) =
93            serial::k1_ssd_chunk_cumsum(dt_discretized_bhnl.clone(), a_decay_h);
94        assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
95        assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
96
97        // ── Kernel 2 ──────────────────────────────────────────────────────────
98        // IO: (..) -> (cb_bngll [used in K5][!])
99        let cb_bngll: Tensor<Self, 5> = serial::k2_ssd_bmm(c_bnlgr.clone(), b_bnlgr.clone());
100        assert_eq!(
101            [batch, nchunks, ngroups, chunk_len, chunk_len],
102            cb_bngll.dims()
103        );
104        // Note: cb_bngll is then only used by Kernel 5.
105
106        // ── Kernel 3 ──────────────────────────────────────────────────────────
107        // IO: (..) -> (intra_chunk_state_bnhpr [used in K4][!])
108        let intra_chunk_state_bnhpr: Tensor<Self, 5> = serial::k3_ssd_chunk_state(
109            x_bnlhp.clone(),
110            b_bnlgr.clone(),
111            da_cumsum_bhnl.clone(),
112            dt_discretized_bhnl.clone(),
113        );
114        assert_eq!(
115            [batch, nchunks, nheads, per_head_dim, state_rank],
116            intra_chunk_state_bnhpr.dims()
117        );
118
119        // ── Kernel 4 ──────────────────────────────────────────────────────────
120        // IO: (..) -> (chunk_input_state_bnhpr [used in K5][!], final_state_bhpr [final output])
121        let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<Self, 5>, Tensor<Self, 4>) =
122            serial::k4_ssd_state_passing(
123                intra_chunk_state_bnhpr.clone(),
124                da_chunk_end_bhn.clone(),
125                initial_state_bhpr,
126            );
127        assert_eq!(
128            [batch, nchunks, nheads, per_head_dim, state_rank],
129            chunk_input_state_bnhpr.dims()
130        );
131        assert_eq!(
132            [batch, nheads, per_head_dim, state_rank],
133            final_state_bhpr.dims()
134        );
135
136        // ── Kernel 5 ──────────────────────────────────────────────────────────
137        let y_bnlhp: Tensor<Self, 5> = serial::k5_ssd_chunk_scan(
138            da_cumsum_bhnl,
139            dt_discretized_bhnl,
140            x_bnlhp,
141            c_bnlgr,
142            cb_bngll,
143            chunk_input_state_bnhpr,
144            d_h,
145        );
146        assert_eq!(
147            [batch, nchunks, chunk_len, nheads, per_head_dim],
148            y_bnlhp.dims()
149        );
150
151        let y_bnlhp = y_bnlhp.into_primitive().tensor();
152        let final_state_bhpr = final_state_bhpr.into_primitive().tensor();
153        (y_bnlhp, final_state_bhpr)
154    }
155}
156
// Each concrete backend below gets an empty `Mamba2BackendExt` impl, so it uses the
// trait's default method (which replicates `Mamba2::ssd_serial`). This is the path for
// inference and for any backend without a custom implementation.
//
// A blanket `impl<B: Backend> Mamba2BackendExt for B {}` would be simpler, but it would
// conflict with the custom autodiff impl. Each backend must therefore be listed explicitly.
//
// TODO: somehow avoid leaking backend-* features into the library

#[cfg(feature = "backend-ndarray")]
impl<E, Idx> Mamba2BackendExt for burn::backend::NdArray<E, Idx> {}

#[cfg(feature = "backend-flex")]
impl Mamba2BackendExt for burn::backend::Flex {}

#[cfg(any(feature = "backend-tch-cpu", feature = "backend-tch-gpu"))]
impl<E, Idx> Mamba2BackendExt for burn::backend::libtorch::LibTorch<E, Idx> {}

#[cfg(feature = "backend-remote")]
impl<E, Idx> Mamba2BackendExt for burn::backend::RemoteBackend<E, Idx> {}
// CubeCL backends: empty impl, so the trait's default method is used.
#[cfg(feature = "cubecl")]
mod cubecl {
    impl<R, F, I, BT> super::Mamba2BackendExt for burn_cubecl::CubeBackend<R, F, I, BT>
    where
        R: burn_cubecl::CubeRuntime,
        F: burn_cubecl::FloatElement,
        I: burn_cubecl::IntElement,
        BT: burn_cubecl::element::BoolElement,
    {
    }
}
181
// Fusion backends: empty impl, so the trait's default method runs the serial computation
// with standard tensor operations, ultimately executed by the wrapped backend. That inner
// backend must itself implement `Mamba2BackendExt`.
#[cfg(feature = "fusion")]
mod fusion {
    impl<Inner> super::Mamba2BackendExt for burn_fusion::Fusion<Inner>
    where
        Inner: burn_fusion::FusionBackend + super::Mamba2BackendExt,
    {
    }
}
189
/// Marker trait for autodiff backends supported by the custom backward implementation.
///
/// It declares no methods of its own. It only bundles the supertraits `Backend`,
/// `Mamba2BackendExt` and `AutodiffBackend`.
#[cfg(feature = "autodiff")]
pub trait Mamba2AutodiffBackendExt:
    Backend + Mamba2BackendExt + burn::tensor::backend::AutodiffBackend
{
}

// Any `Autodiff<Inner>` whose inner backend implements `Mamba2BackendExt` qualifies.
//
// This impl is only the marker. The custom `Mamba2BackendExt` implementation for the
// autodiff backend lives in `super::serial_recalculated::backward`.
#[cfg(feature = "autodiff")]
impl<Inner> Mamba2AutodiffBackendExt for burn::backend::Autodiff<Inner>
where
    Inner: Backend + Mamba2BackendExt,
{
}
202
203/// Conversion helper.
204pub(crate) fn mk<B: Backend, const D: usize>(p: FloatTensor<B>) -> Tensor<B, D> {
205    Tensor::from_primitive(TensorPrimitive::Float(p))
206}