Skip to main content

burn_mamba/mamba3/double_ssd/ssd/serial_recalculated/
backward.rs

//! # Custom autodiff node for the Mamba-3 double-SSD recompute backward
//!
//! Implements [`Mamba3DoubleSsdBackendExt`](crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdBackendExt)
//! for `Autodiff<B, C>` via a single Burn [`Backward`] node.  The forward keeps only
//! its leaf inputs; backprop replays the serial kernels and the gradient math in
//! [`super::combined_backward`], so the large intermediates are never retained.
//! The two outputs (`y`, `final_state`) are flattened into one tracked tensor
//! (via [`crate::utils::combined_grad`]) so one node covers both.

10#![allow(non_snake_case)]
11
12use crate::mamba3::double_ssd::ssd;
13use crate::utils::fprim::F;
14use burn::backend::autodiff::{
15    Autodiff,
16    checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
17    grads::Gradients,
18    ops::{Backward, Ops, OpsKind},
19};
20use burn::backend::tensor::FloatTensor;
21use burn::backend::{Backend, BackendTypes};
22use ssd::serial_recalculated::{
23    Mamba3DoubleSsdBackendExt,
24    combined_backward::{self, CombinedGrads},
25};
26
27impl<B: Backend + Mamba3DoubleSsdBackendExt, C: CheckpointStrategy> Mamba3DoubleSsdBackendExt
28    for Autodiff<B, C>
29{
30    /// Memory-efficient combined forward+backward for the Mamba-3 MIMO SSD.
31    ///
32    /// The two output tensors (`y_bnlmhp`, `final_state_bhpr`) are flattened
33    /// and concatenated into a single 1-dimensional tracked tensor so a single
34    /// `Backward<B, 5>` node covers both. The caller receives split+reshaped
35    /// slices of that combined tensor; burn's autodiff accumulates their
36    /// upstream gradients back into one gradient vector before invoking
37    /// `backward`.
38    fn double_ssd_serial_recalculated(
39        v_bnlmhp: FloatTensor<Self>,
40        da_bnlh: FloatTensor<Self>,
41        b_bnlmhr: FloatTensor<Self>,
42        c_bnlmhr: FloatTensor<Self>,
43        initial_state_bhpr: FloatTensor<Self>,
44    ) -> (FloatTensor<Self>, FloatTensor<Self>) {
45        // ── Backward node definition ─────────────────────────────────────────
46        #[derive(Debug)]
47        struct CombinedKernelsBackward;
48
49        #[derive(Clone, Debug)]
50        struct State<B: Backend> {
51            // Saved forward inputs
52            v_bnlmhp: <B as BackendTypes>::FloatTensorPrimitive,
53            da_bnlh: <B as BackendTypes>::FloatTensorPrimitive,
54            b_bnlmhr: <B as BackendTypes>::FloatTensorPrimitive,
55            c_bnlmhr: <B as BackendTypes>::FloatTensorPrimitive,
56            initial_state_bhpr: <B as BackendTypes>::FloatTensorPrimitive,
57            // Flat lengths for splitting the combined upstream gradient
58            flat_len_y_BNLMHP: usize,
59            flat_len_final_state_BHPR: usize,
60            // Shapes needed to reconstruct tensors at the right ranks
61            shape_v_bnlmhp: [usize; 6],
62            shape_da_bnlh: [usize; 4],
63            shape_b_bnlmhr: [usize; 6],
64            shape_c_bnlmhr: [usize; 6],
65            shape_initial_state_bhpr: [usize; 4],
66            shape_y_bnlmhp: [usize; 6],
67            shape_final_state_bhpr: [usize; 4],
68        }
69
70        impl<B: Backend + Mamba3DoubleSsdBackendExt> Backward<B, 5> for CombinedKernelsBackward {
71            type State = State<B>;
72
73            fn backward(
74                self,
75                ops: Ops<Self::State, 5>,
76                grads: &mut Gradients,
77                _checkpointer: &mut Checkpointer,
78            ) {
79                let [
80                    node_v_bnlmhp,
81                    node_da_bnlh,
82                    node_b_bnlmhr,
83                    node_c_bnlmhr,
84                    node_initial_state_bhpr,
85                ] = ops.parents;
86
87                // Upstream gradient of the combined 1-dimensional output.
88                let d_combined: <B as BackendTypes>::FloatTensorPrimitive =
89                    grads.consume::<B>(&ops.node);
90
91                let State {
92                    v_bnlmhp,
93                    da_bnlh,
94                    b_bnlmhr,
95                    c_bnlmhr,
96                    initial_state_bhpr,
97                    flat_len_y_BNLMHP,
98                    flat_len_final_state_BHPR,
99                    shape_v_bnlmhp,
100                    shape_da_bnlh,
101                    shape_b_bnlmhr,
102                    shape_c_bnlmhr,
103                    shape_initial_state_bhpr,
104                    shape_y_bnlmhp,
105                    shape_final_state_bhpr,
106                } = ops.state;
107
108                // ── Reconstruct saved tensors as rank-tagged primitives ──
109                let v_bnlmhp = F::<B, 6>::new(v_bnlmhp).reshape(shape_v_bnlmhp);
110                let da_bnlh = F::<B, 4>::new(da_bnlh).reshape(shape_da_bnlh);
111                let b_bnlmhr = F::<B, 6>::new(b_bnlmhr).reshape(shape_b_bnlmhr);
112                let c_bnlmhr = F::<B, 6>::new(c_bnlmhr).reshape(shape_c_bnlmhr);
113                let initial_state_bhpr =
114                    F::<B, 4>::new(initial_state_bhpr).reshape(shape_initial_state_bhpr);
115
116                // ── Split combined gradient vector ──────────────────────
117                let (d_y_bnlmhp, d_final_state_bhpr) =
118                    crate::utils::combined_grad::unflatten_pair::<B, 6, 4>(
119                        d_combined,
120                        flat_len_y_BNLMHP,
121                        flat_len_final_state_BHPR,
122                        shape_y_bnlmhp,
123                        shape_final_state_bhpr,
124                    );
125
126                // ── Core gradient computation ───────────────────────────
127                let CombinedGrads {
128                    d_v_bnlmhp,
129                    d_da_bnlh,
130                    d_b_bnlmhr,
131                    d_c_bnlmhr,
132                    d_initial_state_bhpr,
133                    ..
134                } = combined_backward::combined_backward(
135                    F::<B, 6>::new(d_y_bnlmhp),
136                    F::<B, 4>::new(d_final_state_bhpr),
137                    v_bnlmhp,
138                    da_bnlh,
139                    b_bnlmhr,
140                    c_bnlmhr,
141                    initial_state_bhpr,
142                );
143
144                // ── Register gradients with autodiff ────────────────────
145                if let Some(n) = node_v_bnlmhp {
146                    grads.register::<B>(n.id, d_v_bnlmhp.inner());
147                }
148                if let Some(n) = node_da_bnlh {
149                    grads.register::<B>(n.id, d_da_bnlh.inner());
150                }
151                if let Some(n) = node_b_bnlmhr {
152                    grads.register::<B>(n.id, d_b_bnlmhr.inner());
153                }
154                if let Some(n) = node_c_bnlmhr {
155                    grads.register::<B>(n.id, d_c_bnlmhr.inner());
156                }
157                if let Some(n) = node_initial_state_bhpr {
158                    grads.register::<B>(n.id, d_initial_state_bhpr.inner());
159                }
160            }
161        }
162
163        // ── Shape extraction (via the AutodiffTensor wrappers) ─────────────
164        use burn::backend::TensorMetadata;
165        let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] =
166            v_bnlmhp.primitive.shape().dims();
167        let [.., state_rank] = b_bnlmhr.primitive.shape().dims::<6>();
168
169        let flat_len_y_BNLMHP = batch * nchunks * chunk_len * mimo_rank * nheads * per_head_dim;
170        let flat_len_final_state_BHPR = batch * nheads * per_head_dim * state_rank;
171
172        let shape_v_bnlmhp: [usize; 6] =
173            [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim];
174        let shape_da_bnlh: [usize; 4] = [batch, nchunks, chunk_len, nheads];
175        let shape_b_bnlmhr: [usize; 6] = [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank];
176        let shape_c_bnlmhr: [usize; 6] = [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank];
177        let shape_initial_state_bhpr: [usize; 4] = [batch, nheads, per_head_dim, state_rank];
178        let shape_y_bnlmhp: [usize; 6] =
179            [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim];
180        let shape_final_state_bhpr: [usize; 4] = [batch, nheads, per_head_dim, state_rank];
181
182        // ── Register backward / run forward ───────────────────────────────
183        match CombinedKernelsBackward
184            .prepare::<C>([
185                v_bnlmhp.node.clone(),
186                da_bnlh.node.clone(),
187                b_bnlmhr.node.clone(),
188                c_bnlmhr.node.clone(),
189                initial_state_bhpr.node.clone(),
190            ])
191            .compute_bound()
192            .stateful()
193        {
194            OpsKind::Tracked(prep) => {
195                let (prim_y_bnlmhp, prim_final_state_bhpr) = B::double_ssd_serial_recalculated(
196                    v_bnlmhp.primitive.clone(),
197                    da_bnlh.primitive.clone(),
198                    b_bnlmhr.primitive.clone(),
199                    c_bnlmhr.primitive.clone(),
200                    initial_state_bhpr.primitive.clone(),
201                );
202
203                // prep.finish takes a single tensor, so pack both outputs into a
204                // single 1-D tensor; one Backward node then covers both.
205                let (prim_combined, _, _) = crate::utils::combined_grad::flatten_pair::<B>(
206                    prim_y_bnlmhp,
207                    prim_final_state_bhpr,
208                );
209
210                let state = State {
211                    v_bnlmhp: v_bnlmhp.primitive.clone(),
212                    da_bnlh: da_bnlh.primitive.clone(),
213                    b_bnlmhr: b_bnlmhr.primitive.clone(),
214                    c_bnlmhr: c_bnlmhr.primitive.clone(),
215                    initial_state_bhpr: initial_state_bhpr.primitive.clone(),
216                    flat_len_y_BNLMHP,
217                    flat_len_final_state_BHPR,
218                    shape_v_bnlmhp,
219                    shape_da_bnlh,
220                    shape_b_bnlmhr,
221                    shape_c_bnlmhr,
222                    shape_initial_state_bhpr,
223                    shape_y_bnlmhp,
224                    shape_final_state_bhpr,
225                };
226                let tracked_combined: FloatTensor<Autodiff<B, C>> =
227                    prep.finish(state, prim_combined);
228
229                // Split the tracked combined tensor back into the two outputs.
230                // The narrow/reshape ops are thin autodiff pass-throughs whose
231                // backwards accumulate into the combined gradient vector that
232                // `backward` above consumes.
233                let (tracked_y_bnlmhp, tracked_final_state_bhpr) =
234                    crate::utils::combined_grad::autodiff_unflatten_pair::<B, C, 6, 4>(
235                        tracked_combined,
236                        flat_len_y_BNLMHP,
237                        flat_len_final_state_BHPR,
238                        shape_y_bnlmhp,
239                        shape_final_state_bhpr,
240                    );
241
242                (tracked_y_bnlmhp, tracked_final_state_bhpr)
243            }
244
245            OpsKind::UnTracked(prep) => {
246                // No gradient tracking — just run the bare forward.
247                let (prim_y_bnlmhp, prim_final_state_bhpr) = B::double_ssd_serial_recalculated(
248                    v_bnlmhp.primitive,
249                    da_bnlh.primitive,
250                    b_bnlmhr.primitive,
251                    c_bnlmhr.primitive,
252                    initial_state_bhpr.primitive,
253                );
254
255                let (prim_combined, _, _) = crate::utils::combined_grad::flatten_pair::<B>(
256                    prim_y_bnlmhp,
257                    prim_final_state_bhpr,
258                );
259
260                let tracked_combined: FloatTensor<Autodiff<B, C>> = prep.finish(prim_combined);
261
262                let (tracked_y_bnlmhp, tracked_final_state_bhpr) =
263                    crate::utils::combined_grad::autodiff_unflatten_pair::<B, C, 6, 4>(
264                        tracked_combined,
265                        flat_len_y_BNLMHP,
266                        flat_len_final_state_BHPR,
267                        shape_y_bnlmhp,
268                        shape_final_state_bhpr,
269                    );
270
271                (tracked_y_bnlmhp, tracked_final_state_bhpr)
272            }
273        }
274    }
275}