Skip to main content

burn_mamba/mamba3/single_ssd/ssd/serial_recalculated/
serial_recalculated.rs

//! # Serial SSD with a custom, memory-efficient backward (Mamba-3 single-SSD)
//!
//! The `SerialRecalculated` implementation of the single-SSD pathway.  The forward is
//! the same serial scan as [`super::super::serial`], routed through the
//! [`Mamba3SingleSsdBackendExt`] trait so `Autodiff` backends can substitute a
//! custom backward that recomputes per-chunk intermediates rather than storing
//! them (see [`super::backward`] / [`super::combined_backward`]).  Unlike the
//! double-SSD form, the kernels here apply the trapezoid `gamma`/`scale` and the
//! boundary-β seed internally, so the backward also returns `d_gamma`/`d_scale`.
//!
//! The default body runs under a generic backend `B`, where the high-level
//! `Tensor` (pinned to `Dispatch`) is unavailable, so the K1–K5 math goes
//! through the rank-tagged [`F`] primitive wrapper.  K1–K4 are mode-agnostic and
//! reused from the double-SSD forward; only the single-SSD K5 (strict-lower
//! intra-chunk + γ-correction) is owned here, and it is forward-only.
16
17#![allow(non_snake_case)]
18
19use crate::mamba3::double_ssd::ssd::serial_recalculated::{
20    k1_ssd_chunk_cumsum, k2_ssd_bmm, k3_ssd_chunk_state, k4_ssd_state_passing,
21};
22use crate::mamba3::single_ssd::prelude::*;
23use crate::utils::fprim::{F, san};
24use burn::backend::tensor::FloatTensor;
25use burn::backend::*;
26use burn::backend::{Backend, Dispatch, backend_extension};
27use burn::tensor::Tensor;
28
29impl Mamba3SingleSsdInput {
30    /// MIMO-first single-ssd form Serial SSD with recalculated backward.
31    ///
32    /// Delegates the full K1–K5 (single-ssd) computation to
33    /// [`Mamba3SingleSsdBackendExt::single_ssd_serial_recalculated`], which can provide
34    /// a memory-efficient custom backward for supported backends (the Autodiff
35    /// wrapper) and falls back to the standard K1–K5 forward on others.
36    ///
37    /// # Returns
38    /// - `y_bnlmhp`:         `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
39    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
40    pub fn single_ssd_serial_recalculated(self) -> (Tensor<6>, Tensor<4>) {
41        let input = self;
42        input.sanity();
43        assert!(
44            input.init_state_hpr.is_none(),
45            "init_state_hpr not yet implemented for single_ssd_serial_recalculated"
46        );
47
48        let (y_bnlmhp, final_state_bhpr) =
49            <Dispatch as Mamba3SingleSsdBackendExt>::single_ssd_serial_recalculated(
50                input.v_bnlmhp.into_primitive(),
51                input.da_bnlh.into_primitive(),
52                input.b_bnlmhr.into_primitive(),
53                input.c_bnlmhr.into_primitive(),
54                input.gamma_bnlh.into_primitive(),
55                input.scale_bnlh.into_primitive(),
56                input.initial_state_bhpr.into_primitive(),
57            );
58        let y_bnlmhp = Tensor::from_primitive(y_bnlmhp);
59        let final_state_bhpr = Tensor::from_primitive(final_state_bhpr);
60        (y_bnlmhp, final_state_bhpr)
61    }
62}
63
64/// Extends the backend for the memory-efficient single-ssd form serial SSD.
65///
66/// The default implementation runs K1–K5 using primitive tensor operations,
67/// reusing the mode-agnostic K1/K2/K3/K4 from the double-SSD forward and the
68/// single-ssd form K5 below. Backends that support a custom memory-efficient
69/// backward (the Autodiff wrapper) override this to recompute forward
70/// intermediates during backward instead of saving them.
71#[backend_extension(
72    Cpu:  cfg(feature = "backend-cpu"),
73    Cuda: cfg(feature = "backend-cuda"),
74    Rocm:  cfg(feature = "backend-rocm"),
75    Metal:  cfg(feature = "backend-metal"),
76    Vulkan:  cfg(feature = "backend-vulkan"),
77    Wgpu:  cfg(feature = "backend-wgpu"),
78    WebGpu:  cfg(feature = "backend-webgpu"),
79    Flex:  cfg(feature = "backend-flex"),
80    NdArray:  cfg(feature = "backend-ndarray"),
81    LibTorch:  cfg(any(feature = "backend-tch-cpu", feature = "backend-tch-gpu")),
82    Autodiff:  cfg(feature = "autodiff"),
83)]
84pub trait Mamba3SingleSsdBackendExt: Backend {
85    /// Memory-efficient MIMO single-ssd form serial SSD.
86    ///
87    /// # Arguments
88    /// - `v_bnlmhp`:           `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
89    /// - `da_bnlh`:            `[batch, nchunks, chunk_len, nheads]` — pre-combined Δ·A
90    /// - `b_bnlmhr`:           `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
91    /// - `c_bnlmhr`:           `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
92    /// - `gamma_bnlh`:         `[batch, nchunks, chunk_len, nheads]` — `γₜ = λₜ Δₜ`
93    /// - `scale_bnlh`:         `[batch, nchunks, chunk_len, nheads]` — `scaleₜ = γₜ + (1−λₜ₊₁)Δₜ₊₁`
94    /// - `initial_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
95    ///
96    /// # Returns
97    /// - `y_bnlmhp`:         `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
98    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
99    fn single_ssd_serial_recalculated(
100        v_bnlmhp: FloatTensor<Self>,
101        da_bnlh: FloatTensor<Self>,
102        b_bnlmhr: FloatTensor<Self>,
103        c_bnlmhr: FloatTensor<Self>,
104        gamma_bnlh: FloatTensor<Self>,
105        scale_bnlh: FloatTensor<Self>,
106        initial_state_bhpr: FloatTensor<Self>,
107    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
108        // Default impl: replicate the single-ssd form K1–K5 on primitives.
109        let v_bnlmhp = F::<Self, 6>::new(v_bnlmhp);
110        let da_bnlh = F::<Self, 4>::new(da_bnlh);
111        let b_bnlmhr = F::<Self, 6>::new(b_bnlmhr);
112        let c_bnlmhr = F::<Self, 6>::new(c_bnlmhr);
113        let gamma_bnlh = F::<Self, 4>::new(gamma_bnlh);
114        let scale_bnlh = F::<Self, 4>::new(scale_bnlh);
115        let initial_state_bhpr = F::<Self, 4>::new(initial_state_bhpr);
116
117        // K1 — chunk cumulative decay.
118        let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum::<Self>(da_bnlh);
119        san(&da_cumsum_bhnl);
120
121        // K2 — CB matrix on unscaled B/C.
122        let cb_bnhLMLM = k2_ssd_bmm::<Self>(c_bnlmhr.clone(), b_bnlmhr.clone());
123        san(&cb_bnhLMLM);
124
125        // K3 — chunk state on K_scaled = scaleₜ · B.
126        let scale_bnlh11 = scale_bnlh.clone().unsqueeze_dims::<6>(&[3, 5]);
127        let k_scaled_bnlmhr = b_bnlmhr.clone() * scale_bnlh11;
128        let intra_chunk_state_bnhpr =
129            k3_ssd_chunk_state::<Self>(v_bnlmhp.clone(), k_scaled_bnlmhr, da_cumsum_bhnl.clone());
130        san(&intra_chunk_state_bnhpr);
131
132        // K4 — sequential state passing.
133        let (chunk_input_state_bnhpr, final_state_bhpr) = k4_ssd_state_passing::<Self>(
134            intra_chunk_state_bnhpr,
135            da_chunk_end_bhn,
136            initial_state_bhpr,
137        );
138        san(&chunk_input_state_bnhpr);
139        san(&final_state_bhpr);
140
141        // K5 — single-ssd form chunk scan.
142        let y_bnlmhp = k5_single_ssd_chunk_scan::<Self>(
143            da_cumsum_bhnl,
144            v_bnlmhp,
145            c_bnlmhr,
146            b_bnlmhr,
147            cb_bnhLMLM,
148            gamma_bnlh,
149            scale_bnlh,
150            chunk_input_state_bnhpr,
151        );
152        san(&y_bnlmhp);
153
154        (y_bnlmhp.inner(), final_state_bhpr.inner())
155    }
156}
157
// Judging by its name, this declares the autodiff-side companion trait
// (`Mamba3SingleSsdAutodiffBackendExt`) tied to `Mamba3SingleSsdBackendExt`.
// NOTE(review): confirm the exact items it generates against the macro definition.
crate::decl_ssd_autodiff_backend_ext!(Mamba3SingleSsdAutodiffBackendExt, Mamba3SingleSsdBackendExt);

