1#![allow(non_snake_case)]
18
19use crate::mamba3::double_ssd::ssd::serial_recalculated::combined_backward::k3_ssd_chunk_state_extended;
20use crate::mamba3::double_ssd::ssd::serial_recalculated::{
21 k1_ssd_chunk_cumsum, k2_ssd_bmm, k4_ssd_state_passing,
22};
23use crate::utils::fprim::{F, san};
24use burn::backend::Backend;
25use burn::tensor::s;
26
27#[non_exhaustive]
31pub struct CombinedSingleSsdGrads<B: Backend> {
32 pub d_v_bnlmhp: F<B, 6>,
34 pub d_da_bnlh: F<B, 4>,
36 pub d_b_bnlmhr: F<B, 6>,
38 pub d_c_bnlmhr: F<B, 6>,
40 pub d_gamma_bnlh: F<B, 4>,
42 pub d_scale_bnlh: F<B, 4>,
44 pub d_initial_state_bhpr: F<B, 4>,
46}
47
48#[allow(clippy::too_many_arguments)]
69pub fn combined_backward<B: Backend>(
70 d_y_bnlmhp: F<B, 6>,
71 d_final_bhpr: F<B, 4>,
72 v_bnlmhp: F<B, 6>,
74 da_bnlh: F<B, 4>,
75 b_bnlmhr: F<B, 6>,
76 c_bnlmhr: F<B, 6>,
77 gamma_bnlh: F<B, 4>,
78 scale_bnlh: F<B, 4>,
79 initial_state_bhpr: F<B, 4>,
80) -> CombinedSingleSsdGrads<B> {
81 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
82 let [.., state_rank] = b_bnlmhr.dims();
83 let device = v_bnlmhp.device();
84 let dtype = v_bnlmhp.dtype();
85
86 san(&d_y_bnlmhp);
87 san(&d_final_bhpr);
88 san(&v_bnlmhp);
89 san(&da_bnlh);
90 san(&b_bnlmhr);
91 san(&c_bnlmhr);
92 san(&gamma_bnlh);
93 san(&scale_bnlh);
94 san(&initial_state_bhpr);
95
96 let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(da_bnlh.clone());
102 san(&da_cumsum_bhnl);
103
104 let cb_bnhLMLM = k2_ssd_bmm(c_bnlmhr.clone(), b_bnlmhr.clone());
106 san(&cb_bnhLMLM);
107
108 let scale_bnlh11 = scale_bnlh.clone().unsqueeze_dims::<6>(&[3, 5]);
110 let k_scaled_bnlmhr = b_bnlmhr.clone() * scale_bnlh11.clone();
111 let (intra_chunk_state_bnhpr, k3_decay_bhnLM, k3_decayed_v_bnLMhp) =
112 k3_ssd_chunk_state_extended(
113 v_bnlmhp.clone(),
114 k_scaled_bnlmhr.clone(),
115 da_cumsum_bhnl.clone(),
116 );
117
118 let (chunk_input_state_bnhpr, _final_state_bhpr) = k4_ssd_state_passing(
120 intra_chunk_state_bnhpr,
121 da_chunk_end_bhn.clone(),
122 initial_state_bhpr,
123 );
124
125 let da_cumsum_bhnLM = da_cumsum_bhnl
127 .clone()
128 .unsqueeze_dim::<5>(4)
129 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
130 .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]);
131
132 let d_y_bnhLMp = d_y_bnlmhp
134 .clone()
135 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim])
136 .permute([0, 1, 3, 2, 4]);
137 san(&d_y_bnhLMp);
138
139 let (d_v_diag_bnlmhp, d_c_diag_bnlmhr, d_b_diag_bnlmhr, d_gamma_bnlh) = {
147 let c_bnlhmr = c_bnlmhr.clone().permute([0, 1, 2, 4, 3, 5]); let b_bnlhmr = b_bnlmhr.clone().permute([0, 1, 2, 4, 3, 5]); let v_bnlhmp = v_bnlmhp.clone().permute([0, 1, 2, 4, 3, 5]); let d_y_bnlhmp = d_y_bnlmhp.clone().permute([0, 1, 2, 4, 3, 5]); let qk_dot_bnlhmM = c_bnlhmr
154 .clone()
155 .matmul(b_bnlhmr.clone().permute([0, 1, 2, 3, 5, 4]));
156 let y_d_unw_bnlhmp = qk_dot_bnlhmM.clone().matmul(v_bnlhmp.clone());
158
159 let d_gamma_bnlh: F<B, 4> = (d_y_bnlhmp.clone() * y_d_unw_bnlhmp)
161 .sum_dim(5) .squeeze_dim::<5>(5) .sum_dim(4) .squeeze_dim::<4>(4); san(&d_gamma_bnlh);
166
167 let gamma_bnlh11 = gamma_bnlh.clone().unsqueeze_dims::<6>(&[4, 5]);
169 let d_y_d_unw_bnlhmp = d_y_bnlhmp * gamma_bnlh11;
170
171 let d_qk_dot_bnlhmM = d_y_d_unw_bnlhmp
173 .clone()
174 .matmul(v_bnlhmp.clone().permute([0, 1, 2, 3, 5, 4])); let d_v_diag_bnlhmp = qk_dot_bnlhmM
178 .permute([0, 1, 2, 3, 5, 4]) .matmul(d_y_d_unw_bnlhmp.clone()); let d_c_diag_bnlhmr = d_qk_dot_bnlhmM.clone().matmul(b_bnlhmr); let d_b_diag_bnlhmr = d_qk_dot_bnlhmM
185 .permute([0, 1, 2, 3, 5, 4]) .matmul(c_bnlhmr); let d_v_diag_bnlmhp = d_v_diag_bnlhmp.permute([0, 1, 2, 4, 3, 5]);
190 let d_c_diag_bnlmhr = d_c_diag_bnlhmr.permute([0, 1, 2, 4, 3, 5]);
191 let d_b_diag_bnlmhr = d_b_diag_bnlhmr.permute([0, 1, 2, 4, 3, 5]);
192 (
193 d_v_diag_bnlmhp,
194 d_c_diag_bnlmhr,
195 d_b_diag_bnlmhr,
196 d_gamma_bnlh,
197 )
198 };
199
200 let neg_inf_strict_ll: F<B, 2> =
203 F::<B, 2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device, dtype).triu(0);
204
205 let mut vec_lower_d_v_bhLMp: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
209 let mut vec_blue_d_c_bhLMr: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
210 let mut vec_d_cb_bhLMLM: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
211 let mut vec_blue_d_da_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
212 let mut vec_lower_d_da_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
213 let mut vec_lower_d_scale_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
214 let mut vec_d_intra_bhpr: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
215 let mut vec_d_da_end_bh: Vec<F<B, 2>> = Vec::with_capacity(nchunks);
216
217 let mut d_running_state_bhpr: F<B, 4> = d_final_bhpr;
218
219 for i_chunk in (0..nchunks).rev() {
220 let v_bhLMp: F<B, 4> = v_bnlmhp
222 .clone()
223 .slice(s![.., i_chunk, .., .., .., ..])
224 .squeeze_dim::<5>(1)
225 .reshape([batch, chunk_len * mimo_rank, nheads, per_head_dim])
226 .permute([0, 2, 1, 3]);
227
228 let c_bhLMr: F<B, 4> = c_bnlmhr
229 .clone()
230 .slice(s![.., i_chunk, .., .., .., ..])
231 .squeeze_dim::<5>(1)
232 .reshape([batch, chunk_len * mimo_rank, nheads, state_rank])
233 .permute([0, 2, 1, 3]);
234
235 let cb_bhLMLM: F<B, 4> = cb_bnhLMLM
236 .clone()
237 .slice(s![.., i_chunk, .., .., ..])
238 .squeeze_dim::<4>(1);
239
240 let da_cumsum_bhLM: F<B, 3> = da_cumsum_bhnLM
241 .clone()
242 .slice(s![.., .., i_chunk, ..])
243 .squeeze_dim::<3>(2);
244
245 let scale_bhLM: F<B, 3> = scale_bnlh
247 .clone()
248 .slice(s![.., i_chunk, .., ..]) .squeeze_dim::<3>(1)
250 .swap_dims(1, 2) .unsqueeze_dim::<4>(3) .expand([batch, nheads, chunk_len, mimo_rank])
253 .reshape([batch, nheads, chunk_len * mimo_rank]);
254
255 let chunk_input_state_bhpr: F<B, 4> = chunk_input_state_bnhpr
256 .clone()
257 .slice(s![.., i_chunk, .., .., ..])
258 .squeeze_dim::<4>(1);
259 san(&chunk_input_state_bhpr);
260
261 let d_y_bhLMp: F<B, 4> = d_y_bnhLMp
262 .clone()
263 .slice(s![.., i_chunk, .., .., ..])
264 .squeeze_dim::<4>(1);
265
266 let exp_da_cumsum_bhLM: F<B, 3> = da_cumsum_bhLM.clone().exp();
268 let exp_da_cumsum_bhLMp: F<B, 4> = exp_da_cumsum_bhLM
269 .clone()
270 .unsqueeze_dim::<4>(3)
271 .expand([batch, nheads, chunk_len * mimo_rank, per_head_dim]);
272 let d_ch_bhLMp: F<B, 4> = d_y_bhLMp.clone() * exp_da_cumsum_bhLMp.clone();
273 san(&d_ch_bhLMp);
274
275 let d_chunk_input_state_bhpr: F<B, 4> = c_bhLMr
276 .clone()
277 .permute([0, 1, 3, 2]) .matmul(d_ch_bhLMp.clone()) .permute([0, 1, 3, 2]); san(&d_chunk_input_state_bhpr);
281
282 let d_c_blue_bhLMr: F<B, 4> = d_ch_bhLMp.clone().matmul(chunk_input_state_bhpr.clone());
283 vec_blue_d_c_bhLMr.push(d_c_blue_bhLMr);
284
285 let ch_bhLMp: F<B, 4> = c_bhLMr
286 .clone()
287 .matmul(chunk_input_state_bhpr.clone().permute([0, 1, 3, 2]));
288 let d_da_blue_bhLM: F<B, 3> = (d_y_bhLMp.clone() * ch_bhLMp * exp_da_cumsum_bhLMp)
289 .sum_dim(3)
290 .squeeze_dim::<3>(3);
291 let d_da_blue_bhl: F<B, 3> = d_da_blue_bhLM
292 .reshape([batch, nheads, chunk_len, mimo_rank])
293 .sum_dim(3)
294 .squeeze_dim::<3>(3);
295 vec_blue_d_da_bhl.push(d_da_blue_bhl);
296
297 let da_target_bhLMLM: F<B, 4> = da_cumsum_bhLM.clone().unsqueeze_dim::<4>(3).expand([
299 batch,
300 nheads,
301 chunk_len * mimo_rank,
302 chunk_len * mimo_rank,
303 ]);
304 let da_source_bhLMLM: F<B, 4> = da_cumsum_bhLM.unsqueeze_dim::<4>(2).expand([
305 batch,
306 nheads,
307 chunk_len * mimo_rank,
308 chunk_len * mimo_rank,
309 ]);
310 let diff_bhLMLM = da_target_bhLMLM - da_source_bhLMLM;
311
312 let neg_inf_mimo_bhLMLM: F<B, 4> = neg_inf_strict_ll
315 .clone()
316 .unsqueeze_dims::<4>(&[0, 1])
317 .expand([batch, nheads, chunk_len, chunk_len])
318 .unsqueeze_dim::<5>(3)
319 .expand([batch, nheads, chunk_len, mimo_rank, chunk_len])
320 .reshape([batch, nheads, chunk_len * mimo_rank, chunk_len])
321 .unsqueeze_dim::<5>(4)
322 .expand([batch, nheads, chunk_len * mimo_rank, chunk_len, mimo_rank])
323 .reshape([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]);
324 let decay_strict_bhLMLM = (diff_bhLMLM + neg_inf_mimo_bhLMLM).exp();
325 san(&decay_strict_bhLMLM);
326
327 let scale_col_bhLMLM: F<B, 4> = scale_bhLM
328 .unsqueeze_dim::<4>(2) .expand([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]);
330
331 let prod_bhLMLM = cb_bhLMLM.clone() * decay_strict_bhLMLM.clone();
333 let w_bhLMLM = prod_bhLMLM.clone() * scale_col_bhLMLM.clone();
334
335 let d_w_bhLMLM: F<B, 4> = d_y_bhLMp
337 .clone()
338 .matmul(v_bhLMp.clone().permute([0, 1, 3, 2]));
339 san(&d_w_bhLMLM);
340
341 let d_v_lower_bhLMp: F<B, 4> = w_bhLMLM.permute([0, 1, 3, 2]).matmul(d_y_bhLMp.clone());
343 san(&d_v_lower_bhLMp);
344 vec_lower_d_v_bhLMp.push(d_v_lower_bhLMp);
345
346 let d_prod_bhLMLM = d_w_bhLMLM.clone() * scale_col_bhLMLM;
348 let d_scale_at_bhLMLM = d_w_bhLMLM * prod_bhLMLM;
349
350 let d_cb_lower_bhLMLM = d_prod_bhLMLM.clone() * decay_strict_bhLMLM.clone();
352 vec_d_cb_bhLMLM.push(d_cb_lower_bhLMLM);
353
354 let d_decay_strict_bhLMLM = d_prod_bhLMLM * cb_bhLMLM;
356 let d_diff_bhLMLM = d_decay_strict_bhLMLM * decay_strict_bhLMLM;
357
358 let d_da_target_bhLM: F<B, 3> = d_diff_bhLMLM.clone().sum_dim(3).squeeze_dim::<3>(3);
359 let d_da_source_bhLM: F<B, 3> = d_diff_bhLMLM.sum_dim(2).squeeze_dim::<3>(2);
360 let d_da_lower_bhLM = d_da_target_bhLM - d_da_source_bhLM;
361 let d_da_lower_bhl: F<B, 3> = d_da_lower_bhLM
362 .reshape([batch, nheads, chunk_len, mimo_rank])
363 .sum_dim(3)
364 .squeeze_dim::<3>(3);
365 vec_lower_d_da_bhl.push(d_da_lower_bhl);
366
367 let d_scale_lower_bhl: F<B, 3> = d_scale_at_bhLMLM
369 .sum_dim(2) .squeeze_dim::<3>(2) .reshape([batch, nheads, chunk_len, mimo_rank])
372 .sum_dim(3) .squeeze_dim::<3>(3); vec_lower_d_scale_bhl.push(d_scale_lower_bhl);
375
376 vec_d_intra_bhpr.push(d_running_state_bhpr.clone());
378
379 let decay_chunk_bhpr: F<B, 4> = da_chunk_end_bhn
380 .clone()
381 .slice(s![.., .., i_chunk])
382 .exp()
383 .unsqueeze_dim::<4>(3)
384 .expand([batch, nheads, per_head_dim, state_rank]);
385 san(&decay_chunk_bhpr);
386
387 let d_decay_chunk_bhpr = d_running_state_bhpr.clone() * chunk_input_state_bhpr;
388 let d_da_chunk_end_bh: F<B, 2> = (d_decay_chunk_bhpr * decay_chunk_bhpr.clone())
389 .reshape([batch, nheads, per_head_dim * state_rank])
390 .sum_dim(2)
391 .squeeze_dim::<2>(2);
392 vec_d_da_end_bh.push(d_da_chunk_end_bh);
393
394 d_running_state_bhpr = decay_chunk_bhpr * d_running_state_bhpr + d_chunk_input_state_bhpr;
395 san(&d_running_state_bhpr);
396 }
397 let d_initial_state_bhpr = d_running_state_bhpr;
398
399 vec_lower_d_v_bhLMp.reverse();
401 vec_blue_d_c_bhLMr.reverse();
402 vec_d_cb_bhLMLM.reverse();
403 vec_blue_d_da_bhl.reverse();
404 vec_lower_d_da_bhl.reverse();
405 vec_lower_d_scale_bhl.reverse();
406 vec_d_intra_bhpr.reverse();
407 vec_d_da_end_bh.reverse();
408
409 let d_v_lower_bnhLMp: F<B, 5> = F::stack(vec_lower_d_v_bhLMp, 1);
411 let d_c_blue_bnhLMr: F<B, 5> = F::stack(vec_blue_d_c_bhLMr, 1);
412 let d_cb_bnhLMLM: F<B, 5> = F::stack(vec_d_cb_bhLMLM, 1);
413 let d_da_blue_bhnl: F<B, 4> = F::stack(vec_blue_d_da_bhl, 2);
414 let d_da_lower_bhnl: F<B, 4> = F::stack(vec_lower_d_da_bhl, 2);
415 let d_scale_lower_bhnl: F<B, 4> = F::stack(vec_lower_d_scale_bhl, 2);
416 let d_intra_chunk_state_bnhpr: F<B, 5> = F::stack(vec_d_intra_bhpr, 1);
417 let d_da_end_bhn: F<B, 3> = F::stack(vec_d_da_end_bh, 2);
418 let d_da_cumsum_k4_bhnl: F<B, 4> = {
419 let zeros = F::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device, dtype);
420 let d_da_end_bhn1 = d_da_end_bhn.unsqueeze_dim::<4>(3);
421 F::cat(vec![zeros, d_da_end_bhn1], 3)
422 };
423
424 let v_bnLMhp =
432 v_bnlmhp
433 .clone()
434 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
435 let k_scaled_bnLMhr =
436 k_scaled_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
437 let k_scaled_bnhLMr = k_scaled_bnLMhr.permute([0, 1, 3, 2, 4]);
438 let decayed_v_bnhpLM = k3_decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
439
440 let d_decayed_v_bnhpLM: F<B, 5> = d_intra_chunk_state_bnhpr
441 .clone()
442 .matmul(k_scaled_bnhLMr.clone().permute([0, 1, 2, 4, 3])); let d_k_scaled_bnhLMr: F<B, 5> = decayed_v_bnhpLM
444 .permute([0, 1, 2, 4, 3]) .matmul(d_intra_chunk_state_bnhpr);
446
447 let d_decayed_v_bnLMhp = d_decayed_v_bnhpLM.permute([0, 1, 4, 2, 3]);
448 let d_decay_bhnLM: F<B, 4> = (d_decayed_v_bnLMhp.clone() * v_bnLMhp)
449 .sum_dim(4)
450 .squeeze_dim::<4>(4)
451 .permute([0, 3, 1, 2]);
452
453 let k3_decay_bnLMh1 = k3_decay_bhnLM
454 .clone()
455 .permute([0, 2, 3, 1])
456 .unsqueeze_dim::<5>(4);
457 let d_v_k3_bnLMhp: F<B, 5> = d_decayed_v_bnLMhp * k3_decay_bnLMh1;
458 let d_v_k3_bnlmhp: F<B, 6> =
459 d_v_k3_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
460
461 let d_decay_times_decay_bhnLM = d_decay_bhnLM * k3_decay_bhnLM;
463 let d_a_cumsum_last_bhn: F<B, 3> = d_decay_times_decay_bhnLM
464 .clone()
465 .sum_dim(3)
466 .squeeze_dim::<3>(3);
467 let d_da_cumsum_bhnLM = -d_decay_times_decay_bhnLM;
468
469 let d_da_cumsum_k3_from_fused_bhnl: F<B, 4> = d_da_cumsum_bhnLM
470 .reshape([batch, nheads, nchunks, chunk_len, mimo_rank])
471 .sum_dim(4)
472 .squeeze_dim::<4>(4);
473 let d_da_cumsum_k3_from_last_bhnl: F<B, 4> = {
474 let zeros = F::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device, dtype);
475 let d_last = d_a_cumsum_last_bhn.unsqueeze_dim::<4>(3);
476 F::cat(vec![zeros, d_last], 3)
477 };
478 let d_da_cumsum_k3_bhnl = d_da_cumsum_k3_from_fused_bhnl + d_da_cumsum_k3_from_last_bhnl;
479
480 let d_k_scaled_bnlmhr: F<B, 6> = d_k_scaled_bnhLMr
482 .permute([0, 1, 3, 2, 4]) .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
484 let d_b_k3_bnlmhr: F<B, 6> = d_k_scaled_bnlmhr.clone() * scale_bnlh11;
485 let d_scale_k3_bnlh: F<B, 4> = (d_k_scaled_bnlmhr * b_bnlmhr.clone())
486 .sum_dim(5) .squeeze_dim::<5>(5) .sum_dim(3) .squeeze_dim::<4>(3); let b_bnLMhr =
495 b_bnlmhr
496 .clone()
497 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
498 let c_bnhLMr = c_bnlmhr
499 .clone()
500 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank])
501 .permute([0, 1, 3, 2, 4]);
502 let b_for_k2_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
503
504 let d_c_k2_bnhLMr: F<B, 5> = d_cb_bnhLMLM.clone().matmul(b_for_k2_bnhLMr);
505 let d_b_k2_bnhrLM: F<B, 5> = c_bnhLMr.permute([0, 1, 2, 4, 3]).matmul(d_cb_bnhLMLM);
506
507 let d_c_k2_bnlmhr: F<B, 6> = d_c_k2_bnhLMr
508 .permute([0, 1, 3, 2, 4])
509 .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
510 let d_b_k2_bnlmhr: F<B, 6> = d_b_k2_bnhrLM
511 .permute([0, 1, 4, 2, 3])
512 .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
513
514 let d_c_blue_bnlmhr: F<B, 6> = d_c_blue_bnhLMr
516 .permute([0, 1, 3, 2, 4])
517 .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
518 let d_v_lower_bnlmhp: F<B, 6> = d_v_lower_bnhLMp.permute([0, 1, 3, 2, 4]).reshape([
519 batch,
520 nchunks,
521 chunk_len,
522 mimo_rank,
523 nheads,
524 per_head_dim,
525 ]);
526
527 let d_da_cumsum_bhnl =
531 d_da_blue_bhnl + d_da_lower_bhnl + d_da_cumsum_k3_bhnl + d_da_cumsum_k4_bhnl;
532 san(&d_da_cumsum_bhnl);
533
534 let d_da_bhnl = {
536 let d_total_bhnl = d_da_cumsum_bhnl
537 .clone()
538 .sum_dim(3)
539 .expand([batch, nheads, nchunks, chunk_len]);
540 let prefix_bhnl = d_da_cumsum_bhnl.cumsum(3);
541 let zeros_bhn1 = F::<B, 4>::zeros([batch, nheads, nchunks, 1], &device, dtype);
542 let prefix_shifted_bhnl =
543 F::cat(vec![zeros_bhn1, prefix_bhnl.narrow(3, 0, chunk_len - 1)], 3);
544 d_total_bhnl - prefix_shifted_bhnl
545 };
546 let d_da_bnlh = d_da_bhnl.permute([0, 2, 3, 1]);
547
548 let d_v_bnlmhp = d_v_k3_bnlmhp + d_v_lower_bnlmhp + d_v_diag_bnlmhp;
550 let d_b_bnlmhr = d_b_k2_bnlmhr + d_b_k3_bnlmhr + d_b_diag_bnlmhr;
551 let d_c_bnlmhr = d_c_k2_bnlmhr + d_c_blue_bnlmhr + d_c_diag_bnlmhr;
552 let d_scale_bnlh = d_scale_lower_bhnl.permute([0, 2, 3, 1]) + d_scale_k3_bnlh;
553
554 san(&d_v_bnlmhp);
555 san(&d_da_bnlh);
556 san(&d_b_bnlmhr);
557 san(&d_c_bnlmhr);
558 san(&d_gamma_bnlh);
559 san(&d_scale_bnlh);
560 san(&d_initial_state_bhpr);
561
562 CombinedSingleSsdGrads {
563 d_v_bnlmhp,
564 d_da_bnlh,
565 d_b_bnlmhr,
566 d_c_bnlmhr,
567 d_gamma_bnlh,
568 d_scale_bnlh,
569 d_initial_state_bhpr,
570 }
571}