Skip to main content

burn_mamba/mamba3/ssd/serial_recalculated/
serial_recalculated.rs

1#![allow(non_snake_case)]
2
3use crate::mamba3::prelude::*;
4use crate::mamba3::ssd::serial;
5use burn::prelude::*;
6use burn::tensor::{Tensor, TensorPrimitive, ops::FloatTensor};
7
8impl<B: Backend + Mamba3BackendExt> Mamba3<B> {
9    /// MIMO-first Serial SSD with recalculated backward.
10    ///
11    /// Computes K1 eagerly (so the cumsum is available for the backward pass),
12    /// then delegates the remaining computation to [`Mamba3BackendExt::ssd_serial_recalculated`]
13    /// which can provide a memory-efficient custom backward for supported backends.
14    ///
15    /// Falls back to the standard K2-K5 serial computation on unsupported backends.
16    ///
17    /// # Returns
18    /// - `y_bnlrhp`:        `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
19    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
20    pub fn ssd_serial_recalculated(
21        input: super::super::Mamba3SsdInput<B>,
22    ) -> (Tensor<B, 6>, Tensor<B, 4>) {
23        assert!(
24            input.init_state_hpr.is_none(),
25            "init_state_hpr not yet implemented for ssd_serial_recalculated"
26        );
27
28        // K1 runs on the tracked compute graph so the cumsum is available during backward.
29        let (da_cumsum_bhnl, _da_chunk_end_bhn): (Tensor<B, 4>, Tensor<B, 3>) =
30            serial::k1_ssd_chunk_cumsum(input.da_bnlh.clone());
31
32        let (y_bnlrhp, final_state_bhpr) = <B as Mamba3BackendExt>::ssd_serial_recalculated(
33            input.v_bnlrhp.into_primitive().tensor(),
34            input.da_bnlh.into_primitive().tensor(),
35            input.b_bnlrhn.into_primitive().tensor(),
36            input.c_bnlrhn.into_primitive().tensor(),
37            input.initial_state_bhpr.into_primitive().tensor(),
38            da_cumsum_bhnl.into_primitive().tensor(),
39        );
40        let y_bnlrhp = Tensor::from_primitive(TensorPrimitive::Float(y_bnlrhp));
41        let final_state_bhpr = Tensor::from_primitive(TensorPrimitive::Float(final_state_bhpr));
42        (y_bnlrhp, final_state_bhpr)
43    }
44}
45
46/// Extends the backend for the memory-efficient serial recalculated SSD.
47///
48/// The default implementation runs K2-K5 using standard tensor operations.
49/// Backends that support a custom memory-efficient backward can override this.
50pub trait Mamba3BackendExt: burn::tensor::backend::Backend {
51    /// Memory-efficient MIMO serial SSD.
52    ///
53    /// # Arguments
54    /// - `v_bnlrhp`:           `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
55    /// - `da_bnlh`:            `[batch, nchunks, chunk_len, nheads]` — pre-combined Δ·A
56    /// - `b_bnlrhn`:           `[batch, nchunks, chunk_len, R, nheads, state_rank]`
57    /// - `c_bnlrhn`:           `[batch, nchunks, chunk_len, R, nheads, state_rank]`
58    /// - `initial_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
59    /// - `da_cumsum_bhnl`:     `[batch, nheads, nchunks, chunk_len]` — pre-computed by K1
60    ///
61    /// # Returns
62    /// - `y_bnlrhp`:        `[batch, nchunks, chunk_len, R, nheads, per_head_dim]`
63    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
64    fn ssd_serial_recalculated(
65        v_bnlrhp: FloatTensor<Self>,
66        _da_bnlh: FloatTensor<Self>,
67        b_bnlrhn: FloatTensor<Self>,
68        c_bnlrhn: FloatTensor<Self>,
69        initial_state_bhpr: FloatTensor<Self>,
70        da_cumsum_bhnl: FloatTensor<Self>,
71    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
72        // Default impl: run K2-K5 using the pre-computed cumsum.
73
74        let v: Tensor<Self, 6> = mk(v_bnlrhp);
75        let b: Tensor<Self, 6> = mk(b_bnlrhn);
76        let c: Tensor<Self, 6> = mk(c_bnlrhn);
77        let da_cumsum: Tensor<Self, 4> = mk(da_cumsum_bhnl);
78        let init_state: Tensor<Self, 4> = mk(initial_state_bhpr);
79
80        // Recalculate da_chunk_end_bhn from the pre-computed cumsum (the "recalculated" part).
81        let da_chunk_end_bhn: Tensor<Self, 3> =
82            da_cumsum.clone().slice(s![.., .., .., -1]).squeeze_dim(3);
83
84        let cb_bnhLL: Tensor<Self, 5> = serial::k2_ssd_bmm(c.clone(), b.clone());
85        let intra_chunk_state_bnhpr: Tensor<Self, 5> =
86            serial::k3_ssd_chunk_state(v.clone(), b, da_cumsum.clone());
87        let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<Self, 5>, Tensor<Self, 4>) =
88            serial::k4_ssd_state_passing(intra_chunk_state_bnhpr, da_chunk_end_bhn, init_state);
89        let y_bnlrhp: Tensor<Self, 6> =
90            serial::k5_ssd_chunk_scan(da_cumsum, v, c, cb_bnhLL, chunk_input_state_bnhpr);
91
92        let y_bnlrhp = y_bnlrhp.into_primitive().tensor();
93        let final_state_bhpr = final_state_bhpr.into_primitive().tensor();
94        (y_bnlrhp, final_state_bhpr)
95    }
96}
97
/// Marker for autodiff backends that also implement [`Mamba3BackendExt`].
#[cfg(feature = "autodiff")]
pub trait Mamba3AutodiffBackendExt:
    Backend + Mamba3BackendExt + burn::tensor::backend::AutodiffBackend
{
}

