burn_mamba/mamba2/ssd/serial.rs
1#![allow(unused_variables)]
2
3use crate::mamba2::prelude::*;
4use crate::utils::sanity::sanity as san;
5use burn::prelude::*;
6
7impl<B: Backend> Mamba2<B> {
8 /// Forward pass for the Mamba-2 SSD module.
9 ///
10 /// Returns:
11 /// - `y_bnlhp`.
12 /// - `final_state_bhpr`.
13 #[allow(non_snake_case)]
14 pub fn ssd_serial(input: super::Mamba2SsdInput<B>) -> (Tensor<B, 5>, Tensor<B, 4>) {
15 let [batch, nchunks, chunk_len, nheads, per_head_dim] = input.x_bnlhp.dims();
16 let [.., ngroups, state_rank] = input.b_bnlgr.dims();
17 let device = input.x_bnlhp.device();
18 assert_ne!(ngroups, 0);
19 assert_eq!(nheads % ngroups, 0);
20 assert!(nchunks > 0, "sequence length must be at least 1");
21 // `heads_per_group` is called `nheads_ngroups_ratio` in every Triton kernel.
22 // It is the compile-time constant used by GQA (Grouped Query Attention) to map
23 // a head index to its B/C group: `group_idx = head_idx / heads_per_group`.
24 let heads_per_group = nheads / ngroups;
25
26 san(&input.x_bnlhp);
27 san(&input.dt_bnlh);
28 san(&input.a_decay_h);
29 san(&input.b_bnlgr);
30 san(&input.c_bnlgr);
31 san(&input.d_h);
32 san(&input.initial_state_bhpr);
33
34 assert!(
35 input.init_state_hpr.is_none(),
36 "init_state_hpr not yet implemented"
37 );
38
39 // ── Permutes ──────────────────────────────────────────────────────────────────
40 // Note: dt_bnlh calculation (originally in Kernel 1) moved to Step 4 (before padding).
41 let dt_discretized_bhnl = input.dt_bnlh.permute([0, 3, 1, 2]);
42 assert_eq!(
43 [batch, nheads, nchunks, chunk_len],
44 dt_discretized_bhnl.dims()
45 );
46 san(&dt_discretized_bhnl);
47
48 // ── Kernel 1 ──────────────────────────────────────────────────────────────────
49 // IO: (..) -> (da_cumsum_bhnl [used in K3+K5][*], da_chunk_end_bhn [used in K4][omitted][*])
50 let (da_cumsum_bhnl, da_chunk_end_bhn): (Tensor<B, 4>, Tensor<B, 3>) =
51 k1_ssd_chunk_cumsum(dt_discretized_bhnl.clone(), input.a_decay_h.clone());
52 assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
53 assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
54 san(&da_cumsum_bhnl);
55 san(&da_chunk_end_bhn);
56
57 // ── Kernel 2 ──────────────────────────────────────────────────────────────────
58 // IO: (..) -> (cb_bngll [used in K5][!])
59 let cb_bngll: Tensor<B, 5> = k2_ssd_bmm(input.c_bnlgr.clone(), input.b_bnlgr.clone());
60 assert_eq!(
61 [batch, nchunks, ngroups, chunk_len, chunk_len],
62 cb_bngll.dims()
63 );
64 // Note: cb_bngll is then only used by Kernel 5.
65 san(&cb_bngll);
66
67 // ── Kernel 3 ──────────────────────────────────────────────────────────────────
68 // IO: (..) -> (intra_chunk_state_bnhpr [used in K4][!])
69 let intra_chunk_state_bnhpr: Tensor<B, 5> = k3_ssd_chunk_state(
70 input.x_bnlhp.clone(),
71 input.b_bnlgr.clone(),
72 da_cumsum_bhnl.clone(),
73 dt_discretized_bhnl.clone(),
74 );
75 assert_eq!(
76 [batch, nchunks, nheads, per_head_dim, state_rank],
77 intra_chunk_state_bnhpr.dims()
78 );
79 san(&intra_chunk_state_bnhpr);
80
81 // ── Kernel 4 ──────────────────────────────────────────────────────────────────
82 // IO: (..) -> (chunk_input_state_bnhpr [used in K5][!], final_state_bhpr [final output])
83 let flat_state_dim = per_head_dim * state_rank;
84 let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<B, 5>, Tensor<B, 4>) =
85 k4_ssd_state_passing(
86 intra_chunk_state_bnhpr.clone(),
87 da_chunk_end_bhn.clone(),
88 input.initial_state_bhpr,
89 );
90 assert_eq!(
91 [batch, nchunks, nheads, per_head_dim, state_rank],
92 chunk_input_state_bnhpr.dims()
93 );
94 assert_eq!(
95 [batch, nheads, per_head_dim, state_rank],
96 final_state_bhpr.dims()
97 );
98 san(&chunk_input_state_bnhpr);
99 san(&final_state_bhpr);
100
101 // ── Kernel 5 ──────────────────────────────────────────────────────────────────
102 let y_bnlhp: Tensor<B, 5> = k5_ssd_chunk_scan(
103 da_cumsum_bhnl,
104 dt_discretized_bhnl,
105 input.x_bnlhp,
106 input.c_bnlgr,
107 cb_bngll,
108 chunk_input_state_bnhpr,
109 input.d_h,
110 );
111 assert_eq!(
112 [batch, nchunks, chunk_len, nheads, per_head_dim],
113 y_bnlhp.dims()
114 );
115 san(&y_bnlhp);
116
117 (y_bnlhp, final_state_bhpr)
118 }
119}
120
121/// Based on the Kernel 1 Triton reference `_chunk_cumsum_fwd_kernel` (`ssd_chunk_state.py`).
122///
123/// Returns:
124/// - da_cumsum_bhnl [used in K3+K5][*] - intra-chunk cumsum.
125/// - da_chunk_end_bhn [used in K4][omitted][*] - last da_cumsum per chunk.
126pub fn k1_ssd_chunk_cumsum<B: Backend>(
127 dt_discretized_bhnl: Tensor<B, 4>,
128 a_decay_h: Tensor<B, 1>,
129) -> (Tensor<B, 4>, Tensor<B, 3>) {
130 let [batch, nheads, nchunks, chunk_len] = dt_discretized_bhnl.dims();
131 let da_cumsum_bhnl: Tensor<B, 4> = {
132 let a_decay_bhnl = a_decay_h
133 // - 1/6: unsqueeze-dims: (a_decay_h [*]) -> (a_decay_1h11)
134 .unsqueeze_dims::<4>(&[0, 2, 3]) // a_decay_1h11
135 // - 2: expand: (a_decay_1h11) -> (a_decay_bhnl)
136 .expand([batch, nheads, nchunks, chunk_len]);
137 // - 3: mul: (dt_discretized_bhnl [*], a_decay_bhnl) -> (da_bhnl)
138 // - 4: cumsum: (da_bhnl) -> (da_cumsum_bhnl [out][*])
139 (dt_discretized_bhnl * a_decay_bhnl).cumsum(3)
140 };
141 assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
142
143 let da_chunk_end_bhn = da_cumsum_bhnl
144 .clone()
145 // - 5: slice: (da_cumsum_bhnl [*]) -> (da_cumsum_bhn1)
146 .slice(s![.., .., .., -1]) // da_cumsum_bhn1
147 // - 6/6: squeeze: (da_cumsum_bhn1) -> (da_chunk_end_bhn [out])
148 .squeeze_dim::<3>(3);
149 assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
150
151 (da_cumsum_bhnl, da_chunk_end_bhn)
152}
153
154/// Based on the Kernel 2 Triton reference `_bmm_chunk_fwd_kernel` (`ssd_bmm.py`).
155///
156/// Returns:
157/// - cb_bngll [used in K5][!].
158pub fn k2_ssd_bmm<B: Backend>(c_bnlgr: Tensor<B, 5>, b_bnlgr: Tensor<B, 5>) -> Tensor<B, 5> {
159 let [batch, nchunks, chunk_len, ngroups, state_rank] = c_bnlgr.dims();
160
161 // - 1/3: permute: (c_bnlgr [in][*]) -> (c_bnglr)
162 let c_bnglr = c_bnlgr.clone().permute([0, 1, 3, 2, 4]);
163 // - 2: permute: (b_bnlgr [in][*]) -> (b_bngrl)
164 let b_bngrl = b_bnlgr.clone().permute([0, 1, 3, 4, 2]);
165 // - 3/3: matmul: (c_bnglr, b_bngrl) -> (cb_bngll [out][!])
166 let cb_bngll: Tensor<B, 5> = c_bnglr.matmul(b_bngrl);
167 assert_eq!(
168 [batch, nchunks, ngroups, chunk_len, chunk_len],
169 cb_bngll.dims()
170 );
171 // Note: cb_bngll is then only used by Kernel 5.
172 cb_bngll
173}
174
175/// Based on the Kernel 3 Triton reference `_chunk_state_fwd_kernel` (`ssd_chunk_state.py`).
176///
177/// Returns:
178/// - cb_bngll [used in K5][!] - state assuming zero initial state at each chunk boundary.
179/// - b_bar_scale_bhnl [*] - intermediary
180pub fn k3_ssd_chunk_state<B: Backend>(
181 x_bnlhp: Tensor<B, 5>,
182 b_bnlgr: Tensor<B, 5>,
183 da_cumsum_bhnl: Tensor<B, 4>,
184 dt_discretized_bhnl: Tensor<B, 4>,
185) -> Tensor<B, 5> {
186 use burn::tensor::s;
187
188 let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.dims();
189 let [.., ngroups, state_rank] = b_bnlgr.dims();
190
191 // permute b and x to prepare them for the mamtul
192 // - 1/15: permute: (x_bnlhp [in][*]) -> (x_bnhpl)
193 let x_bnhpl = x_bnlhp.clone().permute([0, 1, 3, 4, 2]);
194 assert_eq!(
195 [batch, nchunks, nheads, per_head_dim, chunk_len],
196 x_bnhpl.dims()
197 );
198 // - 2: permute: (b_bnlgr [in][*]) -> (b_bnglr)
199 let b_bnglr = b_bnlgr.permute([0, 1, 3, 2, 4]); // note: still in groups instead of heads
200 assert_eq!(
201 [batch, nchunks, ngroups, chunk_len, state_rank],
202 b_bnglr.dims()
203 );
204
205 // Expand B from ngroups to nheads by repeating each group's
206 // projection across all heads_per_group heads in that group.
207 let heads_per_group = nheads / ngroups;
208 let b_bnhlr = b_bnglr
209 // - 3: unsqueeze: (b_bnglr) -> (b_bng1lr)
210 .unsqueeze_dim::<6>(3) // b_bng1lr
211 // - 4: expand: (b_bng1lr) -> (b_bngHlr)
212 .expand([
213 batch,
214 nchunks,
215 ngroups,
216 heads_per_group,
217 chunk_len,
218 state_rank,
219 ]) // b_bngHlr
220 // - 5: reshape: (b_bngHlr) -> (b_bnhlr)
221 .reshape([batch, nchunks, nheads, chunk_len, state_rank]);
222
223 // scale b
224 let b_scaled_bnhlr = {
225 let b_bar_scale_bhnl = {
226 let da_cumsum_last_in_chunk_bhn1 =
227 // - 6: slice: (da_cumsum_bhnl [in][*]) -> (da_cumsum_last_in_chunk_bhn1)
228 da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
229 assert_eq!(
230 [batch, nheads, nchunks, 1],
231 da_cumsum_last_in_chunk_bhn1.dims()
232 );
233
234 // - 7: expand: (da_cumsum_last_in_chunk_bhn1) -> (da_cumsum_last_bhnl)
235 let da_cumsum_last_bhnl =
236 da_cumsum_last_in_chunk_bhn1.expand([batch, nheads, nchunks, chunk_len]);
237 // - 8: sub: (da_cumsum_last_bhnl, da_cumsum_bhnl [from K1][*]) -> (da_delta_bhnl)
238 let da_delta_bhnl = da_cumsum_last_bhnl - da_cumsum_bhnl.clone();
239 // - 9: exp: (da_delta_bhnl) -> (forward_decay_to_chunk_end_bhnl [+])
240 let forward_decay_to_chunk_end_bhnl = da_delta_bhnl.exp();
241 assert_eq!(
242 [batch, nheads, nchunks, chunk_len],
243 forward_decay_to_chunk_end_bhnl.dims()
244 );
245
246 // - 10: mul: (forward_decay_to_chunk_end_bhnl [+], dt_discretized_bhnl [in][*]) -> (b_bar_scale_bhnl [+])
247 forward_decay_to_chunk_end_bhnl * dt_discretized_bhnl.clone()
248 };
249 assert_eq!([batch, nheads, nchunks, chunk_len], b_bar_scale_bhnl.dims());
250
251 // - 11: permute: (b_bar_scale_bhnl [+]) -> (b_bar_scale_bnhl)
252 let b_bar_scale_bnhl = b_bar_scale_bhnl.permute([0, 2, 1, 3]);
253 assert_eq!([batch, nchunks, nheads, chunk_len], b_bar_scale_bnhl.dims());
254 let b_bar_scale_bnhlr = b_bar_scale_bnhl
255 // - 12: unsqueeze: (b_bar_scale_bnhl) -> (b_bar_scale_bnhl1)
256 .unsqueeze_dim::<5>(4) // b_bar_scale_bnhl1
257 // - 13: expand: (b_bar_scale_bnhl1) -> (b_bar_scale_bnhlr)
258 .expand([batch, nchunks, nheads, chunk_len, state_rank]);
259 // - 14: mul: (b_bnhlr, b_bar_scale_bnhlr) -> (b_scaled_bnhlr [+])
260 b_bnhlr * b_bar_scale_bnhlr
261 };
262 assert_eq!(
263 [batch, nchunks, nheads, chunk_len, state_rank],
264 b_scaled_bnhlr.dims()
265 );
266
267 // - 15/15: matmul: (x_bnhpl, b_scaled_bnhlr [+]) -> (intra_chunk_state_bnhpr [out][!])
268 let intra_chunk_state_bnhpr: Tensor<B, 5> = x_bnhpl.matmul(b_scaled_bnhlr);
269 assert_eq!(
270 [batch, nchunks, nheads, per_head_dim, state_rank],
271 intra_chunk_state_bnhpr.dims()
272 );
273 intra_chunk_state_bnhpr
274}
275
276/// Based on the Kernel 4 Triton reference `_state_passing_fwd_kernel` (`ssd_state_passing.py`).
277///
278/// Returns:
279/// - chunk_input_state_bnhpr [used in K5][!].
280/// - final_state_bhpr [final output].
281pub fn k4_ssd_state_passing<B: Backend>(
282 intra_chunk_state_bnhpr: Tensor<B, 5>,
283 da_chunk_end_bhn: Tensor<B, 3>,
284 initial_state_bhpr: Tensor<B, 4>,
285) -> (Tensor<B, 5>, Tensor<B, 4>) {
286 let [batch, nchunks, nheads, per_head_dim, state_rank] = intra_chunk_state_bnhpr.dims();
287 let flat_state_dim = per_head_dim * state_rank;
288
289 // - 1/5: init-mut: (initial_state_bhpr [in][*]) -> (running_state_bhpr)
290 let mut running_state_bhpr = initial_state_bhpr;
291 assert_eq!(
292 [batch, nheads, per_head_dim, state_rank],
293 running_state_bhpr.dims()
294 );
295
296 let mut chunk_input_state_vec_bhpr = Vec::with_capacity(nchunks + 1);
297 // - 2: vec-push: (running_state_bhpr [elem]) -> (chunk_input_state_vec_bhpr [vec][!])
298 chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
299
300 // - 3: serial-loop: (0..nchunks)
301 for i_chunk in 0..nchunks {
302 let intra_state_bhpr = intra_chunk_state_bnhpr
303 .clone()
304 // - 3.1/3.9: slice: (intra_chunk_state_bnhpr [in][!]) -> (intra_chunk_state_b1hpr)
305 .slice(s![.., i_chunk, .., .., ..]) // intra_chunk_state_b1hpr
306 // - 3.2: squeeze: (intra_chunk_state_b1hpr) -> (intra_state_bhpr)
307 .squeeze_dim::<4>(1);
308 assert_eq!(
309 [batch, nheads, per_head_dim, state_rank],
310 intra_state_bhpr.dims()
311 );
312
313 let decay_bhpr = da_chunk_end_bhn
314 .clone()
315 // - 3.3: slice: (da_chunk_end_bhn [in][*]) -> (da_chunk_end_bh1)
316 .slice(s![.., .., i_chunk]) // da_chunk_end_bh1
317 // - 3.4: exp: (da_chunk_end_bh1) -> (exp_da_chunk_end_bh1)
318 .exp() // exp_da_chunk_end_bh1
319 // - 3.5: unsqueeze: (exp_da_chunk_end_bh1) -> (exp_da_chunk_end_bh11)
320 .unsqueeze_dim::<4>(3) // exp_da_chunk_end_bh11
321 // - 3.6: expand: (exp_da_chunk_end_bh11) -> (decay_bhpr)
322 .expand([batch, nheads, per_head_dim, state_rank]);
323
324 // SSM recurrence: running_state = decay * running_state + intra_state
325 running_state_bhpr =
326 // - 3.7: mul: (decay_bhpr, running_state_bhpr) -> (running_state_bhpr)
327 (decay_bhpr * running_state_bhpr) // running_state_bhpr
328 // - 3.8: add: (running_state_bhpr, intra_state_bhpr) -> (running_state_bhpr)
329 + intra_state_bhpr;
330 // - 3.9/3.9: vec-push: (running_state_bhpr [elem]) -> (chunk_input_state_vec_bhpr [vec][!])
331 chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
332 }
333
334 // - 4: vec-pop: (chunk_input_state_vec_bhpr [vec][!]) -> (final_state_bhpr [elem][out][!])
335 let final_state_bhpr = chunk_input_state_vec_bhpr.pop().unwrap();
336 assert_eq!(
337 [batch, nheads, per_head_dim, state_rank],
338 final_state_bhpr.dims()
339 );
340
341 // - 5/5: stack: (chunk_input_state_vec_bhpr [!]) -> (chunk_input_state_bnhpr [out][!])
342 let chunk_input_state_bnhpr = Tensor::stack(chunk_input_state_vec_bhpr, 1);
343 assert_eq!(
344 [batch, nchunks, nheads, per_head_dim, state_rank],
345 chunk_input_state_bnhpr.dims()
346 );
347
348 (chunk_input_state_bnhpr, final_state_bhpr)
349}
350
351/// Based on the Kernel 5 Triton reference `_chunk_scan_fwd_kernel` (`ssd_chunk_scan.py`).
352///
353/// Returns:
354/// - y_bnlhp [final output]
355pub fn k5_ssd_chunk_scan<B: Backend>(
356 da_cumsum_bhnl: Tensor<B, 4>,
357 dt_discretized_bhnl: Tensor<B, 4>,
358 x_bnlhp: Tensor<B, 5>,
359 c_bnlgr: Tensor<B, 5>,
360 cb_bngll: Tensor<B, 5>,
361 chunk_input_state_bnhpr: Tensor<B, 5>,
362 d_h: Tensor<B, 1>,
363) -> Tensor<B, 5> {
364 let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.dims();
365 let [.., ngroups, state_rank] = c_bnlgr.dims();
366 let heads_per_group = nheads / ngroups;
367 let device = x_bnlhp.device();
368
369 // Rearrange inputs to the common [batch, nchunks, nheads, ...] ordering used below.
370 // - 1/36: permute: (da_cumsum_bhnl [*]) -> (da_cumsum_bnhl)
371 let da_cumsum_bnhl = da_cumsum_bhnl.permute([0, 2, 1, 3]);
372 san(&da_cumsum_bnhl);
373 // - 2: permute: (dt_discretized_bhnl [*]) -> (dt_bnhl)
374 let dt_bnhl = dt_discretized_bhnl.permute([0, 2, 1, 3]);
375 san(&dt_bnhl);
376 // - 3: permute: (x_bnlhp [*]) -> (x_bnhlp)
377 let x_bnhlp = x_bnlhp.clone().permute([0, 1, 3, 2, 4]);
378 san(&x_bnhlp);
379
380 // GQA: expand C [b,n,l,g,r] → [b,n,h,l,r].
381 let c_bnhlr = c_bnlgr
382 // - 4: unsqueeze: (c_bnlgr [*]) -> (c_bnlg1r)
383 .unsqueeze_dim::<6>(4) // c_bnlg1r
384 // - 5: expand: (c_bnlg1r) -> (c_bnlgHr)
385 .expand([
386 batch,
387 nchunks,
388 chunk_len,
389 ngroups,
390 heads_per_group,
391 state_rank,
392 ]) // c_bnlgHr
393 // - 6: reshape: (c_bnlgHr) -> (c_bnlhr)
394 .reshape([batch, nchunks, chunk_len, nheads, state_rank]) // c_bnlhr
395 // - 7: permute: (c_bnlhr) -> (c_bnhlr)
396 .permute([0, 1, 3, 2, 4]);
397 san(&c_bnhlr);
398
399 // GQA: expand CB [b,n,g,l,l] → [b,n,h,l,l].
400 let cb_bnhll = cb_bngll
401 // - 8: unsqueeze: (cb_bngll [!]) -> (cb_bng1ll)
402 .unsqueeze_dim::<6>(3) // cb_bng1ll
403 // - 9: expand: (cb_bng1ll) -> (cb_bngHll)
404 .expand([
405 batch,
406 nchunks,
407 ngroups,
408 heads_per_group,
409 chunk_len,
410 chunk_len,
411 ]) // cb_bngHll
412 // - 10: reshape: (cb_bngHll) -> (cb_bnhll)
413 .reshape([batch, nchunks, nheads, chunk_len, chunk_len]);
414 san(&cb_bnhll);
415
416 // ── BLUE: exp(dA[l]) · C[l,:] @ state_in^T ─────────────────────────────
417 //
418 // blue[b,n,h,l,p] = exp(da[b,n,h,l]) · Σ_r c[b,n,h,l,r] · state[b,n,h,p,r]
419 //
420 // [b,n,h,l,r] @ [b,n,h,r,p] → [b,n,h,l,p]
421 let exp_da_cumsum_bnhlp = da_cumsum_bnhl
422 .clone()
423 // - 11: exp: (da_cumsum_bnhl) -> (exp_da_cumsum_bnhl)
424 .exp()
425 // - 12: unsqueeze: (exp_da_cumsum_bnhl) -> (exp_da_cumsum_bnhl1)
426 .unsqueeze_dim::<5>(4) // exp_da_cumsum_bnhl1
427 // - 13: expand: (exp_da_cumsum_bnhl1) -> (exp_da_cumsum_bnhlp)
428 .expand([batch, nchunks, nheads, chunk_len, per_head_dim]);
429 san(&exp_da_cumsum_bnhlp);
430 // - 14: permute: (chunk_input_state_bnhpr [!]) -> (chunk_input_state_bnhrp)
431 let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
432 // - 15: matmul: (c_bnhlr, chunk_input_state_bnhrp) -> (blue_bnhlp)
433 let blue_scaled_bnhlp = c_bnhlr
434 .matmul(chunk_input_state_bnhrp) // blue_bnhlp
435 // - 16: mul: (blue_bnhlp, exp_da_cumsum_bnhlp) -> (blue_scaled_bnhlp)
436 * exp_da_cumsum_bnhlp;
437 san(&blue_scaled_bnhlp);
438
439 // ── ORANGE: causal CB_weighted @ X ──────────────────────────────────────
440 //
441 // orange[b,n,h,l,p] = Σ_{s≤l} CB[l,s] · exp(da[l]-da[s]) · dt[s] · x[s,p]
442 //
443 // Precompute the full lower-triangular weight matrix, then do a single matmul.
444 //
445 let da_cumsum_target_bnhll = da_cumsum_bnhl
446 .clone()
447 // - 17: unsqueeze: (da_cumsum_bnhl) -> (da_cumsum_bnhl1)
448 .unsqueeze_dim::<5>(4) // da_cumsum_bnhl1
449 // - 18: expand: (da_cumsum_bnhl1) -> (da_cumsum_target_bnhll)
450 .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
451 // println!("{}", da_cumsum_target_bnhll);
452 san(&da_cumsum_target_bnhll);
453 let da_cumsum_source_bnhll = da_cumsum_bnhl
454 // - 19: unsqueeze: (da_cumsum_bnhl) -> (da_cumsum_bnh1l)
455 .unsqueeze_dim::<5>(3) // da_cumsum_bnh1l
456 // - 20: expand: (da_cumsum_bnh1l) -> (da_cumsum_source_bnhll)
457 .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
458 // println!("{}", da_cumsum_source_bnhll);
459 san(&da_cumsum_source_bnhll);
460 // - 21: sub: (da_cumsum_target_bnhll, da_cumsum_source_bnhll) -> (da_cumsum_diff_bnhll)
461 let da_cumsum_diff_bnhll = da_cumsum_target_bnhll - da_cumsum_source_bnhll;
462 san(&da_cumsum_diff_bnhll);
463
464 // note: overflow instability at step 22, a `minimal::segsum`-like upper triangle protection is necessary.
465 // - 21.1: tril-mask: (0) -> (causal_mask_ll), expanded as a view to causal_mask_bnhll.
466 // true above the main diagonal, false at diagonal and below.
467 // Built at [L,L] and broadcast — the mask values do not depend on (b,n,h).
468 let causal_mask_bnhll: Tensor<B, 5, burn::prelude::Bool> =
469 Tensor::<B, 2, burn::prelude::Bool>::tril_mask([chunk_len, chunk_len], 0, &device)
470 .reshape([1, 1, 1, chunk_len, chunk_len])
471 .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
472 // - 21.2: mask-fill: (da_cumsum_diff_bnhll, causal_mask_bnhll) -> (da_cumsum_diff_masked_bnhll)
473 // Causal mask and exp stabilizer: above upper diagonal set to -inf.
474 let da_cumsum_diff_masked_bnhll =
475 da_cumsum_diff_bnhll.mask_fill(causal_mask_bnhll, f32::NEG_INFINITY);
476
477 // - 22: exp: (da_cumsum_diff_masked_bnhll) -> (da_cumsum_diff_exp_bnhll)
478 let da_cumsum_diff_exp_bnhll = da_cumsum_diff_masked_bnhll.exp();
479 san(&da_cumsum_diff_exp_bnhll);
480 let dt_source_bnhll = dt_bnhl
481 // - 23: unsqueeze: (dt_bnhl) -> (dt_bnh1l)
482 .unsqueeze_dim::<5>(3) // dt_bnh1l
483 // - 24: expand: (dt_bnh1l) -> (dt_source_bnhll)
484 .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
485 san(&dt_source_bnhll);
486
487 // note: steps 25, 26 and 29 are no longer necessary.
488 // // Causal mask (0 above the main diagonal, 1 elsewhere).
489 // let causal_mask_bnhll =
490 // // - 25: ones: (1) -> (ones_bnhll)
491 // Tensor::ones([batch, nchunks, nheads, chunk_len, chunk_len], &device)
492 // // - 26: tril: (ones_bnhll, 0) -> (causal_mask_bnhll)
493 // .tril(0);
494
495 // [b,n,h,l,l] @ [b,n,h,l,p] → [b,n,h,l,p]
496 // - 27: mul: (cb_bnhll, da_cumsum_diff_exp_bnhll) -> (orange_lhs_partial1_bnhll)
497 let orange_lhs_partial1_bnhll = cb_bnhll * da_cumsum_diff_exp_bnhll;
498 san(&orange_lhs_partial1_bnhll);
499 // - 28: mul: (orange_lhs_partial1_bnhll, dt_source_bnhll) -> (orange_lhs_partial2_bnhll)
500 let orange_lhs_partial2_bnhll = orange_lhs_partial1_bnhll * dt_source_bnhll;
501 san(&orange_lhs_partial2_bnhll);
502 // // - 29: mul: (orange_lhs_partial2_bnhll, causal_mask_bnhll) -> (orange_lhs_partial3_bnhll)
503 // let orange_lhs_partial3_bnhll = orange_lhs_partial2_bnhll * causal_mask_bnhll;
504 // san(&orange_lhs_partial3_bnhll);
505 // - 30: matmul: (orange_lhs_partial3_bnhll, x_bnhlp) -> (orange_bnhlp)
506 // - 30: matmul: (orange_lhs_partial2_bnhll, x_bnhlp) -> (orange_bnhlp)
507 let orange_bnhlp = orange_lhs_partial2_bnhll.matmul(x_bnhlp);
508 san(&orange_bnhlp);
509
510 // ── SKIP: D[h] · x[l,p] ─────────────────────────────────────────────────
511 //
512 // D_HAS_HDIM = False: D is a scalar per head, shape [nheads].
513 // Triton: `acc += x_residual * D`
514 let skip_bnlhp = d_h
515 // - 31: unsqueeze-dims: (d_h [*]) -> (d_111h1)
516 .unsqueeze_dims::<5>(&[0, 1, 2, 4]) // d_111h1
517 // - 32: expand: (d_111h1) -> (d_bnlhp)
518 .expand([
519 batch,
520 nchunks,
521 chunk_len,
522 nheads,
523 per_head_dim,
524 ]) // d_bnlhp
525 // - 33: mul: (d_bnlhp, x_bnlhp[*]) -> (skip_bnlhp)
526 * x_bnlhp;
527 san(&skip_bnlhp);
528
529 // Permute BLUE + ORANGE from [b,n,h,l,p] back to [b,n,l,h,p], then add SKIP.
530 // - 34: add: (blue_scaled_bnhlp, orange_bnhlp) -> (y_partial_bnhlp)
531 let y_partial_bnhlp = blue_scaled_bnhlp + orange_bnhlp;
532 san(&y_partial_bnhlp);
533 // - 35: permute: (y_partial_bnhlp) -> (y_partial_bnlhp)
534 let y_partial_bnlhp = y_partial_bnhlp.permute([0, 1, 3, 2, 4]);
535 san(&y_partial_bnlhp);
536 // - 36/36: add: (y_partial_bnlhp, skip_bnlhp) -> (y_bnlhp [out])
537 let y_bnlhp: Tensor<B, 5> = y_partial_bnlhp + skip_bnlhp;
538 san(&y_bnlhp);
539
540 assert_eq!(
541 [batch, nchunks, chunk_len, nheads, per_head_dim],
542 y_bnlhp.dims()
543 );
544 y_bnlhp
545}