pub trait Mamba3SingleSsdBackendExt: Backend {
// Provided method
fn single_ssd_serial_recalculated(
v_bnlmhp: FloatTensor<Self>,
da_bnlh: FloatTensor<Self>,
b_bnlmhr: FloatTensor<Self>,
c_bnlmhr: FloatTensor<Self>,
gamma_bnlh: FloatTensor<Self>,
scale_bnlh: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>) { ... }
}
Extends the backend with the memory-efficient serial SSD in single-SSD form.
The default implementation runs K1–K5 using primitive tensor operations, reusing the mode-agnostic K1–K4 from the double-SSD forward and the single-SSD-form K5 below. Backends that support a custom memory-efficient backward (the Autodiff wrapper) override this method to recompute forward intermediates during the backward pass instead of saving them.
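The default-plus-override arrangement described above can be sketched with plain Rust traits. This is only an illustration of the pattern, not the crate's code: `SsdExt`, `Plain`, `Recomputing`, and the doubling "forward" are hypothetical stand-ins, and the string tag exists purely to show which implementation ran.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for the extension trait; not part of the crate.
trait SsdExt {
    // Provided method: the default path built from primitive operations.
    // The doubling is a placeholder for the real forward computation.
    fn run(x: &[f32]) -> (Vec<f32>, &'static str) {
        (x.iter().map(|v| v * 2.0).collect(), "default")
    }
}

// A base backend that simply inherits the provided method.
struct Plain;
impl SsdExt for Plain {}

// A wrapper backend (in the spirit of an autodiff wrapper) that overrides
// the method and delegates the forward computation to the inner backend.
#[allow(dead_code)]
struct Recomputing<B>(PhantomData<B>);
impl<B: SsdExt> SsdExt for Recomputing<B> {
    fn run(x: &[f32]) -> (Vec<f32>, &'static str) {
        let (y, _) = B::run(x); // forward via the wrapped backend
        // A real override would register a backward step here that
        // re-runs the forward instead of storing its intermediates.
        (y, "override")
    }
}
```

Calling `Plain::run` takes the default path, while `Recomputing::<Plain>::run` takes the overriding one and produces the same forward values.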
Provided Methods§
fn single_ssd_serial_recalculated(
v_bnlmhp: FloatTensor<Self>,
da_bnlh: FloatTensor<Self>,
b_bnlmhr: FloatTensor<Self>,
c_bnlmhr: FloatTensor<Self>,
gamma_bnlh: FloatTensor<Self>,
scale_bnlh: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>)
Memory-efficient MIMO single-ssd form serial SSD.
§Arguments
- v_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
- da_bnlh: [batch, nchunks, chunk_len, nheads] (pre-combined Δ·A)
- b_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
- c_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
- gamma_bnlh: [batch, nchunks, chunk_len, nheads] (γₜ = λₜ Δₜ)
- scale_bnlh: [batch, nchunks, chunk_len, nheads] (scaleₜ = γₜ + (1−λₜ₊₁)Δₜ₊₁)
- initial_state_bhpr: [batch, nheads, per_head_dim, state_rank]
§Returns
- y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
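As a worked illustration of the γₜ and scaleₜ definitions given for gamma_bnlh and scale_bnlh, the sketch below computes both along the time axis for a single head, using plain slices rather than backend tensors. The function name `gamma_and_scale` is hypothetical, and treating the look-ahead term as zero at the final step is an assumption made here; the source does not say how the boundary is handled.

```rust
/// Computes gamma_t = lambda_t * delta_t and
/// scale_t = gamma_t + (1 - lambda_{t+1}) * delta_{t+1} for one head.
///
/// Assumption (not from the source): at the last time step there is no
/// t+1 entry, so the look-ahead term is taken to be zero.
fn gamma_and_scale(lambda: &[f64], delta: &[f64]) -> (Vec<f64>, Vec<f64>) {
    assert_eq!(lambda.len(), delta.len(), "lambda and delta must align");
    let n = lambda.len();

    // gamma_t = lambda_t * delta_t
    let gamma: Vec<f64> = lambda.iter().zip(delta).map(|(l, d)| l * d).collect();

    // scale_t = gamma_t + (1 - lambda_{t+1}) * delta_{t+1}
    let scale: Vec<f64> = (0..n)
        .map(|t| {
            let look_ahead = if t + 1 < n {
                (1.0 - lambda[t + 1]) * delta[t + 1]
            } else {
                0.0 // assumed boundary value
            };
            gamma[t] + look_ahead
        })
        .collect();

    (gamma, scale)
}
```

With λ = [0.5, 0.25] and Δ = [2.0, 4.0], this gives γ = [1.0, 1.0] and scale = [4.0, 1.0]: the first scale entry is 1.0 + (1 − 0.25)·4.0.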
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".
Implementations on Foreign Types§
impl Mamba3SingleSsdBackendExt for Dispatch
fn single_ssd_serial_recalculated( v_bnlmhp: FloatTensor<Self>, da_bnlh: FloatTensor<Self>, b_bnlmhr: FloatTensor<Self>, c_bnlmhr: FloatTensor<Self>, gamma_bnlh: FloatTensor<Self>, scale_bnlh: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>, ) -> (FloatTensor<Self>, FloatTensor<Self>)
impl Mamba3SingleSsdBackendExt for Flex
Available on crate feature backend-flex only.
impl<B: Backend + Mamba3SingleSsdBackendExt, C: CheckpointStrategy> Mamba3SingleSsdBackendExt for Autodiff<B, C>
fn single_ssd_serial_recalculated(
v_bnlmhp: FloatTensor<Self>,
da_bnlh: FloatTensor<Self>,
b_bnlmhr: FloatTensor<Self>,
c_bnlmhr: FloatTensor<Self>,
gamma_bnlh: FloatTensor<Self>,
scale_bnlh: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>)
Memory-efficient combined forward+backward for the Mamba-3 MIMO Single-SSD.
The two outputs (y_bnlmhp, final_state_bhpr) are flattened and
concatenated into a single 1-D tracked tensor so one Backward<B, 7>
node covers both. The seven differentiable inputs are v, da, b, c, gamma, scale, initial_state.
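The flatten-and-concatenate step can be illustrated on plain slices. This is a minimal sketch under stated assumptions, not the Autodiff implementation: `numel`, `pack`, and `unpack` are hypothetical helpers, and placing y before the final state in the buffer is an assumed ordering.

```rust
/// Number of elements in a tensor with the given dimensions.
fn numel(dims: &[usize]) -> usize {
    dims.iter().product()
}

/// Flattens both outputs into one 1-D buffer.
/// Assumed layout: all of y first, then the final state.
fn pack(y: &[f32], state: &[f32]) -> Vec<f32> {
    let mut flat = Vec::with_capacity(y.len() + state.len());
    flat.extend_from_slice(y);
    flat.extend_from_slice(state);
    flat
}

/// Splits the 1-D buffer back into (y, final_state) given y's element count.
fn unpack(flat: &[f32], y_len: usize) -> (&[f32], &[f32]) {
    flat.split_at(y_len)
}
```

For example, with y of shape [1, 2, 3, 1, 2, 4] (48 elements) and a final state of shape [1, 2, 4, 5] (40 elements), the packed buffer holds 88 elements, and `unpack(&flat, 48)` recovers the two pieces. A gradient arriving for the packed tensor can be split the same way before being routed to the seven inputs.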