// ---------------------------------------------------------------------------
// Per-backend impls: each delegates to the trait's default (K1–K5) body. The
// custom autodiff backward lives in `super::backward` as a separate impl.
// ---------------------------------------------------------------------------
crate::impl_ssd_backend_ext_for_burn_backends!(Mamba3SingleSsdBackendExt);
165
166// ─── Primitive forward K5 (single-ssd) ───────────────────────────────────────
167// Primitive port of [`super::super::serial::k5_single_ssd_chunk_scan`]. Combines
168// the strict-lower intra-chunk (LOWER), the γ-weighted same-step (DIAG), and the
169// state-to-output (BLUE/Y_off) contributions. Forward-only; the backward
170// computes these gradients analytically. K1–K4 are reused from the double-SSD
171// forward (mode-agnostic).
172
173/// SingleSsd chunk scan, on primitives.
174///
175/// - **Strict lower triangular intra-chunk** (`t1 > t2`):
176///   `(cb[i,j] · scale[t2] · exp(cumA[t1] − cumA[t2])) · V[t2]`
177/// - **Same-time-step (`t1 == t2`) γ-correction**:
178///   `γ[t] · (Σₙ C[t,r_out,n] · B[t,r_in,n]) · V[t,r_in,p]`
179/// - **State-to-output (Y_off)**: `exp(cumA[t]) · C[t] · h'[n-1]`
180///
181/// `cb_bnhLMLM` is the unscaled `C · Bᵀ` from K2; `b_bnlmhr` is the unscaled
182/// K/B tensor (γ-correction matmul).
183#[allow(clippy::too_many_arguments)]
184fn k5_single_ssd_chunk_scan<B: Backend>(
185    da_cumsum_bhnl: F<B, 4>,
186    v_bnlmhp: F<B, 6>,
187    c_bnlmhr: F<B, 6>,
188    b_bnlmhr: F<B, 6>,
189    cb_bnhLMLM: F<B, 5>,
190    gamma_bnlh: F<B, 4>,
191    scale_bnlh: F<B, 4>,
192    chunk_input_state_bnhpr: F<B, 5>,
193) -> F<B, 6> {
194    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
195    let [.., state_rank] = c_bnlmhr.dims();
196    let device = v_bnlmhp.device();
197    let dtype = v_bnlmhp.dtype();
198    let fused = chunk_len * mimo_rank;
199
200    // Fuse mimo_rank into chunk_len for the SSM-style matmul.
201    let v_bnLMhp = v_bnlmhp
202        .clone()
203        .reshape([batch, nchunks, fused, nheads, per_head_dim]);
204    let c_bnLMhr = c_bnlmhr
205        .clone()
206        .reshape([batch, nchunks, fused, nheads, state_rank]);
207
208    // Per-fused-step cumulative decay (interleave-expand the base grid).
209    let da_cumsum_bhnLM = da_cumsum_bhnl
210        .unsqueeze_dim::<5>(4)
211        .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
212        .reshape([batch, nheads, nchunks, fused]);
213
214    // ── Y_off: exp(cumA[t]) · C[t] · h'[n-1] ────────────────────────────────
215    let exp_da_bnhLMp = da_cumsum_bhnLM
216        .clone()
217        .exp()
218        .permute([0, 2, 1, 3]) // bnhLM
219        .unsqueeze_dim::<5>(4) // bnhLM1
220        .expand([batch, nchunks, nheads, fused, per_head_dim]);
221    let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
222    let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
223    let ch_bnhLMp = c_bnhLMr.matmul(chunk_input_state_bnhrp);
224    let y_off_bnhLMp = ch_bnhLMp * exp_da_bnhLMp;
225
226    // ── Y_lower: strict lower-tri intra-chunk with scale and decay ──────────
227    let da_cumsum_bnhLM = da_cumsum_bhnLM.permute([0, 2, 1, 3]); // bnhLM
228    let target_da_cumsum_bnhLMLM = da_cumsum_bnhLM
229        .clone()
230        .unsqueeze_dim::<5>(4) // bnhLM1
231        .expand([batch, nchunks, nheads, fused, fused]);
232    let source_da_cumsum_bnhLMLM = da_cumsum_bnhLM
233        .unsqueeze_dim::<5>(3) // bnh1LM
234        .expand([batch, nchunks, nheads, fused, fused]);
235    let diff_bnhLMLM = target_da_cumsum_bnhLMLM - source_da_cumsum_bnhLMLM;
236
237    // Strict-upper -inf mask on the base time grid (`t1 <= t2` → -inf), then
238    // interleave-expand to fused length so MIMO same-time blocks are zeroed.
239    let inf_upper_bnhLMLM =
240        F::<B, 2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device, dtype)
241            .triu(0) // upper triangle INCLUDING diagonal
242            .unsqueeze_dims::<5>(&[0, 1, 2])
243            .expand([batch, nchunks, nheads, chunk_len, chunk_len])
244            .unsqueeze_dim::<6>(4)
245            .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len])
246            .reshape([batch, nchunks, nheads, fused, chunk_len])
247            .unsqueeze_dim::<6>(5)
248            .expand([batch, nchunks, nheads, fused, chunk_len, mimo_rank])
249            .reshape([batch, nchunks, nheads, fused, fused]);
250    let decay_strict_bnhLMLM = (diff_bnhLMLM + inf_upper_bnhLMLM).exp();
251
252    // Per-column scale: `scale[t2]` lives on the source axis (column).
253    let scale_bnhLM = scale_bnlh
254        .permute([0, 1, 3, 2]) // bnhl
255        .unsqueeze_dim::<5>(4) // bnhl1
256        .expand([batch, nchunks, nheads, chunk_len, mimo_rank])
257        .reshape([batch, nchunks, nheads, fused]);
258    let scale_col_bnhLMLM = scale_bnhLM
259        .unsqueeze_dim::<5>(3) // bnh1LM
260        .expand([batch, nchunks, nheads, fused, fused]);
261
262    let kernel_bnhLMLM = decay_strict_bnhLMLM * scale_col_bnhLMLM;
263    let masked_cb_bnhLMLM = cb_bnhLMLM * kernel_bnhLMLM;
264    let v_bnhLMp = v_bnLMhp.permute([0, 1, 3, 2, 4]);
265    let y_lower_bnhLMp = masked_cb_bnhLMLM.matmul(v_bnhLMp);
266
267    // ── Y_diag: γ-weighted same-step correction ─────────────────────────────
268    let c_bnlhmr = c_bnlmhr.permute([0, 1, 2, 4, 3, 5]);
269    let b_bnlhrm = b_bnlmhr.permute([0, 1, 2, 4, 5, 3]);
270    let qk_dot_bnlhmM = c_bnlhmr.matmul(b_bnlhrm); // bnlhm_outm_in
271    let v_bnlhmp = v_bnlmhp.permute([0, 1, 2, 4, 3, 5]);
272    let y_d_bnlhmp = qk_dot_bnlhmM.matmul(v_bnlhmp); // bnlhm_outp
273    let gamma_bnlh11 = gamma_bnlh.unsqueeze_dims::<6>(&[4, 5]);
274    let y_d_bnlhmp_scaled = y_d_bnlhmp * gamma_bnlh11;
275
276    let y_diag_bnlmhp = y_d_bnlhmp_scaled.permute([0, 1, 2, 4, 3, 5]);
277    let y_diag_bnLMhp = y_diag_bnlmhp.reshape([batch, nchunks, fused, nheads, per_head_dim]);
278    let y_diag_bnhLMp = y_diag_bnLMhp.permute([0, 1, 3, 2, 4]);
279
280    // ── Combine and reshape ─────────────────────────────────────────────────
281    let y_bnhLMp = y_off_bnhLMp + y_lower_bnhLMp + y_diag_bnhLMp;
282    let y_bnLMhp = y_bnhLMp.permute([0, 1, 3, 2, 4]);
283    y_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim])
284}