Flatten/unflatten (y, final_state) into one tracked tensor for the custom
backward.
Helpers for the “two-output, one autodiff node” pattern used by both
crate::mamba2::ssd::Mamba2BackendExt::ssd_serial_recalculated
and the two Mamba-3 variants,
crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdBackendExt::double_ssd_serial_recalculated and
crate::mamba3::single_ssd::ssd::Mamba3SingleSsdBackendExt::single_ssd_serial_recalculated.
Burn’s prep.finish accepts only a single tracked tensor, so the two
outputs (y and final_state) are flattened and concatenated into a
single 1-D tracked tensor; the caller then narrows it back into two
reshaped views. Burn’s autodiff then accumulates the upstream gradients of those
two views into a single combined gradient vector, which the custom backward consumes.
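A minimal plain-Rust sketch of the gradient side, assuming (as for any slice) that the backward of each narrowed view scatters that view's upstream gradient into a zero buffer the length of the combined tensor, and that contributions reaching the same node are summed. The names narrow_backward, combined_grad and split_grad are hypothetical; this models the bookkeeping only, not Burn's autodiff internals or the crate's actual backward.

```rust
/// Hypothetical: backward of narrowing `start..start + grad_view.len()` out of a
/// 1-D tensor of length `total`. The view's upstream gradient is scattered into
/// a zero vector of the full combined length.
fn narrow_backward(grad_view: &[f32], start: usize, total: usize) -> Vec<f32> {
    let mut g = vec![0.0f32; total];
    g[start..start + grad_view.len()].copy_from_slice(grad_view);
    g
}

/// Hypothetical: sum the contributions of both views, as an autodiff engine
/// accumulates gradients flowing into the same node. `dy` / `d_state` are
/// `None` when the corresponding output did not reach the loss.
pub fn combined_grad(dy: Option<&[f32]>, d_state: Option<&[f32]>, lens: [usize; 2]) -> Vec<f32> {
    let total = lens[0] + lens[1];
    let mut acc = vec![0.0f32; total];
    // (upstream gradient, offset of the view in the combined tensor, view length)
    let views = [(dy, 0, lens[0]), (d_state, lens[0], lens[1])];
    for (grad, start, len) in views {
        if let Some(g) = grad {
            assert_eq!(g.len(), len, "upstream gradient has wrong length");
            let contribution = narrow_backward(g, start, total);
            for (a, c) in acc.iter_mut().zip(contribution.iter()) {
                *a += *c;
            }
        }
    }
    acc
}

/// Hypothetical: what a custom backward would do with the combined gradient,
/// namely split it back into the gradient for `y` and for `final_state`.
pub fn split_grad(combined: &[f32], lens: [usize; 2]) -> (&[f32], &[f32]) {
    assert_eq!(combined.len(), lens[0] + lens[1], "combined length mismatch");
    combined.split_at(lens[0])
}
```

In this sketch, if only y reaches the loss, the final_state slice of the combined gradient is simply zeros, so the backward can treat both slices uniformly.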
Functions
- autodiff_unflatten_pair - Inverse of flatten_pair: split a 1-D combined tensor back into the two outputs at their original ranks/shapes.
- flatten_pair - Flatten the two outputs (y and final_state) and concatenate them end to end along axis 0 into a single 1-D tensor. Returns the combined tensor and the per-output flat lengths needed to split it later.
- unflatten_pair - Inverse of flatten_pair: split a 1-D combined tensor back into the two outputs at their original ranks/shapes.
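To make the round trip concrete, here is a plain-Rust sketch of what flatten_pair and unflatten_pair are described as doing, using a hypothetical Dense type (row-major Vec<f32> plus shape) instead of Burn tensors. The signatures are illustrative assumptions, not the crate's; in particular, this version passes the original shapes to the unflatten step explicitly.

```rust
/// Hypothetical stand-in for a tensor: row-major data plus its shape.
/// Not Burn's `Tensor` type.
#[derive(Debug, Clone, PartialEq)]
pub struct Dense {
    pub data: Vec<f32>,
    pub shape: Vec<usize>,
}

impl Dense {
    pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
        assert_eq!(data.len(), shape.iter().product::<usize>(), "shape/data mismatch");
        Dense { data, shape }
    }
}

/// Hypothetical signature: flatten both outputs and concatenate them end to end
/// into one 1-D buffer. Returns the buffer and each output's flat length.
pub fn flatten_pair(y: &Dense, final_state: &Dense) -> (Vec<f32>, [usize; 2]) {
    let mut combined = Vec::with_capacity(y.data.len() + final_state.data.len());
    combined.extend_from_slice(&y.data);
    combined.extend_from_slice(&final_state.data);
    (combined, [y.data.len(), final_state.data.len()])
}

/// Hypothetical signature: inverse of `flatten_pair`. Narrow the combined buffer
/// at the stored lengths and reshape each piece to its original shape.
pub fn unflatten_pair(
    combined: &[f32],
    lens: [usize; 2],
    y_shape: &[usize],
    state_shape: &[usize],
) -> (Dense, Dense) {
    assert_eq!(combined.len(), lens[0] + lens[1], "combined length mismatch");
    let (y_flat, s_flat) = combined.split_at(lens[0]);
    (
        Dense::new(y_flat.to_vec(), y_shape.to_vec()),
        Dense::new(s_flat.to_vec(), state_shape.to_vec()),
    )
}
```

Returning the flat lengths from the flatten step is what lets the inverse split the buffer without re-deriving sizes from the shapes; reshaping each slice afterwards restores the original ranks.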