Skip to main content

burn_mamba/mamba2/ssd/serial_recalculated/
combined_backward.rs

1use crate::mamba2::ssd::serial;
2use crate::utils::sanity::sanity as san;
3use burn::prelude::*;
4
/// Gradients returned by [`combined_backward`], one per forward input
/// (`x`, `dt_discretized`, `b`, `c`, `d`, `initial_state`, `a_decay`).
///
/// Field-name suffixes spell out the axis order, using the names this file
/// destructures tensor dims into: `b` = batch, `n` = nchunks, `l` = chunk_len,
/// `h` = nheads, `p` = per_head_dim, `g` = ngroups, `r` = state_rank.
///
/// NOTE(review): each gradient is presumably shaped like the input it
/// corresponds to; the construction site is not visible here — confirm.
pub struct CombinedGrads<B: Backend> {
    /// Gradient w.r.t. the input `x_bnlhp`.
    pub d_x_bnlhp: Tensor<B, 5>,
    /// Gradient w.r.t. the discretized step sizes `dt_discretized_bhnl`.
    pub d_dt_discretized_bhnl: Tensor<B, 4>,
    /// Gradient w.r.t. the (grouped) `b_bnlgr` projection.
    pub d_b_bnlgr: Tensor<B, 5>,
    /// Gradient w.r.t. the (grouped) `c_bnlgr` projection.
    pub d_c_bnlgr: Tensor<B, 5>,
    /// Gradient w.r.t. the per-head skip coefficient `d_h`.
    pub d_d_h: Tensor<B, 1>,
    /// Gradient w.r.t. the initial recurrent state `initial_state_bhpr`.
    pub d_initial_state_bhpr: Tensor<B, 4>,
    /// Gradient w.r.t. the per-head decay parameter `a_decay_h`.
    pub d_a_decay_h: Tensor<B, 1>,
}
14
15/// Same as [k3_ssd_chunk_state](serial::k3_ssd_chunk_state) but return some intermediaries
16/// that are useful to the custom backward operation.
17///
18/// Returns:
19/// - intra_chunk_state_bnhpr
20/// - b_bar_scale_bhnl
21/// - forward_decay_to_chunk_end_bhnl
22/// - b_scaled_bnhlr
23pub fn k3_ssd_chunk_state_extended<B: Backend>(
24    x_bnlhp: Tensor<B, 5>,
25    b_bnlgr: Tensor<B, 5>,
26    da_cumsum_bhnl: Tensor<B, 4>,
27    dt_discretized_bhnl: Tensor<B, 4>,
28) -> (Tensor<B, 5>, Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 5>) {
29    use burn::tensor::s;
30
31    let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.dims();
32    let [.., ngroups, state_rank] = b_bnlgr.dims();
33
34    // permute b and x to prepare them for the mamtul
35    // - 1/15: permute: (x_bnlhp [in][*]) -> (x_bnhpl)
36    let x_bnhpl = x_bnlhp.clone().permute([0, 1, 3, 4, 2]);
37    assert_eq!(
38        [batch, nchunks, nheads, per_head_dim, chunk_len],
39        x_bnhpl.dims()
40    );
41    // - 2: permute: (b_bnlgr [in][*]) -> (b_bnglr)
42    let b_bnglr = b_bnlgr.permute([0, 1, 3, 2, 4]); // note: still in groups instead of heads
43    assert_eq!(
44        [batch, nchunks, ngroups, chunk_len, state_rank],
45        b_bnglr.dims()
46    );
47
48    // Expand B from ngroups to nheads by repeating each group's
49    // projection across all heads_per_group heads in that group.
50    let heads_per_group = nheads / ngroups;
51    let b_bnhlr = b_bnglr
52        // - 3: unsqueeze: (b_bnglr) -> (b_bng1lr)
53        .unsqueeze_dim::<6>(3) // b_bng1lr
54        // - 4: expand: (b_bng1lr) -> (b_bngHlr)
55        .expand([
56            batch,
57            nchunks,
58            ngroups,
59            heads_per_group,
60            chunk_len,
61            state_rank,
62        ]) // b_bngHlr
63        // - 5: reshape: (b_bngHlr) -> (b_bnhlr)
64        .reshape([batch, nchunks, nheads, chunk_len, state_rank]);
65
66    // scale b
67    let da_cumsum_last_in_chunk_bhn1 =
68        // - 6: slice: (da_cumsum_bhnl [in][*]) -> (da_cumsum_last_in_chunk_bhn1)
69        da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
70    assert_eq!(
71        [batch, nheads, nchunks, 1],
72        da_cumsum_last_in_chunk_bhn1.dims()
73    );
74
75    // - 7: expand: (da_cumsum_last_in_chunk_bhn1) -> (da_cumsum_last_bhnl)
76    let da_cumsum_last_bhnl =
77        da_cumsum_last_in_chunk_bhn1.expand([batch, nheads, nchunks, chunk_len]);
78    // - 8: sub: (da_cumsum_last_bhnl, da_cumsum_bhnl [from K1][*]) -> (da_delta_bhnl)
79    let da_delta_bhnl = da_cumsum_last_bhnl - da_cumsum_bhnl.clone();
80    san(&da_delta_bhnl);
81    // - 9: exp: (da_delta_bhnl) -> (forward_decay_to_chunk_end_bhnl [+])
82    let forward_decay_to_chunk_end_bhnl = da_delta_bhnl.exp();
83    assert_eq!(
84        [batch, nheads, nchunks, chunk_len],
85        forward_decay_to_chunk_end_bhnl.dims()
86    );
87    san(&forward_decay_to_chunk_end_bhnl);
88
89    // - 10: mul: (forward_decay_to_chunk_end_bhnl [+], dt_discretized_bhnl [in][*]) -> (b_bar_scale_bhnl [+])
90    let b_bar_scale_bhnl = forward_decay_to_chunk_end_bhnl.clone() * dt_discretized_bhnl.clone();
91    assert_eq!([batch, nheads, nchunks, chunk_len], b_bar_scale_bhnl.dims());
92    san(&b_bar_scale_bhnl);
93
94    // - 11: permute: (b_bar_scale_bhnl [+]) -> (b_bar_scale_bnhl)
95    let b_bar_scale_bnhl = b_bar_scale_bhnl.clone().permute([0, 2, 1, 3]);
96    assert_eq!([batch, nchunks, nheads, chunk_len], b_bar_scale_bnhl.dims());
97    let b_bar_scale_bnhlr = b_bar_scale_bnhl
98        // - 12: unsqueeze: (b_bar_scale_bnhl) -> (b_bar_scale_bnhl1)
99        .unsqueeze_dim::<5>(4) // b_bar_scale_bnhl1
100        // - 13: expand: (b_bar_scale_bnhl1) -> (b_bar_scale_bnhlr)
101        .expand([batch, nchunks, nheads, chunk_len, state_rank]);
102    // - 14: mul: (b_bnhlr, b_bar_scale_bnhlr) -> (b_scaled_bnhlr [+])
103    let b_scaled_bnhlr = b_bnhlr * b_bar_scale_bnhlr;
104    assert_eq!(
105        [batch, nchunks, nheads, chunk_len, state_rank],
106        b_scaled_bnhlr.dims()
107    );
108    san(&b_scaled_bnhlr);
109
110    // - 15/15: matmul: (x_bnhpl, b_scaled_bnhlr [+]) -> (intra_chunk_state_bnhpr [out][!])
111    let intra_chunk_state_bnhpr: Tensor<B, 5> = x_bnhpl.matmul(b_scaled_bnhlr.clone());
112    assert_eq!(
113        [batch, nchunks, nheads, per_head_dim, state_rank],
114        intra_chunk_state_bnhpr.dims()
115    );
116    san(&intra_chunk_state_bnhpr);
117    (
118        intra_chunk_state_bnhpr,
119        b_bar_scale_bhnl,
120        forward_decay_to_chunk_end_bhnl,
121        b_scaled_bnhlr,
122    )
123}
124
125/// Core gradient computation.  All arguments use the shapes from the forward.
126///
127/// `d_y_bnlhp`         : upstream gradient of the scan output  [B,N,L,H,P]
128/// `d_final_bhpr`      : upstream gradient of the final state  [B,H,P,R]
129///
130/// Returns one `CombinedGrads` struct containing gradients for all 7 inputs.
131#[allow(clippy::too_many_arguments)]
132pub fn combined_backward<B: Backend>(
133    d_y_bnlhp: Tensor<B, 5>,
134    d_final_bhpr: Tensor<B, 4>,
135    // Saved forward inputs
136    x_bnlhp: Tensor<B, 5>,
137    dt_discretized_bhnl: Tensor<B, 4>,
138    b_bnlgr: Tensor<B, 5>,
139    c_bnlgr: Tensor<B, 5>,
140    d_h: Tensor<B, 1>,
141    initial_state_bhpr: Tensor<B, 4>,
142    a_decay_h: Tensor<B, 1>,
143) -> CombinedGrads<B> {
144    let [batch, nheads, nchunks, chunk_len] = dt_discretized_bhnl.dims();
145    let [.., per_head_dim] = x_bnlhp.dims();
146    let [.., ngroups, state_rank] = b_bnlgr.dims();
147    let heads_per_group = nheads / ngroups;
148    let device = dt_discretized_bhnl.device();
149
150    san(&d_y_bnlhp);
151    san(&d_final_bhpr);
152    san(&x_bnlhp);
153    san(&dt_discretized_bhnl);
154    san(&b_bnlgr);
155    san(&c_bnlgr);
156    san(&d_h);
157    san(&initial_state_bhpr);
158    san(&a_decay_h);
159
160    // ═══════════════════════════════════════════════════════════════════════
161    // RECOMPUTE FORWARD INTERMEDIATES (the memory-saving heart of this op)
162    // ═══════════════════════════════════════════════════════════════════════
163
164    // K1 recomputation ─────────────────────────────────────────────────────
165    // da_cumsum is not saved across the boundary; recompute from dt and a_decay.
166    let (da_cumsum_bhnl, da_chunk_end_bhn) =
167        serial::k1_ssd_chunk_cumsum(dt_discretized_bhnl.clone(), a_decay_h.clone());
168    san(&da_cumsum_bhnl);
169    san(&da_chunk_end_bhn);
170
171    // K2 ───────────────────────────────────────────────────────────────────
172    let cb_bngll = serial::k2_ssd_bmm(c_bnlgr.clone(), b_bnlgr.clone());
173    // let cb_bngll = k2_forward(&c_bnlgr, &b_bnlgr);          // [B,N,G,L,L]
174    san(&cb_bngll);
175
176    // K3 (with intermediates) ──────────────────────────────────────────────
177    let (
178        intra_chunk_state_bnhpr,
179        b_bar_scale_bhnl,
180        forward_decay_to_chunk_end_bhnl,
181        b_scaled_bnhlr,
182    ) = k3_ssd_chunk_state_extended(
183        x_bnlhp.clone(),
184        b_bnlgr.clone(),
185        da_cumsum_bhnl.clone(),
186        dt_discretized_bhnl.clone(),
187    );
188    san(&intra_chunk_state_bnhpr);
189    san(&b_bar_scale_bhnl);
190    san(&forward_decay_to_chunk_end_bhnl);
191    san(&b_scaled_bnhlr);
192
193    // K4 ───────────────────────────────────────────────────────────────────
194    let (chunk_input_state_bnhpr, _final_state_bhpr): (Tensor<B, 5>, Tensor<B, 4>) =
195        serial::k4_ssd_state_passing(
196            intra_chunk_state_bnhpr.clone(),
197            da_chunk_end_bhn.clone(),
198            initial_state_bhpr,
199        );
200    san(&chunk_input_state_bnhpr);
201    san(&_final_state_bhpr);
202
203    // ═══════════════════════════════════════════════════════════════════════
204    // K5 BACKWARD
205    // ═══════════════════════════════════════════════════════════════════════
206    // Expand CB for all heads
207    let cb_bnhll = cb_bngll
208        .clone()
209        .unsqueeze_dim::<6>(3) // cb_bng1ll
210        .expand([
211            batch,
212            nchunks,
213            ngroups,
214            heads_per_group,
215            chunk_len,
216            chunk_len,
217        ]) // cb_bngHll
218        .reshape([batch, nchunks, nheads, chunk_len, chunk_len]);
219
220    // Reshape inputs to [B,N,H,L,...] convention used inside K5
221    let da_cumsum_bnhl: Tensor<B, 4> = da_cumsum_bhnl.permute([0, 2, 1, 3]);
222    let dt_bnhl: Tensor<B, 4> = dt_discretized_bhnl.clone().permute([0, 2, 1, 3]);
223    let x_bnhlp: Tensor<B, 5> = x_bnlhp.clone().permute([0, 1, 3, 2, 4]);
224    let d_y_bnhlp: Tensor<B, 5> = d_y_bnlhp.clone().permute([0, 1, 3, 2, 4]);
225
226    // GQA-expand C: [B,N,L,G,R] → [B,N,H,L,R]
227    let c_bnhlr = c_bnlgr
228        .clone()
229        .unsqueeze_dim::<6>(4) // c_bnlg1r
230        .expand([
231            batch,
232            nchunks,
233            chunk_len,
234            ngroups,
235            heads_per_group,
236            state_rank,
237        ]) // c_bnlgHr
238        .reshape([batch, nchunks, chunk_len, nheads, state_rank]) // c_bnlhr
239        .permute([0, 1, 3, 2, 4]);
240
241    // ── SKIP backward ──────────────────────────────────────────────────────
242    // - | 36/36: add: (y_partial_bnlhp, skip_bnlhp) -> (y_bnlhp [out])
243    // - | (d_skip_bnlhp = d_y_bnlhp)
244    let d_skip_bnlhp = d_y_bnlhp.clone();
245    //
246    //
247    // For d_h:
248    // - - | 33: mul: (d_bnlhp, x_bnlhp[*]) -> (skip_bnlhp)
249    // - - | (d_d_bnlhp = d_skip_bnlhp * x_bnlhp)
250    // - - | 32: expand: (d_111h1) -> (d_bnlhp)
251    // - - | 31: unsqueeze-dims: (d_h [*]) -> (d_111h1)
252    //
253    // - - | d_d[h] = Σ_{b,n,l,p} dy * x   — use permute+reshape to avoid chained sum_dim
254    let d_d_h = {
255        // [B,N,L,H,P] → permute to [H,B,N,L,P] → reshape [H, rest] → sum → [H]
256        d_skip_bnlhp.clone()
257            .permute([3, 0, 1, 2, 4]) // d_y_hbnlp
258            .reshape([nheads, batch * nchunks * chunk_len * per_head_dim]) // d_y_hBNLP
259            * x_bnlhp
260                .clone()
261                .permute([3, 0, 1, 2, 4]) // x_hbnlp
262                .reshape([nheads, batch * nchunks * chunk_len * per_head_dim]) // x_hBNLP
263    }
264    .sum_dim(1) // d_d_h1
265    .reshape([nheads]);
266    san(&d_d_h);
267    //
268    // For d_x:
269    // - - | 33: mul: (d_bnlhp, x_bnlhp[*]) -> (skip_bnlhp)
270    // - - | (d_x_skip_bnlhp = d_skip_bnlhp * d_bnlhp)
271    let d_x_skip_bnlhp = d_skip_bnlhp
272        * d_h
273            .clone()
274            .unsqueeze_dims::<5>(&[0, 1, 2, 4]) // d_111h1
275            // d_bnlhp
276            .expand([batch, nchunks, chunk_len, nheads, per_head_dim]);
277    san(&d_x_skip_bnlhp);
278
279    // ── BLUE backward ──────────────────────────────────────────────────────
280    // - | 36/36: add: (y_partial_bnlhp, skip_bnlhp) -> (y_bnlhp [out])
281    // - | (d_y_partial_bnlhp = d_y_bnlhp)
282    let d_y_partial_bnhlp = d_y_bnhlp.clone();
283    //
284    // - | 35: permute: (y_partial_bnhlp) -> (y_partial_bnlhp)
285    // - | 34: add: (blue_scaled_bnhlp, orange_bnhlp) -> (y_partial_bnhlp)
286    // - | (d_blue_scaled_bnhlp = d_y_partial_bnhlp)
287    let d_blue_scaled_bnhlp = d_y_partial_bnhlp;
288    // - | 16: mul: (blue_bnhlp, exp_da_cumsum_bnhlp) -> (blue_scaled_bnhlp)
289    // - | (d_blue_bnhlp = d_blue_scaled_bnhlp * exp_da_cumsum_bnhlp)
290    //
291    // - | blue[b,n,h,l,p] = exp(da[b,n,h,l]) * Σ_r C[b,n,h,l,r] * state[b,n,h,p,r]
292    let exp_da_cumsum_bnhl: Tensor<B, 4> = da_cumsum_bnhl.clone().exp();
293    san(&exp_da_cumsum_bnhl);
294    let exp_da_cumsum_bnhlp = exp_da_cumsum_bnhl.clone().unsqueeze_dim::<5>(4).expand([
295        batch,
296        nchunks,
297        nheads,
298        chunk_len,
299        per_head_dim,
300    ]);
301    let d_blue_bnhlp: Tensor<B, 5> = d_blue_scaled_bnhlp.clone() * exp_da_cumsum_bnhlp.clone();
302    san(&d_blue_bnhlp);
303    //
304    // For d_chunk_input_state_bnhpr:
305    // - | 15: matmul: (c_bnhlr, chunk_input_state_bnhrp) -> (blue_bnhlp)
306    // - - | (d_chunk_input_state_bnhrp = c_bnhlr^T @ d_blue_bnhlp)
307    // - - | 14: permute: (chunk_input_state_bnhpr [!]) -> (chunk_input_state_bnhrp)
308    //
309    // - - | d_state[b,n,h,p,r] = Σ_l (scaled_dy[b,n,h,l,p] * C[b,n,h,l,r])
310    // - - |  = C^T[R,L] @ scaled_dy[L,P]  for fixed (b,n,h)
311    // - - |  [B,N,H,R,L] @ [B,N,H,L,P] → [B,N,H,R,P] → permute → [B,N,H,P,R]
312    let d_chunk_input_state_bnhpr = c_bnhlr
313        .clone()
314        .permute([0, 1, 2, 4, 3]) // c_bnhrl
315        .matmul(d_blue_bnhlp.clone()) // d_chunk_input_state_bnhrp
316        .permute([0, 1, 2, 4, 3]);
317    san(&d_chunk_input_state_bnhpr);
318    //
319    // For d_c from BLUE:
320    // - | 15: matmul: (c_bnhlr, chunk_input_state_bnhrp) -> (blue_bnhlp)
321    // - - | (d_c_bnhlr = d_blue_bnhlp @ chunk_input_state_bnhrp^T)
322    // - - | 7: permute: (c_bnlhr) -> (c_bnhlr)
323    // - - | 6: reshape: (c_bnlgHr) -> (c_bnlhr)
324    // - - | 5: expand: (c_bnlg1r) -> (c_bnlgHr)
325    // - - | 4: unsqueeze: (c_bnlgr [*]) -> (c_bnlg1r)
326    //
327    // - - | d_C[l,r] = Σ_p scaled_dy[l,p] * state[p,r]
328    // - - |  [B,N,H,L,P] @ [B,N,H,P,R] → [B,N,H,L,R]
329    let d_c_blue_bnhlr = d_blue_bnhlp.clone().matmul(chunk_input_state_bnhpr.clone());
330    san(&d_c_blue_bnhlr);
331    // - - | GQA reduce: [B,N,H,L,R] → [B,N,L,G,R]
332    let d_c_blue_bnlgr = d_c_blue_bnhlr
333        .reshape([
334            batch,
335            nchunks,
336            ngroups,
337            heads_per_group,
338            chunk_len,
339            state_rank,
340        ]) // d_c_blue_bngHlr
341        .sum_dim(3) // d_c_blue_bng1lr
342        .squeeze_dim::<5>(3) // d_c_blue_bnglr
343        .permute([0, 1, 3, 2, 4]);
344    san(&d_c_blue_bnlgr);
345    //
346    // For d_da_cumsum from BLUE:
347    // - | 16: mul: (blue_bnhlp, exp_da_cumsum_bnhlp) -> (blue_scaled_bnhlp)
348    // - | (d_exp_da_cumsum_bnhlp = d_blue_scaled_bnhlp * blue_bnhlp)
349    let blue_bnhlp = c_bnhlr
350        .clone()
351        .matmul(chunk_input_state_bnhpr.clone().permute([0, 1, 2, 4, 3])); // replay forward step 15
352    san(&blue_bnhlp);
353    let d_exp_da_cumsum_bnhlp = d_blue_scaled_bnhlp.clone() * blue_bnhlp;
354    san(&d_exp_da_cumsum_bnhlp);
355    //
356    // - | blue_no_scale = C @ state^T  [L,P]
357    // - - | 13: expand: (exp_da_cumsum_bnhl1) -> (exp_da_cumsum_bnhlp)
358    // - - | 12: unsqueeze: (exp_da_cumsum_bnhl) -> (exp_da_cumsum_bnhl1)
359    // - - | 11: exp: (da_cumsum_bnhl) -> (exp_da_cumsum_bnhl)
360    // - - | (d_da_cumsum_bnhl = d_exp_da_cumsum_bnhlp * exp(da_cumsum_bnhl))
361    // - - | 1/36: permute: (da_cumsum_bhnl [*]) -> (da_cumsum_bnhl)
362    //
363    // - - | d_da[l] = Σ_p dy[l,p] * exp_da[l] * blue_no_scale[l,p]
364    let d_da_blue_bnhl = (d_exp_da_cumsum_bnhlp * exp_da_cumsum_bnhlp)
365        .sum_dim(4) // d_da_blue_bnhl1
366        .squeeze_dim::<4>(4);
367    san(&d_da_blue_bnhl);
368    let d_da_blue_bhnl = d_da_blue_bnhl.permute([0, 2, 1, 3]);
369
370    // ── ORANGE backward ─────────────────────────────────────────────────────
371    //  y_orange[l,p] = Σ_{s≤l} CB[l,s] * exp(da[l]-da[s]) * dt[s] * x[s,p]
372    // Precompute weight matrix CB_w [B,N,H,L_tgt,L_src]
373    // replay forward steps 17-29
374    let da_cumsum_target_bnhll = da_cumsum_bnhl
375        .clone()
376        .unsqueeze_dim::<5>(4) // da_cumsum_bnhl1 // forward step 17
377        .expand([batch, nchunks, nheads, chunk_len, chunk_len]); // forward step 18
378    let da_cumsum_source_bnhll = da_cumsum_bnhl
379        .clone()
380        .unsqueeze_dim::<5>(3) // da_cumsum_bnh1l // forward step 19
381        .expand([batch, nchunks, nheads, chunk_len, chunk_len]); // forward step 20
382    let da_cumsum_diff_bnhll = da_cumsum_target_bnhll - da_cumsum_source_bnhll; // forward step 21
383    san(&da_cumsum_diff_bnhll);
384    // forward step 21.1: built at [L,L] and broadcast — mask values do not depend on (b,n,h).
385    let causal_mask_bnhll: Tensor<B, 5, burn::prelude::Bool> =
386        Tensor::<B, 2, burn::prelude::Bool>::tril_mask([chunk_len, chunk_len], 0, &device)
387            .reshape([1, 1, 1, chunk_len, chunk_len])
388            .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
389    // forward step 21.2
390    // Causal mask and exp stabilizer (-inf above the main diagonal, 0 elsewhere).
391    let da_cumsum_diff_masked_bnhll =
392        da_cumsum_diff_bnhll.mask_fill(causal_mask_bnhll.clone(), f32::NEG_INFINITY);
393    let da_cumsum_diff_exp_bnhll = (da_cumsum_diff_masked_bnhll).exp(); // forward steps 22
394    san(&da_cumsum_diff_exp_bnhll);
395    let dt_source_bnhll = dt_bnhl
396        .clone()
397        .unsqueeze_dim::<5>(3) // dt_bnh1l // forward step 23
398        .expand([batch, nchunks, nheads, chunk_len, chunk_len]); // forward step 24
399    // // Causal mask (0 above the main diagonal, 1 elsewhere).
400    // let causal_mask_bnhll =
401    //     Tensor::ones([batch, nchunks, nheads, chunk_len, chunk_len], &device).tril(0); // forward steps 25-26
402    // CB_w[l,s] = CB[l,s] * decay[l,s] * dt[s] * mask[l,s]
403    let orange_lhs_partial1_bnhll: Tensor<B, 5> = // forward step 27
404        cb_bnhll.clone() * da_cumsum_diff_exp_bnhll.clone();
405    san(&orange_lhs_partial1_bnhll);
406    let orange_lhs_partial2_bnhll: Tensor<B, 5> = // forward step 28
407        orange_lhs_partial1_bnhll.clone() * dt_source_bnhll.clone();
408    san(&orange_lhs_partial2_bnhll);
409    // let orange_lhs_partial3_bnhll: Tensor<B, 5> = // forward step 29
410    //     orange_lhs_partial2_bnhll.clone() * causal_mask_bnhll.clone();
411    //
412    // Backwads:
413    // - | 36/36: add: (y_partial_bnlhp, skip_bnlhp) -> (y_bnlhp [out])
414    // - | (d_y_partial_bnlhp = d_y_bnlhp)
415    let d_y_partial_bnhlp = d_y_bnhlp.clone();
416    // - | 35: permute: (y_partial_bnhlp) -> (y_partial_bnlhp)
417    // - | 34: add: (blue_scaled_bnhlp, orange_bnhlp) -> (y_partial_bnhlp)
418    // - | (d_orange_bnhlp = d_y_partial_bnhlp)
419    let d_orange_bnhlp = d_y_partial_bnhlp;
420    // - | 30: matmul: (orange_lhs_partial2_bnhll, x_bnhlp) -> (orange_bnhlp)
421    // - | (d_orange_lhs_partial2_bnhll = d_orange_bnhlp @ x_bnhlp^T)
422    // d_CB_w: dy @ x^T   [B,N,H,L_tgt,L_src]
423    let d_orange_lhs_partial2_bnhll = d_orange_bnhlp
424        .clone()
425        .matmul(x_bnhlp.clone().permute([0, 1, 2, 4, 3])); // [B,N,H,L_tgt,L_src]
426    san(&d_orange_lhs_partial2_bnhll);
427    //
428    // - | For d_x:
429    // - - | (d_x_bnhlp = orange_lhs_partial2_bnhll^T @ d_orange_bnhlp)
430    // - - | d_x from ORANGE: CB_w^T @ dy  (transpose source/target dims)
431    // - - |  [B,N,H,L_src,L_tgt] @ [B,N,H,L_tgt,P] → [B,N,H,L_src,P]
432    let d_x_orange_bnhlp = orange_lhs_partial2_bnhll
433        .clone()
434        .permute([0, 1, 2, 4, 3]) // [B,N,H,L_src,L_tgt]
435        .matmul(d_orange_bnhlp.clone()); // [B,N,H,L_src,P]
436    san(&d_x_orange_bnhlp);
437    //
438    // - | 21.2: mask-fill: (.., ..) -> (..)
439    // Bring the (step 21.2) causal mask ahead: above upper diagonal set to 0.
440    let d_orange_lhs_partial2_bnhll = d_orange_lhs_partial2_bnhll.mask_fill(causal_mask_bnhll, 0.);
441    san(&d_orange_lhs_partial2_bnhll);
442    // - | 28: mul: (orange_lhs_partial1_bnhll, dt_source_bnhll) -> (orange_lhs_partial2_bnhll)
443    let d_orange_lhs_partial1_bnhll = d_orange_lhs_partial2_bnhll.clone() * dt_source_bnhll.clone();
444    san(&d_orange_lhs_partial1_bnhll);
445    // - | For d_dt from ORANGE:
446    // - - | 24: expand: (dt_bnh1l) -> (dt_source_bnhll)
447    // - - | 23: unsqueeze: (dt_bnhl) -> (dt_bnh1l)
448    // - - | 2: permute: (dt_discretized_bhnl [*]) -> (dt_bnhl)
449    // - - | d_dt[s] = Σ_{l≥s} d_CB_w[l,s] * CB[l,s] * decay[l,s] * mask[l,s]
450    // - - |  = (d_cb_w * cb * decay * mask).sum(L_tgt dim=3)
451    let d_dt_orange_bnhl = (d_orange_lhs_partial2_bnhll.clone()
452        * orange_lhs_partial1_bnhll.clone())
453    .sum_dim(3) // d_dt_orange_bnh1l
454    .squeeze_dim::<4>(3);
455    san(&d_dt_orange_bnhl);
456    let d_dt_orange_bhnl = d_dt_orange_bnhl.permute([0, 2, 1, 3]);
457    //
458    // - | For d_da from ORANGE:
459    // - - | decay = exp(da_tgt - da_src)
460    // - - | d_decay = d_CB_w * CB * dt_src * mask
461    // - - | d_da_tgt[l] += Σ_s (d_decay * decay)[l,s]
462    // - - | d_da_src[s] -= Σ_l (d_decay * decay)[l,s]
463    // - - | 27: mul: (cb_bnhll, da_cumsum_diff_exp_bnhll) -> (orange_lhs_partial1_bnhll)
464    let d_da_cumsum_diff_exp_bnhll = d_orange_lhs_partial1_bnhll.clone() * cb_bnhll.clone();
465    san(&d_da_cumsum_diff_exp_bnhll);
466    // - - | 22: exp: (da_cumsum_diff_bnhll) -> (da_cumsum_diff_exp_bnhll)
467    // - - | (d_da_cumsum_diff_bnhll = d_da_cumsum_diff_exp_bnhll * exp(da_cumsum_diff_bnhll))
468    let d_da_cumsum_diff_bnhll = d_da_cumsum_diff_exp_bnhll * da_cumsum_diff_exp_bnhll.clone();
469    san(&d_da_cumsum_diff_bnhll);
470    // - - | 21: sub: (da_cumsum_target_bnhll, da_cumsum_source_bnhll) -> (da_cumsum_diff_bnhll)
471    // - - | 20: expand: (da_cumsum_bnh1l) -> (da_cumsum_source_bnhll)
472    // - - | 19: unsqueeze: (da_cumsum_bnhl) -> (da_cumsum_bnh1l)
473    // - - | 18: expand: (da_cumsum_bnhl1) -> (da_cumsum_target_bnhll)
474    // - - | 17: unsqueeze: (da_cumsum_bnhl) -> (da_cumsum_bnhl1)
475    // - - | 1/36: permute: (da_cumsum_bhnl [*]) -> (da_cumsum_bnhl)
476    let d_da_tgt_bnhl = d_da_cumsum_diff_bnhll
477        .clone()
478        .sum_dim(4) // d_da_cumsum_diff_bnhl1
479        .squeeze_dim::<4>(4);
480    san(&d_da_tgt_bnhl);
481    let d_da_src_bnhl = d_da_cumsum_diff_bnhll
482        .sum_dim(3) // d_da_cumsum_diff_bnh1l
483        .squeeze_dim::<4>(3);
484    san(&d_da_src_bnhl);
485    let d_da_orange_bhnl = (d_da_tgt_bnhl - d_da_src_bnhl).permute([0, 2, 1, 3]); // [B,H,N,L]
486    san(&d_da_orange_bhnl);
487    //
488    // - | For d_cb:
489    // - - | 27: mul: (cb_bnhll, da_cumsum_diff_exp_bnhll) -> (orange_lhs_partial1_bnhll)
490    let d_cb_bnhll = d_orange_lhs_partial1_bnhll * da_cumsum_diff_exp_bnhll.clone();
491    san(&d_cb_bnhll);
492    // - - | d_CB (per head, before GQA reduction):
493    // - - |  CB_w = CB * decay * dt * mask  →  d_CB[l,s] = d_CB_w[l,s] * decay[l,s] * dt[s] * mask
494    // - - | GQA reduce: [B,N,H,L,L] → [B,N,G,L,L]
495    let d_cb_bngll = d_cb_bnhll
496        .reshape([
497            batch,
498            nchunks,
499            ngroups,
500            heads_per_group,
501            chunk_len,
502            chunk_len,
503        ]) // d_cb_bngHll
504        .sum_dim(3) // d_cb_bng1ll
505        .squeeze_dim::<5>(3);
506    san(&d_cb_bngll);
507
508    // ═══════════════════════════════════════════════════════════════════════
509    // K4 BACKWARD (reverse serial recurrence)
510    // ═══════════════════════════════════════════════════════════════════════
511    //
512    // - 5/5: stack: (chunk_input_state_vec_bhpr [!]) -> (chunk_input_state_bnhpr [out][!])
513    // - 4: vec-pop: (chunk_input_state_vec_bhpr [vec][!]) -> (final_state_bhpr [elem][out][!])
514    // - 3: serial-loop: (0..nchunks)
515    //
516    // last d_running_state_bhpr:
517    let mut d_running_state_bhpr: Tensor<B, 4> = d_final_bhpr; // [B,H,P,R]
518    //
519    // d_intra[c] and d_da_end[c] collected during reverse traversal.
520    let mut d_intra_slices: Vec<Tensor<B, 4>> = Vec::with_capacity(nchunks);
521    let mut d_da_end_bh_slices: Vec<Tensor<B, 2>> = Vec::with_capacity(nchunks);
522    //
523    for i_chunk in (0..nchunks).rev() {
524        // access re-calculated running state
525        let running_state_bhpr = chunk_input_state_bnhpr
526            .clone()
527            .slice(s![.., i_chunk, .., .., ..])
528            .squeeze_dim(1);
529        assert_eq!(
530            [batch, nheads, per_head_dim, state_rank],
531            running_state_bhpr.dims()
532        );
533        //
534        // - 3.9/3.9: vec-push: (running_state_bhpr [elem]) -> (chunk_input_state_vec_bhpr [vec][!])
535        d_intra_slices.push(d_running_state_bhpr.clone());
536        //
537        // - 3.8: add: (running_state_bhpr, intra_state_bhpr) -> (running_state_bhpr)
538        let _d_intra_state_bhpr = d_running_state_bhpr.clone();
539        //
540        // - 3.7: mul: (decay_bhpr, running_state_bhpr) -> (running_state_bhpr)
541        let d_decay_bhpr = d_running_state_bhpr.clone() * running_state_bhpr.clone();
542        san(&d_decay_bhpr);
543        // recalculate decay_bhpr
544        let decay_bhpr = da_chunk_end_bhn
545            .clone()
546            .slice(s![.., .., i_chunk]) // da_chunk_end_bh1 // replay forward step 3.3
547            .exp() // exp_da_chunk_end_bh1 // replay forward step 3.4
548            .unsqueeze_dim::<4>(3) // exp_da_chunk_end_bh11 // replay forward step 3.5
549            .expand([batch, nheads, per_head_dim, state_rank]); // replay forward step 3.6
550        san(&decay_bhpr);
551        // - 3.6: expand: (exp_da_chunk_end_bh11) -> (decay_bhpr)
552        // - 3.5: unsqueeze: (exp_da_chunk_end_bh1) -> (exp_da_chunk_end_bh11)
553        // - 3.4: exp: (da_chunk_end_bh1) -> (exp_da_chunk_end_bh1)
554        // (d_da_chunk_end_bh1 = d_exp_da_chunk_end_bh1 * exp(da_chunk_end_bh1))
555        // - 3.3: slice: (da_chunk_end_bhn [in][*]) -> (da_chunk_end_bh1)
556        let d_da_chunk_end_bhpr = d_decay_bhpr * decay_bhpr.clone(); // note: decay is expanded exp(da_chunk_end)
557        san(&d_da_chunk_end_bhpr);
558        let d_da_chunk_end_bh = d_da_chunk_end_bhpr
559            .reshape([batch, nheads, per_head_dim * state_rank]) // d_da_chunk_end_bhPR
560            .sum_dim(2) // d_da_chunk_end_bh1
561            .squeeze_dim::<2>(2);
562        san(&d_da_chunk_end_bh);
563        d_da_end_bh_slices.push(d_da_chunk_end_bh);
564        //
565        // - 3.2: squeeze: (intra_chunk_state_b1hpr) -> (intra_state_bhpr)
566        // - 3.1/3.9: slice: (intra_chunk_state_bnhpr [in][!]) -> (intra_chunk_state_b1hpr)
567        //
568        // Propagate: d_running_state_bhpr_prev = scale * d_running_state_bhpr + d_chunk_input_state_bhpr
569        //   (d_cis[c] = gradient of chunk_input_state[:, c] flowing in from K5 BLUE)
570        let d_chunk_input_state_bhpr = d_chunk_input_state_bnhpr
571            .clone()
572            .slice(s![.., i_chunk, .., .., ..]) // d_chunk_input_state_b1hpr // d_chunk_input_state_b1hpr
573            .squeeze_dim::<4>(1);
574        // TODO: understand this.
575        d_running_state_bhpr = decay_bhpr * d_running_state_bhpr + d_chunk_input_state_bhpr;
576        san(&d_running_state_bhpr);
577    }
578    // - 2: vec-push: (running_state_bhpr [elem]) -> (chunk_input_state_vec_bhpr [vec][!])
579    // - 1/5: init-mut: (initial_state_bhpr [in][*]) -> (running_state_bhpr)
580    //
581    // After the loop, d_initial_state = the (reverse loop) tailing d_running_state_bhpr
582    let d_initial_state_bhpr = d_running_state_bhpr;
583    //
584    // Restore natural order
585    d_intra_slices.reverse();
586    d_da_end_bh_slices.reverse();
587    //
588    let d_intra_chunk_state_bnhpr = Tensor::stack(d_intra_slices, 1);
589    //
590    // d_da_end_bhn [B,H,N]: scatter to last position of d_da_cumsum
591    let d_da_end_bhn: Tensor<B, 3> = Tensor::stack(d_da_end_bh_slices, 2);
592    //
593    // TODO: understand this.
594    // Pad to [B,H,N,L] — only last L-position is non-zero
595    let d_da_cumsum_k4_bhnl = {
596        let zeros = Tensor::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device);
597        let d_da_end_bhn1 = d_da_end_bhn.unsqueeze_dim::<4>(3);
598        Tensor::cat(vec![zeros, d_da_end_bhn1], 3)
599    };
600
601    // ═══════════════════════════════════════════════════════════════════════
602    // K3 BACKWARD
603    // ═══════════════════════════════════════════════════════════════════════
604    let x_bnhpl = x_bnlhp.clone().permute([0, 1, 3, 4, 2]);
605    // For d_x_bnlhp:
606    // - 15/15: matmul: (x_bnhpl, b_scaled_bnhlr [+]) -> (intra_chunk_state_bnhpr [out][!])
607    // - (d_x_bnhpl = d_intra_chunk_state_bnhpr @ b_scaled_bnhlr^T)
608    let d_x_k3_bnhpl = d_intra_chunk_state_bnhpr
609        .clone()
610        .matmul(b_scaled_bnhlr.clone().permute([0, 1, 2, 4, 3]));
611    san(&d_x_k3_bnhpl);
612    // - 1/15: permute: (x_bnlhp [in][*]) -> (x_bnhpl)
613    let d_x_k3_bnlhp = d_x_k3_bnhpl.permute([0, 1, 4, 2, 3]);
614    //
615    // - 15/15: matmul: (x_bnhpl, b_scaled_bnhlr [+]) -> (intra_chunk_state_bnhpr [out][!])
616    // (d_b_scaled_bnhlr = x_bnhpl^T @ d_intra_chunk_state_bnhpr)
617    let d_b_scaled_bnhlr = x_bnhpl
618        .permute([0, 1, 2, 4, 3]) // x_bnhlp
619        .matmul(d_intra_chunk_state_bnhpr);
620    san(&d_b_scaled_bnhlr);
621    //
622    // For d_b:
623    // - 14: mul: (b_bnhlr, b_bar_scale_bnhlr) -> (b_scaled_bnhlr [+])
624    // - (d_b_bnhlr = d_b_scaled_bnhlr * b_bar_scale_bnhlr)
625    let b_bar_scale_bnhlr = b_bar_scale_bhnl
626        .clone()
627        .permute([0, 2, 1, 3]) // b_bar_scale_bnhl // replay forward step 11
628        .unsqueeze_dim::<5>(4) // b_bar_scale_bnhl1 // replay forward step 12
629        .expand([batch, nchunks, nheads, chunk_len, state_rank]); // replay forward step 13
630    let d_b_k3_bnhlr = d_b_scaled_bnhlr.clone() * b_bar_scale_bnhlr;
631    san(&d_b_k3_bnhlr);
632    // - 5: reshape: (b_bngHlr) -> (b_bnhlr)
633    // - 4: expand: (b_bng1lr) -> (b_bngHlr)
634    // - 3: unsqueeze: (b_bnglr) -> (b_bng1lr)
635    // - 2: permute: (b_bnlgr [in][*]) -> (b_bnglr)
636    // GQA reduce: [B,N,H,L,R] → [B,N,G,L,R] → [B,N,L,G,R]
637    let d_b_k3_bnlgr = d_b_k3_bnhlr
638        .reshape([
639            batch,
640            nchunks,
641            ngroups,
642            heads_per_group,
643            chunk_len,
644            state_rank,
645        ]) // d_b_k3_bngHlr
646        .sum_dim(3) // d_b_k3_bng1lr
647        .squeeze_dim::<5>(3) // d_b_k3_bnglr
648        .permute([0, 1, 3, 2, 4]);
649    san(&d_b_k3_bnlgr);
650
651    // - 14: mul: (b_bnhlr, b_bar_scale_bnhlr) -> (b_scaled_bnhlr [+])
652    // - (d_b_bar_scale_bnhlr = d_b_scaled_bnhlr * b_bnhlr)
653    // GQA-expand B back to per-head for the product: [B,N,G,L,R] → [B,N,H,L,R]
654    let b_bnhlr = b_bnlgr
655        .clone()
656        .permute([0, 1, 3, 2, 4]) // b_bnglr // replay forward step 2
657        .unsqueeze_dim::<6>(3) // b_bng1lr // replay forward step 3
658        // b_bngHlr
659        .expand([
660            batch,
661            nchunks,
662            ngroups,
663            heads_per_group,
664            chunk_len,
665            state_rank,
666        ]) // replay forward step 4
667        .reshape([batch, nchunks, nheads, chunk_len, state_rank]); // replay forward step 5
668    let d_b_bar_scale_bnhlr = d_b_scaled_bnhlr.clone() * b_bnhlr;
669    san(&d_b_bar_scale_bnhlr);
670    // - 13: expand: (b_bar_scale_bnhl1) -> (b_bar_scale_bnhlr)
671    // - 12: unsqueeze: (b_bar_scale_bnhl) -> (b_bar_scale_bnhl1)
672    // - 11: permute: (b_bar_scale_bhnl [+]) -> (b_bar_scale_bnhl)
673    let d_b_bar_scale_bhnl = d_b_bar_scale_bnhlr
674        .sum_dim(4) // d_b_bar_scale_bnhl1
675        .squeeze_dim::<4>(4) // d_b_bar_scale_bnhl
676        .permute([0, 2, 1, 3]);
677    san(&d_b_bar_scale_bhnl);
678    //
679    // For d_da_cumsum_bhnl:
680    // - 10: mul: (forward_decay_to_chunk_end_bhnl [+], dt_discretized_bhnl [in][*]) -> (b_bar_scale_bhnl [+])
681    // - (d_forward_decay_to_chunk_end_bhnl = d_b_bar_scale_bhnl * dt_discretized_bhnl)
682    let d_forward_decay_to_chunk_end_bhnl =
683        d_b_bar_scale_bhnl.clone() * dt_discretized_bhnl.clone();
684    san(&d_forward_decay_to_chunk_end_bhnl);
685    // - 9: exp: (da_delta_bhnl) -> (forward_decay_to_chunk_end_bhnl [+])
686    // - (d_da_delta_bhnl = d_forward_decay_to_chunk_end_bhnl * exp(da_delta_bhnl))
687    // note: forward_decay_to_chunk_end_bhnl = exp(da_delta_bhnl)
688    let d_da_delta_bhnl =
689        d_forward_decay_to_chunk_end_bhnl * forward_decay_to_chunk_end_bhnl.clone();
690    san(&d_da_delta_bhnl);
691    // - 8: sub: (da_cumsum_last_bhnl, da_cumsum_bhnl [from K1][*]) -> (da_delta_bhnl)
692    let d_da_cumsum_last_bhnl = d_da_delta_bhnl.clone();
693    let d_da_cumsum_sub_bhnl = -d_da_delta_bhnl.clone();
694    // - 7: expand: (da_cumsum_last_in_chunk_bhn1) -> (da_cumsum_last_bhnl)
695    // - 6: slice: (da_cumsum_bhnl [in][*]) -> (da_cumsum_last_in_chunk_bhn1)
696    let d_da_cumsum_last_bhn = d_da_cumsum_last_bhnl
697        .sum_dim(3) // d_da_cumsum_last_bhn1
698        .squeeze_dim::<3>(3);
699    san(&d_da_cumsum_last_bhn);
700    //
701    // For d_dt_discretized_bhnl:
702    // - 10: mul: (forward_decay_to_chunk_end_bhnl [+], dt_discretized_bhnl [in][*]) -> (b_bar_scale_bhnl [+])
703    // - (d_dt_discretized_bhnl = d_b_bar_scale_bhnl * forward_decay_to_chunk_end_bhnl)
704    let d_dt_discretized_k3_bhnl = d_b_bar_scale_bhnl * forward_decay_to_chunk_end_bhnl;
705    san(&d_dt_discretized_k3_bhnl);
706    //
707
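708    // Slice backward (step 6): place d_da_cumsum_last at index chunk_len - 1 (zeros elsewhere), then add the sub-branch gradient from step 8.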
709    let d_da_cumsum_k3_bhnl = {
710        let zeros = Tensor::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device);
711        let d_last = d_da_cumsum_last_bhn.unsqueeze_dim::<4>(3);
712        d_da_cumsum_sub_bhnl + Tensor::cat(vec![zeros, d_last], 3)
713    };
714    san(&d_da_cumsum_k3_bhnl);
715
716    // ═══════════════════════════════════════════════════════════════════════
717    // K2 BACKWARD (from d_cb_bngll)
718    // ═══════════════════════════════════════════════════════════════════════
719    let c_bnglr = c_bnlgr.clone().permute([0, 1, 3, 2, 4]);
720    let b_bnglr = b_bnlgr.clone().permute([0, 1, 3, 2, 4]);
721    // - 3/3: matmul: (c_bnglr, b_bngrl) -> (cb_bngll [out][!])
722    // - cb[b,n,g,l,s] = Σ_r c[l,r]*b[s,r]  →  CB = C @ B^T
723    // -  d_C_bnglr = d_CB @ B   [B,N,G,L,L_src] @ [B,N,G,L_src,R] → [B,N,G,L,R]
724    // -  d_B_bnglr = d_CB^T @ C [B,N,G,L_src,L] @ [B,N,G,L,R]   → [B,N,G,L_src,R]
725    let d_c_k2_bnglr = d_cb_bngll.clone().matmul(b_bnglr.clone());
726    san(&d_c_k2_bnglr);
727    let d_c_k2_bnlgr = d_c_k2_bnglr.permute([0, 1, 3, 2, 4]);
728
729    let d_b_k2_bnglr = d_cb_bngll
730        .permute([0, 1, 2, 4, 3]) // [B,N,G,L_src,L_tgt]
731        .matmul(c_bnglr.clone()); // [B,N,G,L_src,R]
732    san(&d_b_k2_bnglr);
733    let d_b_k2_bnlgr = d_b_k2_bnglr.permute([0, 1, 3, 2, 4]); // [B,N,L,G,R]
734
735    // ═══════════════════════════════════════════════════════════════════════
736    // SUM GRADIENT CONTRIBUTIONS
737    // ═══════════════════════════════════════════════════════════════════════
738
739    // Accumulated gradient of the cumulative sum produced by K1.
740    let d_da_cumsum_bhnl =
741        d_da_blue_bhnl + d_da_orange_bhnl + d_da_cumsum_k3_bhnl + d_da_cumsum_k4_bhnl;
742    san(&d_da_cumsum_bhnl);
743
744    // ── K1 BACKWARD ────────────────────────────────────────────────────────
745    // K1 forward: da_cumsum[l] = cumsum_l(dt[l] * a_decay)
746    //
747    // Reverse cumsum (suffix sum) converts d_da_cumsum → d_da:
748    //   d_da[l] = sum_{k >= l} d_da_cumsum[k]
749    //           = total_sum - cumsum(d_da_cumsum)[l-1]   (cumsum[-1] == 0)
750    let d_da_cumsum_total_bhnl = d_da_cumsum_bhnl
751        .clone()
752        .sum_dim(3) // [B,H,N,1]
753        .expand([batch, nheads, nchunks, chunk_len]);
754    let prefix_sum_bhnl = d_da_cumsum_bhnl.clone().cumsum(3); // [B,H,N,L]
755    let zeros_bhn1 = Tensor::<B, 4>::zeros([batch, nheads, nchunks, 1], &device);
756    // prefix_sum shifted right by 1 (i.e., cumsum[l-1], with cumsum[-1] = 0)
757    let prefix_sum_shifted_bhnl = Tensor::cat(
758        vec![zeros_bhn1, prefix_sum_bhnl.narrow(3, 0, chunk_len - 1)],
759        3,
760    );
761    let d_da_bhnl = d_da_cumsum_total_bhnl - prefix_sum_shifted_bhnl; // suffix sum [B,H,N,L]
762    san(&d_da_bhnl);
763    // d_dt from K1: d_dt = d_da * a_decay
764    let a_decay_expand = a_decay_h
765        .clone()
766        .unsqueeze_dims::<4>(&[0, 2, 3])
767        .expand([batch, nheads, nchunks, chunk_len]);
768    let d_dt_k1_bhnl = d_da_bhnl.clone() * a_decay_expand;
769    san(&d_dt_k1_bhnl);
770    // d_a_decay_h from K1: d_a[h] = sum_{b,n,l} d_da[b,h,n,l] * dt[b,h,n,l]
771    let d_a_decay_h = (d_da_bhnl * dt_discretized_bhnl.clone())
772        .permute([1, 0, 2, 3]) // [H,B,N,L]
773        .reshape([nheads, batch * nchunks * chunk_len])
774        .sum_dim(1) // [H,1]
775        .reshape([nheads]);
776    san(&d_a_decay_h);
777
778    let d_dt_discretized_bhnl = d_dt_orange_bhnl + d_dt_discretized_k3_bhnl + d_dt_k1_bhnl;
779    san(&d_dt_discretized_bhnl);
780
781    let d_x_orange_bnlhp = d_x_orange_bnhlp.permute([0, 1, 3, 2, 4]);
782    let d_x_bnlhp = d_x_skip_bnlhp + d_x_k3_bnlhp + d_x_orange_bnlhp;
783    san(&d_x_bnlhp);
784
785    let d_b_bnlgr = d_b_k2_bnlgr + d_b_k3_bnlgr;
786    san(&d_b_bnlgr);
787    let d_c_bnlgr = d_c_k2_bnlgr + d_c_blue_bnlgr;
788    san(&d_c_bnlgr);
789
790    CombinedGrads {
791        d_a_decay_h,
792        d_dt_discretized_bhnl,
793        d_x_bnlhp,
794        d_b_bnlgr,
795        d_c_bnlgr,
796        d_d_h,
797        d_initial_state_bhpr,
798    }
799}