

Trait Mamba2BackendExt 

pub trait Mamba2BackendExt: Backend {
    // Provided method
    fn ssd_serial_recalculated(
        x_bnlhp: FloatTensor<Self>,
        dt_discretized_bhnl: FloatTensor<Self>,
        b_bnlgr: FloatTensor<Self>,
        c_bnlgr: FloatTensor<Self>,
        d_h: FloatTensor<Self>,
        initial_state_bhpr: FloatTensor<Self>,
        a_decay_h: FloatTensor<Self>,
    ) -> (FloatTensor<Self>, FloatTensor<Self>) { ... }
}

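Extends a burn Backend with the Mamba-2 SSD operation ssd_serial_recalculated. The trait supplies a provided (default) implementation, and the implementation for Autodiff<B, C> wraps the inner backend's version so that the operation works with burn's autodiff.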
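The extension-trait pattern described here can be sketched with small stand-ins. Every name below (Backend, MyOpExt, Cpu, Tracked, scale) is hypothetical and defined locally; the Backend stand-in is not burn's trait, and Tracked only imitates the role Autodiff<B, C> plays by overriding the provided method and delegating to the inner backend.

```rust
use std::marker::PhantomData;

// Hypothetical stand-in for burn's Backend; the real trait is far larger.
trait Backend: Sized {}

// Extension trait: adds one operation on top of any backend.
trait MyOpExt: Backend {
    // Provided method: implementors inherit this body unless they override it.
    fn scale(x: Vec<f32>, k: f32) -> Vec<f32> {
        x.into_iter().map(|v| v * k).collect()
    }
}

// A plain backend that accepts the provided implementation as-is.
struct Cpu;
impl Backend for Cpu {}
impl MyOpExt for Cpu {}

// Hypothetical wrapper, analogous in spirit to Autodiff<B, C>.
struct Tracked<B>(PhantomData<B>);
impl<B: Backend> Backend for Tracked<B> {}

impl<B: MyOpExt> MyOpExt for Tracked<B> {
    // Override: delegate the forward computation to the inner backend.
    // A real autodiff wrapper would also record a backward step here.
    fn scale(x: Vec<f32>, k: f32) -> Vec<f32> {
        B::scale(x, k)
    }
}
```

With these definitions, <Tracked<Cpu> as MyOpExt>::scale(vec![1.0, 2.0], 3.0) runs the inner Cpu implementation through the wrapper, which is the layering the impl for Autodiff<B, C> relies on.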

Provided Methods


fn ssd_serial_recalculated(x_bnlhp: FloatTensor<Self>, dt_discretized_bhnl: FloatTensor<Self>, b_bnlgr: FloatTensor<Self>, c_bnlgr: FloatTensor<Self>, d_h: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>, a_decay_h: FloatTensor<Self>) -> (FloatTensor<Self>, FloatTensor<Self>)

Returns a tuple of:

  • y_bnlhp, the output tensor.
  • final_state_bhpr, the final state, in the same layout as initial_state_bhpr.
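A minimal serial reference for what a scan like this might compute, for one batch element and one head, is sketched below. It rests on assumptions not stated on this page: that the suffixes name dimension orders (b batch, n chunk, l position in chunk, h head, p head dimension, g group, r state dimension), that the n and l axes flatten into one time axis, and that the standard Mamba-2 recurrence applies, with dt scaling both the decay exponent and the input. The function ssd_serial_reference is hypothetical and is not this crate's kernel.

```rust
/// Hypothetical serial scan for one batch element and one head.
/// Assumed layouts: x[t][p], dt[t], b[t][r], c[t][r], h0[p][r];
/// `a` is the per-head decay and `d` the per-head skip coefficient.
fn ssd_serial_reference(
    x: &[Vec<f32>],
    dt: &[f32],
    b: &[Vec<f32>],
    c: &[Vec<f32>],
    d: f32,
    h0: &[Vec<f32>],
    a: f32,
) -> (Vec<Vec<f32>>, Vec<Vec<f32>>) {
    let mut h = h0.to_vec(); // running state, layout [p][r]
    let mut y = Vec::with_capacity(x.len());
    for t in 0..x.len() {
        let decay = (dt[t] * a).exp(); // assumed per-step decay exp(dt * a)
        let mut y_t = vec![0.0f32; x[t].len()];
        for p in 0..x[t].len() {
            for r in 0..b[t].len() {
                // h_t = exp(dt_t * a) * h_{t-1} + dt_t * x_t * B_t
                h[p][r] = decay * h[p][r] + dt[t] * x[t][p] * b[t][r];
                // y_t accumulates C_t . h_t over the state dimension
                y_t[p] += c[t][r] * h[p][r];
            }
            // Skip connection: y_t += D * x_t
            y_t[p] += d * x[t][p];
        }
        y.push(y_t);
    }
    (y, h) // (output, final state), mirroring (y_bnlhp, final_state_bhpr)
}
```

Because each step only needs the previous state, a serial scan can recompute intermediate states during the backward pass instead of storing them; whether this kernel does exactly that is suggested by the name "recalculated" but is not confirmed here.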

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
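Because the trait cannot be used as dyn Mamba2BackendExt, callers select the backend through a generic parameter, and ssd_serial_recalculated, which takes no self receiver, is called through the type. The sketch below shows that shape with a hypothetical local trait DoubleExt; an associated function without a self receiver is enough on its own to make a trait not dyn compatible.

```rust
// Hypothetical trait with the same shape as Mamba2BackendExt's method:
// no `self` receiver, so it is called through a type, not a value.
trait DoubleExt {
    fn double(x: f32) -> f32 {
        2.0 * x
    }
}

struct MyBackend;
impl DoubleExt for MyBackend {}

// Generic code picks the implementation at compile time (static dispatch).
// Writing Box<dyn DoubleExt> instead would fail to compile.
fn apply_twice<B: DoubleExt>(x: f32) -> f32 {
    B::double(B::double(x))
}
```

For the real trait, the equivalent call inside a function generic over B: Mamba2BackendExt is B::ssd_serial_recalculated(x, dt, b, c, d, initial_state, a_decay), with the arguments in the order of the signature.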

Implementations on Foreign Types


impl<B: Backend + Mamba2BackendExt, C: CheckpointStrategy> Mamba2BackendExt for Autodiff<B, C>


fn ssd_serial_recalculated(x_bnlhp: FloatTensor<Self>, dt_discretized_bhnl: FloatTensor<Self>, b_bnlgr: FloatTensor<Self>, c_bnlgr: FloatTensor<Self>, d_h: FloatTensor<Self>, initial_state_bhpr: FloatTensor<Self>, a_decay_h: FloatTensor<Self>) -> (FloatTensor<Self>, FloatTensor<Self>)

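Memory-efficient implementation that combines the forward pass with a custom backward pass.

The two outputs, y_bnlhp and final_state_bhpr, are flattened and concatenated into a single 1-D tracked tensor, so that one Backward<B, 7> node (one parent per input tensor) covers both outputs. The caller receives slices of that combined tensor, split and reshaped back to the output shapes; burn's autodiff accumulates the upstream gradients of those slices into a single gradient vector for the combined tensor before this backward step runs.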
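The packing and gradient bookkeeping described in the previous paragraph can be modeled on plain f32 slices. The helpers pack, split, and combine_grads are hypothetical and are not burn's API; in burn the accumulation happens inside autodiff, and this sketch only shows how the indices line up.

```rust
/// Flatten two outputs into one buffer, like the combined 1-D tracked tensor.
fn pack(y: &[f32], state: &[f32]) -> Vec<f32> {
    let mut out = Vec::with_capacity(y.len() + state.len());
    out.extend_from_slice(y);
    out.extend_from_slice(state);
    out
}

/// Split the combined buffer back into the two outputs the caller receives.
fn split(combined: &[f32], y_len: usize) -> (&[f32], &[f32]) {
    combined.split_at(y_len)
}

/// Build the gradient of the combined buffer from the upstream gradients of
/// its two slices. A slice that received no gradient contributes zeros.
fn combine_grads(
    grad_y: Option<&[f32]>,
    grad_state: Option<&[f32]>,
    y_len: usize,
    state_len: usize,
) -> Vec<f32> {
    let mut g = vec![0.0f32; y_len + state_len];
    if let Some(gy) = grad_y {
        for (dst, src) in g[..y_len].iter_mut().zip(gy) {
            *dst += *src;
        }
    }
    if let Some(gs) = grad_state {
        for (dst, src) in g[y_len..].iter_mut().zip(gs) {
            *dst += *src;
        }
    }
    g
}
```

Keeping both outputs in one buffer means the backward step always receives one gradient of known length, even when the loss depends on only one of the two outputs.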

Implementors