Skip to main content

burn_mamba/mamba3/double_ssd/ssd/serial_recalculated/
serial_recalculated.rs

//! # Serial SSD with a custom, memory-efficient backward (Mamba-3 double-SSD)
//!
//! The `SerialRecalculated` path for the double-SSD pathway.  The forward is the
//! same serial scan as [`super::super::serial`], routed through the
//! [`Mamba3DoubleSsdBackendExt`] trait so that `Autodiff` backends substitute a
//! custom backward that recomputes per-chunk intermediates instead of storing
//! them (see [`super::backward`] / [`super::combined_backward`]).  Plain
//! backends use the trait's default body, which replays the serial kernels.
//!
//! The default body runs under a generic backend `B`, where the high-level
//! `Tensor` (pinned to `Dispatch`) is unavailable, so the K1–K5 math goes
//! through the rank-tagged [`F`] primitive wrapper.  K1/K2/K4 are reused by the
//! recompute backward in [`super::combined_backward`]; K5 is forward-only.

15#![allow(non_snake_case)]
16
17use crate::mamba3::double_ssd::prelude::*;
18use crate::utils::fprim::{F, san};
19use burn::backend::tensor::FloatTensor;
20use burn::backend::*;
21use burn::backend::{Backend, Dispatch, backend_extension};
22use burn::tensor::Tensor;
23use burn::tensor::s;
24
25impl Mamba3DoubleSsdInput {
26    /// MIMO-first Serial SSD with recalculated backward.
27    ///
28    /// Delegates the full K1-K5 computation to [`Mamba3DoubleSsdBackendExt::double_ssd_serial_recalculated`]
29    /// which can provide a memory-efficient custom backward for supported backends.
30    ///
31    /// Falls back to the standard K1-K5 serial computation on unsupported backends.
32    ///
33    /// # Returns
34    /// - `y_bnlmhp`:         `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
35    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
36    pub fn double_ssd_serial_recalculated(self) -> (Tensor<6>, Tensor<4>) {
37        let input = self;
38        assert!(
39            input.init_state_hpr.is_none(),
40            "init_state_hpr not yet implemented for ssd_serial_recalculated"
41        );
42
43        let (y_bnlmhp, final_state_bhpr) =
44            <Dispatch as Mamba3DoubleSsdBackendExt>::double_ssd_serial_recalculated(
45                input.v_bnlmhp.into_primitive(),
46                input.da_bnlh.into_primitive(),
47                input.b_bnlmhr.into_primitive(),
48                input.c_bnlmhr.into_primitive(),
49                input.initial_state_bhpr.into_primitive(),
50            );
51        let y_bnlmhp = Tensor::from_primitive(y_bnlmhp);
52        let final_state_bhpr = Tensor::from_primitive(final_state_bhpr);
53        (y_bnlmhp, final_state_bhpr)
54    }
55}
56
57/// Extends the backend for the memory-efficient serial recalculated SSD.
58///
59/// The default implementation runs K1-K5 using primitive tensor operations.
60/// Backends that support a custom memory-efficient backward (specifically the
61/// Autodiff wrapper) override this to recompute forward intermediates during
62/// the backward pass instead of saving them.
63#[backend_extension(
64    Cpu:  cfg(feature = "backend-cpu"),
65    Cuda: cfg(feature = "backend-cuda"),
66    Rocm:  cfg(feature = "backend-rocm"),
67    Metal:  cfg(feature = "backend-metal"),
68    Vulkan:  cfg(feature = "backend-vulkan"),
69    Wgpu:  cfg(feature = "backend-wgpu"),
70    WebGpu:  cfg(feature = "backend-webgpu"),
71    Flex:  cfg(feature = "backend-flex"),
72    NdArray:  cfg(feature = "backend-ndarray"),
73    LibTorch:  cfg(any(feature = "backend-tch-cpu", feature = "backend-tch-gpu")),
74    Autodiff:  cfg(feature = "autodiff"),
75)]
76pub trait Mamba3DoubleSsdBackendExt: Backend {
77    /// Memory-efficient MIMO serial SSD.
78    ///
79    /// # Arguments
80    /// - `v_bnlmhp`:           `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
81    /// - `da_bnlh`:            `[batch, nchunks, chunk_len, nheads]` — pre-combined Δ·A
82    /// - `b_bnlmhr`:           `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
83    /// - `c_bnlmhr`:           `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
84    /// - `initial_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
85    ///
86    /// # Returns
87    /// - `y_bnlmhp`:         `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
88    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
89    fn double_ssd_serial_recalculated(
90        v_bnlmhp: FloatTensor<Self>,
91        da_bnlh: FloatTensor<Self>,
92        b_bnlmhr: FloatTensor<Self>,
93        c_bnlmhr: FloatTensor<Self>,
94        initial_state_bhpr: FloatTensor<Self>,
95    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
96        // Default impl: replicate Mamba3::double_ssd_serial (K1-K5) on primitives.
97        let v_bnlmhp = F::<Self, 6>::new(v_bnlmhp);
98        let da_bnlh = F::<Self, 4>::new(da_bnlh);
99        let b_bnlmhr = F::<Self, 6>::new(b_bnlmhr);
100        let c_bnlmhr = F::<Self, 6>::new(c_bnlmhr);
101        let initial_state_bhpr = F::<Self, 4>::new(initial_state_bhpr);
102
103        let nchunks = v_bnlmhp.dims()[1];
104        assert!(nchunks > 0, "sequence length must be at least 1");
105
106        let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum::<Self>(da_bnlh);
107        san(&da_cumsum_bhnl);
108
109        let cb_bnhLMLM = k2_ssd_bmm::<Self>(c_bnlmhr.clone(), b_bnlmhr.clone());
110        san(&cb_bnhLMLM);
111
112        let intra_chunk_state_bnhpr =
113            k3_ssd_chunk_state::<Self>(v_bnlmhp.clone(), b_bnlmhr, da_cumsum_bhnl.clone());
114        san(&intra_chunk_state_bnhpr);
115
116        let (chunk_input_state_bnhpr, final_state_bhpr) = k4_ssd_state_passing::<Self>(
117            intra_chunk_state_bnhpr,
118            da_chunk_end_bhn,
119            initial_state_bhpr,
120        );
121        san(&chunk_input_state_bnhpr);
122        san(&final_state_bhpr);
123
124        let y_bnlmhp = k5_ssd_chunk_scan::<Self>(
125            da_cumsum_bhnl,
126            v_bnlmhp,
127            c_bnlmhr,
128            cb_bnhLMLM,
129            chunk_input_state_bnhpr,
130        );
131        san(&y_bnlmhp);
132
133        (y_bnlmhp.inner(), final_state_bhpr.inner())
134    }
135}
136
// Macro defined elsewhere in the crate, invoked with the new name
// `Mamba3DoubleSsdAutodiffBackendExt` and the extension trait above.
// NOTE(review): presumably declares an autodiff-side companion trait for
// `Mamba3DoubleSsdBackendExt`; confirm against the macro definition.
crate::decl_ssd_autodiff_backend_ext!(Mamba3DoubleSsdAutodiffBackendExt, Mamba3DoubleSsdBackendExt);

