Skip to main content

burn_mamba/mamba3/double_ssd/ssd/serial_recalculated/
combined_backward.rs

//! # Recompute-based gradient math for the Mamba-3 double-SSD
//!
//! The analytic backward of the MIMO-first serial scan used by each pass of the
//! double-SSD decomposition.  The forward intermediates (K1–K4) are recomputed
//! from the saved leaf inputs rather than stashed, then a reverse per-chunk loop
//! fuses the K5 and K4 backwards; K1/K2/K3 backwards run batched once the loop
//! has gathered the per-chunk slices.  The fused `L·M` length carries the
//! `mimo_rank` axis through the intra-chunk products.
//!
//! Everything operates on backend **primitives** through the rank-tagged [`F`]
//! wrapper: the custom [`Backward`](burn::backend::autodiff::ops::Backward) node
//! runs with a generic backend `B`, so the high-level `Tensor` is unavailable
//! and the math uses `B`'s `float_*` ops.  The recomputed K1/K2/K4 kernels are
//! local primitive ports of the high-level [`super::super::serial`] kernels.
15
16#![allow(non_snake_case)]
17
18use super::serial_recalculated::{k1_ssd_chunk_cumsum, k2_ssd_bmm, k4_ssd_state_passing};
19use crate::utils::fprim::{F, san};
20use burn::backend::Backend;
21use burn::tensor::s;
22
23/// Per-input gradients produced by [`combined_backward`] (one field per
24/// differentiable forward input of the double-SSD scan).
25#[non_exhaustive]
26pub struct CombinedGrads<B: Backend> {
27    /// Gradient of the (pre-scaled) input `v`.
28    pub d_v_bnlmhp: F<B, 6>,
29    /// Gradient of `Δ·A` (`da`).
30    pub d_da_bnlh: F<B, 4>,
31    /// Gradient of the input projection `B`.
32    pub d_b_bnlmhr: F<B, 6>,
33    /// Gradient of the output projection `C`.
34    pub d_c_bnlmhr: F<B, 6>,
35    /// Gradient of the initial SSM state.
36    pub d_initial_state_bhpr: F<B, 4>,
37}
38
// ─── Recomputed forward kernels ──────────────────────────────────────────────
// The recompute backward replays the forward's K1/K2/K4 (imported above from
// [`super::serial_recalculated`]) plus the extended K3 below, which returns the
// extra intermediates the gradient math needs.
43
44/// Same as [`k3_ssd_chunk_state`](super::serial_recalculated::k3_ssd_chunk_state) but
45/// also returns intermediates needed by the custom backward:
46/// - `intra_chunk_state_bnhpr` — the chunk-end state assuming zero initial state
47/// - `decay_bhnLM` — the fused-length K3 decay factor `exp(cumA_last − cumA_fused)`
48/// - `decayed_v_bnLMhp` — V already scaled by `decay_bnLMh1`
49pub fn k3_ssd_chunk_state_extended<B: Backend>(
50    v_bnlmhp: F<B, 6>,
51    b_bnlmhr: F<B, 6>,
52    da_cumsum_bhnl: F<B, 4>,
53) -> (F<B, 5>, F<B, 4>, F<B, 5>) {
54    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
55    let [.., state_rank] = b_bnlmhr.dims();
56
57    let v_bnLMhp = v_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
58    let b_bnLMhr = b_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
59
60    let da_cumsum_last_bhn1 = da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
61    let da_cumsum_bhnLM = da_cumsum_bhnl
62        .unsqueeze_dim::<5>(4) // da_cumsum_bhnl1
63        .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // da_cumsum_bhnlm
64        .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); // da_cumsum_bhnLM
65    let decay_bhnLM = (da_cumsum_last_bhn1 - da_cumsum_bhnLM).exp();
66    san(&decay_bhnLM);
67
68    let decay_bnLMh1 = decay_bhnLM
69        .clone()
70        .permute([0, 2, 3, 1])
71        .unsqueeze_dim::<5>(4);
72    let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp;
73    san(&decayed_v_bnLMhp);
74
75    let decayed_v_bnhpLM = decayed_v_bnLMhp.clone().permute([0, 1, 3, 4, 2]);
76    let b_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
77    let intra_chunk_state_bnhpr = decayed_v_bnhpLM.matmul(b_bnhLMr);
78    san(&intra_chunk_state_bnhpr);
79
80    (intra_chunk_state_bnhpr, decay_bhnLM, decayed_v_bnLMhp)
81}
82
83/// Memory-efficient backward for the Mamba-3 MIMO-first chunkwise SSD.
84///
85/// Recomputes the forward intermediates (K1-K4) from the saved inputs, then
86/// runs a reverse per-chunk loop that fuses the K5 (BLUE + ORANGE) backward
87/// with the K4 state-passing backward.  K3/K2/K1 backwards run as single
88/// batched ops once the loop has collected all per-chunk slices.
89///
90/// # Arguments
91/// - `d_y_bnlmhp` — upstream gradient of the SSD output
92/// - `d_final_bhpr` — upstream gradient of the final SSM state
93/// - `v_bnlmhp`, `da_bnlh`, `b_bnlmhr`, `c_bnlmhr`, `initial_state_bhpr` —
94///   the five saved forward inputs
95///
96/// # Returns
97/// One [`CombinedGrads`] struct containing gradients for all 5 inputs.
98pub fn combined_backward<B: Backend>(
99    d_y_bnlmhp: F<B, 6>,
100    d_final_bhpr: F<B, 4>,
101    //
102    v_bnlmhp: F<B, 6>,
103    da_bnlh: F<B, 4>,
104    b_bnlmhr: F<B, 6>,
105    c_bnlmhr: F<B, 6>,
106    initial_state_bhpr: F<B, 4>,
107) -> CombinedGrads<B> {
108    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
109    let [.., state_rank] = b_bnlmhr.dims();
110    let device = v_bnlmhp.device();
111    let dtype = v_bnlmhp.dtype();
112
113    san(&d_y_bnlmhp);
114    san(&d_final_bhpr);
115    san(&v_bnlmhp);
116    san(&da_bnlh);
117    san(&b_bnlmhr);
118    san(&c_bnlmhr);
119    san(&initial_state_bhpr);
120
121    // ═══════════════════════════════════════════════════════════════════════
122    // RECOMPUTE FORWARD INTERMEDIATES
123    // ═══════════════════════════════════════════════════════════════════════
124
125    // K1 — pre-combined Δ·A → intra-chunk cumsum
126    let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(da_bnlh.clone());
127    san(&da_cumsum_bhnl);
128
129    // K2 — CB matrix used in K5 ORANGE
130    let cb_bnhLMLM = k2_ssd_bmm(c_bnlmhr.clone(), b_bnlmhr.clone());
131    san(&cb_bnhLMLM);
132
133    // K3 — intra-chunk state + decay/decayed-V intermediates
134    let (intra_chunk_state_bnhpr, k3_decay_bhnLM, k3_decayed_v_bnLMhp) =
135        k3_ssd_chunk_state_extended(v_bnlmhp.clone(), b_bnlmhr.clone(), da_cumsum_bhnl.clone());
136
137    // K4 — chunk-input state stream consumed by K5 BLUE
138    let (chunk_input_state_bnhpr, _final_state_bhpr) = k4_ssd_state_passing(
139        intra_chunk_state_bnhpr,
140        da_chunk_end_bhn.clone(),
141        initial_state_bhpr,
142    );
143
144    // ═══════════════════════════════════════════════════════════════════════
145    // FUSED-L INTERMEDIATES USED BY THE REVERSE LOOP
146    // ═══════════════════════════════════════════════════════════════════════
147    //
148    // da_cumsum_bhnLM: cumA per fused position. The expand-then-reshape
149    // repeats each base position mimo_rank times along the fused dim, matching K5.
150    let da_cumsum_bhnLM = da_cumsum_bhnl
151        .clone()
152        .unsqueeze_dim::<5>(4) // da_cumsum_bhnl1
153        .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) // da_cumsum_bhnlm
154        .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); // da_cumsum_bhnLM
155
156    // d_y in (batch, nchunks, nheads, chunk_len * mimo_rank, per_head_dim) ordering
157    // — matches the per-chunk slicing.
158    let d_y_bnhLMp = d_y_bnlmhp
159        .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]) // d_y_bnLMhp
160        .permute([0, 1, 3, 2, 4]); // d_y_bnhLMp
161    san(&d_y_bnhLMp);
162
163    // Reusable [chunk_len, chunk_len] -inf upper-triangular base mask for ORANGE.
164    let neg_inf_base_ll: F<B, 2> =
165        { F::<B, 2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device, dtype).triu(1) };
166
167    // ═══════════════════════════════════════════════════════════════════════
168    // REVERSE PER-CHUNK LOOP — K5 (BLUE + ORANGE) + K4 fused
169    //
170    // Per-iteration working tensors are [batch,nheads,chunk_len*mimo_rank,...] rather than the
171    // [batch,state_rank,nheads,chunk_len*mimo_rank,...] tensors a fully batched K5 backward would allocate.
172    // ═══════════════════════════════════════════════════════════════════════
173    let mut vec_orange_d_v_bhLMp: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
174    let mut vec_blue_d_c_bhLMr: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
175    let mut vec_d_cb_bhLMLM: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
176    let mut vec_blue_d_da_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
177    let mut vec_orange_d_da_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
178    let mut vec_d_intra_bhpr: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
179    let mut vec_d_da_end_bh: Vec<F<B, 2>> = Vec::with_capacity(nchunks);
180
181    let mut d_running_state_bhpr: F<B, 4> = d_final_bhpr;
182
183    for i_chunk in (0..nchunks).rev() {
184        // ── Per-chunk slices (fused chunk_len · mimo_rank) ─────────────
185        let v_bhLMp: F<B, 4> = v_bnlmhp
186            .clone()
187            .slice(s![.., i_chunk, .., .., .., ..]) // v_b1lmhp
188            .squeeze_dim::<5>(1) // v_blmhp
189            .reshape([batch, chunk_len * mimo_rank, nheads, per_head_dim]) // v_bLMhp
190            .permute([0, 2, 1, 3]); // v_bhLMp
191
192        let c_bhLMr: F<B, 4> = c_bnlmhr
193            .clone()
194            .slice(s![.., i_chunk, .., .., .., ..]) // c_b1lmhr
195            .squeeze_dim::<5>(1) // c_blmhr
196            .reshape([batch, chunk_len * mimo_rank, nheads, state_rank]) // c_bLMhr
197            .permute([0, 2, 1, 3]); // c_bhLMr
198
199        let cb_bhLMLM: F<B, 4> = cb_bnhLMLM
200            .clone()
201            .slice(s![.., i_chunk, .., .., ..]) // cb_b1hLMLM
202            .squeeze_dim::<4>(1); // cb_bhLMLM
203
204        let da_cumsum_bhLM: F<B, 3> = da_cumsum_bhnLM
205            .clone()
206            .slice(s![.., .., i_chunk, ..]) // da_cumsum_bh1LM
207            .squeeze_dim::<3>(2); // da_cumsum_bhLM
208
209        let chunk_input_state_bhpr: F<B, 4> = chunk_input_state_bnhpr
210            .clone()
211            .slice(s![.., i_chunk, .., .., ..]) // chunk_input_state_b1hpr
212            .squeeze_dim::<4>(1); // chunk_input_state_bhpr
213        san(&chunk_input_state_bhpr);
214
215        let d_y_bhLMp: F<B, 4> = d_y_bnhLMp
216            .clone()
217            .slice(s![.., i_chunk, .., .., ..]) // d_y_b1hLMp
218            .squeeze_dim::<4>(1); // d_y_bhLMp
219
220        // ── BLUE backward ──────────────────────────────────────────────
221        //
222        //   blue[LM,p] = exp(cumA[LM]) · Σᵣ C[LM,r] · state[p,r]
223        //
224        // exp_da depends on the fused position LM only — broadcast over per_head_dim.
225        let exp_da_cumsum_bhLM: F<B, 3> = da_cumsum_bhLM.clone().exp();
226        let exp_da_cumsum_bhLMp: F<B, 4> = exp_da_cumsum_bhLM
227            .clone()
228            .unsqueeze_dim::<4>(3) // exp_da_cumsum_bhLM1
229            .expand([batch, nheads, chunk_len * mimo_rank, per_head_dim]); // exp_da_cumsum_bhLMp
230        let d_ch_bhLMp: F<B, 4> = d_y_bhLMp.clone() * exp_da_cumsum_bhLMp.clone();
231        san(&d_ch_bhLMp);
232
233        // d_chunk_input_state[p,r] = Σ_LM C[LM,r] · d_ch[LM,p]
234        //   C^T (bhrLM) @ d_ch (bhLMp)  → bhrp  → permute → bhpr
235        let d_chunk_input_state_bhpr: F<B, 4> = c_bhLMr
236            .clone()
237            .permute([0, 1, 3, 2]) // c_bhrLM
238            .matmul(d_ch_bhLMp.clone()) // d_chunk_input_state_bhrp
239            .permute([0, 1, 3, 2]); // d_chunk_input_state_bhpr
240        san(&d_chunk_input_state_bhpr);
241
242        // d_C_blue[LM,r] = Σₚ d_ch[LM,p] · state[p,r]
243        //   d_ch (bhLMp) @ state (bhpr)  → bhLMr
244        let d_c_blue_bhLMr: F<B, 4> = d_ch_bhLMp.matmul(chunk_input_state_bhpr.clone());
245        san(&d_c_blue_bhLMr);
246        vec_blue_d_c_bhLMr.push(d_c_blue_bhLMr);
247
248        // d_da from BLUE:
249        //   ch[LM,p] = Σᵣ C[LM,r] · state[p,r]      (= C @ state_rp after permute)
250        //   d_da[LM] = (Σₚ d_y[LM,p] · ch[LM,p]) · exp_da[LM]
251        let ch_bhLMp: F<B, 4> = c_bhLMr.clone().matmul(
252            chunk_input_state_bhpr.clone().permute([0, 1, 3, 2]), // chunk_input_state_bhrp
253        ); // ch_bhLMp
254        let d_da_blue_bhLM: F<B, 3> = (d_y_bhLMp.clone() * ch_bhLMp * exp_da_cumsum_bhLMp)
255            .sum_dim(3) // d_da_blue_bhLM1
256            .squeeze_dim::<3>(3); // d_da_blue_bhLM
257        san(&d_da_blue_bhLM);
258
259        // Reduce fused LM → l (sum the mimo_rank copies that K5 broadcast).
260        let d_da_blue_bhl: F<B, 3> = d_da_blue_bhLM
261            .reshape([batch, nheads, chunk_len, mimo_rank]) // d_da_blue_bhlm
262            .sum_dim(3) // d_da_blue_bhl1
263            .squeeze_dim::<3>(3); // d_da_blue_bhl
264        vec_blue_d_da_bhl.push(d_da_blue_bhl);
265
266        // ── ORANGE backward ────────────────────────────────────────────
267        //
268        //   w[LMₜ,LMₛ] = CB[LMₜ,LMₛ] · decay[LMₜ,LMₛ]   (MIMO causal mask in decay)
269        //   orange[LMₜ,p] = Σ_{LMₛ} w[LMₜ,LMₛ] · v[LMₛ,p]
270        let da_target_bhLMLM: F<B, 4> = da_cumsum_bhLM
271            .clone()
272            .unsqueeze_dim::<4>(3) // da_cumsum_bhLMₜ1
273            .expand([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]); // da_target_bhLMₜLM
274        let da_source_bhLMLM: F<B, 4> = da_cumsum_bhLM
275            .unsqueeze_dim::<4>(2) // da_cumsum_bh1LMₛ
276            .expand([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]); // da_source_bhLMLMₛ
277        let diff_bhLMLM = da_target_bhLMLM - da_source_bhLMLM;
278        san(&diff_bhLMLM);
279
280        // MIMO causal mask: -inf where LMₛ//mimo_rank > LMₜ//mimo_rank — interleaved expansion
281        // of the [l, l] upper-triangular base mask (matches K5).
282        let neg_inf_mimo_bhLMLM: F<B, 4> = neg_inf_base_ll
283            .clone()
284            .unsqueeze_dims::<4>(&[0, 1]) // neg_inf_base_11ll
285            .expand([batch, nheads, chunk_len, chunk_len]) // neg_inf_base_bhll
286            .unsqueeze_dim::<5>(3) // neg_inf_base_bhl1l
287            .expand([batch, nheads, chunk_len, mimo_rank, chunk_len]) // neg_inf_base_bhlml
288            .reshape([batch, nheads, chunk_len * mimo_rank, chunk_len]) // neg_inf_base_bhLMl
289            .unsqueeze_dim::<5>(4) // neg_inf_base_bhLMl1
290            .expand([batch, nheads, chunk_len * mimo_rank, chunk_len, mimo_rank]) // neg_inf_base_bhLMlm
291            .reshape([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]); // neg_inf_mimo_bhLMLM
292        let decay_bhLMLM = (diff_bhLMLM + neg_inf_mimo_bhLMLM).exp();
293        san(&decay_bhLMLM);
294
295        // d_v_orange = w^T @ d_orange ; d_w = d_orange @ v^T
296        let d_orange_bhLMp = d_y_bhLMp;
297        let w_bhLMLM = cb_bhLMLM.clone() * decay_bhLMLM.clone();
298        let d_w_bhLMLM: F<B, 4> = d_orange_bhLMp.clone().matmul(
299            v_bhLMp.clone().permute([0, 1, 3, 2]), // v_bhpLM
300        ); // d_w_bhLMₜLMₛ
301        san(&d_w_bhLMLM);
302        let d_v_orange_bhLMp: F<B, 4> = w_bhLMLM
303            .permute([0, 1, 3, 2]) // w_bhLMₛLMₜ
304            .matmul(d_orange_bhLMp); // d_v_orange_bhLMₛp
305        san(&d_v_orange_bhLMp);
306        vec_orange_d_v_bhLMp.push(d_v_orange_bhLMp);
307
308        // d_cb = d_w · decay ; d_decay = d_w · cb ; d_diff = d_decay · decay
309        // (masked positions where decay=0 contribute 0 to d_diff automatically)
310        let d_cb_bhLMLM = d_w_bhLMLM.clone() * decay_bhLMLM.clone();
311        vec_d_cb_bhLMLM.push(d_cb_bhLMLM);
312
313        let d_decay_bhLMLM = d_w_bhLMLM * cb_bhLMLM;
314        let d_diff_bhLMLM = d_decay_bhLMLM * decay_bhLMLM;
315
316        // d_da_target[LMₜ] = Σ_{LMₛ} d_diff[LMₜ, LMₛ] ;
317        // d_da_source[LMₛ] = Σ_{LMₜ} d_diff[LMₜ, LMₛ] ;
318        // d_da_orange = d_da_target − d_da_source  (diff = target − source).
319        let d_da_target_bhLM: F<B, 3> = d_diff_bhLMLM
320            .clone()
321            .sum_dim(3) // d_diff_bhLMₜ1
322            .squeeze_dim::<3>(3); // d_da_target_bhLMₜ
323        let d_da_source_bhLM: F<B, 3> = d_diff_bhLMLM
324            .sum_dim(2) // d_diff_bh1LMₛ
325            .squeeze_dim::<3>(2); // d_da_source_bhLMₛ
326        let d_da_orange_bhLM = d_da_target_bhLM - d_da_source_bhLM;
327        san(&d_da_orange_bhLM);
328
329        // Reduce fused LM → l (sum over the mimo_rank-broadcast copies).
330        let d_da_orange_bhl: F<B, 3> = d_da_orange_bhLM
331            .reshape([batch, nheads, chunk_len, mimo_rank]) // d_da_orange_bhlm
332            .sum_dim(3) // d_da_orange_bhl1
333            .squeeze_dim::<3>(3); // d_da_orange_bhl
334        vec_orange_d_da_bhl.push(d_da_orange_bhl);
335
336        // ── K4 backward step for chunk i_chunk ─────────────────────────
337        //
338        // Forward (recap):  sᵢ₊₁ = decayᵢ · sᵢ + intra_stateᵢ
339        //   - d_intra_stateᵢ      = d_sᵢ₊₁      (current d_running_state)
340        //   - d_decayᵢ            = d_sᵢ₊₁ · sᵢ
341        //   - d_sᵢ (propagated)   = decayᵢ · d_sᵢ₊₁ + d_chunk_input_state_blue
342        vec_d_intra_bhpr.push(d_running_state_bhpr.clone());
343
344        let decay_chunk_bhpr: F<B, 4> = da_chunk_end_bhn
345            .clone()
346            .slice(s![.., .., i_chunk]) // da_chunk_end_bh1
347            .exp() // decay_chunk_bh
348            .unsqueeze_dim::<4>(3) // decay_chunk_bh11
349            .expand([batch, nheads, per_head_dim, state_rank]); // decay_chunk_bhpr
350        san(&decay_chunk_bhpr);
351
352        let d_decay_chunk_bhpr = d_running_state_bhpr.clone() * chunk_input_state_bhpr;
353        // d_da_chunk_end[b,h] = Σ_{p,r} d_decay · decay (since decay = exp(da_chunk_end))
354        let d_da_chunk_end_bh: F<B, 2> = (d_decay_chunk_bhpr * decay_chunk_bhpr.clone())
355            .reshape([batch, nheads, per_head_dim * state_rank]) // d_da_chunk_end_bhPR
356            .sum_dim(2) // d_da_chunk_end_bh1
357            .squeeze_dim::<2>(2); // d_da_chunk_end_bh
358        san(&d_da_chunk_end_bh);
359        vec_d_da_end_bh.push(d_da_chunk_end_bh);
360
361        d_running_state_bhpr = decay_chunk_bhpr * d_running_state_bhpr + d_chunk_input_state_bhpr;
362        san(&d_running_state_bhpr);
363    }
364    let d_initial_state_bhpr = d_running_state_bhpr;
365
366    // ── Restore natural (forward) chunk order ─────────────────────────────
367    vec_orange_d_v_bhLMp.reverse();
368    vec_blue_d_c_bhLMr.reverse();
369    vec_d_cb_bhLMLM.reverse();
370    vec_blue_d_da_bhl.reverse();
371    vec_orange_d_da_bhl.reverse();
372    vec_d_intra_bhpr.reverse();
373    vec_d_da_end_bh.reverse();
374
375    // ── Stack per-chunk slices back into batched tensors ──────────────────
376    let d_v_orange_bnhLMp: F<B, 5> = F::stack(vec_orange_d_v_bhLMp, 1);
377    let d_c_blue_bnhLMr: F<B, 5> = F::stack(vec_blue_d_c_bhLMr, 1);
378    let d_cb_bnhLMLM: F<B, 5> = F::stack(vec_d_cb_bhLMLM, 1);
379    let d_da_blue_bhnl: F<B, 4> = F::stack(vec_blue_d_da_bhl, 2);
380    let d_da_orange_bhnl: F<B, 4> = F::stack(vec_orange_d_da_bhl, 2);
381    let d_intra_chunk_state_bnhpr: F<B, 5> = F::stack(vec_d_intra_bhpr, 1);
382    // d_da_end:
383    // [batch,nheads]     → stack@2 → [batch,nheads,nchunks]; scatter into last-l of d_da_cumsum_k4
384    let d_da_end_bhn: F<B, 3> = F::stack(vec_d_da_end_bh, 2);
385    let d_da_cumsum_k4_bhnl: F<B, 4> = {
386        let zeros = F::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device, dtype);
387        let d_da_end_bhn1 = d_da_end_bhn.unsqueeze_dim::<4>(3);
388        F::cat(vec![zeros, d_da_end_bhn1], 3)
389    };
390
391    // ═══════════════════════════════════════════════════════════════════════
392    // K3 BACKWARD (batched)
393    //
394    // Forward (recap):
395    //   v_bnLMhp        = v.reshape
396    //   b_bnLMhr        = b.reshape
397    //   decay_bhnLM     = exp(cumA_last − cumA)
398    //   decay_bnLMh1    = decay_bhnLM.permute([0,2,3,1]).unsqueeze(4)
399    //   decayed_v_bnLMhp = decay_bnLMh1 · v_bnLMhp                      (elementwise)
400    //   decayed_v_bnhpLM = decayed_v_bnLMhp.permute([0,1,3,4,2])
401    //   b_bnhLMr        = b_bnLMhr.permute([0,1,3,2,4])
402    //   intra_state    = decayed_v_bnhpLM @ b_bnhLMr
403    // ═══════════════════════════════════════════════════════════════════════
404    let v_bnLMhp =
405        v_bnlmhp
406            .clone()
407            .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
408    let b_bnLMhr =
409        b_bnlmhr
410            .clone()
411            .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
412    let b_bnhLMr = b_bnLMhr.clone().permute([0, 1, 3, 2, 4]);
413    let decayed_v_bnhpLM = k3_decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
414
415    let d_decayed_v_bnhpLM: F<B, 5> = d_intra_chunk_state_bnhpr.clone().matmul(
416        b_bnhLMr.permute([0, 1, 2, 4, 3]), // b_bnhrLM
417    ); // d_decayed_v_bnhpLM
418    let d_b_k3_bnhLMr: F<B, 5> = decayed_v_bnhpLM
419        .permute([0, 1, 2, 4, 3]) // decayed_v_bnhLMp
420        .matmul(d_intra_chunk_state_bnhpr); // d_b_k3_bnhLMr
421
422    let d_decayed_v_bnLMhp = d_decayed_v_bnhpLM.permute([0, 1, 4, 2, 3]);
423    let d_decay_bhnLM: F<B, 4> = (d_decayed_v_bnLMhp.clone() * v_bnLMhp)
424        .sum_dim(4) // d_decay_bnLMh1
425        .squeeze_dim::<4>(4) // d_decay_bnLMh
426        .permute([0, 3, 1, 2]); // d_decay_bhnLM
427
428    // d_v_k3_bnLMhp = d_decayed_v · decay (broadcast)
429    let k3_decay_bnLMh1 = k3_decay_bhnLM
430        .clone()
431        .permute([0, 2, 3, 1]) // k3_decay_bnLMh
432        .unsqueeze_dim::<5>(4); // k3_decay_bnLMh1
433    let d_v_k3_bnLMhp: F<B, 5> = d_decayed_v_bnLMhp * k3_decay_bnLMh1;
434    let d_v_k3_bnlrhp: F<B, 6> =
435        d_v_k3_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
436
437    // d(cumA_last − cumA) = d_decay · decay
438    let d_decay_times_decay_bhnLM = d_decay_bhnLM * k3_decay_bhnLM;
439    // d_a_cumsum_last: Σ over LM (broadcast dim).
440    let d_a_cumsum_last_bhn: F<B, 3> = d_decay_times_decay_bhnLM
441        .clone()
442        .sum_dim(3) // d_decay_times_decay_bhn1
443        .squeeze_dim::<3>(3); // d_a_cumsum_last_bhn
444    // d_a_cumsum: negated (subtraction).
445    let d_da_cumsum_bhnLM = -d_decay_times_decay_bhnLM;
446
447    // Contribution to d_da_cumsum from the fused-cumA expand (sum mimo_rank copies).
448    let d_da_cumsum_k3_from_fused_bhnl: F<B, 4> = d_da_cumsum_bhnLM
449        .reshape([batch, nheads, nchunks, chunk_len, mimo_rank]) // d_da_cumsum_bhnlm
450        .sum_dim(4) // d_da_cumsum_bhnl1
451        .squeeze_dim::<4>(4); // d_da_cumsum_k3_from_fused_bhnl
452    // Contribution from cumA_last: only the last-l position.
453    let d_da_cumsum_k3_from_last_bhnl: F<B, 4> = {
454        let zeros = F::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device, dtype);
455        let d_last = d_a_cumsum_last_bhn.unsqueeze_dim::<4>(3);
456        F::cat(vec![zeros, d_last], 3)
457    };
458    let d_da_cumsum_k3_bhnl = d_da_cumsum_k3_from_fused_bhnl + d_da_cumsum_k3_from_last_bhnl;
459
460    // d_b_k3
461    let d_b_k3_bnLMhr = d_b_k3_bnhLMr.permute([0, 1, 3, 2, 4]);
462    let d_b_k3_bnlmhr: F<B, 6> =
463        d_b_k3_bnLMhr.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
464
465    // ═══════════════════════════════════════════════════════════════════════
466    // K2 BACKWARD (batched)
467    //
468    //   cb_bnhLMLM = c_bnhLMr @ b_bnhrLM   (contracts state_rank)
469    //   d_c_bnhLMr = d_cb @ b_bnhLMr      (= d_cb @ b_bnhrLM^T)
470    //   d_b_bnhrLM = c_bnhrLM @ d_cb      (= c_bnhLMr^T @ d_cb)
471    // ═══════════════════════════════════════════════════════════════════════
472    let c_bnhLMr = c_bnlmhr
473        .clone()
474        .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]) // c_bnLMhr
475        .permute([0, 1, 3, 2, 4]); // c_bnhLMr
476    let b_for_k2_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
477
478    let d_c_k2_bnhLMr: F<B, 5> = d_cb_bnhLMLM.clone().matmul(b_for_k2_bnhLMr);
479    let d_b_k2_bnhrLM: F<B, 5> = c_bnhLMr
480        .permute([0, 1, 2, 4, 3]) // c_bnhrLM
481        .matmul(d_cb_bnhLMLM); // d_b_k2_bnhrLM
482
483    // Undo permutes and reshape back
484    let d_c_k2_bnlmhr: F<B, 6> = d_c_k2_bnhLMr
485        .permute([0, 1, 3, 2, 4]) // d_c_k2_bnLMhr
486        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]); // d_c_k2_bnlmhr
487    let d_b_k2_bnlmhr: F<B, 6> = d_b_k2_bnhrLM
488        .permute([0, 1, 4, 2, 3]) // d_b_k2_bnLMhr
489        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]); // d_b_k2_bnlmhr
490
491    // ── Unstack d_c_blue / d_v_orange and reshape back ────────────────────
492    let d_c_blue_bnlmhr: F<B, 6> = d_c_blue_bnhLMr
493        .permute([0, 1, 3, 2, 4]) // d_c_blue_bnLMhr
494        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]); // d_c_blue_bnlmhr
495    let d_v_orange_bnlrhp: F<B, 6> = d_v_orange_bnhLMp
496        .permute([0, 1, 3, 2, 4]) // d_v_orange_bnLMhp
497        .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]); // d_v_orange_bnlrhp
498
499    // ═══════════════════════════════════════════════════════════════════════
500    // K1 BACKWARD + SUM CONTRIBUTIONS
501    // ═══════════════════════════════════════════════════════════════════════
502    let d_da_cumsum_bhnl =
503        d_da_blue_bhnl + d_da_orange_bhnl + d_da_cumsum_k3_bhnl + d_da_cumsum_k4_bhnl;
504    san(&d_da_cumsum_bhnl);
505
506    // K1 inverse: da_cumsum[l] = cumsum(da)[l]  →  d_da[l] = Σ_{k ≥ l} d_da_cumsum[k]
507    //
508    // Suffix sum:  d_da[l] = total_sum − cumsum(d_da_cumsum)[l-1] (cumsum[-1] = 0).
509    let d_da_bhnl = {
510        let d_total_bhnl = d_da_cumsum_bhnl
511            .clone()
512            .sum_dim(3) // d_da_cumsum_bhn1
513            .expand([batch, nheads, nchunks, chunk_len]); // d_total_bhnl
514        let prefix_bhnl = d_da_cumsum_bhnl.cumsum(3);
515        let zeros_bhn1 = F::<B, 4>::zeros([batch, nheads, nchunks, 1], &device, dtype);
516        let prefix_shifted_bhnl =
517            F::cat(vec![zeros_bhn1, prefix_bhnl.narrow(3, 0, chunk_len - 1)], 3);
518        d_total_bhnl - prefix_shifted_bhnl
519    };
520    san(&d_da_bhnl);
521    // Undo permute
522    let d_da_bnlh = d_da_bhnl.permute([0, 2, 3, 1]);
523
524    // ── Combine per-input gradient contributions ──────────────────────────
525    let d_v_bnlmhp = d_v_k3_bnlrhp + d_v_orange_bnlrhp;
526    let d_b_bnlmhr = d_b_k2_bnlmhr + d_b_k3_bnlmhr;
527    let d_c_bnlmhr = d_c_k2_bnlmhr + d_c_blue_bnlmhr;
528
529    san(&d_v_bnlmhp);
530    san(&d_da_bnlh);
531    san(&d_b_bnlmhr);
532    san(&d_c_bnlmhr);
533    san(&d_initial_state_bhpr);
534
535    CombinedGrads {
536        d_v_bnlmhp,
537        d_da_bnlh,
538        d_b_bnlmhr,
539        d_c_bnlmhr,
540        d_initial_state_bhpr,
541    }
542}