pub trait Mamba3DoubleSsdBackendExt: Backend {
// Provided method
fn double_ssd_serial_recalculated(
v_bnlmhp: FloatTensor<Self>,
da_bnlh: FloatTensor<Self>,
b_bnlmhr: FloatTensor<Self>,
c_bnlmhr: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>) { ... }
}
Extends the backend for the memory-efficient serial recalculated SSD.
The default implementation runs K1-K5 using primitive tensor operations. Backends that support a custom memory-efficient backward (specifically the Autodiff wrapper) override this to recompute forward intermediates during the backward pass instead of saving them.
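The trade-off between saving and recomputing intermediates can be illustrated with a toy scalar function. This is only a sketch: `forward`, `backward_saved`, and `backward_recomputed` are hypothetical names invented for the illustration, and the arithmetic stands in for the real K1-K5 kernels, which are not reproduced here.

```rust
/// Toy forward pass: `h = a * x` is the intermediate, `y = h * h` is the output.
/// Returns `(y, h)` so a caller may choose to keep `h` around.
pub fn forward(a: f32, x: f32) -> (f32, f32) {
    let h = a * x;
    (h * h, h)
}

/// Backward pass that relies on the intermediate `h` saved during forward.
/// dy/dx = 2 * h * a, scaled by the upstream gradient `grad_y`.
pub fn backward_saved(a: f32, h: f32, grad_y: f32) -> f32 {
    grad_y * 2.0 * h * a
}

/// Backward pass that keeps only the inputs and recomputes `h` on demand.
/// It costs an extra forward evaluation but stores no intermediate.
pub fn backward_recomputed(a: f32, x: f32, grad_y: f32) -> f32 {
    let (_y, h) = forward(a, x);
    backward_saved(a, h, grad_y)
}
```

Both backward variants produce the same gradient; the recomputed one trades extra compute for lower peak memory, which is the trade-off the Autodiff override is described as making.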
Provided Methods§
fn double_ssd_serial_recalculated(
v_bnlmhp: FloatTensor<Self>,
da_bnlh: FloatTensor<Self>,
b_bnlmhr: FloatTensor<Self>,
c_bnlmhr: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>)
Memory-efficient MIMO serial SSD.
§Arguments
v_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
da_bnlh: [batch, nchunks, chunk_len, nheads] (pre-combined Δ·A)
b_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
c_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
initial_state_bhpr: [batch, nheads, per_head_dim, state_rank]
§Returns
y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
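The shape contract above can be checked before calling the method. The following is a minimal sketch that works on plain shape arrays rather than backend tensors; `SsdDims` and `infer_dims` are hypothetical helpers introduced for illustration and are not part of this crate.

```rust
/// Dimension sizes named after the suffix letters used in the argument names
/// (b, n, l, m, h, p, r). Hypothetical helper, not part of the crate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SsdDims {
    pub batch: usize,
    pub nchunks: usize,
    pub chunk_len: usize,
    pub mimo_rank: usize,
    pub nheads: usize,
    pub per_head_dim: usize,
    pub state_rank: usize,
}

impl SsdDims {
    /// Shape of `v_bnlmhp`, which is also the shape of the returned `y_bnlmhp`.
    pub fn v_bnlmhp(&self) -> [usize; 6] {
        [self.batch, self.nchunks, self.chunk_len, self.mimo_rank, self.nheads, self.per_head_dim]
    }
    /// Shape of `da_bnlh`.
    pub fn da_bnlh(&self) -> [usize; 4] {
        [self.batch, self.nchunks, self.chunk_len, self.nheads]
    }
    /// Shape shared by `b_bnlmhr` and `c_bnlmhr`.
    pub fn bc_bnlmhr(&self) -> [usize; 6] {
        [self.batch, self.nchunks, self.chunk_len, self.mimo_rank, self.nheads, self.state_rank]
    }
    /// Shape of `initial_state_bhpr`, also the shape of `final_state_bhpr`.
    pub fn state_bhpr(&self) -> [usize; 4] {
        [self.batch, self.nheads, self.per_head_dim, self.state_rank]
    }
}

/// Reads the dimensions off `v` and `state`, then verifies that the other
/// inputs agree. Returns `None` if any shape is inconsistent.
pub fn infer_dims(
    v: [usize; 6],
    da: [usize; 4],
    b: [usize; 6],
    c: [usize; 6],
    state: [usize; 4],
) -> Option<SsdDims> {
    let dims = SsdDims {
        batch: v[0],
        nchunks: v[1],
        chunk_len: v[2],
        mimo_rank: v[3],
        nheads: v[4],
        per_head_dim: v[5],
        state_rank: state[3],
    };
    let ok = da == dims.da_bnlh()
        && b == dims.bc_bnlmhr()
        && c == dims.bc_bnlmhr()
        && state == dims.state_bhpr();
    ok.then_some(dims)
}
```

When `infer_dims` returns `Some(dims)`, the expected output shapes are `dims.v_bnlmhp()` for `y_bnlmhp` and `dims.state_bhpr()` for `final_state_bhpr`.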
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety".
Implementations on Foreign Types§
impl Mamba3DoubleSsdBackendExt for Dispatch
fn double_ssd_serial_recalculated(
v_bnlmhp: FloatTensor<Self>,
da_bnlh: FloatTensor<Self>,
b_bnlmhr: FloatTensor<Self>,
c_bnlmhr: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>)
impl Mamba3DoubleSsdBackendExt for Flex
backend-flex only.
impl<B: Backend + Mamba3DoubleSsdBackendExt, C: CheckpointStrategy> Mamba3DoubleSsdBackendExt for Autodiff<B, C>
fn double_ssd_serial_recalculated(
v_bnlmhp: FloatTensor<Self>,
da_bnlh: FloatTensor<Self>,
b_bnlmhr: FloatTensor<Self>,
c_bnlmhr: FloatTensor<Self>,
initial_state_bhpr: FloatTensor<Self>,
) -> (FloatTensor<Self>, FloatTensor<Self>)
Memory-efficient combined forward+backward for the Mamba-3 MIMO SSD.
The two output tensors (y_bnlmhp, final_state_bhpr) are flattened
and concatenated into a single 1-dimensional tracked tensor, so that one
Backward<B, 5> node covers both. The caller receives split and reshaped
slices of that combined tensor; Burn's autodiff accumulates their
upstream gradients back into one gradient vector before invoking
backward.
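The packing scheme described above can be sketched with plain `f32` buffers in place of backend tensors. The names `Packed`, `pack`, `split`, and `accumulate_grads` are hypothetical and only mirror the idea; the actual implementation operates on Burn tensors and its autodiff graph.

```rust
/// Two flattened outputs stored back to back in one buffer.
/// `y_len` records where the first output ends.
pub struct Packed {
    pub data: Vec<f32>,
    pub y_len: usize,
}

/// Concatenates the flattened `y` and `final_state` into a single 1-D buffer.
pub fn pack(y: &[f32], final_state: &[f32]) -> Packed {
    let mut data = Vec::with_capacity(y.len() + final_state.len());
    data.extend_from_slice(y);
    data.extend_from_slice(final_state);
    Packed { data, y_len: y.len() }
}

/// Splits the combined buffer back into the two output slices.
/// In the real code each slice would then be reshaped to its tensor shape.
pub fn split(p: &Packed) -> (&[f32], &[f32]) {
    p.data.split_at(p.y_len)
}

/// Scatters the upstream gradients of the two slices into one gradient
/// vector with the same layout as the packed buffer. A missing gradient
/// (an output that was never used downstream) contributes zeros.
pub fn accumulate_grads(
    y_len: usize,
    total_len: usize,
    grad_y: Option<&[f32]>,
    grad_state: Option<&[f32]>,
) -> Vec<f32> {
    let mut grad = vec![0.0_f32; total_len];
    if let Some(gy) = grad_y {
        for (dst, src) in grad[..y_len].iter_mut().zip(gy) {
            *dst += *src;
        }
    }
    if let Some(gs) = grad_state {
        for (dst, src) in grad[y_len..].iter_mut().zip(gs) {
            *dst += *src;
        }
    }
    grad
}
```

Because both outputs live in one buffer, a single backward step sees one gradient vector and can recover each output's gradient by slicing at `y_len`.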