Skip to main content

burn_mamba/utils/
combined_grad.rs

//! Helpers for the "two-output, one autodiff node" pattern used by
//! [`crate::mamba2::ssd::Mamba2BackendExt::ssd_serial_recalculated`],
//! [`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdBackendExt::double_ssd_serial_recalculated`]
//! and
//! [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdBackendExt::single_ssd_serial_recalculated`].
//!
//! Burn's `prep.finish` accepts only a single tracked tensor, so the two
//! outputs (`y` and `final_state`) are flattened and concatenated into one
//! 1-D tracked tensor; the caller then splits it back into two reshaped views
//! (see [`unflatten_pair`] and [`autodiff_unflatten_pair`]). Burn's autodiff
//! accumulates the upstream gradients of those views into the gradient of the
//! combined tensor, which the custom backward pass then consumes.
11
12use burn::backend::Autodiff;
13use burn::backend::Backend;
14use burn::backend::BackendTypes;
15use burn::backend::TensorMetadata;
16use burn::backend::autodiff::checkpoint::strategy::CheckpointStrategy;
17use burn::backend::ops::FloatTensorOps;
18use burn::backend::tensor::FloatTensor;
19use burn::prelude::*;
20
21/// Flatten the two outputs (`y` and `final_state`) and concatenate them along a
22/// fresh axis-0 into a single 1-D tensor. Returns the combined tensor and the
23/// per-output flat lengths needed to split it later.
24pub fn flatten_pair<B: Backend>(
25    y: <B as BackendTypes>::FloatTensorPrimitive,
26    final_state: <B as BackendTypes>::FloatTensorPrimitive,
27) -> (<B as BackendTypes>::FloatTensorPrimitive, usize, usize) {
28    let flat_y_len = y.shape().num_elements();
29    let flat_s_len = final_state.shape().num_elements();
30    let flat_y = B::float_reshape(y, Shape::new([flat_y_len]));
31    let flat_s = B::float_reshape(final_state, Shape::new([flat_s_len]));
32    let combined = B::float_cat(vec![flat_y, flat_s], 0);
33    (combined, flat_y_len, flat_s_len)
34}
35
36/// Inverse of [`flatten_pair`]: split a 1-D combined tensor back into the two
37/// outputs at their original ranks/shapes.
38pub fn unflatten_pair<B: Backend, const DA: usize, const DB: usize>(
39    combined: <B as BackendTypes>::FloatTensorPrimitive,
40    flat_y_len: usize,
41    flat_s_len: usize,
42    shape_y: [usize; DA],
43    shape_s: [usize; DB],
44) -> (
45    <B as BackendTypes>::FloatTensorPrimitive,
46    <B as BackendTypes>::FloatTensorPrimitive,
47) {
48    let flat_y = B::float_slice(combined.clone(), &[s![0..flat_y_len]]);
49    let y = B::float_reshape(flat_y, Shape::new(shape_y));
50    let flat_s = B::float_slice(combined, &[s![flat_y_len..flat_y_len + flat_s_len]]);
51    let s = B::float_reshape(flat_s, Shape::new(shape_s));
52    (y, s)
53}
54
55/// Inverse of [`flatten_pair`]: split a 1-D combined tensor back into the two
56/// outputs at their original ranks/shapes.
57pub fn autodiff_unflatten_pair<
58    B: Backend,
59    C: CheckpointStrategy,
60    const DA: usize,
61    const DB: usize,
62>(
63    combined: FloatTensor<Autodiff<B, C>>,
64    flat_y_len: usize,
65    flat_s_len: usize,
66    shape_y: [usize; DA],
67    shape_s: [usize; DB],
68) -> (FloatTensor<Autodiff<B, C>>, FloatTensor<Autodiff<B, C>>) {
69    let flat_y = Autodiff::<B, C>::float_slice(combined.clone(), &[s![0..flat_y_len]]);
70    let y = Autodiff::<B, C>::float_reshape(flat_y, Shape::new(shape_y));
71    let flat_s =
72        Autodiff::<B, C>::float_slice(combined, &[s![flat_y_len..flat_y_len + flat_s_len]]);
73    let s = Autodiff::<B, C>::float_reshape(flat_s, Shape::new(shape_s));
74    (y, s)
75}