Trait Mamba3DoubleSsdBackendExt

pub trait Mamba3DoubleSsdBackendExt: Backend {
    // Provided method
    fn double_ssd_serial_recalculated(
        v_bnlmhp: FloatTensor<Self>,
        da_bnlh: FloatTensor<Self>,
        b_bnlmhr: FloatTensor<Self>,
        c_bnlmhr: FloatTensor<Self>,
        initial_state_bhpr: FloatTensor<Self>,
    ) -> (FloatTensor<Self>, FloatTensor<Self>) { ... }
}

Extends a backend with the memory-efficient, serial, recalculated SSD.

The default implementation runs K1-K5 using primitive tensor operations. Backends that support a custom memory-efficient backward pass (specifically the Autodiff wrapper) override this method so that forward intermediates are recomputed during the backward pass instead of being saved.
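The provided-method-plus-override pattern described above can be sketched in plain Rust. This is a minimal illustration, not Burn code: `Backend`, `SsdExt`, `Plain`, and `Recompute` are hypothetical stand-ins, tensors are reduced to `Vec<f32>`, and the returned string only labels which path ran.

```rust
use std::marker::PhantomData;

// Hypothetical stand-ins; these are NOT Burn's real types.
type Tensor = Vec<f32>;

trait Backend {}

trait SsdExt: Backend {
    // Provided method: stands in for the default K1-K5 path built from
    // primitive tensor operations. Returns the output plus a path label.
    fn ssd(x: Tensor) -> (Tensor, &'static str) {
        (x, "default")
    }
}

// A plain backend that simply accepts the provided default.
struct Plain;
impl Backend for Plain {}
impl SsdExt for Plain {}

// A wrapper generic over an inner backend, loosely mirroring Autodiff<B, C>.
struct Recompute<B>(PhantomData<B>);
impl<B: Backend> Backend for Recompute<B> {}
impl<B: SsdExt> SsdExt for Recompute<B> {
    fn ssd(x: Tensor) -> (Tensor, &'static str) {
        // The forward pass still runs on the inner backend. A real autodiff
        // wrapper would also register a backward step here that recomputes
        // the forward intermediates instead of keeping them alive.
        let (y, _) = B::ssd(x);
        (y, "override")
    }
}

fn main() {
    println!("{}", Plain::ssd(vec![1.0]).1);
    println!("{}", Recompute::<Plain>::ssd(vec![1.0]).1);
}
```

The generic `impl<B: SsdExt> SsdExt for Recompute<B>` has the same shape as the `Autodiff<B, C>` impl listed under Implementations on Foreign Types: the wrapper requires the inner backend to implement the extension as well, so its forward pass can delegate to it.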

Provided Methods

fn double_ssd_serial_recalculated(v_bnlmhp: FloatTensor<Self>, da_bnlh: FloatTensor<Self>, b_bnlmhr: FloatTensor<Self>, c_bnlmhr: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>) -> (FloatTensor<Self>, FloatTensor<Self>)

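Memory-efficient MIMO serial SSD. Each tensor name ends in a suffix that lists its axes in order: b = batch, n = nchunks, l = chunk_len, m = mimo_rank, h = nheads, p = per_head_dim, r = state_rank.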

Arguments
  • v_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • da_bnlh: [batch, nchunks, chunk_len, nheads] — pre-combined Δ·A
  • b_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
  • c_bnlmhr: [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]
  • initial_state_bhpr: [batch, nheads, per_head_dim, state_rank]
Returns
  • y_bnlmhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
  • final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
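For checking shapes and semantics on tiny inputs, a naive single-threaded reference can help. The sketch below assumes the standard MIMO SSM recurrence, which this page does not spell out: at each time step, state = exp(Δ·A) · state + Σ_m v_m ⊗ B_m, then y_m = state · C_m (contracting over state_rank), with no extra Δ scaling of v and no skip term. `Dims` and `serial_ssd_reference` are hypothetical names, and tensors are flat row-major `Vec<f32>` buffers in the axis orders documented above. It is not the crate's kernel and ignores the memory-saving recomputation entirely.

```rust
/// Hypothetical dimension bundle: batch, nchunks, chunk_len, mimo_rank,
/// nheads, per_head_dim, state_rank.
#[derive(Clone, Copy)]
struct Dims { b: usize, n: usize, l: usize, m: usize, h: usize, p: usize, r: usize }

/// Hypothetical naive reference for an assumed serial MIMO SSD recurrence.
fn serial_ssd_reference(
    d: Dims,
    v: &[f32],    // [b, n, l, m, h, p]
    da: &[f32],   // [b, n, l, h], pre-combined Δ·A
    bm: &[f32],   // [b, n, l, m, h, r]
    cm: &[f32],   // [b, n, l, m, h, r]
    init: &[f32], // [b, h, p, r]
) -> (Vec<f32>, Vec<f32>) {
    let Dims { b, n, l, m, h, p, r } = d;
    let mut state = init.to_vec();
    let mut y = vec![0.0f32; b * n * l * m * h * p];
    for bi in 0..b {
        // (nchunks, chunk_len) is walked in order as one flattened time axis t.
        for t in 0..n * l {
            for hi in 0..h {
                let decay = da[(bi * n * l + t) * h + hi].exp();
                // State update: decay, then add the rank-m sum of v ⊗ B.
                for pi in 0..p {
                    for ri in 0..r {
                        let s = ((bi * h + hi) * p + pi) * r + ri;
                        let mut inp = 0.0;
                        for mi in 0..m {
                            let row = ((bi * n * l + t) * m + mi) * h + hi;
                            inp += v[row * p + pi] * bm[row * r + ri];
                        }
                        state[s] = decay * state[s] + inp;
                    }
                }
                // Output: contract the updated state with C over state_rank.
                for mi in 0..m {
                    let row = ((bi * n * l + t) * m + mi) * h + hi;
                    for pi in 0..p {
                        let mut acc = 0.0;
                        for ri in 0..r {
                            acc += cm[row * r + ri] * state[((bi * h + hi) * p + pi) * r + ri];
                        }
                        y[row * p + pi] = acc;
                    }
                }
            }
        }
    }
    (y, state)
}

fn main() {
    let d = Dims { b: 1, n: 1, l: 2, m: 1, h: 1, p: 1, r: 1 };
    let (y, s) = serial_ssd_reference(d, &[2.0, 1.0], &[0.0, 0.0], &[3.0, 1.0], &[4.0, 2.0], &[1.0]);
    println!("y = {:?}, final state = {:?}", y, s);
}
```

Because the state carries across chunks and the chunks are visited in order, chunking only changes the layout in this reference, not the result.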

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety".
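One reason a trait like this is not dyn compatible is that its method is an associated function with no `self` receiver, so it cannot be dispatched through a `dyn` pointer. Callers instead choose the backend statically through a type parameter. A minimal sketch with hypothetical names `Ext`, `Cpu`, and `run`:

```rust
// Hypothetical names for illustration; not part of the crate.
trait Ext {
    // No `self` receiver and no `where Self: Sized` bound: this alone
    // makes the trait not dyn compatible, so `Box<dyn Ext>` is rejected.
    fn combine(a: f32, b: f32) -> f32;
}

struct Cpu;

impl Ext for Cpu {
    fn combine(a: f32, b: f32) -> f32 {
        a + b
    }
}

// Static dispatch: the backend is selected through a type parameter,
// and the associated function is called on that type.
fn run<B: Ext>(a: f32, b: f32) -> f32 {
    B::combine(a, b)
}

fn main() {
    println!("{}", run::<Cpu>(1.0, 2.0));
}
```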

Implementations on Foreign Types

impl Mamba3DoubleSsdBackendExt for Dispatch

fn double_ssd_serial_recalculated(v_bnlmhp: FloatTensor<Self>, da_bnlh: FloatTensor<Self>, b_bnlmhr: FloatTensor<Self>, c_bnlmhr: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>) -> (FloatTensor<Self>, FloatTensor<Self>)

impl Mamba3DoubleSsdBackendExt for Flex

Available on crate feature backend-flex only.

impl<B: Backend + Mamba3DoubleSsdBackendExt, C: CheckpointStrategy> Mamba3DoubleSsdBackendExt for Autodiff<B, C>

fn double_ssd_serial_recalculated(v_bnlmhp: FloatTensor<Self>, da_bnlh: FloatTensor<Self>, b_bnlmhr: FloatTensor<Self>, c_bnlmhr: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>) -> (FloatTensor<Self>, FloatTensor<Self>)

Memory-efficient combined forward+backward for the Mamba-3 MIMO SSD.

The two output tensors (y_bnlmhp, final_state_bhpr) are flattened and concatenated into a single one-dimensional tracked tensor, so that one Backward<B, 5> node (one parent per input tensor) covers both outputs. The caller receives split and reshaped slices of that combined tensor; Burn's autodiff accumulates the upstream gradients of those slices into one gradient vector before invoking backward.
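The pack, split, and gradient-merge idea can be illustrated on plain slices. This is a sketch under simplifying assumptions, not Burn's implementation: `pack`, `split`, and `merge_grads` are hypothetical helpers, reshaping is omitted (the slices stay flat), and an output that receives no upstream gradient is represented here by a zero slice.

```rust
// Hypothetical helpers on flat f32 buffers; not Burn APIs.

/// Flatten-and-concatenate: y first, then the final state.
fn pack(y: &[f32], state: &[f32]) -> Vec<f32> {
    let mut combined = Vec::with_capacity(y.len() + state.len());
    combined.extend_from_slice(y);
    combined.extend_from_slice(state);
    combined
}

/// Split the combined buffer back into the two outputs.
fn split(combined: &[f32], y_len: usize) -> (&[f32], &[f32]) {
    combined.split_at(y_len)
}

/// Accumulate the per-slice upstream gradients into one vector with the
/// same layout as the combined buffer, which a single backward node consumes.
fn merge_grads(grad_y: &[f32], grad_state: &[f32]) -> Vec<f32> {
    let y_len = grad_y.len();
    let mut grad = vec![0.0f32; y_len + grad_state.len()];
    for (g, dy) in grad[..y_len].iter_mut().zip(grad_y) {
        *g += *dy;
    }
    for (g, ds) in grad[y_len..].iter_mut().zip(grad_state) {
        *g += *ds;
    }
    grad
}

fn main() {
    let combined = pack(&[1.0, 2.0, 3.0], &[4.0, 5.0]);
    let (y, state) = split(&combined, 3);
    println!("y = {:?}, state = {:?}", y, state);
    let grad = merge_grads(&[0.1, 0.2, 0.3], &[0.4, 0.5]);
    println!("combined grad = {:?}", grad);
}
```

Keeping y first and the state second in both `pack` and `merge_grads` is what lets the single gradient vector line up element for element with the combined forward tensor.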

Implementors