1#![allow(non_snake_case)]
2
3use crate::mamba2::ssd::serial_recalculated::{
4 Mamba2BackendExt,
5 combined_backward::{self, CombinedGrads},
6};
7use burn::backend::autodiff::{
8 Autodiff,
9 checkpoint::{base::Checkpointer, strategy::CheckpointStrategy},
10 grads::Gradients,
11 ops::{Backward, Ops, OpsKind},
12};
13use burn::prelude::*;
14use burn::tensor::{TensorPrimitive, ops::FloatTensor};
15
16impl<B: Backend + Mamba2BackendExt, C: CheckpointStrategy> Mamba2BackendExt for Autodiff<B, C> {
17 fn ssd_serial_recalculated(
25 x_bnlhp: FloatTensor<Self>,
35 dt_discretized_bhnl: FloatTensor<Self>,
36 b_bnlgr: FloatTensor<Self>,
37 c_bnlgr: FloatTensor<Self>,
38 d_h: FloatTensor<Self>,
39 initial_state_bhpr: FloatTensor<Self>,
40 a_decay_h: FloatTensor<Self>,
41 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
42 #[derive(Debug)]
44 struct K2K3K4K5CombinedBackward;
45
46 #[derive(Clone, Debug)]
47 struct State<B: Backend> {
48 x_bnlhp: FloatTensor<B>,
49 dt_discretized_bhnl: FloatTensor<B>,
50 b_bnlgr: FloatTensor<B>,
51 c_bnlgr: FloatTensor<B>,
52 d_h: FloatTensor<B>,
53 initial_state_bhpr: FloatTensor<B>,
54 a_decay_h: FloatTensor<B>,
55 flat_len_y_BNLHP: usize,
57 flat_len_final_state_BHPR: usize,
58 shape_x_bnlhp: [usize; 5],
60 shape_dt_discretized_bhnl: [usize; 4],
61 shape_b_bnlgr: [usize; 5],
62 shape_c_bnlgr: [usize; 5],
63 shape_d_h: [usize; 1],
64 shape_initial_state_bhpr: [usize; 4],
65 shape_a_decay_h: [usize; 1],
66 shape_y_bnlhp: [usize; 5], shape_final_state_bhpr: [usize; 4], }
69
70 #[allow(clippy::type_complexity)]
75 impl<B: Backend + Mamba2BackendExt> Backward<B, 7> for K2K3K4K5CombinedBackward {
76 type State = State<B>;
77
78 fn backward(
79 self,
80 ops: Ops<Self::State, 7>,
81 grads: &mut Gradients,
82 _checkpointer: &mut Checkpointer,
83 ) {
84 let [
85 node_x_bnlhp,
86 node_dt_discretized_bhnl,
87 node_b_bnlgr,
88 node_c_bnlgr,
89 node_d_h,
90 node_initial_state_bhpr,
91 node_a_decay_h,
92 ] = ops.parents;
93
94 let d_combined: Tensor<B, 1> =
96 Tensor::from_primitive(TensorPrimitive::Float(grads.consume::<B>(&ops.node)));
97
98 let State {
99 x_bnlhp,
100 dt_discretized_bhnl,
101 b_bnlgr,
102 c_bnlgr,
103 d_h,
104 initial_state_bhpr,
105 a_decay_h,
106 flat_len_y_BNLHP,
108 flat_len_final_state_BHPR,
109 shape_x_bnlhp,
111 shape_dt_discretized_bhnl,
112 shape_b_bnlgr,
113 shape_c_bnlgr,
114 shape_d_h,
115 shape_initial_state_bhpr,
116 shape_a_decay_h,
117 shape_y_bnlhp,
119 shape_final_state_bhpr,
120 } = ops.state;
121
122 use super::serial_recalculated::mk;
124
125 let x_bnlhp = mk::<_, 5>(x_bnlhp).reshape(shape_x_bnlhp);
126 let dt_discretized_bhnl =
127 mk::<_, 4>(dt_discretized_bhnl).reshape(shape_dt_discretized_bhnl);
128 let b_bnlgr = mk::<_, 5>(b_bnlgr).reshape(shape_b_bnlgr);
129 let c_bnlgr = mk::<_, 5>(c_bnlgr).reshape(shape_c_bnlgr);
130 let d_h = mk::<_, 1>(d_h).reshape(shape_d_h);
131 let initial_state_bhpr =
132 mk::<_, 4>(initial_state_bhpr).reshape(shape_initial_state_bhpr);
133 let a_decay_h = mk::<_, 1>(a_decay_h).reshape(shape_a_decay_h);
134
135 let flat_d_y_BNLHP = d_combined.clone().narrow(0, 0, flat_len_y_BNLHP);
138 let flat_d_final_state_BHPR =
139 d_combined.narrow(0, flat_len_y_BNLHP, flat_len_final_state_BHPR);
140
141 let d_y_bnlhp: Tensor<B, 5> = flat_d_y_BNLHP.reshape(shape_y_bnlhp);
142 let d_final_state_bhpr: Tensor<B, 4> =
143 flat_d_final_state_BHPR.reshape(shape_final_state_bhpr);
144
145 let CombinedGrads {
147 d_x_bnlhp,
148 d_dt_discretized_bhnl,
149 d_b_bnlgr,
150 d_c_bnlgr,
151 d_d_h,
152 d_initial_state_bhpr,
153 d_a_decay_h,
154 } = combined_backward::combined_backward(
155 d_y_bnlhp,
156 d_final_state_bhpr,
157 x_bnlhp,
159 dt_discretized_bhnl,
160 b_bnlgr,
161 c_bnlgr,
162 d_h,
163 initial_state_bhpr,
164 a_decay_h,
165 );
166
167 if let Some(n) = node_x_bnlhp {
186 grads.register::<B>(n.id, d_x_bnlhp.into_primitive().tensor());
187 }
188 if let Some(n) = node_dt_discretized_bhnl {
189 grads.register::<B>(n.id, d_dt_discretized_bhnl.into_primitive().tensor());
190 }
191 if let Some(n) = node_b_bnlgr {
192 grads.register::<B>(n.id, d_b_bnlgr.into_primitive().tensor());
193 }
194 if let Some(n) = node_c_bnlgr {
195 grads.register::<B>(n.id, d_c_bnlgr.into_primitive().tensor());
196 }
197 if let Some(n) = node_d_h {
198 grads.register::<B>(n.id, d_d_h.into_primitive().tensor());
199 }
200 if let Some(n) = node_initial_state_bhpr {
201 grads.register::<B>(n.id, d_initial_state_bhpr.into_primitive().tensor());
202 }
203 if let Some(n) = node_a_decay_h {
204 grads.register::<B>(n.id, d_a_decay_h.into_primitive().tensor());
205 }
206 }
207 } use burn::tensor::TensorMetadata;
213 let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.primitive.shape().dims();
214 let [_, _, _, ngroups, state_rank] = b_bnlgr.primitive.shape().dims();
215
216 let flat_len_y_BNLHP = batch * nchunks * chunk_len * nheads * per_head_dim;
217 let flat_len_final_state_BHPR = batch * nheads * per_head_dim * state_rank;
218
219 let shape_x_bnlhp: [usize; 5] = [batch, nchunks, chunk_len, nheads, per_head_dim];
220 let shape_dt_discretized_bhnl: [usize; 4] = [batch, nheads, nchunks, chunk_len];
221 let shape_b_bnlgr: [usize; 5] = [batch, nchunks, chunk_len, ngroups, state_rank];
222 let shape_c_bnlgr: [usize; 5] = [batch, nchunks, chunk_len, ngroups, state_rank];
223 let shape_d_h: [usize; 1] = [nheads];
224 let shape_initial_state_bhpr: [usize; 4] = [batch, nheads, per_head_dim, state_rank];
225 let shape_a_decay_h: [usize; 1] = [nheads];
226 let shape_y_bnlhp: [usize; 5] = [batch, nchunks, chunk_len, nheads, per_head_dim];
227 let shape_final_state_bhpr: [usize; 4] = [batch, nheads, per_head_dim, state_rank];
228
229 match K2K3K4K5CombinedBackward
231 .prepare::<C>([
232 x_bnlhp.node.clone(),
233 dt_discretized_bhnl.node.clone(),
234 b_bnlgr.node.clone(),
235 c_bnlgr.node.clone(),
236 d_h.node.clone(),
237 initial_state_bhpr.node.clone(),
238 a_decay_h.node.clone(),
239 ])
240 .compute_bound()
241 .stateful() {
243 OpsKind::Tracked(prep) => {
244 let (prim_y_bnlhp, prim_final_state_bhpr) = B::ssd_serial_recalculated(
246 x_bnlhp.primitive.clone(),
247 dt_discretized_bhnl.primitive.clone(),
248 b_bnlgr.primitive.clone(),
249 c_bnlgr.primitive.clone(),
250 d_h.primitive.clone(),
251 initial_state_bhpr.primitive.clone(),
252 a_decay_h.primitive.clone(),
253 );
254
255 let flat_y_BNLHP: Tensor<B, 1> =
259 Tensor::<B, 5>::from_primitive(TensorPrimitive::Float(prim_y_bnlhp))
260 .reshape([flat_len_y_BNLHP]);
261 let flat_final_state_BHPR: Tensor<B, 1> =
262 Tensor::<B, 4>::from_primitive(TensorPrimitive::Float(prim_final_state_bhpr))
263 .reshape([flat_len_final_state_BHPR]);
264 let combined: Tensor<B, 1> =
265 Tensor::cat(vec![flat_y_BNLHP, flat_final_state_BHPR], 0);
266
267 let state = State {
268 x_bnlhp: x_bnlhp.primitive.clone(),
269 dt_discretized_bhnl: dt_discretized_bhnl.primitive.clone(),
270 b_bnlgr: b_bnlgr.primitive.clone(),
271 c_bnlgr: c_bnlgr.primitive.clone(),
272 d_h: d_h.primitive.clone(),
273 initial_state_bhpr: initial_state_bhpr.primitive.clone(),
274 a_decay_h: a_decay_h.primitive.clone(),
275 flat_len_y_BNLHP,
277 flat_len_final_state_BHPR,
278 shape_x_bnlhp, shape_dt_discretized_bhnl, shape_b_bnlgr, shape_c_bnlgr, shape_d_h, shape_initial_state_bhpr, shape_a_decay_h,
280 shape_y_bnlhp, shape_final_state_bhpr,
281 };
282 let tracked_combined: FloatTensor<Autodiff<B, C>> =
283 prep.finish(state, combined.into_primitive().tensor());
284
285 let tracked_combined: Tensor<Autodiff<B, C>, 1> =
290 Tensor::from_primitive(TensorPrimitive::Float(tracked_combined));
291
292 let tracked_y_bnlhp: Tensor<Autodiff<B, C>, 5> = tracked_combined
293 .clone()
294 .narrow(0, 0, flat_len_y_BNLHP)
295 .reshape(shape_y_bnlhp);
296 let tracked_final_state_bhpr: Tensor<Autodiff<B, C>, 4> = tracked_combined
297 .narrow(0, flat_len_y_BNLHP, flat_len_final_state_BHPR)
298 .reshape(shape_final_state_bhpr);
299
300 (
301 tracked_y_bnlhp.into_primitive().tensor(),
302 tracked_final_state_bhpr.into_primitive().tensor(),
303 )
304 }
305
306 OpsKind::UnTracked(prep) => {
307 let (prim_y_bnlhp, prim_final_state_bhpr) = B::ssd_serial_recalculated(
309 x_bnlhp.primitive,
310 dt_discretized_bhnl.primitive,
311 b_bnlgr.primitive,
312 c_bnlgr.primitive,
313 d_h.primitive,
314 initial_state_bhpr.primitive,
315 a_decay_h.primitive,
316 );
317
318 let flat_y_BNLHP: Tensor<B, 1> =
320 Tensor::<B, 5>::from_primitive(TensorPrimitive::Float(prim_y_bnlhp))
321 .reshape([flat_len_y_BNLHP]);
322 let flat_final_state_BHPR: Tensor<B, 1> =
323 Tensor::<B, 4>::from_primitive(TensorPrimitive::Float(prim_final_state_bhpr))
324 .reshape([flat_len_final_state_BHPR]);
325 let combined: Tensor<B, 1> = Tensor::cat(vec![flat_y_BNLHP, flat_final_state_BHPR], 0);
326
327 let tracked_combined: FloatTensor<Autodiff<B, C>> =
328 prep.finish(combined.into_primitive().tensor());
329
330 let tracked_combined: Tensor<Autodiff<B, C>, 1> =
331 Tensor::from_primitive(TensorPrimitive::Float(tracked_combined));
332 let tracked_y_bnlhp: Tensor<Autodiff<B, C>, 5> = tracked_combined
333 .clone()
334 .narrow(0, 0, flat_len_y_BNLHP)
335 .reshape(shape_y_bnlhp);
336 let tracked_final_state_bhpr: Tensor<Autodiff<B, C>, 4> = tracked_combined
337 .narrow(0, flat_len_y_BNLHP, flat_len_final_state_BHPR)
338 .reshape(shape_final_state_bhpr);
339
340 (
341 tracked_y_bnlhp.into_primitive().tensor(),
342 tracked_final_state_bhpr.into_primitive().tensor(),
343 )
344 }
345 } } }