burn_mamba/utils/
combined_grad.rs1use burn::backend::Autodiff;
13use burn::backend::Backend;
14use burn::backend::BackendTypes;
15use burn::backend::TensorMetadata;
16use burn::backend::autodiff::checkpoint::strategy::CheckpointStrategy;
17use burn::backend::ops::FloatTensorOps;
18use burn::backend::tensor::FloatTensor;
19use burn::prelude::*;
20
21pub fn flatten_pair<B: Backend>(
25 y: <B as BackendTypes>::FloatTensorPrimitive,
26 final_state: <B as BackendTypes>::FloatTensorPrimitive,
27) -> (<B as BackendTypes>::FloatTensorPrimitive, usize, usize) {
28 let flat_y_len = y.shape().num_elements();
29 let flat_s_len = final_state.shape().num_elements();
30 let flat_y = B::float_reshape(y, Shape::new([flat_y_len]));
31 let flat_s = B::float_reshape(final_state, Shape::new([flat_s_len]));
32 let combined = B::float_cat(vec![flat_y, flat_s], 0);
33 (combined, flat_y_len, flat_s_len)
34}
35
36pub fn unflatten_pair<B: Backend, const DA: usize, const DB: usize>(
39 combined: <B as BackendTypes>::FloatTensorPrimitive,
40 flat_y_len: usize,
41 flat_s_len: usize,
42 shape_y: [usize; DA],
43 shape_s: [usize; DB],
44) -> (
45 <B as BackendTypes>::FloatTensorPrimitive,
46 <B as BackendTypes>::FloatTensorPrimitive,
47) {
48 let flat_y = B::float_slice(combined.clone(), &[s![0..flat_y_len]]);
49 let y = B::float_reshape(flat_y, Shape::new(shape_y));
50 let flat_s = B::float_slice(combined, &[s![flat_y_len..flat_y_len + flat_s_len]]);
51 let s = B::float_reshape(flat_s, Shape::new(shape_s));
52 (y, s)
53}
54
55pub fn autodiff_unflatten_pair<
58 B: Backend,
59 C: CheckpointStrategy,
60 const DA: usize,
61 const DB: usize,
62>(
63 combined: FloatTensor<Autodiff<B, C>>,
64 flat_y_len: usize,
65 flat_s_len: usize,
66 shape_y: [usize; DA],
67 shape_s: [usize; DB],
68) -> (FloatTensor<Autodiff<B, C>>, FloatTensor<Autodiff<B, C>>) {
69 let flat_y = Autodiff::<B, C>::float_slice(combined.clone(), &[s![0..flat_y_len]]);
70 let y = Autodiff::<B, C>::float_reshape(flat_y, Shape::new(shape_y));
71 let flat_s =
72 Autodiff::<B, C>::float_slice(combined, &[s![flat_y_len..flat_y_len + flat_s_len]]);
73 let s = Autodiff::<B, C>::float_reshape(flat_s, Shape::new(shape_s));
74 (y, s)
75}