Skip to main content

burn_mamba/mamba2/ssd/serial_recalculated/
backward.rs

1#![allow(non_snake_case)]
2
3use crate::mamba2::ssd::serial_recalculated::{
4    Mamba2BackendExt,
5    combined_backward::{self, CombinedGrads},
6};
7use burn::backend::autodiff::{
8    Autodiff,
9    checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
10    grads::Gradients,
11    ops::{Backward, Ops, OpsKind},
12};
13use burn::prelude::*;
14use burn::tensor::{TensorPrimitive, ops::FloatTensor};
15
16impl<B: Backend + Mamba2BackendExt, C: CheckpointStrategy> Mamba2BackendExt for Autodiff<B, C> {
17    /// Memory-efficient combined forward+backward.
18    ///
19    /// The two output tensors are concatenated into a single 1-D tracked tensor
20    /// so that one `Backward<B, 7>` node covers both outputs.  The caller
21    /// receives split+reshaped slices of that combined tensor; burn's autodiff
22    /// accumulates their upstream gradients back into a single gradient vector
23    /// before firing this backward.
24    fn ssd_serial_recalculated(
25        // AI init interface:
26        //
27        // da_cumsum_bhnl: FloatTensor<Self>,
28        // dt_discretized_bhnl: FloatTensor<Self>,
29        // x_bnlhp: FloatTensor<Self>,
30        // b_bnlgr: FloatTensor<Self>,
31        // c_bnlgr: FloatTensor<Self>,
32        // d_h: FloatTensor<Self>,
33        // initial_state_bhpr: FloatTensor<Self>,
34        x_bnlhp: FloatTensor<Self>,
35        dt_discretized_bhnl: FloatTensor<Self>,
36        b_bnlgr: FloatTensor<Self>,
37        c_bnlgr: FloatTensor<Self>,
38        d_h: FloatTensor<Self>,
39        initial_state_bhpr: FloatTensor<Self>,
40        a_decay_h: FloatTensor<Self>,
41    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
42        // ── Backward struct ──────────────────────────────────────────────────
43        #[derive(Debug)]
44        struct K2K3K4K5CombinedBackward;
45
46        #[derive(Clone, Debug)]
47        struct State<B: Backend> {
48            x_bnlhp: FloatTensor<B>,
49            dt_discretized_bhnl: FloatTensor<B>,
50            b_bnlgr: FloatTensor<B>,
51            c_bnlgr: FloatTensor<B>,
52            d_h: FloatTensor<B>,
53            initial_state_bhpr: FloatTensor<B>,
54            a_decay_h: FloatTensor<B>,
55            // flat byte-sizes for splitting the combined gradient vector
56            flat_len_y_BNLHP: usize,
57            flat_len_final_state_BHPR: usize,
58            // shapes needed to reconstruct tensors in the right ranks
59            shape_x_bnlhp: [usize; 5],
60            shape_dt_discretized_bhnl: [usize; 4],
61            shape_b_bnlgr: [usize; 5],
62            shape_c_bnlgr: [usize; 5],
63            shape_d_h: [usize; 1],
64            shape_initial_state_bhpr: [usize; 4],
65            shape_a_decay_h: [usize; 1],
66            shape_y_bnlhp: [usize; 5],          // (output 1)
67            shape_final_state_bhpr: [usize; 4], // (output 2)
68        }
69
70        /// State carried across the forward→backward boundary.
71        ///
72        /// Only the 7 original inputs are saved; all intermediates (cb, intra
73        /// state, chunk_input_state) are recomputed during `backward`.
74        #[allow(clippy::type_complexity)]
75        impl<B: Backend + Mamba2BackendExt> Backward<B, 7> for K2K3K4K5CombinedBackward {
76            type State = State<B>;
77
78            fn backward(
79                self,
80                ops: Ops<Self::State, 7>,
81                grads: &mut Gradients,
82                _checkpointer: &mut Checkpointer,
83            ) {
84                let [
85                    node_x_bnlhp,
86                    node_dt_discretized_bhnl,
87                    node_b_bnlgr,
88                    node_c_bnlgr,
89                    node_d_h,
90                    node_initial_state_bhpr,
91                    node_a_decay_h,
92                ] = ops.parents;
93
94                // Retrieve the gradient of the combined 1-D output.
95                let d_combined: Tensor<B, 1> =
96                    Tensor::from_primitive(TensorPrimitive::Float(grads.consume::<B>(&ops.node)));
97
98                let State {
99                    x_bnlhp,
100                    dt_discretized_bhnl,
101                    b_bnlgr,
102                    c_bnlgr,
103                    d_h,
104                    initial_state_bhpr,
105                    a_decay_h,
106                    //
107                    flat_len_y_BNLHP,
108                    flat_len_final_state_BHPR,
109                    //
110                    shape_x_bnlhp,
111                    shape_dt_discretized_bhnl,
112                    shape_b_bnlgr,
113                    shape_c_bnlgr,
114                    shape_d_h,
115                    shape_initial_state_bhpr,
116                    shape_a_decay_h,
117                    //
118                    shape_y_bnlhp,
119                    shape_final_state_bhpr,
120                } = ops.state;
121
122                // ── Reconstruct saved tensors ──────────────────────────────
123                use super::serial_recalculated::mk;
124
125                let x_bnlhp = mk::<_, 5>(x_bnlhp).reshape(shape_x_bnlhp);
126                let dt_discretized_bhnl =
127                    mk::<_, 4>(dt_discretized_bhnl).reshape(shape_dt_discretized_bhnl);
128                let b_bnlgr = mk::<_, 5>(b_bnlgr).reshape(shape_b_bnlgr);
129                let c_bnlgr = mk::<_, 5>(c_bnlgr).reshape(shape_c_bnlgr);
130                let d_h = mk::<_, 1>(d_h).reshape(shape_d_h);
131                let initial_state_bhpr =
132                    mk::<_, 4>(initial_state_bhpr).reshape(shape_initial_state_bhpr);
133                let a_decay_h = mk::<_, 1>(a_decay_h).reshape(shape_a_decay_h);
134
135                // ── Split incoming combined gradient ───────────────────────
136                // d_combined : [y_flat_len_BNLHP + fs_flat_len_BHPR]
137                let flat_d_y_BNLHP = d_combined.clone().narrow(0, 0, flat_len_y_BNLHP);
138                let flat_d_final_state_BHPR =
139                    d_combined.narrow(0, flat_len_y_BNLHP, flat_len_final_state_BHPR);
140
141                let d_y_bnlhp: Tensor<B, 5> = flat_d_y_BNLHP.reshape(shape_y_bnlhp);
142                let d_final_state_bhpr: Tensor<B, 4> =
143                    flat_d_final_state_BHPR.reshape(shape_final_state_bhpr);
144
145                // ── Core gradient computation ──────────────────────────────
146                let CombinedGrads {
147                    d_x_bnlhp,
148                    d_dt_discretized_bhnl,
149                    d_b_bnlgr,
150                    d_c_bnlgr,
151                    d_d_h,
152                    d_initial_state_bhpr,
153                    d_a_decay_h,
154                } = combined_backward::combined_backward(
155                    d_y_bnlhp,
156                    d_final_state_bhpr,
157                    //
158                    x_bnlhp,
159                    dt_discretized_bhnl,
160                    b_bnlgr,
161                    c_bnlgr,
162                    d_h,
163                    initial_state_bhpr,
164                    a_decay_h,
165                );
166
167                // ── Register gradients ─────────────────────────────────────
168                // TODO: request Node to be re-exported.
169                //
170                // use burn::cubecl::stub::Arc;
171                // use burn::backend::autodiff::Node;
172                // let reg = |node: Option<Arc<_>>, grad: Tensor<B, _>| {
173                //     if let Some(n) = node {
174                //         grads.register::<B>(n.id, grad.into_primitive().tensor());
175                //     }
176                // };
177                // let () = reg(node_x_bnlhp, d_x_bnlhp);
178                // let () = reg(node_dt_discretized_bhnl, d_dt_discretized_bhnl);
179                // let () = reg(node_b_bnlgr, d_b_bnlgr);
180                // let () = reg(node_c_bnlgr, d_c_bnlgr);
181                // let () = reg(node_d_h, d_d_h);
182                // let () = reg(node_initial_state_bhpr, d_initial_state_bhpr);
183                // let () = reg(node_da_cumsum_bhnl, d_da_cumsum_bhnl);
184
185                if let Some(n) = node_x_bnlhp {
186                    grads.register::<B>(n.id, d_x_bnlhp.into_primitive().tensor());
187                }
188                if let Some(n) = node_dt_discretized_bhnl {
189                    grads.register::<B>(n.id, d_dt_discretized_bhnl.into_primitive().tensor());
190                }
191                if let Some(n) = node_b_bnlgr {
192                    grads.register::<B>(n.id, d_b_bnlgr.into_primitive().tensor());
193                }
194                if let Some(n) = node_c_bnlgr {
195                    grads.register::<B>(n.id, d_c_bnlgr.into_primitive().tensor());
196                }
197                if let Some(n) = node_d_h {
198                    grads.register::<B>(n.id, d_d_h.into_primitive().tensor());
199                }
200                if let Some(n) = node_initial_state_bhpr {
201                    grads.register::<B>(n.id, d_initial_state_bhpr.into_primitive().tensor());
202                }
203                if let Some(n) = node_a_decay_h {
204                    grads.register::<B>(n.id, d_a_decay_h.into_primitive().tensor());
205                }
206            }
207        } // end impl Backward
208
209        // ── Shape extraction helpers ───────────────────────────────────────
210        // Accessed via the AutodiffTensor wrappers (which own both .node
211        // and .primitive).
212        use burn::tensor::TensorMetadata;
213        let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.primitive.shape().dims();
214        let [_, _, _, ngroups, state_rank] = b_bnlgr.primitive.shape().dims();
215
216        let flat_len_y_BNLHP = batch * nchunks * chunk_len * nheads * per_head_dim;
217        let flat_len_final_state_BHPR = batch * nheads * per_head_dim * state_rank;
218
219        let shape_x_bnlhp: [usize; 5] = [batch, nchunks, chunk_len, nheads, per_head_dim];
220        let shape_dt_discretized_bhnl: [usize; 4] = [batch, nheads, nchunks, chunk_len];
221        let shape_b_bnlgr: [usize; 5] = [batch, nchunks, chunk_len, ngroups, state_rank];
222        let shape_c_bnlgr: [usize; 5] = [batch, nchunks, chunk_len, ngroups, state_rank];
223        let shape_d_h: [usize; 1] = [nheads];
224        let shape_initial_state_bhpr: [usize; 4] = [batch, nheads, per_head_dim, state_rank];
225        let shape_a_decay_h: [usize; 1] = [nheads];
226        let shape_y_bnlhp: [usize; 5] = [batch, nchunks, chunk_len, nheads, per_head_dim];
227        let shape_final_state_bhpr: [usize; 4] = [batch, nheads, per_head_dim, state_rank];
228
229        // ── Register backward / run forward ───────────────────────────────
230        match K2K3K4K5CombinedBackward
231            .prepare::<C>([
232                x_bnlhp.node.clone(),
233                dt_discretized_bhnl.node.clone(),
234                b_bnlgr.node.clone(),
235                c_bnlgr.node.clone(),
236                d_h.node.clone(),
237                initial_state_bhpr.node.clone(),
238                a_decay_h.node.clone(),
239            ])
240            .compute_bound()
241            .stateful() // requires compute_bound
242        {
243            OpsKind::Tracked(prep) => {
244                // Run the inner (non-autodiff) forward pass.
245                let (prim_y_bnlhp, prim_final_state_bhpr) = B::ssd_serial_recalculated(
246                    x_bnlhp.primitive.clone(),
247                    dt_discretized_bhnl.primitive.clone(),
248                    b_bnlgr.primitive.clone(),
249                    c_bnlgr.primitive.clone(),
250                    d_h.primitive.clone(),
251                    initial_state_bhpr.primitive.clone(),
252                    a_decay_h.primitive.clone(),
253                );
254
255                // Note: prep.finish accepts only a single tensor.
256                // Flatten both outputs and cat into one 1-D tensor so that
257                // one Backward node covers both.
258                let flat_y_BNLHP: Tensor<B, 1> =
259                    Tensor::<B, 5>::from_primitive(TensorPrimitive::Float(prim_y_bnlhp))
260                        .reshape([flat_len_y_BNLHP]);
261                let flat_final_state_BHPR: Tensor<B, 1> =
262                    Tensor::<B, 4>::from_primitive(TensorPrimitive::Float(prim_final_state_bhpr))
263                        .reshape([flat_len_final_state_BHPR]);
264                let combined: Tensor<B, 1> =
265                    Tensor::cat(vec![flat_y_BNLHP, flat_final_state_BHPR], 0);
266
267                let state = State {
268                    x_bnlhp: x_bnlhp.primitive.clone(),
269                    dt_discretized_bhnl: dt_discretized_bhnl.primitive.clone(),
270                    b_bnlgr: b_bnlgr.primitive.clone(),
271                    c_bnlgr: c_bnlgr.primitive.clone(),
272                    d_h: d_h.primitive.clone(),
273                    initial_state_bhpr: initial_state_bhpr.primitive.clone(),
274                    a_decay_h: a_decay_h.primitive.clone(),
275                    //
276                    flat_len_y_BNLHP,
277                    flat_len_final_state_BHPR,
278                    //
279                    shape_x_bnlhp, shape_dt_discretized_bhnl, shape_b_bnlgr, shape_c_bnlgr, shape_d_h, shape_initial_state_bhpr, shape_a_decay_h,
280                    shape_y_bnlhp, shape_final_state_bhpr,
281                };
282                let tracked_combined: FloatTensor<Autodiff<B, C>> =
283                    prep.finish(state, combined.into_primitive().tensor());
284
285                // Split the tracked 1-D tensor back into the two outputs.
286                // The narrow / reshape ops create thin pass-through autodiff
287                // nodes; their backward accumulates gradients into the
288                // combined gradient vector consumed above.
289                let tracked_combined: Tensor<Autodiff<B, C>, 1> =
290                    Tensor::from_primitive(TensorPrimitive::Float(tracked_combined));
291
292                let tracked_y_bnlhp: Tensor<Autodiff<B, C>, 5> = tracked_combined
293                    .clone()
294                    .narrow(0, 0, flat_len_y_BNLHP)
295                    .reshape(shape_y_bnlhp);
296                let tracked_final_state_bhpr: Tensor<Autodiff<B, C>, 4> = tracked_combined
297                    .narrow(0, flat_len_y_BNLHP, flat_len_final_state_BHPR)
298                    .reshape(shape_final_state_bhpr);
299
300                (
301                    tracked_y_bnlhp.into_primitive().tensor(),
302                    tracked_final_state_bhpr.into_primitive().tensor(),
303                )
304            }
305
306            OpsKind::UnTracked(prep) => {
307                // No gradient tracking needed — run bare forward and return.
308                let (prim_y_bnlhp, prim_final_state_bhpr) = B::ssd_serial_recalculated(
309                    x_bnlhp.primitive,
310                    dt_discretized_bhnl.primitive,
311                    b_bnlgr.primitive,
312                    c_bnlgr.primitive,
313                    d_h.primitive,
314                    initial_state_bhpr.primitive,
315                    a_decay_h.primitive,
316                );
317
318                // Note: prep.finish accepts only a single tensor.
319                let flat_y_BNLHP: Tensor<B, 1> =
320                    Tensor::<B, 5>::from_primitive(TensorPrimitive::Float(prim_y_bnlhp))
321                        .reshape([flat_len_y_BNLHP]);
322                let flat_final_state_BHPR: Tensor<B, 1> =
323                    Tensor::<B, 4>::from_primitive(TensorPrimitive::Float(prim_final_state_bhpr))
324                        .reshape([flat_len_final_state_BHPR]);
325                let combined: Tensor<B, 1> = Tensor::cat(vec![flat_y_BNLHP, flat_final_state_BHPR], 0);
326
327                let tracked_combined: FloatTensor<Autodiff<B, C>> =
328                    prep.finish(combined.into_primitive().tensor());
329
330                let tracked_combined: Tensor<Autodiff<B, C>, 1> =
331                    Tensor::from_primitive(TensorPrimitive::Float(tracked_combined));
332                let tracked_y_bnlhp: Tensor<Autodiff<B, C>, 5> = tracked_combined
333                    .clone()
334                    .narrow(0, 0, flat_len_y_BNLHP)
335                    .reshape(shape_y_bnlhp);
336                let tracked_final_state_bhpr: Tensor<Autodiff<B, C>, 4> = tracked_combined
337                    .narrow(0, flat_len_y_BNLHP, flat_len_final_state_BHPR)
338                    .reshape(shape_final_state_bhpr);
339
340                (
341                    tracked_y_bnlhp.into_primitive().tensor(),
342                    tracked_final_state_bhpr.into_primitive().tensor(),
343                )
344            }
345        } // end match
346    } // end fn ssd_serial_recalculated on Autodiff<B, C>
347} // end impl Mamba2BackendExt for Autodiff<B, C>