pub enum Mamba3SsdPath {
Minimal(Option<usize>),
Serial(Option<usize>),
SerialRecalculated(Option<usize>),
}
SSD algorithm selection.
Each variant carries the chunk length Q for the SSD algorithm.
Larger values increase the intra-chunk GEMM work and reduce the
inter-chunk scan length.
The optimal value is approximately √(state_rank · per_head_dim).
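As a rough illustration of the rule of thumb above, the chunk length can be estimated as follows. This is a minimal sketch: `approx_chunk_len` is a hypothetical helper, and rounding to the nearest integer is an assumption; the crate's own `optimal_default` may round or clamp differently.

```rust
/// Hypothetical helper (not part of the crate): nearest integer to
/// sqrt(state_rank * per_head_dim), clamped to at least 1.
/// The rounding mode is an assumption made for this sketch.
fn approx_chunk_len(state_rank: usize, per_head_dim: usize) -> usize {
    let q = ((state_rank * per_head_dim) as f64).sqrt().round() as usize;
    q.max(1)
}
```

With state_rank = 128 and per_head_dim = 64, the product is 8192 and its square root is about 90.5, so this sketch yields a chunk length of 91.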
Variants§
Minimal(Option<usize>)
Minimal SSD.
This algorithm consists mostly of batched matmuls. The backward pass relies on autodiff.
See chunked_selective_scan for more info.
For training, you may prefer SerialRecalculated instead.
Based on /mamba_ssm/modules/ssd_minimal.py from the state-spaces/mamba GitHub repository.
Serial(Option<usize>)
(Hybrid) Serial SSD.
This algorithm combines batched matmuls with a serial loop over the nchunks chunks. The backward pass relies on autodiff. For a custom backward pass that saves memory, see SerialRecalculated.
Based on five kernels under /mamba_ssm/ops/triton/ in the state-spaces/mamba GitHub repository:
ssd_chunk_state.py (K1, K3)
ssd_bmm.py (K2)
ssd_state_passing.py (K4)
ssd_chunk_scan.py (K5)
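The serial loop over chunks can be illustrated with a scalar toy recurrence h_t = a_t · h_{t-1} + x_t. This is only a sketch of the general chunked-scan idea, not the crate's implementation: the function names are hypothetical, and the real kernels operate on batched tensors rather than scalars. Each chunk contributes a total decay and a local end state (both computable independently per chunk), and a short serial loop then carries the state from chunk to chunk.

```rust
/// Sequential reference: h_t = a_t * h_{t-1} + x_t, starting from h = 0.
fn final_state_sequential(a: &[f64], x: &[f64]) -> f64 {
    a.iter().zip(x).fold(0.0_f64, |h, (&a_t, &x_t)| a_t * h + x_t)
}

/// Chunked variant of the same recurrence (hypothetical toy, scalars only).
fn final_state_chunked(a: &[f64], x: &[f64], chunk_len: usize) -> f64 {
    assert!(chunk_len > 0);
    assert_eq!(a.len(), x.len());
    // Per-chunk quantities. These do not depend on other chunks,
    // so a real implementation can compute them in a batched fashion.
    let per_chunk: Vec<(f64, f64)> = a
        .chunks(chunk_len)
        .zip(x.chunks(chunk_len))
        .map(|(a_c, x_c)| {
            // Total decay applied to a state that enters this chunk.
            let decay: f64 = a_c.iter().product();
            // End state of the chunk when it starts from a zero state.
            let local = final_state_sequential(a_c, x_c);
            (decay, local)
        })
        .collect();
    // Serial loop over the chunks: one step per chunk instead of one per token.
    let mut h = 0.0;
    for (decay, local) in per_chunk {
        h = decay * h + local;
    }
    h
}
```

Both functions agree up to floating-point rounding for any positive chunk length, and the serial loop runs once per chunk rather than once per token, which is why a larger chunk length shortens the inter-chunk scan.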
SerialRecalculated(Option<usize>)
(Hybrid) Serial SSD that recomputes intermediate values during the backward pass.
This algorithm combines batched matmuls with a serial loop over the nchunks chunks. It includes a custom backward operation that saves memory. For an autodiff backward pass, see Serial.
Based on the combined kernel /mamba_ssm/ops/triton/ssd_combined.py from the state-spaces/mamba GitHub repository.
Implementations§
impl Mamba3SsdPath
pub fn optimal_default(state_rank: usize, per_head_dim: usize) -> usize
Optimal chunk length is approximately √(state_rank · per_head_dim).
pub fn core_optimal(state_rank: usize, per_head_dim: usize) -> Self
Optimal Minimal variant.
See optimal_default for more info.
pub fn core_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self
Optimal Minimal variant.
See optimal_default for more info.
pub fn chunked_optimal(state_rank: usize, per_head_dim: usize) -> Self
Optimal Serial variant.
See optimal_default for more info.
pub fn chunked_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self
Optimal Serial variant.
See optimal_default for more info.
pub fn chunked_recalculated_optimal(state_rank: usize, per_head_dim: usize) -> Self
Optimal SerialRecalculated variant.
See optimal_default for more info.
pub fn chunked_recalculated_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self
Optimal SerialRecalculated variant.
See optimal_default for more info.
pub fn chunk_len(&self) -> Option<usize>
pub fn chunk_len_or_optimal(&self, state_rank: usize, per_head_dim: usize) -> usize
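The two accessors above are undocumented, but their signatures suggest a fallback pattern: use the chunk length stored in the variant if present, otherwise derive one from the dimensions. The sketch below shows that pattern on a stand-in enum. `SsdPathSketch` is hypothetical, and both the fallback behaviour and the rounding are assumptions, not confirmed behaviour of Mamba3SsdPath.

```rust
/// Stand-in for the documented enum, for illustration only.
#[derive(Clone, Debug, PartialEq)]
enum SsdPathSketch {
    Minimal(Option<usize>),
    Serial(Option<usize>),
    SerialRecalculated(Option<usize>),
}

impl SsdPathSketch {
    /// The chunk length carried by the variant, if one was set.
    fn chunk_len(&self) -> Option<usize> {
        match self {
            Self::Minimal(q) | Self::Serial(q) | Self::SerialRecalculated(q) => *q,
        }
    }

    /// Assumed behaviour: fall back to roughly sqrt(state_rank * per_head_dim)
    /// when the variant carries no explicit chunk length.
    fn chunk_len_or_optimal(&self, state_rank: usize, per_head_dim: usize) -> usize {
        self.chunk_len().unwrap_or_else(|| {
            (((state_rank * per_head_dim) as f64).sqrt().round() as usize).max(1)
        })
    }
}
```

Under these assumptions, `SsdPathSketch::Serial(Some(48))` keeps its explicit chunk length of 48, while `SsdPathSketch::Minimal(None)` with state_rank = 64 and per_head_dim = 64 falls back to 64.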
pub fn run<B: Backend + Mamba3BackendExt>(&self, input: Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>)
Run the SSD algorithm on the given MIMO-first input.
Dispatches to ssd_minimal, ssd_serial, or ssd_serial_recalculated based on the variant.
§Returns
y_bnlrhp: [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]
final_state_bhpr: [batch, nheads, per_head_dim, state_rank]
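To make the returned shapes concrete, the sketch below computes them from plain dimensions. `SsdDims` and `output_shapes` are hypothetical names, and the sketch assumes the sequence length is an exact multiple of the chunk length; how the crate handles a remainder is not stated here.

```rust
/// Hypothetical bundle of the dimensions named in the return shapes.
struct SsdDims {
    batch: usize,
    seq_len: usize,
    chunk_len: usize,
    mimo_rank: usize,
    nheads: usize,
    per_head_dim: usize,
    state_rank: usize,
}

/// Shapes of (y_bnlrhp, final_state_bhpr) for the given dimensions.
/// Assumes seq_len is divisible by chunk_len (an assumption of this sketch).
fn output_shapes(d: &SsdDims) -> ([usize; 6], [usize; 4]) {
    assert!(d.chunk_len > 0 && d.seq_len % d.chunk_len == 0);
    let nchunks = d.seq_len / d.chunk_len;
    (
        [d.batch, nchunks, d.chunk_len, d.mimo_rank, d.nheads, d.per_head_dim],
        [d.batch, d.nheads, d.per_head_dim, d.state_rank],
    )
}
```

For example, batch = 2, seq_len = 256, chunk_len = 64, mimo_rank = 4, nheads = 8, per_head_dim = 64, and state_rank = 128 give y_bnlrhp the shape [2, 4, 64, 4, 8, 64] and final_state_bhpr the shape [2, 8, 64, 128].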
Trait Implementations§
impl Clone for Mamba3SsdPath
fn clone(&self) -> Mamba3SsdPath
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.