burn_mamba/mamba3/single_ssd/ssd/serial_recalculated/mod.rs
1//! Single-SSD serial scan with a custom, memory-efficient backward.
2//!
3//! The forward + [`Mamba3SingleSsdBackendExt`] impl live in
4//! `serial_recalculated`; the registered autodiff [`backward`] node and the
5//! recompute-based gradient math in [`combined_backward`] save training memory
6//! by recomputing intermediates instead of storing them.
7
8/// The registered custom `Backward` node (autodiff op).
9#[cfg(feature = "autodiff")]
10pub mod backward;
11/// Recompute-based gradient math (the memory-efficient backward).
12pub mod combined_backward;
13mod serial_recalculated;
14
15pub use serial_recalculated::Mamba3SingleSsdBackendExt;
16
17#[cfg(feature = "autodiff")]
18pub use serial_recalculated::Mamba3SingleSsdAutodiffBackendExt;