// burn_mamba/mamba3/single_ssd/ssd/serial_recalculated/combined_backward.rs

//! # Recompute-based gradient math for the Mamba-3 single-SSD
//!
//! The analytic backward of the single-pass MIMO-first scan.  Forward
//! intermediates (K1–K4) are recomputed from the saved leaf inputs, then a
//! reverse per-chunk loop fuses the K5 state-to-output (BLUE), the strict
//! lower-triangular intra-chunk (LOWER), and the K4 state-passing backwards; the
//! γ-weighted same-step (DIAG) term is computed batched (no recurrence, tiny
//! `m × m` tensors).  Because this pathway applies the trapezoid weights
//! internally, it additionally returns `d_gamma` and `d_scale`.  The shared K3
//! extended helper (and K1/K2/K4) are reused from the double-SSD module.
//!
//! Everything operates on backend **primitives** through the rank-tagged [`F`]
//! wrapper: the custom [`Backward`](burn::backend::autodiff::ops::Backward) node
//! runs with a generic backend `B`, so the high-level `Tensor` is unavailable
//! and the math uses `B`'s `float_*` ops.

17#![allow(non_snake_case)]
18
19use crate::mamba3::double_ssd::ssd::serial_recalculated::combined_backward::k3_ssd_chunk_state_extended;
20use crate::mamba3::double_ssd::ssd::serial_recalculated::{
21    k1_ssd_chunk_cumsum, k2_ssd_bmm, k4_ssd_state_passing,
22};
23use crate::utils::fprim::{F, san};
24use burn::backend::Backend;
25use burn::tensor::s;
26
27/// Per-input gradients produced by [`combined_backward`] for the Single-SSD.
28/// Adds `d_gamma_bnlh` and `d_scale_bnlh` over the double-ssd form
29/// [`crate::mamba3::double_ssd::ssd::serial_recalculated::combined_backward::CombinedGrads`].
30#[non_exhaustive]
31pub struct CombinedSingleSsdGrads<B: Backend> {
32    /// Gradient of the raw input `v`.
33    pub d_v_bnlmhp: F<B, 6>,
34    /// Gradient of `Δ·A` (`da`).
35    pub d_da_bnlh: F<B, 4>,
36    /// Gradient of the input projection `B`.
37    pub d_b_bnlmhr: F<B, 6>,
38    /// Gradient of the output projection `C`.
39    pub d_c_bnlmhr: F<B, 6>,
40    /// Gradient of the same-step trapezoid weight `γ`.
41    pub d_gamma_bnlh: F<B, 4>,
42    /// Gradient of the key scale `scale = γ + (1−λ₊₁)·Δ₊₁`.
43    pub d_scale_bnlh: F<B, 4>,
44    /// Gradient of the initial SSM state.
45    pub d_initial_state_bhpr: F<B, 4>,
46}
47
48/// Memory-efficient backward for the Mamba-3 MIMO-first chunkwise Single-SSD.
49///
50/// Recomputes the forward intermediates (K1–K4) from the saved inputs, then:
51/// - runs a reverse per-chunk loop that fuses the K5 BLUE (state-to-output) and
52///   the strict lower-triangular LOWER (intra-chunk) backward with the K4
53///   state-passing backward, and
54/// - computes the γ-weighted same-step DIAG backward batched (it has no
55///   recurrence, and the `m × m` working tensors are tiny).
56///
57/// K3/K2/K1 backwards run as single batched ops once the loop has collected all
58/// per-chunk slices.
59///
60/// # Arguments
61/// - `d_y_bnlmhp` — upstream gradient of the SSD output
62/// - `d_final_bhpr` — upstream gradient of the final SSM state
63/// - `v_bnlmhp`, `da_bnlh`, `b_bnlmhr`, `c_bnlmhr`, `gamma_bnlh`, `scale_bnlh`,
64///   `initial_state_bhpr` — the seven saved forward inputs
65///
66/// # Returns
67/// One [`CombinedSingleSsdGrads`] with gradients for all 7 inputs.
68#[allow(clippy::too_many_arguments)]
69pub fn combined_backward<B: Backend>(
70    d_y_bnlmhp: F<B, 6>,
71    d_final_bhpr: F<B, 4>,
72    //
73    v_bnlmhp: F<B, 6>,
74    da_bnlh: F<B, 4>,
75    b_bnlmhr: F<B, 6>,
76    c_bnlmhr: F<B, 6>,
77    gamma_bnlh: F<B, 4>,
78    scale_bnlh: F<B, 4>,
79    initial_state_bhpr: F<B, 4>,
80) -> CombinedSingleSsdGrads<B> {
81    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
82    let [.., state_rank] = b_bnlmhr.dims();
83    let device = v_bnlmhp.device();
84    let dtype = v_bnlmhp.dtype();
85
86    san(&d_y_bnlmhp);
87    san(&d_final_bhpr);
88    san(&v_bnlmhp);
89    san(&da_bnlh);
90    san(&b_bnlmhr);
91    san(&c_bnlmhr);
92    san(&gamma_bnlh);
93    san(&scale_bnlh);
94    san(&initial_state_bhpr);
95
96    // ═══════════════════════════════════════════════════════════════════════
97    // RECOMPUTE FORWARD INTERMEDIATES (K1–K4, single-ssd form)
98    // ═══════════════════════════════════════════════════════════════════════
99
100    // K1
101    let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(da_bnlh.clone());
102    san(&da_cumsum_bhnl);
103
104    // K2 — CB matrix (unscaled), used by LOWER.
105    let cb_bnhLMLM = k2_ssd_bmm(c_bnlmhr.clone(), b_bnlmhr.clone());
106    san(&cb_bnhLMLM);
107
108    // K3 — chunk state from K_scaled = scaleₜ·B.
109    let scale_bnlh11 = scale_bnlh.clone().unsqueeze_dims::<6>(&[3, 5]);
110    let k_scaled_bnlmhr = b_bnlmhr.clone() * scale_bnlh11.clone();
111    let (intra_chunk_state_bnhpr, k3_decay_bhnLM, k3_decayed_v_bnLMhp) =
112        k3_ssd_chunk_state_extended(
113            v_bnlmhp.clone(),
114            k_scaled_bnlmhr.clone(),
115            da_cumsum_bhnl.clone(),
116        );
117
118    // K4 — chunk-input state stream consumed by BLUE.
119    let (chunk_input_state_bnhpr, _final_state_bhpr) = k4_ssd_state_passing(
120        intra_chunk_state_bnhpr,
121        da_chunk_end_bhn.clone(),
122        initial_state_bhpr,
123    );
124
125    // Fused-position cumulative decay.
126    let da_cumsum_bhnLM = da_cumsum_bhnl
127        .clone()
128        .unsqueeze_dim::<5>(4)
129        .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
130        .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]);
131
132    // d_y in (batch, nchunks, nheads, chunk_len·mimo_rank, per_head_dim) ordering.
133    let d_y_bnhLMp = d_y_bnlmhp
134        .clone()
135        .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim])
136        .permute([0, 1, 3, 2, 4]);
137    san(&d_y_bnhLMp);
138
139    // ═══════════════════════════════════════════════════════════════════════
140    // DIAG BACKWARD (batched — no recurrence; m × m working set is tiny)
141    //
142    // Forward (per (b,n,l,h)):
143    //   qk_dot[m_out, m_in] = Σ_r C[m_out, r] · B[m_in, r]
144    //   y_diag[m_out, p]    = γ · Σ_{m_in} qk_dot[m_out, m_in] · V[m_in, p]
145    // ═══════════════════════════════════════════════════════════════════════
146    let (d_v_diag_bnlmhp, d_c_diag_bnlmhr, d_b_diag_bnlmhr, d_gamma_bnlh) = {
147        let c_bnlhmr = c_bnlmhr.clone().permute([0, 1, 2, 4, 3, 5]); // [b,n,l,h,m_out,r]
148        let b_bnlhmr = b_bnlmhr.clone().permute([0, 1, 2, 4, 3, 5]); // [b,n,l,h,m_in,r]
149        let v_bnlhmp = v_bnlmhp.clone().permute([0, 1, 2, 4, 3, 5]); // [b,n,l,h,m_in,p]
150        let d_y_bnlhmp = d_y_bnlmhp.clone().permute([0, 1, 2, 4, 3, 5]); // [b,n,l,h,m_out,p]
151
152        // qk_dot[m_out, m_in] = Σ_r C[m_out,r] · B[m_in,r]
153        let qk_dot_bnlhmM = c_bnlhmr
154            .clone()
155            .matmul(b_bnlhmr.clone().permute([0, 1, 2, 3, 5, 4]));
156        // y_d_unweighted[m_out, p] = Σ_{m_in} qk_dot · V[m_in, p]
157        let y_d_unw_bnlhmp = qk_dot_bnlhmM.clone().matmul(v_bnlhmp.clone());
158
159        // d_gamma[b,n,l,h] = Σ_{m_out,p} d_y · y_d_unweighted
160        let d_gamma_bnlh: F<B, 4> = (d_y_bnlhmp.clone() * y_d_unw_bnlhmp)
161            .sum_dim(5) // bnlhm1
162            .squeeze_dim::<5>(5) // bnlhm
163            .sum_dim(4) // bnlh1
164            .squeeze_dim::<4>(4); // bnlh
165        san(&d_gamma_bnlh);
166
167        // d_y_d_unweighted = γ · d_y  (γ broadcast over m_out, p)
168        let gamma_bnlh11 = gamma_bnlh.clone().unsqueeze_dims::<6>(&[4, 5]);
169        let d_y_d_unw_bnlhmp = d_y_bnlhmp * gamma_bnlh11;
170
171        // d_qk_dot[m_out, m_in] = Σ_p d_y_d_unweighted[m_out, p] · V[m_in, p]
172        let d_qk_dot_bnlhmM = d_y_d_unw_bnlhmp
173            .clone()
174            .matmul(v_bnlhmp.clone().permute([0, 1, 2, 3, 5, 4])); // [b,n,l,h,m_out,m_in]
175
176        // d_v_diag[m_in, p] = Σ_{m_out} qk_dot[m_out, m_in] · d_y_d_unweighted[m_out, p]
177        let d_v_diag_bnlhmp = qk_dot_bnlhmM
178            .permute([0, 1, 2, 3, 5, 4]) // qk_dot^T: [b,n,l,h,m_in,m_out]
179            .matmul(d_y_d_unw_bnlhmp.clone()); // [b,n,l,h,m_in,p]
180
181        // d_C_diag[m_out, r] = Σ_{m_in} d_qk_dot[m_out, m_in] · B[m_in, r]
182        let d_c_diag_bnlhmr = d_qk_dot_bnlhmM.clone().matmul(b_bnlhmr); // [b,n,l,h,m_out,r]
183        // d_B_diag[m_in, r] = Σ_{m_out} d_qk_dot[m_out, m_in] · C[m_out, r]
184        let d_b_diag_bnlhmr = d_qk_dot_bnlhmM
185            .permute([0, 1, 2, 3, 5, 4]) // d_qk_dot^T: [b,n,l,h,m_in,m_out]
186            .matmul(c_bnlhmr); // [b,n,l,h,m_in,r]
187
188        // Back to [b,n,l,m,h,*].
189        let d_v_diag_bnlmhp = d_v_diag_bnlhmp.permute([0, 1, 2, 4, 3, 5]);
190        let d_c_diag_bnlmhr = d_c_diag_bnlhmr.permute([0, 1, 2, 4, 3, 5]);
191        let d_b_diag_bnlmhr = d_b_diag_bnlhmr.permute([0, 1, 2, 4, 3, 5]);
192        (
193            d_v_diag_bnlmhp,
194            d_c_diag_bnlmhr,
195            d_b_diag_bnlmhr,
196            d_gamma_bnlh,
197        )
198    };
199
200    // Reusable [chunk_len, chunk_len] -inf strict-upper mask (triu(0): on+above
201    // diagonal → -inf) for the LOWER (strict lower-triangular) path.
202    let neg_inf_strict_ll: F<B, 2> =
203        F::<B, 2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device, dtype).triu(0);
204
205    // ═══════════════════════════════════════════════════════════════════════
206    // REVERSE PER-CHUNK LOOP — K5 (BLUE + LOWER) + K4 fused
207    // ═══════════════════════════════════════════════════════════════════════
208    let mut vec_lower_d_v_bhLMp: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
209    let mut vec_blue_d_c_bhLMr: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
210    let mut vec_d_cb_bhLMLM: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
211    let mut vec_blue_d_da_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
212    let mut vec_lower_d_da_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
213    let mut vec_lower_d_scale_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
214    let mut vec_d_intra_bhpr: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
215    let mut vec_d_da_end_bh: Vec<F<B, 2>> = Vec::with_capacity(nchunks);
216
217    let mut d_running_state_bhpr: F<B, 4> = d_final_bhpr;
218
219    for i_chunk in (0..nchunks).rev() {
220        // ── Per-chunk slices (fused chunk_len · mimo_rank) ─────────────
221        let v_bhLMp: F<B, 4> = v_bnlmhp
222            .clone()
223            .slice(s![.., i_chunk, .., .., .., ..])
224            .squeeze_dim::<5>(1)
225            .reshape([batch, chunk_len * mimo_rank, nheads, per_head_dim])
226            .permute([0, 2, 1, 3]);
227
228        let c_bhLMr: F<B, 4> = c_bnlmhr
229            .clone()
230            .slice(s![.., i_chunk, .., .., .., ..])
231            .squeeze_dim::<5>(1)
232            .reshape([batch, chunk_len * mimo_rank, nheads, state_rank])
233            .permute([0, 2, 1, 3]);
234
235        let cb_bhLMLM: F<B, 4> = cb_bnhLMLM
236            .clone()
237            .slice(s![.., i_chunk, .., .., ..])
238            .squeeze_dim::<4>(1);
239
240        let da_cumsum_bhLM: F<B, 3> = da_cumsum_bhnLM
241            .clone()
242            .slice(s![.., .., i_chunk, ..])
243            .squeeze_dim::<3>(2);
244
245        // scaleₜ per fused source position: scale[s_time] broadcast over s_m.
246        let scale_bhLM: F<B, 3> = scale_bnlh
247            .clone()
248            .slice(s![.., i_chunk, .., ..]) // [b, l, h]
249            .squeeze_dim::<3>(1)
250            .swap_dims(1, 2) // [b, h, l]
251            .unsqueeze_dim::<4>(3) // [b, h, l, 1]
252            .expand([batch, nheads, chunk_len, mimo_rank])
253            .reshape([batch, nheads, chunk_len * mimo_rank]);
254
255        let chunk_input_state_bhpr: F<B, 4> = chunk_input_state_bnhpr
256            .clone()
257            .slice(s![.., i_chunk, .., .., ..])
258            .squeeze_dim::<4>(1);
259        san(&chunk_input_state_bhpr);
260
261        let d_y_bhLMp: F<B, 4> = d_y_bnhLMp
262            .clone()
263            .slice(s![.., i_chunk, .., .., ..])
264            .squeeze_dim::<4>(1);
265
266        // ── BLUE backward (identical to double-ssd form) ─────────────────
267        let exp_da_cumsum_bhLM: F<B, 3> = da_cumsum_bhLM.clone().exp();
268        let exp_da_cumsum_bhLMp: F<B, 4> = exp_da_cumsum_bhLM
269            .clone()
270            .unsqueeze_dim::<4>(3)
271            .expand([batch, nheads, chunk_len * mimo_rank, per_head_dim]);
272        let d_ch_bhLMp: F<B, 4> = d_y_bhLMp.clone() * exp_da_cumsum_bhLMp.clone();
273        san(&d_ch_bhLMp);
274
275        let d_chunk_input_state_bhpr: F<B, 4> = c_bhLMr
276            .clone()
277            .permute([0, 1, 3, 2]) // c_bhrLM
278            .matmul(d_ch_bhLMp.clone()) // bhrp
279            .permute([0, 1, 3, 2]); // bhpr
280        san(&d_chunk_input_state_bhpr);
281
282        let d_c_blue_bhLMr: F<B, 4> = d_ch_bhLMp.clone().matmul(chunk_input_state_bhpr.clone());
283        vec_blue_d_c_bhLMr.push(d_c_blue_bhLMr);
284
285        let ch_bhLMp: F<B, 4> = c_bhLMr
286            .clone()
287            .matmul(chunk_input_state_bhpr.clone().permute([0, 1, 3, 2]));
288        let d_da_blue_bhLM: F<B, 3> = (d_y_bhLMp.clone() * ch_bhLMp * exp_da_cumsum_bhLMp)
289            .sum_dim(3)
290            .squeeze_dim::<3>(3);
291        let d_da_blue_bhl: F<B, 3> = d_da_blue_bhLM
292            .reshape([batch, nheads, chunk_len, mimo_rank])
293            .sum_dim(3)
294            .squeeze_dim::<3>(3);
295        vec_blue_d_da_bhl.push(d_da_blue_bhl);
296
297        // ── LOWER backward (strict lower-tri + per-column scale) ────────
298        let da_target_bhLMLM: F<B, 4> = da_cumsum_bhLM.clone().unsqueeze_dim::<4>(3).expand([
299            batch,
300            nheads,
301            chunk_len * mimo_rank,
302            chunk_len * mimo_rank,
303        ]);
304        let da_source_bhLMLM: F<B, 4> = da_cumsum_bhLM.unsqueeze_dim::<4>(2).expand([
305            batch,
306            nheads,
307            chunk_len * mimo_rank,
308            chunk_len * mimo_rank,
309        ]);
310        let diff_bhLMLM = da_target_bhLMLM - da_source_bhLMLM;
311
312        // Strict-lower MIMO mask: -inf where s_time ≥ t_time — interleaved
313        // expansion of the [l, l] strict-upper (triu(0)) base mask.
314        let neg_inf_mimo_bhLMLM: F<B, 4> = neg_inf_strict_ll
315            .clone()
316            .unsqueeze_dims::<4>(&[0, 1])
317            .expand([batch, nheads, chunk_len, chunk_len])
318            .unsqueeze_dim::<5>(3)
319            .expand([batch, nheads, chunk_len, mimo_rank, chunk_len])
320            .reshape([batch, nheads, chunk_len * mimo_rank, chunk_len])
321            .unsqueeze_dim::<5>(4)
322            .expand([batch, nheads, chunk_len * mimo_rank, chunk_len, mimo_rank])
323            .reshape([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]);
324        let decay_strict_bhLMLM = (diff_bhLMLM + neg_inf_mimo_bhLMLM).exp();
325        san(&decay_strict_bhLMLM);
326
327        let scale_col_bhLMLM: F<B, 4> = scale_bhLM
328            .unsqueeze_dim::<4>(2) // [b,h,1,LMs]
329            .expand([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]);
330
331        // w = cb · decay_strict · scale_col
332        let prod_bhLMLM = cb_bhLMLM.clone() * decay_strict_bhLMLM.clone();
333        let w_bhLMLM = prod_bhLMLM.clone() * scale_col_bhLMLM.clone();
334
335        // d_w = d_y · vᵀ
336        let d_w_bhLMLM: F<B, 4> = d_y_bhLMp
337            .clone()
338            .matmul(v_bhLMp.clone().permute([0, 1, 3, 2]));
339        san(&d_w_bhLMLM);
340
341        // d_v_lower = wᵀ · d_y
342        let d_v_lower_bhLMp: F<B, 4> = w_bhLMLM.permute([0, 1, 3, 2]).matmul(d_y_bhLMp.clone());
343        san(&d_v_lower_bhLMp);
344        vec_lower_d_v_bhLMp.push(d_v_lower_bhLMp);
345
346        // d_prod = d_w · scale_col ; d_scale_at = d_w · prod
347        let d_prod_bhLMLM = d_w_bhLMLM.clone() * scale_col_bhLMLM;
348        let d_scale_at_bhLMLM = d_w_bhLMLM * prod_bhLMLM;
349
350        // d_cb_lower = d_prod · decay_strict
351        let d_cb_lower_bhLMLM = d_prod_bhLMLM.clone() * decay_strict_bhLMLM.clone();
352        vec_d_cb_bhLMLM.push(d_cb_lower_bhLMLM);
353
354        // d_decay_strict = d_prod · cb ; d_diff = d_decay_strict · decay_strict
355        let d_decay_strict_bhLMLM = d_prod_bhLMLM * cb_bhLMLM;
356        let d_diff_bhLMLM = d_decay_strict_bhLMLM * decay_strict_bhLMLM;
357
358        let d_da_target_bhLM: F<B, 3> = d_diff_bhLMLM.clone().sum_dim(3).squeeze_dim::<3>(3);
359        let d_da_source_bhLM: F<B, 3> = d_diff_bhLMLM.sum_dim(2).squeeze_dim::<3>(2);
360        let d_da_lower_bhLM = d_da_target_bhLM - d_da_source_bhLM;
361        let d_da_lower_bhl: F<B, 3> = d_da_lower_bhLM
362            .reshape([batch, nheads, chunk_len, mimo_rank])
363            .sum_dim(3)
364            .squeeze_dim::<3>(3);
365        vec_lower_d_da_bhl.push(d_da_lower_bhl);
366
367        // d_scale[s_time] = Σ_{LMt, s_m} d_scale_at[LMt, LMs]
368        let d_scale_lower_bhl: F<B, 3> = d_scale_at_bhLMLM
369            .sum_dim(2) // sum over target LMt → [b,h,1,LMs]
370            .squeeze_dim::<3>(2) // [b,h,LMs]
371            .reshape([batch, nheads, chunk_len, mimo_rank])
372            .sum_dim(3) // sum over source mimo → [b,h,l,1]
373            .squeeze_dim::<3>(3); // [b,h,l]
374        vec_lower_d_scale_bhl.push(d_scale_lower_bhl);
375
376        // ── K4 backward step ───────────────────────────────────────────
377        vec_d_intra_bhpr.push(d_running_state_bhpr.clone());
378
379        let decay_chunk_bhpr: F<B, 4> = da_chunk_end_bhn
380            .clone()
381            .slice(s![.., .., i_chunk])
382            .exp()
383            .unsqueeze_dim::<4>(3)
384            .expand([batch, nheads, per_head_dim, state_rank]);
385        san(&decay_chunk_bhpr);
386
387        let d_decay_chunk_bhpr = d_running_state_bhpr.clone() * chunk_input_state_bhpr;
388        let d_da_chunk_end_bh: F<B, 2> = (d_decay_chunk_bhpr * decay_chunk_bhpr.clone())
389            .reshape([batch, nheads, per_head_dim * state_rank])
390            .sum_dim(2)
391            .squeeze_dim::<2>(2);
392        vec_d_da_end_bh.push(d_da_chunk_end_bh);
393
394        d_running_state_bhpr = decay_chunk_bhpr * d_running_state_bhpr + d_chunk_input_state_bhpr;
395        san(&d_running_state_bhpr);
396    }
397    let d_initial_state_bhpr = d_running_state_bhpr;
398
399    // ── Restore natural (forward) chunk order ─────────────────────────────
400    vec_lower_d_v_bhLMp.reverse();
401    vec_blue_d_c_bhLMr.reverse();
402    vec_d_cb_bhLMLM.reverse();
403    vec_blue_d_da_bhl.reverse();
404    vec_lower_d_da_bhl.reverse();
405    vec_lower_d_scale_bhl.reverse();
406    vec_d_intra_bhpr.reverse();
407    vec_d_da_end_bh.reverse();
408
409    // ── Stack per-chunk slices back into batched tensors ──────────────────
410    let d_v_lower_bnhLMp: F<B, 5> = F::stack(vec_lower_d_v_bhLMp, 1);
411    let d_c_blue_bnhLMr: F<B, 5> = F::stack(vec_blue_d_c_bhLMr, 1);
412    let d_cb_bnhLMLM: F<B, 5> = F::stack(vec_d_cb_bhLMLM, 1);
413    let d_da_blue_bhnl: F<B, 4> = F::stack(vec_blue_d_da_bhl, 2);
414    let d_da_lower_bhnl: F<B, 4> = F::stack(vec_lower_d_da_bhl, 2);
415    let d_scale_lower_bhnl: F<B, 4> = F::stack(vec_lower_d_scale_bhl, 2);
416    let d_intra_chunk_state_bnhpr: F<B, 5> = F::stack(vec_d_intra_bhpr, 1);
417    let d_da_end_bhn: F<B, 3> = F::stack(vec_d_da_end_bh, 2);
418    let d_da_cumsum_k4_bhnl: F<B, 4> = {
419        let zeros = F::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device, dtype);
420        let d_da_end_bhn1 = d_da_end_bhn.unsqueeze_dim::<4>(3);
421        F::cat(vec![zeros, d_da_end_bhn1], 3)
422    };
423
424    // ═══════════════════════════════════════════════════════════════════════
425    // K3 BACKWARD (batched) — K_scaled = scaleₜ·B
426    //
427    // intra_state = decayed_vᵀ @ K_scaled, with decayed_v = decay·V.
428    // d_K_scaled = decayed_vᵀ-contraction ; then split into d_b_k3 (·scale) and
429    // d_scale_k3 (Σ_{m,r} ·B).
430    // ═══════════════════════════════════════════════════════════════════════
431    let v_bnLMhp =
432        v_bnlmhp
433            .clone()
434            .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
435    let k_scaled_bnLMhr =
436        k_scaled_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
437    let k_scaled_bnhLMr = k_scaled_bnLMhr.permute([0, 1, 3, 2, 4]);
438    let decayed_v_bnhpLM = k3_decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
439
440    let d_decayed_v_bnhpLM: F<B, 5> = d_intra_chunk_state_bnhpr
441        .clone()
442        .matmul(k_scaled_bnhLMr.clone().permute([0, 1, 2, 4, 3])); // k_scaled_bnhrLM
443    let d_k_scaled_bnhLMr: F<B, 5> = decayed_v_bnhpLM
444        .permute([0, 1, 2, 4, 3]) // decayed_v_bnhLMp
445        .matmul(d_intra_chunk_state_bnhpr);
446
447    let d_decayed_v_bnLMhp = d_decayed_v_bnhpLM.permute([0, 1, 4, 2, 3]);
448    let d_decay_bhnLM: F<B, 4> = (d_decayed_v_bnLMhp.clone() * v_bnLMhp)
449        .sum_dim(4)
450        .squeeze_dim::<4>(4)
451        .permute([0, 3, 1, 2]);
452
453    let k3_decay_bnLMh1 = k3_decay_bhnLM
454        .clone()
455        .permute([0, 2, 3, 1])
456        .unsqueeze_dim::<5>(4);
457    let d_v_k3_bnLMhp: F<B, 5> = d_decayed_v_bnLMhp * k3_decay_bnLMh1;
458    let d_v_k3_bnlmhp: F<B, 6> =
459        d_v_k3_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
460
461    // d(cumA_last − cumA) = d_decay · decay
462    let d_decay_times_decay_bhnLM = d_decay_bhnLM * k3_decay_bhnLM;
463    let d_a_cumsum_last_bhn: F<B, 3> = d_decay_times_decay_bhnLM
464        .clone()
465        .sum_dim(3)
466        .squeeze_dim::<3>(3);
467    let d_da_cumsum_bhnLM = -d_decay_times_decay_bhnLM;
468
469    let d_da_cumsum_k3_from_fused_bhnl: F<B, 4> = d_da_cumsum_bhnLM
470        .reshape([batch, nheads, nchunks, chunk_len, mimo_rank])
471        .sum_dim(4)
472        .squeeze_dim::<4>(4);
473    let d_da_cumsum_k3_from_last_bhnl: F<B, 4> = {
474        let zeros = F::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device, dtype);
475        let d_last = d_a_cumsum_last_bhn.unsqueeze_dim::<4>(3);
476        F::cat(vec![zeros, d_last], 3)
477    };
478    let d_da_cumsum_k3_bhnl = d_da_cumsum_k3_from_fused_bhnl + d_da_cumsum_k3_from_last_bhnl;
479
480    // d_K_scaled → bnlmhr, then split into d_b_k3 (·scale) and d_scale_k3 (Σ·B).
481    let d_k_scaled_bnlmhr: F<B, 6> = d_k_scaled_bnhLMr
482        .permute([0, 1, 3, 2, 4]) // bnLMhr
483        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
484    let d_b_k3_bnlmhr: F<B, 6> = d_k_scaled_bnlmhr.clone() * scale_bnlh11;
485    let d_scale_k3_bnlh: F<B, 4> = (d_k_scaled_bnlmhr * b_bnlmhr.clone())
486        .sum_dim(5) // sum over state_rank → [b,n,l,m,h,1]
487        .squeeze_dim::<5>(5) // [b,n,l,m,h]
488        .sum_dim(3) // sum over mimo_rank → [b,n,l,1,h]
489        .squeeze_dim::<4>(3); // [b,n,l,h]
490
491    // ═══════════════════════════════════════════════════════════════════════
492    // K2 BACKWARD (batched) — cb = C @ Bᵀ
493    // ═══════════════════════════════════════════════════════════════════════
494    let b_bnLMhr =
495        b_bnlmhr
496            .clone()
497            .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
498    let c_bnhLMr = c_bnlmhr
499        .clone()
500        .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank])
501        .permute([0, 1, 3, 2, 4]);
502    let b_for_k2_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
503
504    let d_c_k2_bnhLMr: F<B, 5> = d_cb_bnhLMLM.clone().matmul(b_for_k2_bnhLMr);
505    let d_b_k2_bnhrLM: F<B, 5> = c_bnhLMr.permute([0, 1, 2, 4, 3]).matmul(d_cb_bnhLMLM);
506
507    let d_c_k2_bnlmhr: F<B, 6> = d_c_k2_bnhLMr
508        .permute([0, 1, 3, 2, 4])
509        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
510    let d_b_k2_bnlmhr: F<B, 6> = d_b_k2_bnhrLM
511        .permute([0, 1, 4, 2, 3])
512        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
513
514    // ── Unstack d_c_blue / d_v_lower and reshape back ─────────────────────
515    let d_c_blue_bnlmhr: F<B, 6> = d_c_blue_bnhLMr
516        .permute([0, 1, 3, 2, 4])
517        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
518    let d_v_lower_bnlmhp: F<B, 6> = d_v_lower_bnhLMp.permute([0, 1, 3, 2, 4]).reshape([
519        batch,
520        nchunks,
521        chunk_len,
522        mimo_rank,
523        nheads,
524        per_head_dim,
525    ]);
526
527    // ═══════════════════════════════════════════════════════════════════════
528    // K1 BACKWARD + SUM CONTRIBUTIONS
529    // ═══════════════════════════════════════════════════════════════════════
530    let d_da_cumsum_bhnl =
531        d_da_blue_bhnl + d_da_lower_bhnl + d_da_cumsum_k3_bhnl + d_da_cumsum_k4_bhnl;
532    san(&d_da_cumsum_bhnl);
533
534    // K1 inverse: suffix sum.
535    let d_da_bhnl = {
536        let d_total_bhnl = d_da_cumsum_bhnl
537            .clone()
538            .sum_dim(3)
539            .expand([batch, nheads, nchunks, chunk_len]);
540        let prefix_bhnl = d_da_cumsum_bhnl.cumsum(3);
541        let zeros_bhn1 = F::<B, 4>::zeros([batch, nheads, nchunks, 1], &device, dtype);
542        let prefix_shifted_bhnl =
543            F::cat(vec![zeros_bhn1, prefix_bhnl.narrow(3, 0, chunk_len - 1)], 3);
544        d_total_bhnl - prefix_shifted_bhnl
545    };
546    let d_da_bnlh = d_da_bhnl.permute([0, 2, 3, 1]);
547
548    // ── Combine per-input gradient contributions ──────────────────────────
549    let d_v_bnlmhp = d_v_k3_bnlmhp + d_v_lower_bnlmhp + d_v_diag_bnlmhp;
550    let d_b_bnlmhr = d_b_k2_bnlmhr + d_b_k3_bnlmhr + d_b_diag_bnlmhr;
551    let d_c_bnlmhr = d_c_k2_bnlmhr + d_c_blue_bnlmhr + d_c_diag_bnlmhr;
552    let d_scale_bnlh = d_scale_lower_bhnl.permute([0, 2, 3, 1]) + d_scale_k3_bnlh;
553
554    san(&d_v_bnlmhp);
555    san(&d_da_bnlh);
556    san(&d_b_bnlmhr);
557    san(&d_c_bnlmhr);
558    san(&d_gamma_bnlh);
559    san(&d_scale_bnlh);
560    san(&d_initial_state_bhpr);
561
562    CombinedSingleSsdGrads {
563        d_v_bnlmhp,
564        d_da_bnlh,
565        d_b_bnlmhr,
566        d_c_bnlmhr,
567        d_gamma_bnlh,
568        d_scale_bnlh,
569        d_initial_state_bhpr,
570    }
571}