1#![allow(non_snake_case)]
11
12use crate::mamba3::double_ssd::ssd;
13use crate::utils::fprim::F;
14use burn::backend::autodiff::{
15 Autodiff,
16 checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
17 grads::Gradients,
18 ops::{Backward, Ops, OpsKind},
19};
20use burn::backend::tensor::FloatTensor;
21use burn::backend::{Backend, BackendTypes};
22use ssd::serial_recalculated::{
23 Mamba3DoubleSsdBackendExt,
24 combined_backward::{self, CombinedGrads},
25};
26
27impl<B: Backend + Mamba3DoubleSsdBackendExt, C: CheckpointStrategy> Mamba3DoubleSsdBackendExt
28 for Autodiff<B, C>
29{
30 fn double_ssd_serial_recalculated(
39 v_bnlmhp: FloatTensor<Self>,
40 da_bnlh: FloatTensor<Self>,
41 b_bnlmhr: FloatTensor<Self>,
42 c_bnlmhr: FloatTensor<Self>,
43 initial_state_bhpr: FloatTensor<Self>,
44 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
45 #[derive(Debug)]
47 struct CombinedKernelsBackward;
48
49 #[derive(Clone, Debug)]
50 struct State<B: Backend> {
51 v_bnlmhp: <B as BackendTypes>::FloatTensorPrimitive,
53 da_bnlh: <B as BackendTypes>::FloatTensorPrimitive,
54 b_bnlmhr: <B as BackendTypes>::FloatTensorPrimitive,
55 c_bnlmhr: <B as BackendTypes>::FloatTensorPrimitive,
56 initial_state_bhpr: <B as BackendTypes>::FloatTensorPrimitive,
57 flat_len_y_BNLMHP: usize,
59 flat_len_final_state_BHPR: usize,
60 shape_v_bnlmhp: [usize; 6],
62 shape_da_bnlh: [usize; 4],
63 shape_b_bnlmhr: [usize; 6],
64 shape_c_bnlmhr: [usize; 6],
65 shape_initial_state_bhpr: [usize; 4],
66 shape_y_bnlmhp: [usize; 6],
67 shape_final_state_bhpr: [usize; 4],
68 }
69
70 impl<B: Backend + Mamba3DoubleSsdBackendExt> Backward<B, 5> for CombinedKernelsBackward {
71 type State = State<B>;
72
73 fn backward(
74 self,
75 ops: Ops<Self::State, 5>,
76 grads: &mut Gradients,
77 _checkpointer: &mut Checkpointer,
78 ) {
79 let [
80 node_v_bnlmhp,
81 node_da_bnlh,
82 node_b_bnlmhr,
83 node_c_bnlmhr,
84 node_initial_state_bhpr,
85 ] = ops.parents;
86
87 let d_combined: <B as BackendTypes>::FloatTensorPrimitive =
89 grads.consume::<B>(&ops.node);
90
91 let State {
92 v_bnlmhp,
93 da_bnlh,
94 b_bnlmhr,
95 c_bnlmhr,
96 initial_state_bhpr,
97 flat_len_y_BNLMHP,
98 flat_len_final_state_BHPR,
99 shape_v_bnlmhp,
100 shape_da_bnlh,
101 shape_b_bnlmhr,
102 shape_c_bnlmhr,
103 shape_initial_state_bhpr,
104 shape_y_bnlmhp,
105 shape_final_state_bhpr,
106 } = ops.state;
107
108 let v_bnlmhp = F::<B, 6>::new(v_bnlmhp).reshape(shape_v_bnlmhp);
110 let da_bnlh = F::<B, 4>::new(da_bnlh).reshape(shape_da_bnlh);
111 let b_bnlmhr = F::<B, 6>::new(b_bnlmhr).reshape(shape_b_bnlmhr);
112 let c_bnlmhr = F::<B, 6>::new(c_bnlmhr).reshape(shape_c_bnlmhr);
113 let initial_state_bhpr =
114 F::<B, 4>::new(initial_state_bhpr).reshape(shape_initial_state_bhpr);
115
116 let (d_y_bnlmhp, d_final_state_bhpr) =
118 crate::utils::combined_grad::unflatten_pair::<B, 6, 4>(
119 d_combined,
120 flat_len_y_BNLMHP,
121 flat_len_final_state_BHPR,
122 shape_y_bnlmhp,
123 shape_final_state_bhpr,
124 );
125
126 let CombinedGrads {
128 d_v_bnlmhp,
129 d_da_bnlh,
130 d_b_bnlmhr,
131 d_c_bnlmhr,
132 d_initial_state_bhpr,
133 ..
134 } = combined_backward::combined_backward(
135 F::<B, 6>::new(d_y_bnlmhp),
136 F::<B, 4>::new(d_final_state_bhpr),
137 v_bnlmhp,
138 da_bnlh,
139 b_bnlmhr,
140 c_bnlmhr,
141 initial_state_bhpr,
142 );
143
144 if let Some(n) = node_v_bnlmhp {
146 grads.register::<B>(n.id, d_v_bnlmhp.inner());
147 }
148 if let Some(n) = node_da_bnlh {
149 grads.register::<B>(n.id, d_da_bnlh.inner());
150 }
151 if let Some(n) = node_b_bnlmhr {
152 grads.register::<B>(n.id, d_b_bnlmhr.inner());
153 }
154 if let Some(n) = node_c_bnlmhr {
155 grads.register::<B>(n.id, d_c_bnlmhr.inner());
156 }
157 if let Some(n) = node_initial_state_bhpr {
158 grads.register::<B>(n.id, d_initial_state_bhpr.inner());
159 }
160 }
161 }
162
163 use burn::backend::TensorMetadata;
165 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] =
166 v_bnlmhp.primitive.shape().dims();
167 let [.., state_rank] = b_bnlmhr.primitive.shape().dims::<6>();
168
169 let flat_len_y_BNLMHP = batch * nchunks * chunk_len * mimo_rank * nheads * per_head_dim;
170 let flat_len_final_state_BHPR = batch * nheads * per_head_dim * state_rank;
171
172 let shape_v_bnlmhp: [usize; 6] =
173 [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim];
174 let shape_da_bnlh: [usize; 4] = [batch, nchunks, chunk_len, nheads];
175 let shape_b_bnlmhr: [usize; 6] = [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank];
176 let shape_c_bnlmhr: [usize; 6] = [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank];
177 let shape_initial_state_bhpr: [usize; 4] = [batch, nheads, per_head_dim, state_rank];
178 let shape_y_bnlmhp: [usize; 6] =
179 [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim];
180 let shape_final_state_bhpr: [usize; 4] = [batch, nheads, per_head_dim, state_rank];
181
182 match CombinedKernelsBackward
184 .prepare::<C>([
185 v_bnlmhp.node.clone(),
186 da_bnlh.node.clone(),
187 b_bnlmhr.node.clone(),
188 c_bnlmhr.node.clone(),
189 initial_state_bhpr.node.clone(),
190 ])
191 .compute_bound()
192 .stateful()
193 {
194 OpsKind::Tracked(prep) => {
195 let (prim_y_bnlmhp, prim_final_state_bhpr) = B::double_ssd_serial_recalculated(
196 v_bnlmhp.primitive.clone(),
197 da_bnlh.primitive.clone(),
198 b_bnlmhr.primitive.clone(),
199 c_bnlmhr.primitive.clone(),
200 initial_state_bhpr.primitive.clone(),
201 );
202
203 let (prim_combined, _, _) = crate::utils::combined_grad::flatten_pair::<B>(
206 prim_y_bnlmhp,
207 prim_final_state_bhpr,
208 );
209
210 let state = State {
211 v_bnlmhp: v_bnlmhp.primitive.clone(),
212 da_bnlh: da_bnlh.primitive.clone(),
213 b_bnlmhr: b_bnlmhr.primitive.clone(),
214 c_bnlmhr: c_bnlmhr.primitive.clone(),
215 initial_state_bhpr: initial_state_bhpr.primitive.clone(),
216 flat_len_y_BNLMHP,
217 flat_len_final_state_BHPR,
218 shape_v_bnlmhp,
219 shape_da_bnlh,
220 shape_b_bnlmhr,
221 shape_c_bnlmhr,
222 shape_initial_state_bhpr,
223 shape_y_bnlmhp,
224 shape_final_state_bhpr,
225 };
226 let tracked_combined: FloatTensor<Autodiff<B, C>> =
227 prep.finish(state, prim_combined);
228
229 let (tracked_y_bnlmhp, tracked_final_state_bhpr) =
234 crate::utils::combined_grad::autodiff_unflatten_pair::<B, C, 6, 4>(
235 tracked_combined,
236 flat_len_y_BNLMHP,
237 flat_len_final_state_BHPR,
238 shape_y_bnlmhp,
239 shape_final_state_bhpr,
240 );
241
242 (tracked_y_bnlmhp, tracked_final_state_bhpr)
243 }
244
245 OpsKind::UnTracked(prep) => {
246 let (prim_y_bnlmhp, prim_final_state_bhpr) = B::double_ssd_serial_recalculated(
248 v_bnlmhp.primitive,
249 da_bnlh.primitive,
250 b_bnlmhr.primitive,
251 c_bnlmhr.primitive,
252 initial_state_bhpr.primitive,
253 );
254
255 let (prim_combined, _, _) = crate::utils::combined_grad::flatten_pair::<B>(
256 prim_y_bnlmhp,
257 prim_final_state_bhpr,
258 );
259
260 let tracked_combined: FloatTensor<Autodiff<B, C>> = prep.finish(prim_combined);
261
262 let (tracked_y_bnlmhp, tracked_final_state_bhpr) =
263 crate::utils::combined_grad::autodiff_unflatten_pair::<B, C, 6, 4>(
264 tracked_combined,
265 flat_len_y_BNLMHP,
266 flat_len_final_state_BHPR,
267 shape_y_bnlmhp,
268 shape_final_state_bhpr,
269 );
270
271 (tracked_y_bnlmhp, tracked_final_state_bhpr)
272 }
273 }
274 }
275}