/// `Autodiff<B>` implements [`Mamba3BackendExt`] whenever the wrapped `B` does.
///
/// This empty impl uses the trait's *default* method body, evaluated with
/// `Self = Autodiff<B>`. It does NOT forward to `B`'s implementation: if `B` overrides
/// `ssd_serial_recalculated` with a custom kernel, that override is not reached through
/// `Autodiff<B>` unless this impl is extended to call it explicitly.
#[cfg(feature = "autodiff")]
impl<B: Backend + Mamba3BackendExt> Mamba3BackendExt for burn::backend::Autodiff<B> {}

/// `Autodiff<B>` satisfies the marker whenever the wrapped `B` implements
/// [`Mamba3BackendExt`]. Only this `Autodiff<B>` form is covered by the impl below.
#[cfg(feature = "autodiff")]
impl<B: Backend + Mamba3BackendExt> Mamba3AutodiffBackendExt for burn::backend::Autodiff<B> {}
110
// ---------------------------------------------------------------------------
// Backend impls — each backend uses the default (K2-K5) implementation.
//
// `Mamba3BackendExt` has `Backend` as a supertrait, so a generic impl may only cover the
// parameterisations that really are backends. The `where <Type>: Backend` clauses state that
// requirement directly instead of repeating each backend's element-trait bounds; without
// them the compiler cannot prove the supertrait for arbitrary `F`, `I`.
// ---------------------------------------------------------------------------

#[cfg(feature = "backend-ndarray")]
impl<F, I> Mamba3BackendExt for burn::backend::NdArray<F, I> where
    burn::backend::NdArray<F, I>: Backend
{
}

#[cfg(feature = "backend-flex")]
impl Mamba3BackendExt for burn::backend::Flex {}

#[cfg(any(feature = "backend-tch-cpu", feature = "backend-tch-gpu"))]
impl<F, I> Mamba3BackendExt for burn::backend::libtorch::LibTorch<F, I> where
    burn::backend::libtorch::LibTorch<F, I>: Backend
{
}

// NOTE(review): confirm `RemoteBackend`'s generic arity against the pinned burn version.
#[cfg(feature = "backend-remote")]
impl<F, I> Mamba3BackendExt for burn::backend::RemoteBackend<F, I> where
    burn::backend::RemoteBackend<F, I>: Backend
{
}
123
// CubeCL backends
#[cfg(feature = "cubecl")]
mod cubecl {
    use super::Mamba3BackendExt;
    use burn_cubecl::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};

    /// CubeCL backends use the default (K2-K5) body of [`Mamba3BackendExt`].
    impl<R, F, I, BT> Mamba3BackendExt for CubeBackend<R, F, I, BT>
    where
        R: CubeRuntime,
        F: FloatElement,
        I: IntElement,
        BT: BoolElement,
    {
    }
}
133
// Fusion backends
#[cfg(feature = "fusion")]
mod fusion {
    use super::Mamba3BackendExt;
    use burn_fusion::{Fusion, FusionBackend};

    /// `Fusion<B>` uses the default (K2-K5) body of [`Mamba3BackendExt`]; the impl is
    /// available whenever the inner `B` also implements the extension.
    impl<B> Mamba3BackendExt for Fusion<B>
    where
        B: FusionBackend + Mamba3BackendExt,
    {
    }
}
140
141/// Conversion helper: `FloatTensor<B>` → `Tensor<B, D>`.
142pub(crate) fn mk<B: Backend, const D: usize>(p: FloatTensor<B>) -> Tensor<B, D> {
143    Tensor::from_primitive(TensorPrimitive::Float(p))
144}