Module combined_grad

Flatten/unflatten (y, final_state) into one tracked tensor for the custom backward.

Helpers for the “two-output, one autodiff node” pattern used by crate::mamba2::ssd::Mamba2BackendExt::ssd_serial_recalculated, crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdBackendExt::double_ssd_serial_recalculated, and crate::mamba3::single_ssd::ssd::Mamba3SingleSsdBackendExt::single_ssd_serial_recalculated.

Burn’s prep.finish accepts only a single tracked tensor, so the two outputs (y and final_state) are flattened and concatenated into a single 1-D tracked tensor; the caller then narrows it back into two views and reshapes each to its original shape. Burn’s autodiff accumulates the upstream gradients of those views into the combined gradient vector, which the custom backward consumes.
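The gradient flow described above can be illustrated on plain slices. This is a minimal sketch, not Burn's API: `accumulate_view_grads` and `split_combined_grad` are hypothetical names, and `Vec<f32>` stands in for a 1-D tracked tensor. It assumes the combined layout is all of y's elements followed by all of final_state's elements.

```rust
/// Hypothetical stand-in for what autodiff does for the two narrowed views:
/// each view's upstream gradient is added into its own segment of one
/// combined gradient vector (y first, then final_state).
pub fn accumulate_view_grads(grad_y: &[f32], grad_state: &[f32]) -> Vec<f32> {
    let y_len = grad_y.len();
    let mut combined = vec![0.0_f32; y_len + grad_state.len()];
    // Segment [0, y_len) receives the gradient of the `y` view.
    for (slot, g) in combined[..y_len].iter_mut().zip(grad_y) {
        *slot += *g;
    }
    // Segment [y_len, total) receives the gradient of the `final_state` view.
    for (slot, g) in combined[y_len..].iter_mut().zip(grad_state) {
        *slot += *g;
    }
    combined
}

/// Hypothetical stand-in for the first step of the custom backward:
/// recover the two per-output gradients from the combined vector.
/// Panics if `y_len` exceeds `combined.len()`.
pub fn split_combined_grad(combined: &[f32], y_len: usize) -> (&[f32], &[f32]) {
    combined.split_at(y_len)
}
```

The only piece of bookkeeping the backward needs in this sketch is `y_len`, the flat length of the first output; the second length follows from the total.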

Functions§

autodiff_unflatten_pair
Inverse of flatten_pair: split a 1-D combined tensor back into the two outputs at their original ranks/shapes.
flatten_pair
Flatten the two outputs (y and final_state) and concatenate them along axis 0 into a single 1-D tensor. Returns the combined tensor and the per-output flat lengths needed to split it later.
unflatten_pair
Inverse of flatten_pair: split a 1-D combined tensor back into the two outputs at their original ranks/shapes.
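The round trip between the flatten and unflatten helpers can be sketched without any tensor library. The code below is an illustration under stated assumptions, not the crate's implementation: `Dense`, `flatten_pair_sketch`, and `unflatten_pair_sketch` are hypothetical names, data is stored row-major in a `Vec<f32>`, and the original shapes are passed back in explicitly by the caller.

```rust
/// Hypothetical tiny dense array: row-major data plus its shape.
#[derive(Debug, Clone, PartialEq)]
pub struct Dense {
    pub data: Vec<f32>,
    pub shape: Vec<usize>,
}

impl Dense {
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        // The element count must match the product of the dimensions.
        assert_eq!(data.len(), shape.iter().product::<usize>(), "shape/data mismatch");
        Self { data, shape }
    }
}

/// Flatten both outputs and concatenate them into one 1-D buffer.
/// Returns the buffer and the per-output flat lengths.
pub fn flatten_pair_sketch(y: &Dense, final_state: &Dense) -> (Vec<f32>, [usize; 2]) {
    let lens = [y.data.len(), final_state.data.len()];
    let mut combined = Vec::with_capacity(lens[0] + lens[1]);
    combined.extend_from_slice(&y.data);
    combined.extend_from_slice(&final_state.data);
    (combined, lens)
}

/// Inverse of `flatten_pair_sketch`: split at the recorded length and
/// restore each half to the shape supplied by the caller.
pub fn unflatten_pair_sketch(
    combined: &[f32],
    lens: [usize; 2],
    y_shape: &[usize],
    state_shape: &[usize],
) -> (Dense, Dense) {
    assert_eq!(combined.len(), lens[0] + lens[1], "combined length mismatch");
    let (y_flat, state_flat) = combined.split_at(lens[0]);
    (
        Dense::new(y_flat.to_vec(), y_shape.to_vec()),
        Dense::new(state_flat.to_vec(), state_shape.to_vec()),
    )
}
```

In this sketch the two outputs may have different ranks, since only their flat lengths enter the combined buffer; the shapes are reapplied after the split.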