Skip to main content

burn_mamba/mamba2/ssd/
serial.rs

1#![allow(unused_variables)]
2
3use crate::mamba2::prelude::*;
4use crate::utils::sanity::sanity as san;
5use burn::prelude::*;
6
7impl<B: Backend> Mamba2<B> {
8    /// Forward pass for the Mamba-2 SSD module.
9    ///
10    /// Returns:
11    /// - `y_bnlhp`.
12    /// - `final_state_bhpr`.
13    #[allow(non_snake_case)]
14    pub fn ssd_serial(input: super::Mamba2SsdInput<B>) -> (Tensor<B, 5>, Tensor<B, 4>) {
15        let [batch, nchunks, chunk_len, nheads, per_head_dim] = input.x_bnlhp.dims();
16        let [.., ngroups, state_rank] = input.b_bnlgr.dims();
17        let device = input.x_bnlhp.device();
18        assert_ne!(ngroups, 0);
19        assert_eq!(nheads % ngroups, 0);
20        assert!(nchunks > 0, "sequence length must be at least 1");
21        // `heads_per_group` is called `nheads_ngroups_ratio` in every Triton kernel.
22        // It is the compile-time constant used by GQA (Grouped Query Attention) to map
23        // a head index to its B/C group: `group_idx = head_idx / heads_per_group`.
24        let heads_per_group = nheads / ngroups;
25
26        san(&input.x_bnlhp);
27        san(&input.dt_bnlh);
28        san(&input.a_decay_h);
29        san(&input.b_bnlgr);
30        san(&input.c_bnlgr);
31        san(&input.d_h);
32        san(&input.initial_state_bhpr);
33
34        assert!(
35            input.init_state_hpr.is_none(),
36            "init_state_hpr not yet implemented"
37        );
38
39        // ── Permutes ──────────────────────────────────────────────────────────────────
40        // Note: dt_bnlh calculation (originally in Kernel 1) moved to Step 4 (before padding).
41        let dt_discretized_bhnl = input.dt_bnlh.permute([0, 3, 1, 2]);
42        assert_eq!(
43            [batch, nheads, nchunks, chunk_len],
44            dt_discretized_bhnl.dims()
45        );
46        san(&dt_discretized_bhnl);
47
48        // ── Kernel 1 ──────────────────────────────────────────────────────────────────
49        // IO: (..) -> (da_cumsum_bhnl [used in K3+K5][*], da_chunk_end_bhn [used in K4][omitted][*])
50        let (da_cumsum_bhnl, da_chunk_end_bhn): (Tensor<B, 4>, Tensor<B, 3>) =
51            k1_ssd_chunk_cumsum(dt_discretized_bhnl.clone(), input.a_decay_h.clone());
52        assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
53        assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
54        san(&da_cumsum_bhnl);
55        san(&da_chunk_end_bhn);
56
57        // ── Kernel 2 ──────────────────────────────────────────────────────────────────
58        // IO: (..) -> (cb_bngll [used in K5][!])
59        let cb_bngll: Tensor<B, 5> = k2_ssd_bmm(input.c_bnlgr.clone(), input.b_bnlgr.clone());
60        assert_eq!(
61            [batch, nchunks, ngroups, chunk_len, chunk_len],
62            cb_bngll.dims()
63        );
64        // Note: cb_bngll is then only used by Kernel 5.
65        san(&cb_bngll);
66
67        // ── Kernel 3 ──────────────────────────────────────────────────────────────────
68        // IO: (..) -> (intra_chunk_state_bnhpr [used in K4][!])
69        let intra_chunk_state_bnhpr: Tensor<B, 5> = k3_ssd_chunk_state(
70            input.x_bnlhp.clone(),
71            input.b_bnlgr.clone(),
72            da_cumsum_bhnl.clone(),
73            dt_discretized_bhnl.clone(),
74        );
75        assert_eq!(
76            [batch, nchunks, nheads, per_head_dim, state_rank],
77            intra_chunk_state_bnhpr.dims()
78        );
79        san(&intra_chunk_state_bnhpr);
80
81        // ── Kernel 4 ──────────────────────────────────────────────────────────────────
82        // IO: (..) -> (chunk_input_state_bnhpr [used in K5][!], final_state_bhpr [final output])
83        let flat_state_dim = per_head_dim * state_rank;
84        let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<B, 5>, Tensor<B, 4>) =
85            k4_ssd_state_passing(
86                intra_chunk_state_bnhpr.clone(),
87                da_chunk_end_bhn.clone(),
88                input.initial_state_bhpr,
89            );
90        assert_eq!(
91            [batch, nchunks, nheads, per_head_dim, state_rank],
92            chunk_input_state_bnhpr.dims()
93        );
94        assert_eq!(
95            [batch, nheads, per_head_dim, state_rank],
96            final_state_bhpr.dims()
97        );
98        san(&chunk_input_state_bnhpr);
99        san(&final_state_bhpr);
100
101        // ── Kernel 5 ──────────────────────────────────────────────────────────────────
102        let y_bnlhp: Tensor<B, 5> = k5_ssd_chunk_scan(
103            da_cumsum_bhnl,
104            dt_discretized_bhnl,
105            input.x_bnlhp,
106            input.c_bnlgr,
107            cb_bngll,
108            chunk_input_state_bnhpr,
109            input.d_h,
110        );
111        assert_eq!(
112            [batch, nchunks, chunk_len, nheads, per_head_dim],
113            y_bnlhp.dims()
114        );
115        san(&y_bnlhp);
116
117        (y_bnlhp, final_state_bhpr)
118    }
119}
120
121/// Based on the Kernel 1 Triton reference `_chunk_cumsum_fwd_kernel` (`ssd_chunk_state.py`).
122///
123/// Returns:
124/// - da_cumsum_bhnl [used in K3+K5][*] - intra-chunk cumsum.
125/// - da_chunk_end_bhn [used in K4][omitted][*] - last da_cumsum per chunk.
126pub fn k1_ssd_chunk_cumsum<B: Backend>(
127    dt_discretized_bhnl: Tensor<B, 4>,
128    a_decay_h: Tensor<B, 1>,
129) -> (Tensor<B, 4>, Tensor<B, 3>) {
130    let [batch, nheads, nchunks, chunk_len] = dt_discretized_bhnl.dims();
131    let da_cumsum_bhnl: Tensor<B, 4> = {
132        let a_decay_bhnl = a_decay_h
133            // - 1/6: unsqueeze-dims: (a_decay_h [*]) -> (a_decay_1h11)
134            .unsqueeze_dims::<4>(&[0, 2, 3]) // a_decay_1h11
135            // - 2: expand: (a_decay_1h11) -> (a_decay_bhnl)
136            .expand([batch, nheads, nchunks, chunk_len]);
137        // - 3: mul: (dt_discretized_bhnl [*], a_decay_bhnl) -> (da_bhnl)
138        // - 4: cumsum: (da_bhnl) -> (da_cumsum_bhnl [out][*])
139        (dt_discretized_bhnl * a_decay_bhnl).cumsum(3)
140    };
141    assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
142
143    let da_chunk_end_bhn = da_cumsum_bhnl
144        .clone()
145        // - 5: slice: (da_cumsum_bhnl [*]) -> (da_cumsum_bhn1)
146        .slice(s![.., .., .., -1]) // da_cumsum_bhn1
147        // - 6/6: squeeze: (da_cumsum_bhn1) -> (da_chunk_end_bhn [out])
148        .squeeze_dim::<3>(3);
149    assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
150
151    (da_cumsum_bhnl, da_chunk_end_bhn)
152}
153
154/// Based on the Kernel 2 Triton reference `_bmm_chunk_fwd_kernel` (`ssd_bmm.py`).
155///
156/// Returns:
157/// - cb_bngll [used in K5][!].
158pub fn k2_ssd_bmm<B: Backend>(c_bnlgr: Tensor<B, 5>, b_bnlgr: Tensor<B, 5>) -> Tensor<B, 5> {
159    let [batch, nchunks, chunk_len, ngroups, state_rank] = c_bnlgr.dims();
160
161    // - 1/3: permute: (c_bnlgr [in][*]) -> (c_bnglr)
162    let c_bnglr = c_bnlgr.clone().permute([0, 1, 3, 2, 4]);
163    // - 2: permute: (b_bnlgr [in][*]) -> (b_bngrl)
164    let b_bngrl = b_bnlgr.clone().permute([0, 1, 3, 4, 2]);
165    // - 3/3: matmul: (c_bnglr, b_bngrl) -> (cb_bngll [out][!])
166    let cb_bngll: Tensor<B, 5> = c_bnglr.matmul(b_bngrl);
167    assert_eq!(
168        [batch, nchunks, ngroups, chunk_len, chunk_len],
169        cb_bngll.dims()
170    );
171    // Note: cb_bngll is then only used by Kernel 5.
172    cb_bngll
173}
174
175/// Based on the Kernel 3 Triton reference `_chunk_state_fwd_kernel` (`ssd_chunk_state.py`).
176///
177/// Returns:
178/// - cb_bngll [used in K5][!] - state assuming zero initial state at each chunk boundary.
179/// - b_bar_scale_bhnl [*] - intermediary
180pub fn k3_ssd_chunk_state<B: Backend>(
181    x_bnlhp: Tensor<B, 5>,
182    b_bnlgr: Tensor<B, 5>,
183    da_cumsum_bhnl: Tensor<B, 4>,
184    dt_discretized_bhnl: Tensor<B, 4>,
185) -> Tensor<B, 5> {
186    use burn::tensor::s;
187
188    let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.dims();
189    let [.., ngroups, state_rank] = b_bnlgr.dims();
190
191    // permute b and x to prepare them for the mamtul
192    // - 1/15: permute: (x_bnlhp [in][*]) -> (x_bnhpl)
193    let x_bnhpl = x_bnlhp.clone().permute([0, 1, 3, 4, 2]);
194    assert_eq!(
195        [batch, nchunks, nheads, per_head_dim, chunk_len],
196        x_bnhpl.dims()
197    );
198    // - 2: permute: (b_bnlgr [in][*]) -> (b_bnglr)
199    let b_bnglr = b_bnlgr.permute([0, 1, 3, 2, 4]); // note: still in groups instead of heads
200    assert_eq!(
201        [batch, nchunks, ngroups, chunk_len, state_rank],
202        b_bnglr.dims()
203    );
204
205    // Expand B from ngroups to nheads by repeating each group's
206    // projection across all heads_per_group heads in that group.
207    let heads_per_group = nheads / ngroups;
208    let b_bnhlr = b_bnglr
209        // - 3: unsqueeze: (b_bnglr) -> (b_bng1lr)
210        .unsqueeze_dim::<6>(3) // b_bng1lr
211        // - 4: expand: (b_bng1lr) -> (b_bngHlr)
212        .expand([
213            batch,
214            nchunks,
215            ngroups,
216            heads_per_group,
217            chunk_len,
218            state_rank,
219        ]) // b_bngHlr
220        // - 5: reshape: (b_bngHlr) -> (b_bnhlr)
221        .reshape([batch, nchunks, nheads, chunk_len, state_rank]);
222
223    // scale b
224    let b_scaled_bnhlr = {
225        let b_bar_scale_bhnl = {
226            let da_cumsum_last_in_chunk_bhn1 =
227                // - 6: slice: (da_cumsum_bhnl [in][*]) -> (da_cumsum_last_in_chunk_bhn1)
228                da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
229            assert_eq!(
230                [batch, nheads, nchunks, 1],
231                da_cumsum_last_in_chunk_bhn1.dims()
232            );
233
234            // - 7: expand: (da_cumsum_last_in_chunk_bhn1) -> (da_cumsum_last_bhnl)
235            let da_cumsum_last_bhnl =
236                da_cumsum_last_in_chunk_bhn1.expand([batch, nheads, nchunks, chunk_len]);
237            // - 8: sub: (da_cumsum_last_bhnl, da_cumsum_bhnl [from K1][*]) -> (da_delta_bhnl)
238            let da_delta_bhnl = da_cumsum_last_bhnl - da_cumsum_bhnl.clone();
239            // - 9: exp: (da_delta_bhnl) -> (forward_decay_to_chunk_end_bhnl [+])
240            let forward_decay_to_chunk_end_bhnl = da_delta_bhnl.exp();
241            assert_eq!(
242                [batch, nheads, nchunks, chunk_len],
243                forward_decay_to_chunk_end_bhnl.dims()
244            );
245
246            // - 10: mul: (forward_decay_to_chunk_end_bhnl [+], dt_discretized_bhnl [in][*]) -> (b_bar_scale_bhnl [+])
247            forward_decay_to_chunk_end_bhnl * dt_discretized_bhnl.clone()
248        };
249        assert_eq!([batch, nheads, nchunks, chunk_len], b_bar_scale_bhnl.dims());
250
251        // - 11: permute: (b_bar_scale_bhnl [+]) -> (b_bar_scale_bnhl)
252        let b_bar_scale_bnhl = b_bar_scale_bhnl.permute([0, 2, 1, 3]);
253        assert_eq!([batch, nchunks, nheads, chunk_len], b_bar_scale_bnhl.dims());
254        let b_bar_scale_bnhlr = b_bar_scale_bnhl
255            // - 12: unsqueeze: (b_bar_scale_bnhl) -> (b_bar_scale_bnhl1)
256            .unsqueeze_dim::<5>(4) // b_bar_scale_bnhl1
257            // - 13: expand: (b_bar_scale_bnhl1) -> (b_bar_scale_bnhlr)
258            .expand([batch, nchunks, nheads, chunk_len, state_rank]);
259        // - 14: mul: (b_bnhlr, b_bar_scale_bnhlr) -> (b_scaled_bnhlr [+])
260        b_bnhlr * b_bar_scale_bnhlr
261    };
262    assert_eq!(
263        [batch, nchunks, nheads, chunk_len, state_rank],
264        b_scaled_bnhlr.dims()
265    );
266
267    // - 15/15: matmul: (x_bnhpl, b_scaled_bnhlr [+]) -> (intra_chunk_state_bnhpr [out][!])
268    let intra_chunk_state_bnhpr: Tensor<B, 5> = x_bnhpl.matmul(b_scaled_bnhlr);
269    assert_eq!(
270        [batch, nchunks, nheads, per_head_dim, state_rank],
271        intra_chunk_state_bnhpr.dims()
272    );
273    intra_chunk_state_bnhpr
274}
275
276/// Based on the Kernel 4 Triton reference `_state_passing_fwd_kernel` (`ssd_state_passing.py`).
277///
278/// Returns:
279/// - chunk_input_state_bnhpr [used in K5][!].
280/// - final_state_bhpr [final output].
281pub fn k4_ssd_state_passing<B: Backend>(
282    intra_chunk_state_bnhpr: Tensor<B, 5>,
283    da_chunk_end_bhn: Tensor<B, 3>,
284    initial_state_bhpr: Tensor<B, 4>,
285) -> (Tensor<B, 5>, Tensor<B, 4>) {
286    let [batch, nchunks, nheads, per_head_dim, state_rank] = intra_chunk_state_bnhpr.dims();
287    let flat_state_dim = per_head_dim * state_rank;
288
289    // - 1/5: init-mut: (initial_state_bhpr [in][*]) -> (running_state_bhpr)
290    let mut running_state_bhpr = initial_state_bhpr;
291    assert_eq!(
292        [batch, nheads, per_head_dim, state_rank],
293        running_state_bhpr.dims()
294    );
295
296    let mut chunk_input_state_vec_bhpr = Vec::with_capacity(nchunks + 1);
297    // - 2: vec-push: (running_state_bhpr [elem]) -> (chunk_input_state_vec_bhpr [vec][!])
298    chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
299
300    // - 3: serial-loop: (0..nchunks)
301    for i_chunk in 0..nchunks {
302        let intra_state_bhpr = intra_chunk_state_bnhpr
303            .clone()
304            //   - 3.1/3.9: slice: (intra_chunk_state_bnhpr [in][!]) -> (intra_chunk_state_b1hpr)
305            .slice(s![.., i_chunk, .., .., ..]) // intra_chunk_state_b1hpr
306            //   - 3.2: squeeze: (intra_chunk_state_b1hpr) -> (intra_state_bhpr)
307            .squeeze_dim::<4>(1);
308        assert_eq!(
309            [batch, nheads, per_head_dim, state_rank],
310            intra_state_bhpr.dims()
311        );
312
313        let decay_bhpr = da_chunk_end_bhn
314            .clone()
315            //   - 3.3: slice: (da_chunk_end_bhn [in][*]) -> (da_chunk_end_bh1)
316            .slice(s![.., .., i_chunk]) // da_chunk_end_bh1
317            //   - 3.4: exp: (da_chunk_end_bh1) -> (exp_da_chunk_end_bh1)
318            .exp() // exp_da_chunk_end_bh1
319            //   - 3.5: unsqueeze: (exp_da_chunk_end_bh1) -> (exp_da_chunk_end_bh11)
320            .unsqueeze_dim::<4>(3) // exp_da_chunk_end_bh11
321            //   - 3.6: expand: (exp_da_chunk_end_bh11) -> (decay_bhpr)
322            .expand([batch, nheads, per_head_dim, state_rank]);
323
324        // SSM recurrence: running_state = decay * running_state + intra_state
325        running_state_bhpr =
326        //   - 3.7: mul: (decay_bhpr, running_state_bhpr) -> (running_state_bhpr)
327            (decay_bhpr * running_state_bhpr) // running_state_bhpr
328        //   - 3.8: add: (running_state_bhpr, intra_state_bhpr) -> (running_state_bhpr)
329            + intra_state_bhpr;
330        //   - 3.9/3.9: vec-push: (running_state_bhpr [elem]) -> (chunk_input_state_vec_bhpr [vec][!])
331        chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
332    }
333
334    // - 4: vec-pop: (chunk_input_state_vec_bhpr [vec][!]) -> (final_state_bhpr [elem][out][!])
335    let final_state_bhpr = chunk_input_state_vec_bhpr.pop().unwrap();
336    assert_eq!(
337        [batch, nheads, per_head_dim, state_rank],
338        final_state_bhpr.dims()
339    );
340
341    // - 5/5: stack: (chunk_input_state_vec_bhpr [!]) -> (chunk_input_state_bnhpr [out][!])
342    let chunk_input_state_bnhpr = Tensor::stack(chunk_input_state_vec_bhpr, 1);
343    assert_eq!(
344        [batch, nchunks, nheads, per_head_dim, state_rank],
345        chunk_input_state_bnhpr.dims()
346    );
347
348    (chunk_input_state_bnhpr, final_state_bhpr)
349}
350
351/// Based on the Kernel 5 Triton reference `_chunk_scan_fwd_kernel` (`ssd_chunk_scan.py`).
352///
353/// Returns:
354/// - y_bnlhp [final output]
355pub fn k5_ssd_chunk_scan<B: Backend>(
356    da_cumsum_bhnl: Tensor<B, 4>,
357    dt_discretized_bhnl: Tensor<B, 4>,
358    x_bnlhp: Tensor<B, 5>,
359    c_bnlgr: Tensor<B, 5>,
360    cb_bngll: Tensor<B, 5>,
361    chunk_input_state_bnhpr: Tensor<B, 5>,
362    d_h: Tensor<B, 1>,
363) -> Tensor<B, 5> {
364    let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.dims();
365    let [.., ngroups, state_rank] = c_bnlgr.dims();
366    let heads_per_group = nheads / ngroups;
367    let device = x_bnlhp.device();
368
369    // Rearrange inputs to the common [batch, nchunks, nheads, ...] ordering used below.
370    // - 1/36: permute: (da_cumsum_bhnl [*]) -> (da_cumsum_bnhl)
371    let da_cumsum_bnhl = da_cumsum_bhnl.permute([0, 2, 1, 3]);
372    san(&da_cumsum_bnhl);
373    // - 2: permute: (dt_discretized_bhnl [*]) -> (dt_bnhl)
374    let dt_bnhl = dt_discretized_bhnl.permute([0, 2, 1, 3]);
375    san(&dt_bnhl);
376    // - 3: permute: (x_bnlhp [*]) -> (x_bnhlp)
377    let x_bnhlp = x_bnlhp.clone().permute([0, 1, 3, 2, 4]);
378    san(&x_bnhlp);
379
380    // GQA: expand C  [b,n,l,g,r] → [b,n,h,l,r].
381    let c_bnhlr = c_bnlgr
382        // - 4: unsqueeze: (c_bnlgr [*]) -> (c_bnlg1r)
383        .unsqueeze_dim::<6>(4) // c_bnlg1r
384        // - 5: expand: (c_bnlg1r) -> (c_bnlgHr)
385        .expand([
386            batch,
387            nchunks,
388            chunk_len,
389            ngroups,
390            heads_per_group,
391            state_rank,
392        ]) // c_bnlgHr
393        // - 6: reshape: (c_bnlgHr) -> (c_bnlhr)
394        .reshape([batch, nchunks, chunk_len, nheads, state_rank]) // c_bnlhr
395        // - 7: permute: (c_bnlhr) -> (c_bnhlr)
396        .permute([0, 1, 3, 2, 4]);
397    san(&c_bnhlr);
398
399    // GQA: expand CB [b,n,g,l,l] → [b,n,h,l,l].
400    let cb_bnhll = cb_bngll
401        // - 8: unsqueeze: (cb_bngll [!]) -> (cb_bng1ll)
402        .unsqueeze_dim::<6>(3) // cb_bng1ll
403        // - 9: expand: (cb_bng1ll) -> (cb_bngHll)
404        .expand([
405            batch,
406            nchunks,
407            ngroups,
408            heads_per_group,
409            chunk_len,
410            chunk_len,
411        ]) // cb_bngHll
412        // - 10: reshape: (cb_bngHll) -> (cb_bnhll)
413        .reshape([batch, nchunks, nheads, chunk_len, chunk_len]);
414    san(&cb_bnhll);
415
416    // ── BLUE: exp(dA[l]) · C[l,:] @ state_in^T ─────────────────────────────
417    //
418    //   blue[b,n,h,l,p] = exp(da[b,n,h,l]) · Σ_r  c[b,n,h,l,r] · state[b,n,h,p,r]
419    //
420    //   [b,n,h,l,r] @ [b,n,h,r,p]  →  [b,n,h,l,p]
421    let exp_da_cumsum_bnhlp = da_cumsum_bnhl
422        .clone()
423        // - 11: exp: (da_cumsum_bnhl) -> (exp_da_cumsum_bnhl)
424        .exp()
425        // - 12: unsqueeze: (exp_da_cumsum_bnhl) -> (exp_da_cumsum_bnhl1)
426        .unsqueeze_dim::<5>(4) // exp_da_cumsum_bnhl1
427        // - 13: expand: (exp_da_cumsum_bnhl1) -> (exp_da_cumsum_bnhlp)
428        .expand([batch, nchunks, nheads, chunk_len, per_head_dim]);
429    san(&exp_da_cumsum_bnhlp);
430    // - 14: permute: (chunk_input_state_bnhpr [!]) -> (chunk_input_state_bnhrp)
431    let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
432    // - 15: matmul: (c_bnhlr, chunk_input_state_bnhrp) -> (blue_bnhlp)
433    let blue_scaled_bnhlp = c_bnhlr
434        .matmul(chunk_input_state_bnhrp)  // blue_bnhlp
435        // - 16: mul: (blue_bnhlp, exp_da_cumsum_bnhlp) -> (blue_scaled_bnhlp)
436        * exp_da_cumsum_bnhlp;
437    san(&blue_scaled_bnhlp);
438
439    // ── ORANGE: causal CB_weighted @ X ──────────────────────────────────────
440    //
441    //   orange[b,n,h,l,p] = Σ_{s≤l} CB[l,s] · exp(da[l]-da[s]) · dt[s] · x[s,p]
442    //
443    // Precompute the full lower-triangular weight matrix, then do a single matmul.
444    //
445    let da_cumsum_target_bnhll = da_cumsum_bnhl
446        .clone()
447        // - 17: unsqueeze: (da_cumsum_bnhl) -> (da_cumsum_bnhl1)
448        .unsqueeze_dim::<5>(4) // da_cumsum_bnhl1
449        // - 18: expand: (da_cumsum_bnhl1) -> (da_cumsum_target_bnhll)
450        .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
451    // println!("{}", da_cumsum_target_bnhll);
452    san(&da_cumsum_target_bnhll);
453    let da_cumsum_source_bnhll = da_cumsum_bnhl
454        // - 19: unsqueeze: (da_cumsum_bnhl) -> (da_cumsum_bnh1l)
455        .unsqueeze_dim::<5>(3) // da_cumsum_bnh1l
456        // - 20: expand: (da_cumsum_bnh1l) -> (da_cumsum_source_bnhll)
457        .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
458    // println!("{}", da_cumsum_source_bnhll);
459    san(&da_cumsum_source_bnhll);
460    // - 21: sub: (da_cumsum_target_bnhll, da_cumsum_source_bnhll) -> (da_cumsum_diff_bnhll)
461    let da_cumsum_diff_bnhll = da_cumsum_target_bnhll - da_cumsum_source_bnhll;
462    san(&da_cumsum_diff_bnhll);
463
464    // note: overflow instability at step 22, a `minimal::segsum`-like upper triangle protection is necessary.
465    // - 21.1: tril-mask: (0) -> (causal_mask_ll), expanded as a view to causal_mask_bnhll.
466    // true above the main diagonal, false at diagonal and below.
467    // Built at [L,L] and broadcast — the mask values do not depend on (b,n,h).
468    let causal_mask_bnhll: Tensor<B, 5, burn::prelude::Bool> =
469        Tensor::<B, 2, burn::prelude::Bool>::tril_mask([chunk_len, chunk_len], 0, &device)
470            .reshape([1, 1, 1, chunk_len, chunk_len])
471            .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
472    // - 21.2: mask-fill: (da_cumsum_diff_bnhll, causal_mask_bnhll) -> (da_cumsum_diff_masked_bnhll)
473    // Causal mask and exp stabilizer: above upper diagonal set to -inf.
474    let da_cumsum_diff_masked_bnhll =
475        da_cumsum_diff_bnhll.mask_fill(causal_mask_bnhll, f32::NEG_INFINITY);
476
477    // - 22: exp: (da_cumsum_diff_masked_bnhll) -> (da_cumsum_diff_exp_bnhll)
478    let da_cumsum_diff_exp_bnhll = da_cumsum_diff_masked_bnhll.exp();
479    san(&da_cumsum_diff_exp_bnhll);
480    let dt_source_bnhll = dt_bnhl
481        // - 23: unsqueeze: (dt_bnhl) -> (dt_bnh1l)
482        .unsqueeze_dim::<5>(3) // dt_bnh1l
483        // - 24: expand: (dt_bnh1l) -> (dt_source_bnhll)
484        .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
485    san(&dt_source_bnhll);
486
487    // note: steps 25, 26 and 29 are no longer necessary.
488    // // Causal mask (0 above the main diagonal, 1 elsewhere).
489    // let causal_mask_bnhll =
490    //     // - 25: ones: (1) -> (ones_bnhll)
491    //     Tensor::ones([batch, nchunks, nheads, chunk_len, chunk_len], &device)
492    //     // - 26: tril: (ones_bnhll, 0) -> (causal_mask_bnhll)
493    //     .tril(0);
494
495    //   [b,n,h,l,l] @ [b,n,h,l,p]  →  [b,n,h,l,p]
496    // - 27: mul: (cb_bnhll, da_cumsum_diff_exp_bnhll) -> (orange_lhs_partial1_bnhll)
497    let orange_lhs_partial1_bnhll = cb_bnhll * da_cumsum_diff_exp_bnhll;
498    san(&orange_lhs_partial1_bnhll);
499    // - 28: mul: (orange_lhs_partial1_bnhll, dt_source_bnhll) -> (orange_lhs_partial2_bnhll)
500    let orange_lhs_partial2_bnhll = orange_lhs_partial1_bnhll * dt_source_bnhll;
501    san(&orange_lhs_partial2_bnhll);
502    // // - 29: mul: (orange_lhs_partial2_bnhll, causal_mask_bnhll) -> (orange_lhs_partial3_bnhll)
503    // let orange_lhs_partial3_bnhll = orange_lhs_partial2_bnhll * causal_mask_bnhll;
504    // san(&orange_lhs_partial3_bnhll);
505    // - 30: matmul: (orange_lhs_partial3_bnhll, x_bnhlp) -> (orange_bnhlp)
506    // - 30: matmul: (orange_lhs_partial2_bnhll, x_bnhlp) -> (orange_bnhlp)
507    let orange_bnhlp = orange_lhs_partial2_bnhll.matmul(x_bnhlp);
508    san(&orange_bnhlp);
509
510    // ── SKIP: D[h] · x[l,p] ─────────────────────────────────────────────────
511    //
512    //   D_HAS_HDIM = False: D is a scalar per head, shape [nheads].
513    //   Triton: `acc += x_residual * D`
514    let skip_bnlhp = d_h
515        // - 31: unsqueeze-dims: (d_h [*]) -> (d_111h1)
516        .unsqueeze_dims::<5>(&[0, 1, 2, 4]) // d_111h1
517        // - 32: expand: (d_111h1) -> (d_bnlhp)
518        .expand([
519            batch,
520            nchunks,
521            chunk_len,
522            nheads,
523            per_head_dim,
524        ]) // d_bnlhp
525    // - 33: mul: (d_bnlhp, x_bnlhp[*]) -> (skip_bnlhp)
526    * x_bnlhp;
527    san(&skip_bnlhp);
528
529    // Permute BLUE + ORANGE from [b,n,h,l,p] back to [b,n,l,h,p], then add SKIP.
530    // - 34: add: (blue_scaled_bnhlp, orange_bnhlp) -> (y_partial_bnhlp)
531    let y_partial_bnhlp = blue_scaled_bnhlp + orange_bnhlp;
532    san(&y_partial_bnhlp);
533    // - 35: permute: (y_partial_bnhlp) -> (y_partial_bnlhp)
534    let y_partial_bnlhp = y_partial_bnhlp.permute([0, 1, 3, 2, 4]);
535    san(&y_partial_bnlhp);
536    // - 36/36: add: (y_partial_bnlhp, skip_bnlhp) -> (y_bnlhp [out])
537    let y_bnlhp: Tensor<B, 5> = y_partial_bnlhp + skip_bnlhp;
538    san(&y_bnlhp);
539
540    assert_eq!(
541        [batch, nchunks, chunk_len, nheads, per_head_dim],
542        y_bnlhp.dims()
543    );
544    y_bnlhp
545}