Skip to main content

burn_mamba/mamba3/single_ssd/
mod.rs

1//! # Single-SSD pathway (official-kernel form)
2//!
3//! Realises the Mamba-3 trapezoidal recurrence as a **single SSD call** (the
4//! official Triton-SISO / Tilelang-MIMO form): a key scale
5//! `scaleₜ = γₜ + (1−λₜ₊₁)·Δₜ₊₁`, a strict lower-triangular intra-chunk mask, a
6//! same-step γ correction, and a boundary-β seed folded into the initial state.
7//!
8//! Uses ≈ half the training memory of the
9//! [`double_ssd`](crate::mamba3::double_ssd) pathway.  Its cache's SSM
10//! accumulator `h'` has **different mid-sequence semantics** than the double-SSD
11//! state (hence a distinct cache type), but the two coincide at sequence
12//! boundaries and inter-convert via field-identity `From` impls.
13
14/// The single-SSD cache (same fields as double-SSD, different `ssm` semantics).
15pub mod cache;
16/// `forward_single_ssd` (scale + boundary-β seed) and `step_single_ssd`.
17pub mod single_ssd;
18/// The standard SSD kernels specialised to the single-pass scale/mask.
19pub mod ssd;
20
21/// Public re-exports for the single-SSD pathway.
22pub mod prelude {
23    use super::*;
24    pub use cache::{
25        Mamba3SingleSsdCache, Mamba3SingleSsdCacheConfig, Mamba3SingleSsdCaches,
26        Mamba3SingleSsdCachesConfig,
27    };
28    #[cfg(feature = "autodiff")]
29    pub use ssd::Mamba3SingleSsdAutodiffBackendExt;
30    pub use ssd::Mamba3SingleSsdBackendExt;
31    pub use ssd::Mamba3SingleSsdInput;
32}