burn_mamba/mamba2/ssd/serial_recalculated/combined_backward.rs
1use crate::mamba2::ssd::serial;
2use crate::utils::sanity::sanity as san;
3use burn::prelude::*;
4
/// Gradients produced by [`combined_backward`], one per differentiable input of the
/// combined forward operation.
///
/// Field-name suffixes spell out the axis order, using the same letters as the rest of
/// this file: `b` = batch, `n` = chunk index, `l` = position within a chunk, `h` = head,
/// `p` = per-head dim, `g` = group, `r` = state rank.
pub struct CombinedGrads<B: Backend> {
    /// Gradient w.r.t. `x_bnlhp` (sum of the skip, K3 and K5-orange contributions).
    pub d_x_bnlhp: Tensor<B, 5>,
    /// Gradient w.r.t. `dt_discretized_bhnl` (sum of the K5-orange, K3 and K1 contributions).
    pub d_dt_discretized_bhnl: Tensor<B, 4>,
    /// Gradient w.r.t. `b_bnlgr` (sum of the K2 and K3 contributions).
    pub d_b_bnlgr: Tensor<B, 5>,
    /// Gradient w.r.t. `c_bnlgr` (sum of the K2 and K5-blue contributions).
    pub d_c_bnlgr: Tensor<B, 5>,
    /// Gradient w.r.t. the per-head skip coefficient `d_h`.
    pub d_d_h: Tensor<B, 1>,
    /// Gradient w.r.t. `initial_state_bhpr` (what remains of the running-state gradient
    /// after the reverse K4 traversal).
    pub d_initial_state_bhpr: Tensor<B, 4>,
    /// Gradient w.r.t. the per-head decay `a_decay_h` (from K1).
    pub d_a_decay_h: Tensor<B, 1>,
}
14
15/// Same as [k3_ssd_chunk_state](serial::k3_ssd_chunk_state) but return some intermediaries
16/// that are useful to the custom backward operation.
17///
18/// Returns:
19/// - intra_chunk_state_bnhpr
20/// - b_bar_scale_bhnl
21/// - forward_decay_to_chunk_end_bhnl
22/// - b_scaled_bnhlr
23pub fn k3_ssd_chunk_state_extended<B: Backend>(
24 x_bnlhp: Tensor<B, 5>,
25 b_bnlgr: Tensor<B, 5>,
26 da_cumsum_bhnl: Tensor<B, 4>,
27 dt_discretized_bhnl: Tensor<B, 4>,
28) -> (Tensor<B, 5>, Tensor<B, 4>, Tensor<B, 4>, Tensor<B, 5>) {
29 use burn::tensor::s;
30
31 let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.dims();
32 let [.., ngroups, state_rank] = b_bnlgr.dims();
33
34 // permute b and x to prepare them for the mamtul
35 // - 1/15: permute: (x_bnlhp [in][*]) -> (x_bnhpl)
36 let x_bnhpl = x_bnlhp.clone().permute([0, 1, 3, 4, 2]);
37 assert_eq!(
38 [batch, nchunks, nheads, per_head_dim, chunk_len],
39 x_bnhpl.dims()
40 );
41 // - 2: permute: (b_bnlgr [in][*]) -> (b_bnglr)
42 let b_bnglr = b_bnlgr.permute([0, 1, 3, 2, 4]); // note: still in groups instead of heads
43 assert_eq!(
44 [batch, nchunks, ngroups, chunk_len, state_rank],
45 b_bnglr.dims()
46 );
47
48 // Expand B from ngroups to nheads by repeating each group's
49 // projection across all heads_per_group heads in that group.
50 let heads_per_group = nheads / ngroups;
51 let b_bnhlr = b_bnglr
52 // - 3: unsqueeze: (b_bnglr) -> (b_bng1lr)
53 .unsqueeze_dim::<6>(3) // b_bng1lr
54 // - 4: expand: (b_bng1lr) -> (b_bngHlr)
55 .expand([
56 batch,
57 nchunks,
58 ngroups,
59 heads_per_group,
60 chunk_len,
61 state_rank,
62 ]) // b_bngHlr
63 // - 5: reshape: (b_bngHlr) -> (b_bnhlr)
64 .reshape([batch, nchunks, nheads, chunk_len, state_rank]);
65
66 // scale b
67 let da_cumsum_last_in_chunk_bhn1 =
68 // - 6: slice: (da_cumsum_bhnl [in][*]) -> (da_cumsum_last_in_chunk_bhn1)
69 da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
70 assert_eq!(
71 [batch, nheads, nchunks, 1],
72 da_cumsum_last_in_chunk_bhn1.dims()
73 );
74
75 // - 7: expand: (da_cumsum_last_in_chunk_bhn1) -> (da_cumsum_last_bhnl)
76 let da_cumsum_last_bhnl =
77 da_cumsum_last_in_chunk_bhn1.expand([batch, nheads, nchunks, chunk_len]);
78 // - 8: sub: (da_cumsum_last_bhnl, da_cumsum_bhnl [from K1][*]) -> (da_delta_bhnl)
79 let da_delta_bhnl = da_cumsum_last_bhnl - da_cumsum_bhnl.clone();
80 san(&da_delta_bhnl);
81 // - 9: exp: (da_delta_bhnl) -> (forward_decay_to_chunk_end_bhnl [+])
82 let forward_decay_to_chunk_end_bhnl = da_delta_bhnl.exp();
83 assert_eq!(
84 [batch, nheads, nchunks, chunk_len],
85 forward_decay_to_chunk_end_bhnl.dims()
86 );
87 san(&forward_decay_to_chunk_end_bhnl);
88
89 // - 10: mul: (forward_decay_to_chunk_end_bhnl [+], dt_discretized_bhnl [in][*]) -> (b_bar_scale_bhnl [+])
90 let b_bar_scale_bhnl = forward_decay_to_chunk_end_bhnl.clone() * dt_discretized_bhnl.clone();
91 assert_eq!([batch, nheads, nchunks, chunk_len], b_bar_scale_bhnl.dims());
92 san(&b_bar_scale_bhnl);
93
94 // - 11: permute: (b_bar_scale_bhnl [+]) -> (b_bar_scale_bnhl)
95 let b_bar_scale_bnhl = b_bar_scale_bhnl.clone().permute([0, 2, 1, 3]);
96 assert_eq!([batch, nchunks, nheads, chunk_len], b_bar_scale_bnhl.dims());
97 let b_bar_scale_bnhlr = b_bar_scale_bnhl
98 // - 12: unsqueeze: (b_bar_scale_bnhl) -> (b_bar_scale_bnhl1)
99 .unsqueeze_dim::<5>(4) // b_bar_scale_bnhl1
100 // - 13: expand: (b_bar_scale_bnhl1) -> (b_bar_scale_bnhlr)
101 .expand([batch, nchunks, nheads, chunk_len, state_rank]);
102 // - 14: mul: (b_bnhlr, b_bar_scale_bnhlr) -> (b_scaled_bnhlr [+])
103 let b_scaled_bnhlr = b_bnhlr * b_bar_scale_bnhlr;
104 assert_eq!(
105 [batch, nchunks, nheads, chunk_len, state_rank],
106 b_scaled_bnhlr.dims()
107 );
108 san(&b_scaled_bnhlr);
109
110 // - 15/15: matmul: (x_bnhpl, b_scaled_bnhlr [+]) -> (intra_chunk_state_bnhpr [out][!])
111 let intra_chunk_state_bnhpr: Tensor<B, 5> = x_bnhpl.matmul(b_scaled_bnhlr.clone());
112 assert_eq!(
113 [batch, nchunks, nheads, per_head_dim, state_rank],
114 intra_chunk_state_bnhpr.dims()
115 );
116 san(&intra_chunk_state_bnhpr);
117 (
118 intra_chunk_state_bnhpr,
119 b_bar_scale_bhnl,
120 forward_decay_to_chunk_end_bhnl,
121 b_scaled_bnhlr,
122 )
123}
124
/// Core gradient computation. All arguments use the shapes from the forward.
///
/// The function first recomputes the forward intermediates (K1–K4) from the saved
/// inputs, then back-propagates through K5 (skip, "blue" and "orange" terms), K4, K3,
/// K2 and K1, and finally sums the per-kernel contributions for every input.
///
/// Axis letters used in the names: `b` = batch, `n` = chunk index, `l` = position
/// within a chunk, `h` = head, `p` = per-head dim, `g` = group, `r` = state rank.
///
/// `d_y_bnlhp` : upstream gradient of the scan output [B,N,L,H,P]
/// `d_final_bhpr` : upstream gradient of the final state [B,H,P,R]
///
/// The remaining arguments are the saved forward inputs.
///
/// Returns one `CombinedGrads` struct containing gradients for all 7 inputs.
///
/// NOTE(review): `san` is `crate::utils::sanity::sanity`; it is presumably a value
/// check (e.g. NaN/inf) that may panic — confirm against its definition.
/// NOTE(review): `nheads` is assumed to be a multiple of `ngroups`; otherwise the
/// group→head reshapes below cannot preserve the element count.
// The argument list mirrors the forward op's inputs plus the two upstream gradients.
#[allow(clippy::too_many_arguments)]
pub fn combined_backward<B: Backend>(
    d_y_bnlhp: Tensor<B, 5>,
    d_final_bhpr: Tensor<B, 4>,
    // Saved forward inputs
    x_bnlhp: Tensor<B, 5>,
    dt_discretized_bhnl: Tensor<B, 4>,
    b_bnlgr: Tensor<B, 5>,
    c_bnlgr: Tensor<B, 5>,
    d_h: Tensor<B, 1>,
    initial_state_bhpr: Tensor<B, 4>,
    a_decay_h: Tensor<B, 1>,
) -> CombinedGrads<B> {
    let [batch, nheads, nchunks, chunk_len] = dt_discretized_bhnl.dims();
    let [.., per_head_dim] = x_bnlhp.dims();
    let [.., ngroups, state_rank] = b_bnlgr.dims();
    let heads_per_group = nheads / ngroups;
    let device = dt_discretized_bhnl.device();

    // Check every input before doing any work.
    san(&d_y_bnlhp);
    san(&d_final_bhpr);
    san(&x_bnlhp);
    san(&dt_discretized_bhnl);
    san(&b_bnlgr);
    san(&c_bnlgr);
    san(&d_h);
    san(&initial_state_bhpr);
    san(&a_decay_h);

    // ═══════════════════════════════════════════════════════════════════════
    // RECOMPUTE FORWARD INTERMEDIATES (the memory-saving heart of this op)
    // ═══════════════════════════════════════════════════════════════════════

    // K1 recomputation ─────────────────────────────────────────────────────
    // da_cumsum is not saved across the boundary; recompute from dt and a_decay.
    let (da_cumsum_bhnl, da_chunk_end_bhn) =
        serial::k1_ssd_chunk_cumsum(dt_discretized_bhnl.clone(), a_decay_h.clone());
    san(&da_cumsum_bhnl);
    san(&da_chunk_end_bhn);

    // K2 ───────────────────────────────────────────────────────────────────
    // cb_bngll is per group: [B,N,G,L,L]
    let cb_bngll = serial::k2_ssd_bmm(c_bnlgr.clone(), b_bnlgr.clone());
    san(&cb_bngll);

    // K3 (with intermediates) ──────────────────────────────────────────────
    let (
        intra_chunk_state_bnhpr,
        b_bar_scale_bhnl,
        forward_decay_to_chunk_end_bhnl,
        b_scaled_bnhlr,
    ) = k3_ssd_chunk_state_extended(
        x_bnlhp.clone(),
        b_bnlgr.clone(),
        da_cumsum_bhnl.clone(),
        dt_discretized_bhnl.clone(),
    );
    san(&intra_chunk_state_bnhpr);
    san(&b_bar_scale_bhnl);
    san(&forward_decay_to_chunk_end_bhnl);
    san(&b_scaled_bnhlr);

    // K4 ───────────────────────────────────────────────────────────────────
    // Only the per-chunk input states are needed below; the recomputed final state is
    // just sanity-checked.
    let (chunk_input_state_bnhpr, _final_state_bhpr): (Tensor<B, 5>, Tensor<B, 4>) =
        serial::k4_ssd_state_passing(
            intra_chunk_state_bnhpr.clone(),
            da_chunk_end_bhn.clone(),
            initial_state_bhpr,
        );
    san(&chunk_input_state_bnhpr);
    san(&_final_state_bhpr);

    // ═══════════════════════════════════════════════════════════════════════
    // K5 BACKWARD
    // ═══════════════════════════════════════════════════════════════════════
    // Expand CB for all heads: every head of a group shares that group's CB.
    let cb_bnhll = cb_bngll
        .clone()
        .unsqueeze_dim::<6>(3) // cb_bng1ll
        .expand([
            batch,
            nchunks,
            ngroups,
            heads_per_group,
            chunk_len,
            chunk_len,
        ]) // cb_bngHll
        .reshape([batch, nchunks, nheads, chunk_len, chunk_len]);

    // Reshape inputs to [B,N,H,L,...] convention used inside K5
    let da_cumsum_bnhl: Tensor<B, 4> = da_cumsum_bhnl.permute([0, 2, 1, 3]);
    let dt_bnhl: Tensor<B, 4> = dt_discretized_bhnl.clone().permute([0, 2, 1, 3]);
    let x_bnhlp: Tensor<B, 5> = x_bnlhp.clone().permute([0, 1, 3, 2, 4]);
    let d_y_bnhlp: Tensor<B, 5> = d_y_bnlhp.clone().permute([0, 1, 3, 2, 4]);

    // GQA-expand C: [B,N,L,G,R] → [B,N,H,L,R]
    let c_bnhlr = c_bnlgr
        .clone()
        .unsqueeze_dim::<6>(4) // c_bnlg1r
        .expand([
            batch,
            nchunks,
            chunk_len,
            ngroups,
            heads_per_group,
            state_rank,
        ]) // c_bnlgHr
        .reshape([batch, nchunks, chunk_len, nheads, state_rank]) // c_bnlhr
        .permute([0, 1, 3, 2, 4]);

    // ── SKIP backward ──────────────────────────────────────────────────────
    // - | 36/36: add: (y_partial_bnlhp, skip_bnlhp) -> (y_bnlhp [out])
    // - | (d_skip_bnlhp = d_y_bnlhp)
    let d_skip_bnlhp = d_y_bnlhp.clone();
    //
    //
    // For d_h:
    // - - | 33: mul: (d_bnlhp, x_bnlhp[*]) -> (skip_bnlhp)
    // - - | (d_d_bnlhp = d_skip_bnlhp * x_bnlhp)
    // - - | 32: expand: (d_111h1) -> (d_bnlhp)
    // - - | 31: unsqueeze-dims: (d_h [*]) -> (d_111h1)
    //
    // - - | d_d[h] = Σ_{b,n,l,p} dy * x — use permute+reshape to avoid chained sum_dim
    let d_d_h = {
        // [B,N,L,H,P] → permute to [H,B,N,L,P] → reshape [H, rest] → sum → [H]
        d_skip_bnlhp.clone()
            .permute([3, 0, 1, 2, 4]) // d_y_hbnlp
            .reshape([nheads, batch * nchunks * chunk_len * per_head_dim]) // d_y_hBNLP
            * x_bnlhp
                .clone()
                .permute([3, 0, 1, 2, 4]) // x_hbnlp
                .reshape([nheads, batch * nchunks * chunk_len * per_head_dim]) // x_hBNLP
    }
    .sum_dim(1) // d_d_h1
    .reshape([nheads]);
    san(&d_d_h);
    //
    // For d_x:
    // - - | 33: mul: (d_bnlhp, x_bnlhp[*]) -> (skip_bnlhp)
    // - - | (d_x_skip_bnlhp = d_skip_bnlhp * d_bnlhp)
    let d_x_skip_bnlhp = d_skip_bnlhp
        * d_h
            .clone()
            .unsqueeze_dims::<5>(&[0, 1, 2, 4]) // d_111h1
            // d_bnlhp
            .expand([batch, nchunks, chunk_len, nheads, per_head_dim]);
    san(&d_x_skip_bnlhp);

    // ── BLUE backward ──────────────────────────────────────────────────────
    // - | 36/36: add: (y_partial_bnlhp, skip_bnlhp) -> (y_bnlhp [out])
    // - | (d_y_partial_bnlhp = d_y_bnlhp)
    let d_y_partial_bnhlp = d_y_bnhlp.clone();
    //
    // - | 35: permute: (y_partial_bnhlp) -> (y_partial_bnlhp)
    // - | 34: add: (blue_scaled_bnhlp, orange_bnhlp) -> (y_partial_bnhlp)
    // - | (d_blue_scaled_bnhlp = d_y_partial_bnhlp)
    let d_blue_scaled_bnhlp = d_y_partial_bnhlp;
    // - | 16: mul: (blue_bnhlp, exp_da_cumsum_bnhlp) -> (blue_scaled_bnhlp)
    // - | (d_blue_bnhlp = d_blue_scaled_bnhlp * exp_da_cumsum_bnhlp)
    //
    // - | blue[b,n,h,l,p] = exp(da[b,n,h,l]) * Σ_r C[b,n,h,l,r] * state[b,n,h,p,r]
    let exp_da_cumsum_bnhl: Tensor<B, 4> = da_cumsum_bnhl.clone().exp();
    san(&exp_da_cumsum_bnhl);
    let exp_da_cumsum_bnhlp = exp_da_cumsum_bnhl.clone().unsqueeze_dim::<5>(4).expand([
        batch,
        nchunks,
        nheads,
        chunk_len,
        per_head_dim,
    ]);
    let d_blue_bnhlp: Tensor<B, 5> = d_blue_scaled_bnhlp.clone() * exp_da_cumsum_bnhlp.clone();
    san(&d_blue_bnhlp);
    //
    // For d_chunk_input_state_bnhpr:
    // - | 15: matmul: (c_bnhlr, chunk_input_state_bnhrp) -> (blue_bnhlp)
    // - - | (d_chunk_input_state_bnhrp = c_bnhlr^T @ d_blue_bnhlp)
    // - - | 14: permute: (chunk_input_state_bnhpr [!]) -> (chunk_input_state_bnhrp)
    //
    // - - | d_state[b,n,h,p,r] = Σ_l (scaled_dy[b,n,h,l,p] * C[b,n,h,l,r])
    // - - | = C^T[R,L] @ scaled_dy[L,P] for fixed (b,n,h)
    // - - | [B,N,H,R,L] @ [B,N,H,L,P] → [B,N,H,R,P] → permute → [B,N,H,P,R]
    let d_chunk_input_state_bnhpr = c_bnhlr
        .clone()
        .permute([0, 1, 2, 4, 3]) // c_bnhrl
        .matmul(d_blue_bnhlp.clone()) // d_chunk_input_state_bnhrp
        .permute([0, 1, 2, 4, 3]);
    san(&d_chunk_input_state_bnhpr);
    //
    // For d_c from BLUE:
    // - | 15: matmul: (c_bnhlr, chunk_input_state_bnhrp) -> (blue_bnhlp)
    // - - | (d_c_bnhlr = d_blue_bnhlp @ chunk_input_state_bnhrp^T)
    // - - | 7: permute: (c_bnlhr) -> (c_bnhlr)
    // - - | 6: reshape: (c_bnlgHr) -> (c_bnlhr)
    // - - | 5: expand: (c_bnlg1r) -> (c_bnlgHr)
    // - - | 4: unsqueeze: (c_bnlgr [*]) -> (c_bnlg1r)
    //
    // - - | d_C[l,r] = Σ_p scaled_dy[l,p] * state[p,r]
    // - - | [B,N,H,L,P] @ [B,N,H,P,R] → [B,N,H,L,R]
    let d_c_blue_bnhlr = d_blue_bnhlp.clone().matmul(chunk_input_state_bnhpr.clone());
    san(&d_c_blue_bnhlr);
    // - - | GQA reduce: [B,N,H,L,R] → [B,N,L,G,R]
    // - - | (the backward of the group→head expand is a sum over the heads of each group)
    let d_c_blue_bnlgr = d_c_blue_bnhlr
        .reshape([
            batch,
            nchunks,
            ngroups,
            heads_per_group,
            chunk_len,
            state_rank,
        ]) // d_c_blue_bngHlr
        .sum_dim(3) // d_c_blue_bng1lr
        .squeeze_dim::<5>(3) // d_c_blue_bnglr
        .permute([0, 1, 3, 2, 4]);
    san(&d_c_blue_bnlgr);
    //
    // For d_da_cumsum from BLUE:
    // - | 16: mul: (blue_bnhlp, exp_da_cumsum_bnhlp) -> (blue_scaled_bnhlp)
    // - | (d_exp_da_cumsum_bnhlp = d_blue_scaled_bnhlp * blue_bnhlp)
    let blue_bnhlp = c_bnhlr
        .clone()
        .matmul(chunk_input_state_bnhpr.clone().permute([0, 1, 2, 4, 3])); // replay forward step 15
    san(&blue_bnhlp);
    let d_exp_da_cumsum_bnhlp = d_blue_scaled_bnhlp.clone() * blue_bnhlp;
    san(&d_exp_da_cumsum_bnhlp);
    //
    // - | blue_no_scale = C @ state^T [L,P]
    // - - | 13: expand: (exp_da_cumsum_bnhl1) -> (exp_da_cumsum_bnhlp)
    // - - | 12: unsqueeze: (exp_da_cumsum_bnhl) -> (exp_da_cumsum_bnhl1)
    // - - | 11: exp: (da_cumsum_bnhl) -> (exp_da_cumsum_bnhl)
    // - - | (d_da_cumsum_bnhl = d_exp_da_cumsum_bnhlp * exp(da_cumsum_bnhl))
    // - - | 1/36: permute: (da_cumsum_bhnl [*]) -> (da_cumsum_bnhl)
    //
    // - - | d_da[l] = Σ_p dy[l,p] * exp_da[l] * blue_no_scale[l,p]
    let d_da_blue_bnhl = (d_exp_da_cumsum_bnhlp * exp_da_cumsum_bnhlp)
        .sum_dim(4) // d_da_blue_bnhl1
        .squeeze_dim::<4>(4);
    san(&d_da_blue_bnhl);
    let d_da_blue_bhnl = d_da_blue_bnhl.permute([0, 2, 1, 3]);

    // ── ORANGE backward ─────────────────────────────────────────────────────
    // y_orange[l,p] = Σ_{s≤l} CB[l,s] * exp(da[l]-da[s]) * dt[s] * x[s,p]
    // Precompute weight matrix CB_w [B,N,H,L_tgt,L_src]
    // replay forward steps 17-29
    let da_cumsum_target_bnhll = da_cumsum_bnhl
        .clone()
        .unsqueeze_dim::<5>(4) // da_cumsum_bnhl1 // forward step 17
        .expand([batch, nchunks, nheads, chunk_len, chunk_len]); // forward step 18
    let da_cumsum_source_bnhll = da_cumsum_bnhl
        .clone()
        .unsqueeze_dim::<5>(3) // da_cumsum_bnh1l // forward step 19
        .expand([batch, nchunks, nheads, chunk_len, chunk_len]); // forward step 20
    let da_cumsum_diff_bnhll = da_cumsum_target_bnhll - da_cumsum_source_bnhll; // forward step 21
    san(&da_cumsum_diff_bnhll);
    // forward step 21.1: built at [L,L] and broadcast — mask values do not depend on (b,n,h).
    // The mask is `true` on the entries that get filled below (strictly above the main
    // diagonal, i.e. source position later than target position).
    let causal_mask_bnhll: Tensor<B, 5, burn::prelude::Bool> =
        Tensor::<B, 2, burn::prelude::Bool>::tril_mask([chunk_len, chunk_len], 0, &device)
            .reshape([1, 1, 1, chunk_len, chunk_len])
            .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
    // forward step 21.2
    // Causal mask and exp stabilizer: the masked entries are set to -inf before the exp,
    // so they become exactly 0 afterwards and never overflow.
    let da_cumsum_diff_masked_bnhll =
        da_cumsum_diff_bnhll.mask_fill(causal_mask_bnhll.clone(), f32::NEG_INFINITY);
    let da_cumsum_diff_exp_bnhll = (da_cumsum_diff_masked_bnhll).exp(); // forward steps 22
    san(&da_cumsum_diff_exp_bnhll);
    let dt_source_bnhll = dt_bnhl
        .clone()
        .unsqueeze_dim::<5>(3) // dt_bnh1l // forward step 23
        .expand([batch, nchunks, nheads, chunk_len, chunk_len]); // forward step 24
    // Forward steps 25-26 and 29 (multiplying by a 0/1 lower-triangular mask) are not
    // replayed here: the -inf fill before the exp already zeroes the same entries.
    // CB_w[l,s] = CB[l,s] * decay[l,s] * dt[s] * mask[l,s]
    let orange_lhs_partial1_bnhll: Tensor<B, 5> = // forward step 27
        cb_bnhll.clone() * da_cumsum_diff_exp_bnhll.clone();
    san(&orange_lhs_partial1_bnhll);
    let orange_lhs_partial2_bnhll: Tensor<B, 5> = // forward step 28
        orange_lhs_partial1_bnhll.clone() * dt_source_bnhll.clone();
    san(&orange_lhs_partial2_bnhll);
    //
    // Backward:
    // - | 36/36: add: (y_partial_bnlhp, skip_bnlhp) -> (y_bnlhp [out])
    // - | (d_y_partial_bnlhp = d_y_bnlhp)
    let d_y_partial_bnhlp = d_y_bnhlp.clone();
    // - | 35: permute: (y_partial_bnhlp) -> (y_partial_bnlhp)
    // - | 34: add: (blue_scaled_bnhlp, orange_bnhlp) -> (y_partial_bnhlp)
    // - | (d_orange_bnhlp = d_y_partial_bnhlp)
    let d_orange_bnhlp = d_y_partial_bnhlp;
    // - | 30: matmul: (orange_lhs_partial2_bnhll, x_bnhlp) -> (orange_bnhlp)
    // - | (d_orange_lhs_partial2_bnhll = d_orange_bnhlp @ x_bnhlp^T)
    // d_CB_w: dy @ x^T [B,N,H,L_tgt,L_src]
    let d_orange_lhs_partial2_bnhll = d_orange_bnhlp
        .clone()
        .matmul(x_bnhlp.clone().permute([0, 1, 2, 4, 3])); // [B,N,H,L_tgt,L_src]
    san(&d_orange_lhs_partial2_bnhll);
    //
    // - | For d_x:
    // - - | (d_x_bnhlp = orange_lhs_partial2_bnhll^T @ d_orange_bnhlp)
    // - - | d_x from ORANGE: CB_w^T @ dy (transpose source/target dims)
    // - - | [B,N,H,L_src,L_tgt] @ [B,N,H,L_tgt,P] → [B,N,H,L_src,P]
    let d_x_orange_bnhlp = orange_lhs_partial2_bnhll
        .clone()
        .permute([0, 1, 2, 4, 3]) // [B,N,H,L_src,L_tgt]
        .matmul(d_orange_bnhlp.clone()); // [B,N,H,L_src,P]
    san(&d_x_orange_bnhlp);
    //
    // - | 21.2: mask-fill: (.., ..) -> (..)
    // Apply the (step 21.2) causal mask to the gradient up front: entries strictly above
    // the main diagonal did not contribute to the forward output, so their gradient is 0.
    let d_orange_lhs_partial2_bnhll = d_orange_lhs_partial2_bnhll.mask_fill(causal_mask_bnhll, 0.);
    san(&d_orange_lhs_partial2_bnhll);
    // - | 28: mul: (orange_lhs_partial1_bnhll, dt_source_bnhll) -> (orange_lhs_partial2_bnhll)
    let d_orange_lhs_partial1_bnhll = d_orange_lhs_partial2_bnhll.clone() * dt_source_bnhll.clone();
    san(&d_orange_lhs_partial1_bnhll);
    // - | For d_dt from ORANGE:
    // - - | 24: expand: (dt_bnh1l) -> (dt_source_bnhll)
    // - - | 23: unsqueeze: (dt_bnhl) -> (dt_bnh1l)
    // - - | 2: permute: (dt_discretized_bhnl [*]) -> (dt_bnhl)
    // - - | d_dt[s] = Σ_{l≥s} d_CB_w[l,s] * CB[l,s] * decay[l,s] * mask[l,s]
    // - - | = (d_cb_w * cb * decay * mask).sum(L_tgt dim=3)
    let d_dt_orange_bnhl = (d_orange_lhs_partial2_bnhll.clone()
        * orange_lhs_partial1_bnhll.clone())
    .sum_dim(3) // d_dt_orange_bnh1l
    .squeeze_dim::<4>(3);
    san(&d_dt_orange_bnhl);
    let d_dt_orange_bhnl = d_dt_orange_bnhl.permute([0, 2, 1, 3]);
    //
    // - | For d_da from ORANGE:
    // - - | decay = exp(da_tgt - da_src)
    // - - | d_decay = d_CB_w * CB * dt_src * mask
    // - - | d_da_tgt[l] += Σ_s (d_decay * decay)[l,s]
    // - - | d_da_src[s] -= Σ_l (d_decay * decay)[l,s]
    // - - | 27: mul: (cb_bnhll, da_cumsum_diff_exp_bnhll) -> (orange_lhs_partial1_bnhll)
    let d_da_cumsum_diff_exp_bnhll = d_orange_lhs_partial1_bnhll.clone() * cb_bnhll.clone();
    san(&d_da_cumsum_diff_exp_bnhll);
    // - - | 22: exp: (da_cumsum_diff_bnhll) -> (da_cumsum_diff_exp_bnhll)
    // - - | (d_da_cumsum_diff_bnhll = d_da_cumsum_diff_exp_bnhll * exp(da_cumsum_diff_bnhll))
    let d_da_cumsum_diff_bnhll = d_da_cumsum_diff_exp_bnhll * da_cumsum_diff_exp_bnhll.clone();
    san(&d_da_cumsum_diff_bnhll);
    // - - | 21: sub: (da_cumsum_target_bnhll, da_cumsum_source_bnhll) -> (da_cumsum_diff_bnhll)
    // - - | 20: expand: (da_cumsum_bnh1l) -> (da_cumsum_source_bnhll)
    // - - | 19: unsqueeze: (da_cumsum_bnhl) -> (da_cumsum_bnh1l)
    // - - | 18: expand: (da_cumsum_bnhl1) -> (da_cumsum_target_bnhll)
    // - - | 17: unsqueeze: (da_cumsum_bnhl) -> (da_cumsum_bnhl1)
    // - - | 1/36: permute: (da_cumsum_bhnl [*]) -> (da_cumsum_bnhl)
    // The target copy was broadcast along the source axis (4) and the source copy along
    // the target axis (3), so each gradient is the sum over the broadcast axis.
    let d_da_tgt_bnhl = d_da_cumsum_diff_bnhll
        .clone()
        .sum_dim(4) // d_da_cumsum_diff_bnhl1
        .squeeze_dim::<4>(4);
    san(&d_da_tgt_bnhl);
    let d_da_src_bnhl = d_da_cumsum_diff_bnhll
        .sum_dim(3) // d_da_cumsum_diff_bnh1l
        .squeeze_dim::<4>(3);
    san(&d_da_src_bnhl);
    let d_da_orange_bhnl = (d_da_tgt_bnhl - d_da_src_bnhl).permute([0, 2, 1, 3]); // [B,H,N,L]
    san(&d_da_orange_bhnl);
    //
    // - | For d_cb:
    // - - | 27: mul: (cb_bnhll, da_cumsum_diff_exp_bnhll) -> (orange_lhs_partial1_bnhll)
    let d_cb_bnhll = d_orange_lhs_partial1_bnhll * da_cumsum_diff_exp_bnhll.clone();
    san(&d_cb_bnhll);
    // - - | d_CB (per head, before GQA reduction):
    // - - | CB_w = CB * decay * dt * mask → d_CB[l,s] = d_CB_w[l,s] * decay[l,s] * dt[s] * mask
    // - - | GQA reduce: [B,N,H,L,L] → [B,N,G,L,L]
    let d_cb_bngll = d_cb_bnhll
        .reshape([
            batch,
            nchunks,
            ngroups,
            heads_per_group,
            chunk_len,
            chunk_len,
        ]) // d_cb_bngHll
        .sum_dim(3) // d_cb_bng1ll
        .squeeze_dim::<5>(3);
    san(&d_cb_bngll);

    // ═══════════════════════════════════════════════════════════════════════
    // K4 BACKWARD (reverse serial recurrence)
    // ═══════════════════════════════════════════════════════════════════════
    //
    // - 5/5: stack: (chunk_input_state_vec_bhpr [!]) -> (chunk_input_state_bnhpr [out][!])
    // - 4: vec-pop: (chunk_input_state_vec_bhpr [vec][!]) -> (final_state_bhpr [elem][out][!])
    // - 3: serial-loop: (0..nchunks)
    //
    // last d_running_state_bhpr:
    // The state after the last chunk is the final state, so the traversal starts from the
    // upstream gradient of the final state.
    let mut d_running_state_bhpr: Tensor<B, 4> = d_final_bhpr; // [B,H,P,R]
    //
    // d_intra[c] and d_da_end[c] collected during reverse traversal.
    let mut d_intra_slices: Vec<Tensor<B, 4>> = Vec::with_capacity(nchunks);
    let mut d_da_end_bh_slices: Vec<Tensor<B, 2>> = Vec::with_capacity(nchunks);
    //
    for i_chunk in (0..nchunks).rev() {
        // access re-calculated running state (the state entering chunk `i_chunk`)
        let running_state_bhpr = chunk_input_state_bnhpr
            .clone()
            .slice(s![.., i_chunk, .., .., ..])
            .squeeze_dim(1);
        assert_eq!(
            [batch, nheads, per_head_dim, state_rank],
            running_state_bhpr.dims()
        );
        //
        // At this point `d_running_state_bhpr` is the gradient of the state *after*
        // chunk `i_chunk`.
        // - 3.9/3.9: vec-push: (running_state_bhpr [elem]) -> (chunk_input_state_vec_bhpr [vec][!])
        d_intra_slices.push(d_running_state_bhpr.clone());
        //
        // - 3.8: add: (running_state_bhpr, intra_state_bhpr) -> (running_state_bhpr)
        // (kept only to mirror forward step 3.8; the value pushed above is what is used)
        let _d_intra_state_bhpr = d_running_state_bhpr.clone();
        //
        // - 3.7: mul: (decay_bhpr, running_state_bhpr) -> (running_state_bhpr)
        let d_decay_bhpr = d_running_state_bhpr.clone() * running_state_bhpr.clone();
        san(&d_decay_bhpr);
        // recalculate decay_bhpr
        let decay_bhpr = da_chunk_end_bhn
            .clone()
            .slice(s![.., .., i_chunk]) // da_chunk_end_bh1 // replay forward step 3.3
            .exp() // exp_da_chunk_end_bh1 // replay forward step 3.4
            .unsqueeze_dim::<4>(3) // exp_da_chunk_end_bh11 // replay forward step 3.5
            .expand([batch, nheads, per_head_dim, state_rank]); // replay forward step 3.6
        san(&decay_bhpr);
        // - 3.6: expand: (exp_da_chunk_end_bh11) -> (decay_bhpr)
        // - 3.5: unsqueeze: (exp_da_chunk_end_bh1) -> (exp_da_chunk_end_bh11)
        // - 3.4: exp: (da_chunk_end_bh1) -> (exp_da_chunk_end_bh1)
        // (d_da_chunk_end_bh1 = d_exp_da_chunk_end_bh1 * exp(da_chunk_end_bh1))
        // - 3.3: slice: (da_chunk_end_bhn [in][*]) -> (da_chunk_end_bh1)
        let d_da_chunk_end_bhpr = d_decay_bhpr * decay_bhpr.clone(); // note: decay is expanded exp(da_chunk_end)
        san(&d_da_chunk_end_bhpr);
        // The decay was broadcast over (p, r), so its gradient is summed over both axes.
        let d_da_chunk_end_bh = d_da_chunk_end_bhpr
            .reshape([batch, nheads, per_head_dim * state_rank]) // d_da_chunk_end_bhPR
            .sum_dim(2) // d_da_chunk_end_bh1
            .squeeze_dim::<2>(2);
        san(&d_da_chunk_end_bh);
        d_da_end_bh_slices.push(d_da_chunk_end_bh);
        //
        // - 3.2: squeeze: (intra_chunk_state_b1hpr) -> (intra_state_bhpr)
        // - 3.1/3.9: slice: (intra_chunk_state_bnhpr [in][!]) -> (intra_chunk_state_b1hpr)
        //
        // Propagate: d_running_state_bhpr_prev = scale * d_running_state_bhpr + d_chunk_input_state_bhpr
        // (d_cis[c] = gradient of chunk_input_state[:, c] flowing in from K5 BLUE)
        let d_chunk_input_state_bhpr = d_chunk_input_state_bnhpr
            .clone()
            .slice(s![.., i_chunk, .., .., ..]) // d_chunk_input_state_b1hpr
            .squeeze_dim::<4>(1);
        // Per forward steps 3.7-3.8 the state after chunk c is
        //   state_out[c] = decay[c] * state_in[c] + intra[c].
        // state_in[c] therefore receives gradient along two paths:
        //   1. through the recurrence: decay[c] * d_state_out[c];
        //   2. directly from K5 BLUE, which read chunk_input_state[:, c].
        // Their sum becomes the "state after chunk c-1" gradient for the next iteration.
        d_running_state_bhpr = decay_bhpr * d_running_state_bhpr + d_chunk_input_state_bhpr;
        san(&d_running_state_bhpr);
    }
    // - 2: vec-push: (running_state_bhpr [elem]) -> (chunk_input_state_vec_bhpr [vec][!])
    // - 1/5: init-mut: (initial_state_bhpr [in][*]) -> (running_state_bhpr)
    //
    // After the loop, d_initial_state = the (reverse loop) trailing d_running_state_bhpr
    let d_initial_state_bhpr = d_running_state_bhpr;
    //
    // Restore natural order (the slices were pushed from the last chunk to the first)
    d_intra_slices.reverse();
    d_da_end_bh_slices.reverse();
    //
    let d_intra_chunk_state_bnhpr = Tensor::stack(d_intra_slices, 1);
    //
    // d_da_end_bhn [B,H,N]: scatter to last position of d_da_cumsum
    let d_da_end_bhn: Tensor<B, 3> = Tensor::stack(d_da_end_bh_slices, 2);
    //
    // Pad to [B,H,N,L] — only last L-position is non-zero.
    // This treats da_chunk_end[b,h,n] as da_cumsum[b,h,n,L-1], so its gradient is placed
    // at the last position of each chunk and the K1 backward below spreads it over the
    // whole chunk.
    // NOTE(review): assumes that relation holds in serial::k1_ssd_chunk_cumsum — confirm.
    // NOTE(review): `chunk_len - 1` requires chunk_len >= 1; with chunk_len == 1 the
    // zeros tensor has an empty last axis — confirm the backend accepts that.
    let d_da_cumsum_k4_bhnl = {
        let zeros = Tensor::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device);
        let d_da_end_bhn1 = d_da_end_bhn.unsqueeze_dim::<4>(3);
        Tensor::cat(vec![zeros, d_da_end_bhn1], 3)
    };

    // ═══════════════════════════════════════════════════════════════════════
    // K3 BACKWARD
    // ═══════════════════════════════════════════════════════════════════════
    let x_bnhpl = x_bnlhp.clone().permute([0, 1, 3, 4, 2]);
    // For d_x_bnlhp:
    // - 15/15: matmul: (x_bnhpl, b_scaled_bnhlr [+]) -> (intra_chunk_state_bnhpr [out][!])
    // - (d_x_bnhpl = d_intra_chunk_state_bnhpr @ b_scaled_bnhlr^T)
    let d_x_k3_bnhpl = d_intra_chunk_state_bnhpr
        .clone()
        .matmul(b_scaled_bnhlr.clone().permute([0, 1, 2, 4, 3]));
    san(&d_x_k3_bnhpl);
    // - 1/15: permute: (x_bnlhp [in][*]) -> (x_bnhpl)
    let d_x_k3_bnlhp = d_x_k3_bnhpl.permute([0, 1, 4, 2, 3]);
    //
    // - 15/15: matmul: (x_bnhpl, b_scaled_bnhlr [+]) -> (intra_chunk_state_bnhpr [out][!])
    // (d_b_scaled_bnhlr = x_bnhpl^T @ d_intra_chunk_state_bnhpr)
    let d_b_scaled_bnhlr = x_bnhpl
        .permute([0, 1, 2, 4, 3]) // x_bnhlp
        .matmul(d_intra_chunk_state_bnhpr);
    san(&d_b_scaled_bnhlr);
    //
    // For d_b:
    // - 14: mul: (b_bnhlr, b_bar_scale_bnhlr) -> (b_scaled_bnhlr [+])
    // - (d_b_bnhlr = d_b_scaled_bnhlr * b_bar_scale_bnhlr)
    let b_bar_scale_bnhlr = b_bar_scale_bhnl
        .clone()
        .permute([0, 2, 1, 3]) // b_bar_scale_bnhl // replay forward step 11
        .unsqueeze_dim::<5>(4) // b_bar_scale_bnhl1 // replay forward step 12
        .expand([batch, nchunks, nheads, chunk_len, state_rank]); // replay forward step 13
    let d_b_k3_bnhlr = d_b_scaled_bnhlr.clone() * b_bar_scale_bnhlr;
    san(&d_b_k3_bnhlr);
    // - 5: reshape: (b_bngHlr) -> (b_bnhlr)
    // - 4: expand: (b_bng1lr) -> (b_bngHlr)
    // - 3: unsqueeze: (b_bnglr) -> (b_bng1lr)
    // - 2: permute: (b_bnlgr [in][*]) -> (b_bnglr)
    // GQA reduce: [B,N,H,L,R] → [B,N,G,L,R] → [B,N,L,G,R]
    let d_b_k3_bnlgr = d_b_k3_bnhlr
        .reshape([
            batch,
            nchunks,
            ngroups,
            heads_per_group,
            chunk_len,
            state_rank,
        ]) // d_b_k3_bngHlr
        .sum_dim(3) // d_b_k3_bng1lr
        .squeeze_dim::<5>(3) // d_b_k3_bnglr
        .permute([0, 1, 3, 2, 4]);
    san(&d_b_k3_bnlgr);

    // - 14: mul: (b_bnhlr, b_bar_scale_bnhlr) -> (b_scaled_bnhlr [+])
    // - (d_b_bar_scale_bnhlr = d_b_scaled_bnhlr * b_bnhlr)
    // GQA-expand B back to per-head for the product: [B,N,G,L,R] → [B,N,H,L,R]
    let b_bnhlr = b_bnlgr
        .clone()
        .permute([0, 1, 3, 2, 4]) // b_bnglr // replay forward step 2
        .unsqueeze_dim::<6>(3) // b_bng1lr // replay forward step 3
        // b_bngHlr
        .expand([
            batch,
            nchunks,
            ngroups,
            heads_per_group,
            chunk_len,
            state_rank,
        ]) // replay forward step 4
        .reshape([batch, nchunks, nheads, chunk_len, state_rank]); // replay forward step 5
    let d_b_bar_scale_bnhlr = d_b_scaled_bnhlr.clone() * b_bnhlr;
    san(&d_b_bar_scale_bnhlr);
    // - 13: expand: (b_bar_scale_bnhl1) -> (b_bar_scale_bnhlr)
    // - 12: unsqueeze: (b_bar_scale_bnhl) -> (b_bar_scale_bnhl1)
    // - 11: permute: (b_bar_scale_bhnl [+]) -> (b_bar_scale_bnhl)
    let d_b_bar_scale_bhnl = d_b_bar_scale_bnhlr
        .sum_dim(4) // d_b_bar_scale_bnhl1
        .squeeze_dim::<4>(4) // d_b_bar_scale_bnhl
        .permute([0, 2, 1, 3]);
    san(&d_b_bar_scale_bhnl);
    //
    // For d_da_cumsum_bhnl:
    // - 10: mul: (forward_decay_to_chunk_end_bhnl [+], dt_discretized_bhnl [in][*]) -> (b_bar_scale_bhnl [+])
    // - (d_forward_decay_to_chunk_end_bhnl = d_b_bar_scale_bhnl * dt_discretized_bhnl)
    let d_forward_decay_to_chunk_end_bhnl =
        d_b_bar_scale_bhnl.clone() * dt_discretized_bhnl.clone();
    san(&d_forward_decay_to_chunk_end_bhnl);
    // - 9: exp: (da_delta_bhnl) -> (forward_decay_to_chunk_end_bhnl [+])
    // - (d_da_delta_bhnl = d_forward_decay_to_chunk_end_bhnl * exp(da_delta_bhnl))
    // note: forward_decay_to_chunk_end_bhnl = exp(da_delta_bhnl)
    let d_da_delta_bhnl =
        d_forward_decay_to_chunk_end_bhnl * forward_decay_to_chunk_end_bhnl.clone();
    san(&d_da_delta_bhnl);
    // - 8: sub: (da_cumsum_last_bhnl, da_cumsum_bhnl [from K1][*]) -> (da_delta_bhnl)
    let d_da_cumsum_last_bhnl = d_da_delta_bhnl.clone();
    let d_da_cumsum_sub_bhnl = -d_da_delta_bhnl.clone();
    // - 7: expand: (da_cumsum_last_in_chunk_bhn1) -> (da_cumsum_last_bhnl)
    // - 6: slice: (da_cumsum_bhnl [in][*]) -> (da_cumsum_last_in_chunk_bhn1)
    let d_da_cumsum_last_bhn = d_da_cumsum_last_bhnl
        .sum_dim(3) // d_da_cumsum_last_bhn1
        .squeeze_dim::<3>(3);
    san(&d_da_cumsum_last_bhn);
    //
    // For d_dt_discretized_bhnl:
    // - 10: mul: (forward_decay_to_chunk_end_bhnl [+], dt_discretized_bhnl [in][*]) -> (b_bar_scale_bhnl [+])
    // - (d_dt_discretized_bhnl = d_b_bar_scale_bhnl * forward_decay_to_chunk_end_bhnl)
    let d_dt_discretized_k3_bhnl = d_b_bar_scale_bhnl * forward_decay_to_chunk_end_bhnl;
    san(&d_dt_discretized_k3_bhnl);
    //

    // In K3, da_cumsum enters da_delta twice (forward steps 6-8):
    //   da_delta[l] = da_cumsum[L-1] - da_cumsum[l]
    // - the subtracted copy contributes -d_da_delta[l] at every position l;
    // - the sliced-and-broadcast last element contributes Σ_l d_da_delta[l], which
    //   belongs to position L-1 only, hence the zero padding in front of it.
    // NOTE(review): `chunk_len - 1` requires chunk_len >= 1 (see the K4 padding above).
    let d_da_cumsum_k3_bhnl = {
        let zeros = Tensor::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device);
        let d_last = d_da_cumsum_last_bhn.unsqueeze_dim::<4>(3);
        d_da_cumsum_sub_bhnl + Tensor::cat(vec![zeros, d_last], 3)
    };
    san(&d_da_cumsum_k3_bhnl);

    // ═══════════════════════════════════════════════════════════════════════
    // K2 BACKWARD (from d_cb_bngll)
    // ═══════════════════════════════════════════════════════════════════════
    let c_bnglr = c_bnlgr.clone().permute([0, 1, 3, 2, 4]);
    let b_bnglr = b_bnlgr.clone().permute([0, 1, 3, 2, 4]);
    // - 3/3: matmul: (c_bnglr, b_bngrl) -> (cb_bngll [out][!])
    // - cb[b,n,g,l,s] = Σ_r c[l,r]*b[s,r] → CB = C @ B^T
    // - d_C_bngls = d_CB @ B [B,N,G,L,L_src] @ [B,N,G,L_src,R] → [B,N,G,L,R]
    // - d_B_bngls = d_CB^T @ C [B,N,G,L_src,L] @ [B,N,G,L,R] → [B,N,G,L_src,R]
    let d_c_k2_bnglr = d_cb_bngll.clone().matmul(b_bnglr.clone());
    san(&d_c_k2_bnglr);
    let d_c_k2_bnlgr = d_c_k2_bnglr.permute([0, 1, 3, 2, 4]);

    let d_b_k2_bnglr = d_cb_bngll
        .permute([0, 1, 2, 4, 3]) // [B,N,G,L_src,L_tgt]
        .matmul(c_bnglr.clone()); // [B,N,G,L_src,R]
    san(&d_b_k2_bnglr);
    let d_b_k2_bnlgr = d_b_k2_bnglr.permute([0, 1, 3, 2, 4]); // [B,N,L,G,R]

    // ═══════════════════════════════════════════════════════════════════════
    // SUM GRADIENT CONTRIBUTIONS
    // ═══════════════════════════════════════════════════════════════════════

    // Accumulated gradient of the cumulative sum produced by K1.
    let d_da_cumsum_bhnl =
        d_da_blue_bhnl + d_da_orange_bhnl + d_da_cumsum_k3_bhnl + d_da_cumsum_k4_bhnl;
    san(&d_da_cumsum_bhnl);

    // ── K1 BACKWARD ────────────────────────────────────────────────────────
    // K1 forward: da_cumsum[l] = cumsum_l(dt[l] * a_decay)
    // NOTE(review): this is the forward assumed for serial::k1_ssd_chunk_cumsum — confirm.
    //
    // Reverse cumsum (suffix sum) converts d_da_cumsum → d_da:
    // d_da[l] = sum_{k >= l} d_da_cumsum[k]
    // = total_sum - cumsum(d_da_cumsum)[l-1] (cumsum[-1] == 0)
    let d_da_cumsum_total_bhnl = d_da_cumsum_bhnl
        .clone()
        .sum_dim(3) // [B,H,N,1]
        .expand([batch, nheads, nchunks, chunk_len]);
    let prefix_sum_bhnl = d_da_cumsum_bhnl.clone().cumsum(3); // [B,H,N,L]
    let zeros_bhn1 = Tensor::<B, 4>::zeros([batch, nheads, nchunks, 1], &device);
    // prefix_sum shifted right by 1 (i.e., cumsum[l-1], with cumsum[-1] = 0)
    let prefix_sum_shifted_bhnl = Tensor::cat(
        vec![zeros_bhn1, prefix_sum_bhnl.narrow(3, 0, chunk_len - 1)],
        3,
    );
    let d_da_bhnl = d_da_cumsum_total_bhnl - prefix_sum_shifted_bhnl; // suffix sum [B,H,N,L]
    san(&d_da_bhnl);
    // d_dt from K1: d_dt = d_da * a_decay
    let a_decay_expand = a_decay_h
        .clone()
        .unsqueeze_dims::<4>(&[0, 2, 3])
        .expand([batch, nheads, nchunks, chunk_len]);
    let d_dt_k1_bhnl = d_da_bhnl.clone() * a_decay_expand;
    san(&d_dt_k1_bhnl);
    // d_a_decay_h from K1: d_a[h] = sum_{b,n,l} d_da[b,h,n,l] * dt[b,h,n,l]
    let d_a_decay_h = (d_da_bhnl * dt_discretized_bhnl.clone())
        .permute([1, 0, 2, 3]) // [H,B,N,L]
        .reshape([nheads, batch * nchunks * chunk_len])
        .sum_dim(1) // [H,1]
        .reshape([nheads]);
    san(&d_a_decay_h);

    // dt is read by K5 ORANGE, by K3 (b_bar_scale) and by K1.
    let d_dt_discretized_bhnl = d_dt_orange_bhnl + d_dt_discretized_k3_bhnl + d_dt_k1_bhnl;
    san(&d_dt_discretized_bhnl);

    // x is read by the skip connection, by K3 and by K5 ORANGE.
    let d_x_orange_bnlhp = d_x_orange_bnhlp.permute([0, 1, 3, 2, 4]);
    let d_x_bnlhp = d_x_skip_bnlhp + d_x_k3_bnlhp + d_x_orange_bnlhp;
    san(&d_x_bnlhp);

    // b is read by K2 and K3; c is read by K2 and K5 BLUE.
    let d_b_bnlgr = d_b_k2_bnlgr + d_b_k3_bnlgr;
    san(&d_b_bnlgr);
    let d_c_bnlgr = d_c_k2_bnlgr + d_c_blue_bnlgr;
    san(&d_c_bnlgr);

    CombinedGrads {
        d_a_decay_h,
        d_dt_discretized_bhnl,
        d_x_bnlhp,
        d_b_bnlgr,
        d_c_bnlgr,
        d_d_h,
        d_initial_state_bhpr,
    }
}