pub trait Mamba3BackendExt: Backend {
// Provided method
fn ssd_serial_recalculated(
v_bnlrhp: FloatTensor<Self>,
_da_bnlh: FloatTensor<Self>,
b_bnlrhn: FloatTensor<Self>,
c_bnlrhn: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
da_cumsum_bhnl: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>) { ... }
}
Extends a backend with the memory-efficient serial recalculated SSD operation.
The default implementation runs K2-K5 using standard tensor operations. Backends that support a custom memory-efficient backward pass can override this method.
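The provided-method pattern described above can be sketched with hypothetical stand-in traits. Backend, SsdBackendExt, PlainBackend and FusedBackend below are illustrative mocks, not the real backend types, and the method returns a label instead of tensors so that the dispatch is easy to see. A backend that opts in with an empty impl block gets the default path; one that defines the method itself overrides it.

```rust
/// Hypothetical minimal stand-in for a tensor backend (not the real trait).
trait Backend {
    fn name() -> &'static str;
}

/// Hypothetical extension trait with a provided (default) method,
/// mirroring the shape of Mamba3BackendExt.
trait SsdBackendExt: Backend {
    fn ssd_serial_recalculated() -> String {
        // Default path: stands in for K2-K5 composed from ordinary tensor ops.
        format!("{}: default composed-ops path", Self::name())
    }
}

struct PlainBackend;
struct FusedBackend;

impl Backend for PlainBackend {
    fn name() -> &'static str {
        "plain"
    }
}

impl Backend for FusedBackend {
    fn name() -> &'static str {
        "fused"
    }
}

// Opt in with an empty impl block: the provided method is used as-is.
impl SsdBackendExt for PlainBackend {}

// Override the provided method, e.g. to call a custom fused kernel.
impl SsdBackendExt for FusedBackend {
    fn ssd_serial_recalculated() -> String {
        format!("{}: custom memory-efficient kernel", Self::name())
    }
}

fn main() {
    println!("{}", PlainBackend::ssd_serial_recalculated());
    println!("{}", FusedBackend::ssd_serial_recalculated());
}
```

Because the method has a body in the trait, adding the extension to a new backend costs one line; only backends with a faster kernel need to write anything more.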
Provided Methods
fn ssd_serial_recalculated(
v_bnlrhp: FloatTensor<Self>,
_da_bnlh: FloatTensor<Self>,
b_bnlrhn: FloatTensor<Self>,
c_bnlrhn: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
da_cumsum_bhnl: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>)
Memory-efficient MIMO serial SSD.
Arguments
- v_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
- da_bnlh: [batch, nchunks, chunk_len, nheads], pre-combined Δ·A (bound as _da_bnlh in the signature)
- b_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
- c_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
- initial_state_bhpr: [batch, nheads, per_head_dim, state_rank]
- da_cumsum_bhnl: [batch, nheads, nchunks, chunk_len], pre-computed by K1
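The argument list implies a layout change between da_bnlh ([batch, nchunks, chunk_len, nheads]) and da_cumsum_bhnl ([batch, nheads, nchunks, chunk_len]). The sketch below is a hypothetical reference for that K1 step on flat row-major buffers, assuming an inclusive cumulative sum along chunk_len that restarts at every chunk boundary. The documentation only says the tensor is pre-computed by K1, so the exact cumsum convention is an assumption, and the function name da_cumsum_reference is made up for illustration.

```rust
/// Hypothetical reference for K1: per-chunk inclusive cumsum of da, with the
/// axes moved from [batch, nchunks, chunk_len, nheads] to
/// [batch, nheads, nchunks, chunk_len]. ASSUMED convention, not the crate's code.
fn da_cumsum_reference(
    da_bnlh: &[f32],
    batch: usize,
    nchunks: usize,
    chunk_len: usize,
    nheads: usize,
) -> Vec<f32> {
    assert_eq!(da_bnlh.len(), batch * nchunks * chunk_len * nheads);
    let mut out = vec![0.0f32; batch * nheads * nchunks * chunk_len];
    for b in 0..batch {
        for n in 0..nchunks {
            for h in 0..nheads {
                let mut acc = 0.0f32; // restarts at each chunk boundary
                for l in 0..chunk_len {
                    // Source index into [batch, nchunks, chunk_len, nheads].
                    let src = ((b * nchunks + n) * chunk_len + l) * nheads + h;
                    acc += da_bnlh[src];
                    // Destination index into [batch, nheads, nchunks, chunk_len].
                    let dst = ((b * nheads + h) * nchunks + n) * chunk_len + l;
                    out[dst] = acc;
                }
            }
        }
    }
    out
}

fn main() {
    // batch = 1, nchunks = 2, chunk_len = 2, nheads = 1
    let da: [f32; 4] = [-0.5, -0.25, -1.0, -0.5];
    let cs = da_cumsum_reference(&da, 1, 2, 2, 1);
    println!("{:?}", cs); // [-0.5, -0.75, -1.0, -1.5]
}
```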
Returns
- y_bnlrhp: [batch, nchunks, chunk_len, R, nheads, per_head_dim]
- final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
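To make the shapes concrete, here is a naive per-timestep reference for one batch element and one head, with chunks flattened into T = nchunks * chunk_len steps. It assumes a standard rank-R (MIMO) SSM recurrence, h_t = exp(da_t) * h_{t-1} + sum_r v[t,r] ⊗ b[t,r] and y[t,r] = h_t · c[t,r]; this recurrence, the function name ssd_serial_naive and the flat layouts are assumptions for illustration, not the crate's K2-K5 kernels, and the memory-saving recalculation is not modelled.

```rust
/// Naive reference for ONE batch element and ONE head (hypothetical sketch).
/// ASSUMED recurrence:
///   h_t      = exp(da_t) * h_{t-1} + sum_r v[t,r,:] (outer) b[t,r,:]
///   y[t,r,p] = sum_n h_t[p,n] * c[t,r,n]
/// Row-major layouts: v [T, R, P], da [T], b and c [T, R, N], state [P, N].
fn ssd_serial_naive(
    v: &[f32],
    da: &[f32],
    b: &[f32],
    c: &[f32],
    initial_state: &[f32],
    r: usize,
    p: usize,
    n: usize,
) -> (Vec<f32>, Vec<f32>) {
    let t_len = da.len();
    assert_eq!(v.len(), t_len * r * p);
    assert_eq!(b.len(), t_len * r * n);
    assert_eq!(c.len(), t_len * r * n);
    assert_eq!(initial_state.len(), p * n);

    let mut h = initial_state.to_vec(); // running state [P, N]
    let mut y = vec![0.0f32; t_len * r * p];
    for t in 0..t_len {
        let decay = da[t].exp();
        // Decay the previous state, then add the rank-R input update.
        for pi in 0..p {
            for ni in 0..n {
                let mut upd = 0.0f32;
                for ri in 0..r {
                    upd += v[(t * r + ri) * p + pi] * b[(t * r + ri) * n + ni];
                }
                h[pi * n + ni] = decay * h[pi * n + ni] + upd;
            }
        }
        // Read out one output vector per rank index.
        for ri in 0..r {
            for pi in 0..p {
                let mut acc = 0.0f32;
                for ni in 0..n {
                    acc += h[pi * n + ni] * c[(t * r + ri) * n + ni];
                }
                y[(t * r + ri) * p + pi] = acc;
            }
        }
    }
    (y, h) // (y: [T, R, P], final_state: [P, N])
}

fn main() {
    // T = 2, R = P = N = 1, zero initial state, da = 0 (no decay).
    let (y, h) = ssd_serial_naive(&[1.0, 2.0], &[0.0, 0.0], &[1.0, 1.0], &[1.0, 1.0], &[0.0], 1, 1, 1);
    println!("y = {:?}, final_state = {:?}", y, h); // y = [1.0, 3.0], final_state = [3.0]
}
```

The step-by-step form never needs da_cumsum, because the product of per-step decays exp(da_{s+1}) ... exp(da_t) equals exp(cumsum[t] - cumsum[s]); a chunked implementation can use the pre-computed cumsum to apply those decays in bulk instead.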
Dyn Compatibility
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
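One reason visible in the signature: ssd_serial_recalculated takes no self receiver, and a trait with a receiver-less method (lacking a where Self: Sized bound) cannot be used as dyn Trait. Callers therefore select the backend statically through a generic parameter. The sketch below uses a hypothetical mock trait and backend (SsdBackendExt, CpuBackend, run_layer) to show that call-site pattern; it is not the crate's API.

```rust
/// Hypothetical stand-in sharing the key property of Mamba3BackendExt:
/// its method has no self receiver, so the trait is not dyn compatible.
trait SsdBackendExt {
    fn ssd_serial_recalculated(x: Vec<f32>) -> Vec<f32> {
        x // placeholder body for illustration
    }
}

struct CpuBackend;
impl SsdBackendExt for CpuBackend {}

// Box<dyn SsdBackendExt> would be rejected by the compiler, so the backend
// is chosen statically through a generic type parameter instead.
fn run_layer<B: SsdBackendExt>(x: Vec<f32>) -> Vec<f32> {
    // Associated functions are called on the type, not on a value.
    B::ssd_serial_recalculated(x)
}

fn main() {
    let out = run_layer::<CpuBackend>(vec![1.0, 2.0]);
    println!("{:?}", out); // [1.0, 2.0]
}
```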
Implementations on Foreign Types
impl<B: Backend + Mamba3BackendExt> Mamba3BackendExt for Autodiff<B>
Available on crate feature autodiff only.
Autodiff-wrapped backends inherit the inner backend’s Mamba3BackendExt impl.
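One way such inheritance can be expressed is a blanket impl on the wrapper that forwards to the inner backend. The sketch below uses hypothetical mocks (Backend, SsdBackendExt, a zero-sized Autodiff marker, Plain, Fused) with scalar arguments; it ignores gradient tracking and the real wrapper's tensor types entirely and only shows how an override on the inner backend flows through the wrapper.

```rust
use std::marker::PhantomData;

/// Hypothetical stand-ins; the real Backend and Autodiff types differ.
trait Backend {}

trait SsdBackendExt: Backend {
    fn ssd_serial_recalculated(x: f32) -> f32 {
        x + 1.0 // generic default path
    }
}

/// Mock autodiff wrapper: a zero-sized marker over an inner backend B.
struct Autodiff<B>(PhantomData<B>);

impl<B: Backend> Backend for Autodiff<B> {}

// Blanket impl: the wrapper forwards to whatever the inner backend provides,
// whether that is the default method or an override.
impl<B: Backend + SsdBackendExt> SsdBackendExt for Autodiff<B> {
    fn ssd_serial_recalculated(x: f32) -> f32 {
        B::ssd_serial_recalculated(x)
    }
}

struct Plain;
impl Backend for Plain {}
impl SsdBackendExt for Plain {} // uses the default

struct Fused;
impl Backend for Fused {}
impl SsdBackendExt for Fused {
    fn ssd_serial_recalculated(x: f32) -> f32 {
        x * 10.0 // pretend custom kernel
    }
}

fn main() {
    println!("{}", <Autodiff<Plain> as SsdBackendExt>::ssd_serial_recalculated(2.0)); // 3
    println!("{}", <Autodiff<Fused> as SsdBackendExt>::ssd_serial_recalculated(2.0)); // 20
}
```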