pub struct Mamba2CacheConfig {
pub batch: usize,
pub state_rank: usize,
pub conv_kernel: usize,
pub conv_dim: usize,
pub per_head_dim: usize,
pub nheads: usize,
}
Configuration / factory for a single Mamba2Cache.
Fields§
batch: usize
Batch size.
state_rank: usize
State rank N, the number of latent dimensions in the SSM hidden
state. Corresponds to state_rank in Mamba2Config.
conv_kernel: usize
Causal convolution window length. Corresponds to conv_kernel in
Mamba2Config.
conv_dim: usize
Number of channels entering (and leaving) the depthwise convolution.
Equal to d_inner + 2 · ngroups · state_rank.
per_head_dim: usize
Head dimension P. Corresponds to per_head_dim in Mamba2Config.
nheads: usize
Number of SSM heads H.
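As a worked example of the conv_dim relation, the sketch below evaluates d_inner + 2 · ngroups · state_rank for one illustrative set of sizes. The helper name `conv_dim` and the values d_inner = 512 and ngroups = 1 are assumptions chosen for the example, not values taken from this crate.

```rust
/// Hypothetical helper (not part of this crate): number of channels seen by
/// the depthwise convolution, following `d_inner + 2 * ngroups * state_rank`.
fn conv_dim(d_inner: usize, ngroups: usize, state_rank: usize) -> usize {
    d_inner + 2 * ngroups * state_rank
}

fn main() {
    // Illustrative sizes only: d_inner = 512, ngroups = 1, state_rank = 128.
    let channels = conv_dim(512, 1, 128);
    // 512 + 2 * 1 * 128 = 768
    assert_eq!(channels, 768);
    println!("conv_dim = {channels}");
}
```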
Implementations§
impl Mamba2CacheConfig
pub fn new(batch: usize, conv_dim: usize, nheads: usize) -> Self
Create a new instance of the config.
§Arguments
§Required Arguments
§batch
Batch size.
§conv_dim
Number of channels entering (and leaving) the depthwise convolution.
Equal to d_inner + 2 · ngroups · state_rank.
§nheads
Number of SSM heads H.
§Default Arguments
§state_rank
State rank N, the number of latent dimensions in the SSM hidden
state. Corresponds to state_rank in Mamba2Config.
- Defaults to 128
§conv_kernel
Causal convolution window length. Corresponds to conv_kernel in
Mamba2Config.
- Defaults to 4
§per_head_dim
Head dimension P. Corresponds to per_head_dim in Mamba2Config.
- Defaults to 64
impl Mamba2CacheConfig
pub fn with_state_rank(self, state_rank: usize) -> Self
Sets the value for the field state_rank.
State rank N, the number of latent dimensions in the SSM hidden
state. Corresponds to state_rank in Mamba2Config.
- Defaults to 128
pub fn with_conv_kernel(self, conv_kernel: usize) -> Self
Sets the value for the field conv_kernel.
Causal convolution window length. Corresponds to conv_kernel in
Mamba2Config.
- Defaults to 4
pub fn with_per_head_dim(self, per_head_dim: usize) -> Self
Sets the value for the field per_head_dim.
Head dimension P. Corresponds to per_head_dim in Mamba2Config.
- Defaults to 64
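The constructor and the with_* setters compose in the usual builder style. The sketch below uses a local stand-in struct named `CacheConfigSketch` (hypothetical, defined only so the example compiles without this crate) that mirrors the documented signatures and defaults. The real Mamba2CacheConfig is assumed, not verified, to behave the same way.

```rust
/// Local stand-in mirroring the documented fields of `Mamba2CacheConfig`.
/// It is NOT the crate's type; it only illustrates the builder pattern.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
struct CacheConfigSketch {
    batch: usize,
    state_rank: usize,
    conv_kernel: usize,
    conv_dim: usize,
    per_head_dim: usize,
    nheads: usize,
}

#[allow(dead_code)]
impl CacheConfigSketch {
    /// Takes the required arguments; the rest get the documented defaults.
    fn new(batch: usize, conv_dim: usize, nheads: usize) -> Self {
        Self {
            batch,
            state_rank: 128,
            conv_kernel: 4,
            conv_dim,
            per_head_dim: 64,
            nheads,
        }
    }

    /// Each setter consumes `self` and returns the updated config.
    fn with_state_rank(mut self, state_rank: usize) -> Self {
        self.state_rank = state_rank;
        self
    }

    fn with_conv_kernel(mut self, conv_kernel: usize) -> Self {
        self.conv_kernel = conv_kernel;
        self
    }

    fn with_per_head_dim(mut self, per_head_dim: usize) -> Self {
        self.per_head_dim = per_head_dim;
        self
    }
}

fn main() {
    // Defaults only: state_rank = 128, conv_kernel = 4, per_head_dim = 64.
    let defaults = CacheConfigSketch::new(2, 768, 8);
    assert_eq!(defaults.state_rank, 128);
    assert_eq!(defaults.conv_kernel, 4);
    assert_eq!(defaults.per_head_dim, 64);

    // Override two defaults by chaining setters; the third is left as is.
    let custom = CacheConfigSketch::new(2, 640, 8)
        .with_state_rank(64)
        .with_conv_kernel(3);
    assert_eq!(custom.state_rank, 64);
    assert_eq!(custom.conv_kernel, 3);
    assert_eq!(custom.per_head_dim, 64);
    println!("{custom:?}");
}
```

Because each setter takes `self` by value, the chain moves the config through every call without cloning.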
impl Mamba2CacheConfig
pub fn new_from_block_config(batch: usize, block_config: Mamba2Config) -> Self
Derive cache shapes from a Mamba-2 block configuration plus a batch size.
pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Cache<B>
Allocate zero-initialised cache tensors on device.
Zero initialisation is correct because:
- The convolution cache represents “no previous tokens” (identity padding).
- The SSM state represents h₀ = 0 (zero initial condition), which is the standard default. Learnable initial states (if configured) are added on top of this inside [crate::mamba2::Mamba2::forward] / [crate::mamba2::Mamba2::step].
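To see why zeros are a neutral starting point, the sketch below runs one step of a single-channel causal convolution and one step of a scalar linear recurrence, both starting from zero. The functions `conv_step` and `ssm_step`, the weights, and the coefficients are hypothetical scalar stand-ins chosen for illustration; they are not the crate's tensor code, and the actual cache layout is not described on this page.

```rust
/// One step of a single-channel causal convolution over a rolling window.
/// `window` holds the most recent inputs, oldest first, and has the same
/// length as `weights`.
fn conv_step(window: &mut [f32], weights: &[f32], x: f32) -> f32 {
    // Drop the oldest sample and append the new one.
    window.rotate_left(1);
    if let Some(last) = window.last_mut() {
        *last = x;
    }
    window.iter().zip(weights).map(|(v, w)| v * w).sum()
}

/// One step of a scalar linear recurrence h_t = a * h_{t-1} + b * x_t.
fn ssm_step(h: &mut f32, a: f32, b: f32, x: f32) -> f32 {
    *h = a * *h + b * x;
    *h
}

fn main() {
    // Window length 4, matching the documented conv_kernel default.
    let weights = [0.5_f32, 0.25, 0.125, 1.0];
    let mut window = [0.0_f32; 4]; // zero-initialised convolution cache

    // With an all-zero history only the newest tap contributes: 1.0 * 2.0.
    let y1 = conv_step(&mut window, &weights, 2.0);
    assert_eq!(y1, 2.0);

    // On the next step the previous token starts to matter: 0.125 * 2.0 + 1.0 * 3.0.
    let y2 = conv_step(&mut window, &weights, 3.0);
    assert_eq!(y2, 3.25);

    // With h_0 = 0 the first state depends only on the first input: 0.5 * 2.0.
    let mut h = 0.0_f32;
    let h1 = ssm_step(&mut h, 0.9, 0.5, 2.0);
    assert_eq!(h1, 1.0);
}
```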