1#![allow(non_snake_case)]
17
18use super::serial_recalculated::{k1_ssd_chunk_cumsum, k2_ssd_bmm, k4_ssd_state_passing};
19use crate::utils::fprim::{F, san};
20use burn::backend::Backend;
21use burn::tensor::s;
22
23#[non_exhaustive]
26pub struct CombinedGrads<B: Backend> {
27 pub d_v_bnlmhp: F<B, 6>,
29 pub d_da_bnlh: F<B, 4>,
31 pub d_b_bnlmhr: F<B, 6>,
33 pub d_c_bnlmhr: F<B, 6>,
35 pub d_initial_state_bhpr: F<B, 4>,
37}
38
39pub fn k3_ssd_chunk_state_extended<B: Backend>(
50 v_bnlmhp: F<B, 6>,
51 b_bnlmhr: F<B, 6>,
52 da_cumsum_bhnl: F<B, 4>,
53) -> (F<B, 5>, F<B, 4>, F<B, 5>) {
54 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
55 let [.., state_rank] = b_bnlmhr.dims();
56
57 let v_bnLMhp = v_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
58 let b_bnLMhr = b_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
59
60 let da_cumsum_last_bhn1 = da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
61 let da_cumsum_bhnLM = da_cumsum_bhnl
62 .unsqueeze_dim::<5>(4) .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); let decay_bhnLM = (da_cumsum_last_bhn1 - da_cumsum_bhnLM).exp();
66 san(&decay_bhnLM);
67
68 let decay_bnLMh1 = decay_bhnLM
69 .clone()
70 .permute([0, 2, 3, 1])
71 .unsqueeze_dim::<5>(4);
72 let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp;
73 san(&decayed_v_bnLMhp);
74
75 let decayed_v_bnhpLM = decayed_v_bnLMhp.clone().permute([0, 1, 3, 4, 2]);
76 let b_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
77 let intra_chunk_state_bnhpr = decayed_v_bnhpLM.matmul(b_bnhLMr);
78 san(&intra_chunk_state_bnhpr);
79
80 (intra_chunk_state_bnhpr, decay_bhnLM, decayed_v_bnLMhp)
81}
82
83pub fn combined_backward<B: Backend>(
99 d_y_bnlmhp: F<B, 6>,
100 d_final_bhpr: F<B, 4>,
101 v_bnlmhp: F<B, 6>,
103 da_bnlh: F<B, 4>,
104 b_bnlmhr: F<B, 6>,
105 c_bnlmhr: F<B, 6>,
106 initial_state_bhpr: F<B, 4>,
107) -> CombinedGrads<B> {
108 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
109 let [.., state_rank] = b_bnlmhr.dims();
110 let device = v_bnlmhp.device();
111 let dtype = v_bnlmhp.dtype();
112
113 san(&d_y_bnlmhp);
114 san(&d_final_bhpr);
115 san(&v_bnlmhp);
116 san(&da_bnlh);
117 san(&b_bnlmhr);
118 san(&c_bnlmhr);
119 san(&initial_state_bhpr);
120
121 let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(da_bnlh.clone());
127 san(&da_cumsum_bhnl);
128
129 let cb_bnhLMLM = k2_ssd_bmm(c_bnlmhr.clone(), b_bnlmhr.clone());
131 san(&cb_bnhLMLM);
132
133 let (intra_chunk_state_bnhpr, k3_decay_bhnLM, k3_decayed_v_bnLMhp) =
135 k3_ssd_chunk_state_extended(v_bnlmhp.clone(), b_bnlmhr.clone(), da_cumsum_bhnl.clone());
136
137 let (chunk_input_state_bnhpr, _final_state_bhpr) = k4_ssd_state_passing(
139 intra_chunk_state_bnhpr,
140 da_chunk_end_bhn.clone(),
141 initial_state_bhpr,
142 );
143
144 let da_cumsum_bhnLM = da_cumsum_bhnl
151 .clone()
152 .unsqueeze_dim::<5>(4) .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); let d_y_bnhLMp = d_y_bnlmhp
159 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]) .permute([0, 1, 3, 2, 4]); san(&d_y_bnhLMp);
162
163 let neg_inf_base_ll: F<B, 2> =
165 { F::<B, 2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device, dtype).triu(1) };
166
167 let mut vec_orange_d_v_bhLMp: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
174 let mut vec_blue_d_c_bhLMr: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
175 let mut vec_d_cb_bhLMLM: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
176 let mut vec_blue_d_da_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
177 let mut vec_orange_d_da_bhl: Vec<F<B, 3>> = Vec::with_capacity(nchunks);
178 let mut vec_d_intra_bhpr: Vec<F<B, 4>> = Vec::with_capacity(nchunks);
179 let mut vec_d_da_end_bh: Vec<F<B, 2>> = Vec::with_capacity(nchunks);
180
181 let mut d_running_state_bhpr: F<B, 4> = d_final_bhpr;
182
183 for i_chunk in (0..nchunks).rev() {
184 let v_bhLMp: F<B, 4> = v_bnlmhp
186 .clone()
187 .slice(s![.., i_chunk, .., .., .., ..]) .squeeze_dim::<5>(1) .reshape([batch, chunk_len * mimo_rank, nheads, per_head_dim]) .permute([0, 2, 1, 3]); let c_bhLMr: F<B, 4> = c_bnlmhr
193 .clone()
194 .slice(s![.., i_chunk, .., .., .., ..]) .squeeze_dim::<5>(1) .reshape([batch, chunk_len * mimo_rank, nheads, state_rank]) .permute([0, 2, 1, 3]); let cb_bhLMLM: F<B, 4> = cb_bnhLMLM
200 .clone()
201 .slice(s![.., i_chunk, .., .., ..]) .squeeze_dim::<4>(1); let da_cumsum_bhLM: F<B, 3> = da_cumsum_bhnLM
205 .clone()
206 .slice(s![.., .., i_chunk, ..]) .squeeze_dim::<3>(2); let chunk_input_state_bhpr: F<B, 4> = chunk_input_state_bnhpr
210 .clone()
211 .slice(s![.., i_chunk, .., .., ..]) .squeeze_dim::<4>(1); san(&chunk_input_state_bhpr);
214
215 let d_y_bhLMp: F<B, 4> = d_y_bnhLMp
216 .clone()
217 .slice(s![.., i_chunk, .., .., ..]) .squeeze_dim::<4>(1); let exp_da_cumsum_bhLM: F<B, 3> = da_cumsum_bhLM.clone().exp();
226 let exp_da_cumsum_bhLMp: F<B, 4> = exp_da_cumsum_bhLM
227 .clone()
228 .unsqueeze_dim::<4>(3) .expand([batch, nheads, chunk_len * mimo_rank, per_head_dim]); let d_ch_bhLMp: F<B, 4> = d_y_bhLMp.clone() * exp_da_cumsum_bhLMp.clone();
231 san(&d_ch_bhLMp);
232
233 let d_chunk_input_state_bhpr: F<B, 4> = c_bhLMr
236 .clone()
237 .permute([0, 1, 3, 2]) .matmul(d_ch_bhLMp.clone()) .permute([0, 1, 3, 2]); san(&d_chunk_input_state_bhpr);
241
242 let d_c_blue_bhLMr: F<B, 4> = d_ch_bhLMp.matmul(chunk_input_state_bhpr.clone());
245 san(&d_c_blue_bhLMr);
246 vec_blue_d_c_bhLMr.push(d_c_blue_bhLMr);
247
248 let ch_bhLMp: F<B, 4> = c_bhLMr.clone().matmul(
252 chunk_input_state_bhpr.clone().permute([0, 1, 3, 2]), ); let d_da_blue_bhLM: F<B, 3> = (d_y_bhLMp.clone() * ch_bhLMp * exp_da_cumsum_bhLMp)
255 .sum_dim(3) .squeeze_dim::<3>(3); san(&d_da_blue_bhLM);
258
259 let d_da_blue_bhl: F<B, 3> = d_da_blue_bhLM
261 .reshape([batch, nheads, chunk_len, mimo_rank]) .sum_dim(3) .squeeze_dim::<3>(3); vec_blue_d_da_bhl.push(d_da_blue_bhl);
265
266 let da_target_bhLMLM: F<B, 4> = da_cumsum_bhLM
271 .clone()
272 .unsqueeze_dim::<4>(3) .expand([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]); let da_source_bhLMLM: F<B, 4> = da_cumsum_bhLM
275 .unsqueeze_dim::<4>(2) .expand([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]); let diff_bhLMLM = da_target_bhLMLM - da_source_bhLMLM;
278 san(&diff_bhLMLM);
279
280 let neg_inf_mimo_bhLMLM: F<B, 4> = neg_inf_base_ll
283 .clone()
284 .unsqueeze_dims::<4>(&[0, 1]) .expand([batch, nheads, chunk_len, chunk_len]) .unsqueeze_dim::<5>(3) .expand([batch, nheads, chunk_len, mimo_rank, chunk_len]) .reshape([batch, nheads, chunk_len * mimo_rank, chunk_len]) .unsqueeze_dim::<5>(4) .expand([batch, nheads, chunk_len * mimo_rank, chunk_len, mimo_rank]) .reshape([batch, nheads, chunk_len * mimo_rank, chunk_len * mimo_rank]); let decay_bhLMLM = (diff_bhLMLM + neg_inf_mimo_bhLMLM).exp();
293 san(&decay_bhLMLM);
294
295 let d_orange_bhLMp = d_y_bhLMp;
297 let w_bhLMLM = cb_bhLMLM.clone() * decay_bhLMLM.clone();
298 let d_w_bhLMLM: F<B, 4> = d_orange_bhLMp.clone().matmul(
299 v_bhLMp.clone().permute([0, 1, 3, 2]), ); san(&d_w_bhLMLM);
302 let d_v_orange_bhLMp: F<B, 4> = w_bhLMLM
303 .permute([0, 1, 3, 2]) .matmul(d_orange_bhLMp); san(&d_v_orange_bhLMp);
306 vec_orange_d_v_bhLMp.push(d_v_orange_bhLMp);
307
308 let d_cb_bhLMLM = d_w_bhLMLM.clone() * decay_bhLMLM.clone();
311 vec_d_cb_bhLMLM.push(d_cb_bhLMLM);
312
313 let d_decay_bhLMLM = d_w_bhLMLM * cb_bhLMLM;
314 let d_diff_bhLMLM = d_decay_bhLMLM * decay_bhLMLM;
315
316 let d_da_target_bhLM: F<B, 3> = d_diff_bhLMLM
320 .clone()
321 .sum_dim(3) .squeeze_dim::<3>(3); let d_da_source_bhLM: F<B, 3> = d_diff_bhLMLM
324 .sum_dim(2) .squeeze_dim::<3>(2); let d_da_orange_bhLM = d_da_target_bhLM - d_da_source_bhLM;
327 san(&d_da_orange_bhLM);
328
329 let d_da_orange_bhl: F<B, 3> = d_da_orange_bhLM
331 .reshape([batch, nheads, chunk_len, mimo_rank]) .sum_dim(3) .squeeze_dim::<3>(3); vec_orange_d_da_bhl.push(d_da_orange_bhl);
335
336 vec_d_intra_bhpr.push(d_running_state_bhpr.clone());
343
344 let decay_chunk_bhpr: F<B, 4> = da_chunk_end_bhn
345 .clone()
346 .slice(s![.., .., i_chunk]) .exp() .unsqueeze_dim::<4>(3) .expand([batch, nheads, per_head_dim, state_rank]); san(&decay_chunk_bhpr);
351
352 let d_decay_chunk_bhpr = d_running_state_bhpr.clone() * chunk_input_state_bhpr;
353 let d_da_chunk_end_bh: F<B, 2> = (d_decay_chunk_bhpr * decay_chunk_bhpr.clone())
355 .reshape([batch, nheads, per_head_dim * state_rank]) .sum_dim(2) .squeeze_dim::<2>(2); san(&d_da_chunk_end_bh);
359 vec_d_da_end_bh.push(d_da_chunk_end_bh);
360
361 d_running_state_bhpr = decay_chunk_bhpr * d_running_state_bhpr + d_chunk_input_state_bhpr;
362 san(&d_running_state_bhpr);
363 }
364 let d_initial_state_bhpr = d_running_state_bhpr;
365
366 vec_orange_d_v_bhLMp.reverse();
368 vec_blue_d_c_bhLMr.reverse();
369 vec_d_cb_bhLMLM.reverse();
370 vec_blue_d_da_bhl.reverse();
371 vec_orange_d_da_bhl.reverse();
372 vec_d_intra_bhpr.reverse();
373 vec_d_da_end_bh.reverse();
374
375 let d_v_orange_bnhLMp: F<B, 5> = F::stack(vec_orange_d_v_bhLMp, 1);
377 let d_c_blue_bnhLMr: F<B, 5> = F::stack(vec_blue_d_c_bhLMr, 1);
378 let d_cb_bnhLMLM: F<B, 5> = F::stack(vec_d_cb_bhLMLM, 1);
379 let d_da_blue_bhnl: F<B, 4> = F::stack(vec_blue_d_da_bhl, 2);
380 let d_da_orange_bhnl: F<B, 4> = F::stack(vec_orange_d_da_bhl, 2);
381 let d_intra_chunk_state_bnhpr: F<B, 5> = F::stack(vec_d_intra_bhpr, 1);
382 let d_da_end_bhn: F<B, 3> = F::stack(vec_d_da_end_bh, 2);
385 let d_da_cumsum_k4_bhnl: F<B, 4> = {
386 let zeros = F::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device, dtype);
387 let d_da_end_bhn1 = d_da_end_bhn.unsqueeze_dim::<4>(3);
388 F::cat(vec![zeros, d_da_end_bhn1], 3)
389 };
390
391 let v_bnLMhp =
405 v_bnlmhp
406 .clone()
407 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
408 let b_bnLMhr =
409 b_bnlmhr
410 .clone()
411 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
412 let b_bnhLMr = b_bnLMhr.clone().permute([0, 1, 3, 2, 4]);
413 let decayed_v_bnhpLM = k3_decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
414
415 let d_decayed_v_bnhpLM: F<B, 5> = d_intra_chunk_state_bnhpr.clone().matmul(
416 b_bnhLMr.permute([0, 1, 2, 4, 3]), ); let d_b_k3_bnhLMr: F<B, 5> = decayed_v_bnhpLM
419 .permute([0, 1, 2, 4, 3]) .matmul(d_intra_chunk_state_bnhpr); let d_decayed_v_bnLMhp = d_decayed_v_bnhpLM.permute([0, 1, 4, 2, 3]);
423 let d_decay_bhnLM: F<B, 4> = (d_decayed_v_bnLMhp.clone() * v_bnLMhp)
424 .sum_dim(4) .squeeze_dim::<4>(4) .permute([0, 3, 1, 2]); let k3_decay_bnLMh1 = k3_decay_bhnLM
430 .clone()
431 .permute([0, 2, 3, 1]) .unsqueeze_dim::<5>(4); let d_v_k3_bnLMhp: F<B, 5> = d_decayed_v_bnLMhp * k3_decay_bnLMh1;
434 let d_v_k3_bnlrhp: F<B, 6> =
435 d_v_k3_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
436
437 let d_decay_times_decay_bhnLM = d_decay_bhnLM * k3_decay_bhnLM;
439 let d_a_cumsum_last_bhn: F<B, 3> = d_decay_times_decay_bhnLM
441 .clone()
442 .sum_dim(3) .squeeze_dim::<3>(3); let d_da_cumsum_bhnLM = -d_decay_times_decay_bhnLM;
446
447 let d_da_cumsum_k3_from_fused_bhnl: F<B, 4> = d_da_cumsum_bhnLM
449 .reshape([batch, nheads, nchunks, chunk_len, mimo_rank]) .sum_dim(4) .squeeze_dim::<4>(4); let d_da_cumsum_k3_from_last_bhnl: F<B, 4> = {
454 let zeros = F::<B, 4>::zeros([batch, nheads, nchunks, chunk_len - 1], &device, dtype);
455 let d_last = d_a_cumsum_last_bhn.unsqueeze_dim::<4>(3);
456 F::cat(vec![zeros, d_last], 3)
457 };
458 let d_da_cumsum_k3_bhnl = d_da_cumsum_k3_from_fused_bhnl + d_da_cumsum_k3_from_last_bhnl;
459
460 let d_b_k3_bnLMhr = d_b_k3_bnhLMr.permute([0, 1, 3, 2, 4]);
462 let d_b_k3_bnlmhr: F<B, 6> =
463 d_b_k3_bnLMhr.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
464
465 let c_bnhLMr = c_bnlmhr
473 .clone()
474 .reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]) .permute([0, 1, 3, 2, 4]); let b_for_k2_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
477
478 let d_c_k2_bnhLMr: F<B, 5> = d_cb_bnhLMLM.clone().matmul(b_for_k2_bnhLMr);
479 let d_b_k2_bnhrLM: F<B, 5> = c_bnhLMr
480 .permute([0, 1, 2, 4, 3]) .matmul(d_cb_bnhLMLM); let d_c_k2_bnlmhr: F<B, 6> = d_c_k2_bnhLMr
485 .permute([0, 1, 3, 2, 4]) .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]); let d_b_k2_bnlmhr: F<B, 6> = d_b_k2_bnhrLM
488 .permute([0, 1, 4, 2, 3]) .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]); let d_c_blue_bnlmhr: F<B, 6> = d_c_blue_bnhLMr
493 .permute([0, 1, 3, 2, 4]) .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]); let d_v_orange_bnlrhp: F<B, 6> = d_v_orange_bnhLMp
496 .permute([0, 1, 3, 2, 4]) .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]); let d_da_cumsum_bhnl =
503 d_da_blue_bhnl + d_da_orange_bhnl + d_da_cumsum_k3_bhnl + d_da_cumsum_k4_bhnl;
504 san(&d_da_cumsum_bhnl);
505
506 let d_da_bhnl = {
510 let d_total_bhnl = d_da_cumsum_bhnl
511 .clone()
512 .sum_dim(3) .expand([batch, nheads, nchunks, chunk_len]); let prefix_bhnl = d_da_cumsum_bhnl.cumsum(3);
515 let zeros_bhn1 = F::<B, 4>::zeros([batch, nheads, nchunks, 1], &device, dtype);
516 let prefix_shifted_bhnl =
517 F::cat(vec![zeros_bhn1, prefix_bhnl.narrow(3, 0, chunk_len - 1)], 3);
518 d_total_bhnl - prefix_shifted_bhnl
519 };
520 san(&d_da_bhnl);
521 let d_da_bnlh = d_da_bhnl.permute([0, 2, 3, 1]);
523
524 let d_v_bnlmhp = d_v_k3_bnlrhp + d_v_orange_bnlrhp;
526 let d_b_bnlmhr = d_b_k2_bnlmhr + d_b_k3_bnlmhr;
527 let d_c_bnlmhr = d_c_k2_bnlmhr + d_c_blue_bnlmhr;
528
529 san(&d_v_bnlmhp);
530 san(&d_da_bnlh);
531 san(&d_b_bnlmhr);
532 san(&d_c_bnlmhr);
533 san(&d_initial_state_bhpr);
534
535 CombinedGrads {
536 d_v_bnlmhp,
537 d_da_bnlh,
538 d_b_bnlmhr,
539 d_c_bnlmhr,
540 d_initial_state_bhpr,
541 }
542}