// ---------------------------------------------------------------------------
// Per-backend impls: each delegates to the trait's default (K1-K5) body. The
// custom autodiff backward lives in `super::backward` as a separate impl.
// ---------------------------------------------------------------------------
crate::impl_ssd_backend_ext_for_burn_backends!(Mamba3DoubleSsdBackendExt);
144
// ─── Primitive forward kernels (K1–K5) ───────────────────────────────────────
// Primitive ports of the high-level [`super::super::serial`] kernels, expressed
// on `B`'s primitives via [`F`] so the trait default body can run under a
// generic backend. K1/K2/K4 are reused by the recompute backward in
// [`super::combined_backward`]; K5 is forward-only (the backward computes K5's
// gradient analytically rather than recomputing it).

152/// Primitive port of [`super::super::serial::k1_ssd_chunk_cumsum`].
153///
154/// Returns the intra-chunk cumsum `da_cumsum_bhnl` and the per-chunk last value
155/// `da_chunk_end_bhn`.
156pub(crate) fn k1_ssd_chunk_cumsum<B: Backend>(da_bnlh: F<B, 4>) -> (F<B, 4>, F<B, 3>) {
157    let da_bhnl = da_bnlh.permute([0, 3, 1, 2]);
158    let da_cumsum_bhnl = da_bhnl.cumsum(3);
159    let da_chunk_end_bhn = da_cumsum_bhnl
160        .clone()
161        .slice(s![.., .., .., -1])
162        .squeeze_dim::<3>(3);
163    (da_cumsum_bhnl, da_chunk_end_bhn)
164}
165
166/// Primitive port of [`super::super::serial::k2_ssd_bmm`] (fused `L·M`).
167///
168/// Returns the intra-chunk `C·Bᵀ` block matrix `cb_bnhLMLM`.
169pub(crate) fn k2_ssd_bmm<B: Backend>(c_bnlmhr: F<B, 6>, b_bnlmhr: F<B, 6>) -> F<B, 5> {
170    let [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank] = c_bnlmhr.dims();
171    let c_bnLMhr = c_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
172    let b_bnLMhr = b_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
173    let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
174    let b_bnhrLM = b_bnLMhr.permute([0, 1, 3, 4, 2]);
175    c_bnhLMr.matmul(b_bnhrLM)
176}
177
178/// Primitive port of [`super::super::serial::k3_ssd_chunk_state`] (lean:
179/// returns only the chunk-end state).
180///
181/// Returns `intra_chunk_state_bnhpr` — each chunk's contribution to its end
182/// state assuming a zero state at the chunk's start. `v_bnlmhp` is pre-scaled.
183pub(crate) fn k3_ssd_chunk_state<B: Backend>(
184    v_bnlmhp: F<B, 6>,
185    b_bnlmhr: F<B, 6>,
186    da_cumsum_bhnl: F<B, 4>,
187) -> F<B, 5> {
188    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
189    let [.., state_rank] = b_bnlmhr.dims();
190
191    let v_bnLMhp = v_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
192    let b_bnLMhr = b_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
193
194    let da_cumsum_last_bhn1 = da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
195    let da_cumsum_bhnLM = da_cumsum_bhnl
196        .unsqueeze_dim::<5>(4) // da_cumsum_bhnl1
197        .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // da_cumsum_bhnlm
198        .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); // da_cumsum_bhnLM
199    let decay_bhnLM = (da_cumsum_last_bhn1 - da_cumsum_bhnLM).exp();
200
201    let decay_bnLMh1 = decay_bhnLM.permute([0, 2, 3, 1]).unsqueeze_dim::<5>(4); // decay_bnLMh1
202    let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp;
203
204    let decayed_v_bnhpLM = decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
205    let b_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
206    let intra_chunk_state_bnhpr = decayed_v_bnhpLM.matmul(b_bnhLMr);
207    assert_eq!(
208        [batch, nchunks, nheads, per_head_dim, state_rank],
209        intra_chunk_state_bnhpr.dims()
210    );
211    intra_chunk_state_bnhpr
212}
213
214/// Primitive port of [`super::super::serial::k4_ssd_state_passing`].
215///
216/// Returns the per-chunk input-state stream `chunk_input_state_bnhpr` and the
217/// `final_state_bhpr`.
218pub(crate) fn k4_ssd_state_passing<B: Backend>(
219    intra_chunk_state_bnhpr: F<B, 5>,
220    da_chunk_end_bhn: F<B, 3>,
221    initial_state_bhpr: F<B, 4>,
222) -> (F<B, 5>, F<B, 4>) {
223    let [batch, nchunks, nheads, per_head_dim, state_rank] = intra_chunk_state_bnhpr.dims();
224
225    let mut running_state_bhpr = initial_state_bhpr;
226    let mut chunk_input_state_vec_bhpr = Vec::with_capacity(nchunks + 1);
227    chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
228
229    for i_chunk in 0..nchunks {
230        let intra_state_bhpr = intra_chunk_state_bnhpr
231            .clone()
232            .slice(s![.., i_chunk, .., .., ..])
233            .squeeze_dim::<4>(1);
234        let decay_bhpr = da_chunk_end_bhn
235            .clone()
236            .slice(s![.., .., i_chunk])
237            .unsqueeze_dim::<4>(3)
238            .exp()
239            .expand([batch, nheads, per_head_dim, state_rank]);
240        running_state_bhpr = decay_bhpr * running_state_bhpr + intra_state_bhpr;
241        chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
242    }
243
244    let final_state_bhpr = chunk_input_state_vec_bhpr.pop().unwrap();
245    let chunk_input_state_bnhpr = F::stack(chunk_input_state_vec_bhpr, 1);
246    (chunk_input_state_bnhpr, final_state_bhpr)
247}
248
249/// Primitive port of [`super::super::serial::k5_ssd_chunk_scan`].
250///
251/// Combines the intra-chunk (ORANGE, MIMO causal) and inter-chunk (BLUE,
252/// state-carried) contributions into the output `y_bnlmhp`. No `D` skip is
253/// applied — the caller handles it. Forward-only.
254fn k5_ssd_chunk_scan<B: Backend>(
255    da_cumsum_bhnl: F<B, 4>,
256    v_bnlmhp: F<B, 6>,
257    c_bnlmhr: F<B, 6>,
258    cb_bnhLMLM: F<B, 5>,
259    chunk_input_state_bnhpr: F<B, 5>,
260) -> F<B, 6> {
261    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
262    let [.., state_rank] = c_bnlmhr.dims();
263    let device = v_bnlmhp.device();
264    let dtype = v_bnlmhp.dtype();
265
266    let v_bnLMhp = v_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
267    let c_bnLMhr = c_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
268
269    // Expand base da_cumsum to fused length: [b, nheads, n, l] → [b, nheads, n, L]
270    let da_cumsum_bhnLM = da_cumsum_bhnl
271        .unsqueeze_dim::<5>(4) // da_cumsum_bhnl1
272        .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // da_cumsum_bhnlm
273        .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); // da_cumsum_bhnLM
274
275    // ── BLUE (Y_off): exp(cumA[i]) · C[i] · h[n-1] ─────────────────────
276    let exp_da_bnhLMp = da_cumsum_bhnLM
277        .clone()
278        .exp()
279        .permute([0, 2, 1, 3]) // exp_da_bnhLM
280        .unsqueeze_dim::<5>(4) // exp_da_bnhLM1
281        .expand([batch, nchunks, nheads, chunk_len * mimo_rank, per_head_dim]);
282    let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
283    let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
284    let ch_bnhLMp = c_bnhLMr.matmul(chunk_input_state_bnhrp);
285    let blue_bnhLMp = ch_bnhLMp * exp_da_bnhLMp;
286
287    // ── ORANGE (Y_diag): MIMO causal decay matrix · CB @ V ────────────────────
288    let da_cumsum_bnhLM = da_cumsum_bhnLM.permute([0, 2, 1, 3]);
289    let target_da_cumsum_bnhLMLM = da_cumsum_bnhLM
290        .clone()
291        .unsqueeze_dim::<5>(4) // da_cumsum_bnhLM1
292        .expand([
293            batch,
294            nchunks,
295            nheads,
296            chunk_len * mimo_rank,
297            chunk_len * mimo_rank,
298        ]);
299    let source_da_cumsum_bnhLMLM = da_cumsum_bnhLM
300        .unsqueeze_dim::<5>(3) // da_cumsum_bnh1LM
301        .expand([
302            batch,
303            nchunks,
304            nheads,
305            chunk_len * mimo_rank,
306            chunk_len * mimo_rank,
307        ]);
308    let diff_da_cumsum_bnhLMLM = target_da_cumsum_bnhLMLM - source_da_cumsum_bnhLMLM;
309
310    // MIMO causal neg-inf mask: −∞ where j//m > i//m (source strictly ahead of
311    // target in time). Built as interleaved expansion of the standard
312    // 2-dimensional upper-triangle mask.
313    let neg_inf_base_bnhll =
314        F::<B, 2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device, dtype)
315            .triu(1) // [chunk_len, chunk_len]: -inf above diagonal
316            .unsqueeze_dims::<5>(&[0, 1, 2]) // neg_inf_base_111ll
317            .expand([batch, nchunks, nheads, chunk_len, chunk_len]); // neg_inf_base_bnhll
318    let neg_inf_bnhLMLM = neg_inf_base_bnhll
319        .unsqueeze_dim::<6>(4) // neg_inf_base_bnhl1l
320        .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len]) // neg_inf_base_bnhlml
321        .reshape([batch, nchunks, nheads, chunk_len * mimo_rank, chunk_len]) // neg_inf_base_bnhLMl
322        .unsqueeze_dim::<6>(5) // neg_inf_base_bnhLMl1
323        .expand([
324            batch,
325            nchunks,
326            nheads,
327            chunk_len * mimo_rank,
328            chunk_len,
329            mimo_rank,
330        ]) // neg_inf_base_bnhLMlm
331        .reshape([
332            batch,
333            nchunks,
334            nheads,
335            chunk_len * mimo_rank,
336            chunk_len * mimo_rank,
337        ]); // neg_inf_bnhLMLM
338
339    let decay_bnhLMLM = (diff_da_cumsum_bnhLMLM + neg_inf_bnhLMLM).exp();
340
341    let v_bnhLMp = v_bnLMhp.permute([0, 1, 3, 2, 4]);
342    let orange_bnhLMp = (cb_bnhLMLM * decay_bnhLMLM).matmul(v_bnhLMp);
343
344    // ── Combine and reshape ────────────────────────────────────────────────────
345    let y_bnlmhp = (blue_bnhLMp + orange_bnhLMp)
346        .permute([0, 1, 3, 2, 4]) // y_bnLMhp
347        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]); // y_bnlmhp
348    y_bnlmhp
349}