Skip to main content

burn_mamba/mamba3/double_ssd/ssd/serial_recalculated/
mod.rs

1//! Double-SSD serial scan with a custom, memory-efficient backward.
2//!
3//! The forward + [`Mamba3DoubleSsdBackendExt`] impl live in
4//! `serial_recalculated`; the registered autodiff [`backward`] node and the
5//! recompute-based gradient math in [`combined_backward`] save training memory
6//! by recomputing intermediates instead of storing them.
7
8/// The registered custom `Backward` node (autodiff op).
9#[cfg(feature = "autodiff")]
10pub mod backward;
11/// Recompute-based gradient math (the memory-efficient backward).
12pub mod combined_backward;
13mod serial_recalculated;
14
15pub use serial_recalculated::Mamba3DoubleSsdBackendExt;
16
17#[cfg(feature = "autodiff")]
18pub use serial_recalculated::Mamba3DoubleSsdAutodiffBackendExt;
19
20// Primitive forward kernels reused by the recompute backward and by the
21// single-SSD pathway's backward (which shares the standard SSD kernels).
22pub(crate) use serial_recalculated::{
23    k1_ssd_chunk_cumsum, k2_ssd_bmm, k3_ssd_chunk_state, k4_ssd_state_passing,
24};