Trait Mamba3SingleSsdBackendExt 

pub trait Mamba3SingleSsdBackendExt: Backend {
    // Provided method
    fn single_ssd_serial_recalculated(
        v_bnlmhp: FloatTensor<Self>,
        da_bnlh: FloatTensor<Self>,
        b_bnlmhr: FloatTensor<Self>,
        c_bnlmhr: FloatTensor<Self>,
        gamma_bnlh: FloatTensor<Self>,
        scale_bnlh: FloatTensor<Self>,
        initial_state_bhpr: FloatTensor<Self>,
    ) -> (FloatTensor<Self>, FloatTensor<Self>) { ... }
}

Extends the backend with the memory-efficient serial SSD in single-SSD form.

The default implementation runs K1–K5 using primitive tensor operations, reusing the mode-agnostic K1–K4 from the double-SSD forward together with the single-SSD form of K5. Backends that support a custom memory-efficient backward (the Autodiff wrapper) override this method to recompute forward intermediates during the backward pass instead of saving them.
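The recompute-instead-of-save idea can be illustrated on a toy scalar recurrence. This is only a sketch, not the crate's K1–K5: the recurrence `h_t = a_t * h_{t-1} + x_t` and the names `scan_forward` and `scan_backward_recompute` are assumptions made for illustration. The backward function receives nothing but the original inputs and rebuilds the states it needs.

```rust
/// Toy scalar recurrence h_t = a_t * h_{t-1} + x_t.
/// Returns every state (the "outputs") and the final state.
fn scan_forward(a: &[f64], x: &[f64], h0: f64) -> (Vec<f64>, f64) {
    let mut h = h0;
    let mut ys = Vec::with_capacity(a.len());
    for (&at, &xt) in a.iter().zip(x) {
        h = at * h + xt;
        ys.push(h);
    }
    (ys, h)
}

/// Backward pass that is handed only the inputs, not saved intermediates.
/// Returns (grad_a, grad_x, grad_h0).
fn scan_backward_recompute(
    a: &[f64],
    x: &[f64],
    h0: f64,
    grad_ys: &[f64],
    grad_final: f64,
) -> (Vec<f64>, Vec<f64>, f64) {
    // Recompute h_{t-1} for every step instead of reading a saved copy.
    let mut h_prev = Vec::with_capacity(a.len());
    let mut h = h0;
    for (&at, &xt) in a.iter().zip(x) {
        h_prev.push(h);
        h = at * h + xt;
    }

    let mut grad_a = vec![0.0; a.len()];
    let mut grad_x = vec![0.0; a.len()];
    let mut g = grad_final; // running gradient with respect to the current state
    for t in (0..a.len()).rev() {
        g += grad_ys[t]; // the state at step t is also an output
        grad_x[t] = g;
        grad_a[t] = g * h_prev[t];
        g *= a[t]; // propagate to the previous state
    }
    (grad_a, grad_x, g)
}
```

The toy still allocates the recomputed states while the backward pass runs; the saving is that nothing has to be kept alive between the forward and backward passes.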

Provided Methods§

fn single_ssd_serial_recalculated( v_bnlmhp: FloatTensor<Self>, da_bnlh: FloatTensor<Self>, b_bnlmhr: FloatTensor<Self>, c_bnlmhr: FloatTensor<Self>, gamma_bnlh: FloatTensor<Self>, scale_bnlh: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>, ) -> (FloatTensor<Self>, FloatTensor<Self>)

Memory-efficient serial SSD in MIMO single-SSD form.

§Arguments
  • v_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • da_bnlh: [batch, nchunks, chunk_len, nheads] — pre-combined Δ·A
  • b_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
  • c_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
  • gamma_bnlh: [batch, nchunks, chunk_len, nheads], with γₜ = λₜ Δₜ
  • scale_bnlh: [batch, nchunks, chunk_len, nheads], with scaleₜ = γₜ + (1−λₜ₊₁)Δₜ₊₁
  • initial_state_bhpr: [batch, nheads, per_head_dim, state_rank]
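The two formulas for `gamma` and `scale` can be written out on a single flat time axis as a rough sketch. The function name `gamma_and_scale` is hypothetical, the sketch ignores the batch, chunk and head axes, and it assumes the `t+1` term is zero at the last step, which the documentation above does not specify.

```rust
/// Computes (gamma, scale) along one time axis from lambda and delta.
///   gamma[t] = lambda[t] * delta[t]
///   scale[t] = gamma[t] + (1 - lambda[t+1]) * delta[t+1]
/// Assumption: at the last step the t+1 term is treated as zero.
fn gamma_and_scale(lambda: &[f32], delta: &[f32]) -> (Vec<f32>, Vec<f32>) {
    assert_eq!(lambda.len(), delta.len());
    let gamma: Vec<f32> = lambda.iter().zip(delta).map(|(l, d)| l * d).collect();
    let scale: Vec<f32> = (0..gamma.len())
        .map(|t| {
            let next = if t + 1 < gamma.len() {
                (1.0 - lambda[t + 1]) * delta[t + 1]
            } else {
                0.0 // boundary assumption, not taken from the source
            };
            gamma[t] + next
        })
        .collect();
    (gamma, scale)
}
```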
§Returns
  • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
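Judging from the lists above, each letter of a name suffix stands for one axis (`b` batch, `n` nchunks, `l` chunk_len, `m` mimo_rank, `h` nheads, `p` per_head_dim, `r` state_rank). The sketch below turns a suffix into an expected shape, which can be handy for checking inputs before a call. The `Dims` struct and `shape_from_suffix` function are hypothetical helpers, not part of the crate.

```rust
/// Hypothetical bundle of the dimension sizes named in the argument list.
#[derive(Clone, Copy)]
struct Dims {
    batch: usize,
    nchunks: usize,
    chunk_len: usize,
    mimo_rank: usize,
    nheads: usize,
    per_head_dim: usize,
    state_rank: usize,
}

/// Maps a name suffix such as "bnlmhp" to a shape, one letter per axis.
/// Returns None if the suffix contains a letter outside the convention.
fn shape_from_suffix(suffix: &str, d: Dims) -> Option<Vec<usize>> {
    suffix
        .chars()
        .map(|c| match c {
            'b' => Some(d.batch),
            'n' => Some(d.nchunks),
            'l' => Some(d.chunk_len),
            'm' => Some(d.mimo_rank),
            'h' => Some(d.nheads),
            'p' => Some(d.per_head_dim),
            'r' => Some(d.state_rank),
            _ => None,
        })
        .collect()
}
```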

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".

Implementations on Foreign Types§

impl Mamba3SingleSsdBackendExt for Dispatch

fn single_ssd_serial_recalculated( v_bnlmhp: FloatTensor<Self>, da_bnlh: FloatTensor<Self>, b_bnlmhr: FloatTensor<Self>, c_bnlmhr: FloatTensor<Self>, gamma_bnlh: FloatTensor<Self>, scale_bnlh: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>, ) -> (FloatTensor<Self>, FloatTensor<Self>)

impl Mamba3SingleSsdBackendExt for Flex

Available on crate feature backend-flex only.

impl<B: Backend + Mamba3SingleSsdBackendExt, C: CheckpointStrategy> Mamba3SingleSsdBackendExt for Autodiff<B, C>

fn single_ssd_serial_recalculated( v_bnlmhp: FloatTensor<Self>, da_bnlh: FloatTensor<Self>, b_bnlmhr: FloatTensor<Self>, c_bnlmhr: FloatTensor<Self>, gamma_bnlh: FloatTensor<Self>, scale_bnlh: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>, ) -> (FloatTensor<Self>, FloatTensor<Self>)

Memory-efficient combined forward+backward for the Mamba-3 MIMO Single-SSD.

The two outputs (y_bnlmhp, final_state_bhpr) are flattened and concatenated into a single 1-D tracked tensor so one Backward<B, 7> node covers both. The seven differentiable inputs are v, da, b, c, gamma, scale, initial_state.
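The flatten-and-concatenate step can be sketched with plain slices rather than the crate's tensors. The names `pack` and `unpack` are hypothetical; the point is only that one 1-D buffer plus a split index is enough to carry both outputs, and the same split recovers the two gradient pieces on the way back.

```rust
/// Packs two already-flattened outputs into one 1-D buffer.
/// Returns the buffer and the index at which the second part starts.
fn pack(y: &[f32], final_state: &[f32]) -> (Vec<f32>, usize) {
    let mut packed = Vec::with_capacity(y.len() + final_state.len());
    packed.extend_from_slice(y);
    packed.extend_from_slice(final_state);
    (packed, y.len())
}

/// Splits a packed buffer (or a gradient with the same layout)
/// back into the y part and the final-state part.
fn unpack(packed: &[f32], y_len: usize) -> (&[f32], &[f32]) {
    packed.split_at(y_len)
}
```

In a real tensor setting each part would then be reshaped back to `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]` and `[batch, nheads, per_head_dim, state_rank]` respectively.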

Implementors§