pub struct Mamba2Layers<B: Backend> {
    pub n_real_layers: usize,
    pub n_virtual_layers: Option<(usize, Schedule)>,
    pub real_layers: Vec<Mamba2Layer<B>>,
    pub ignore_first_residual: bool,
    pub ignore_last_residual: bool,
}
A stack of Mamba-2 layers with optional virtual-layer scheduling.
The stack maintains n_real_layers distinct weight sets but can execute
n_virtual_layers logical forward passes, cycling through weights
according to the provided Schedule.
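As a rough illustration of how virtual layers can map onto real weight sets, the sketch below assumes a simple cyclic schedule in which virtual layer v reuses real layer v % n_real_layers. The function names are hypothetical and the crate's actual Schedule type may define different orderings; only the fallback to n_real_layers when no virtual count is given comes from the field documentation.

```rust
/// Hypothetical cyclic schedule: virtual layer `v` reuses real layer `v % n_real_layers`.
/// This is an assumption for illustration, not the crate's `Schedule` definition.
fn cyclic_real_index(virtual_idx: usize, n_real_layers: usize) -> usize {
    assert!(n_real_layers > 0, "need at least one real layer");
    virtual_idx % n_real_layers
}

/// Expand a virtual-layer count into the sequence of real-layer indices visited.
fn execution_order(n_real_layers: usize, n_virtual_layers: Option<usize>) -> Vec<usize> {
    // `None` falls back to one pass per real layer (no sharing).
    let n_virtual = n_virtual_layers.unwrap_or(n_real_layers);
    (0..n_virtual)
        .map(|v| cyclic_real_index(v, n_real_layers))
        .collect()
}
```

Under this assumed schedule, three real layers run as seven virtual layers would visit the weight sets in the order 0, 1, 2, 0, 1, 2, 0.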
Fields§
n_real_layers: usize
Number of real (weight-bearing) layers.
n_virtual_layers: Option<(usize, Schedule)>
Optional (n_virtual_layers, schedule) pair for weight sharing.
When None, the virtual layer count falls back to n_real_layers (no
sharing). Marked module(skip) so Burn does not treat it as a
trainable parameter.
real_layers: Vec<Mamba2Layer<B>>
The actual weight-bearing layer instances.
Length: n_real_layers.
ignore_first_residual: bool
When true, the residual connection of the first virtual layer is
scaled to zero (i.e. the first block acts as a pure projection, not a
residual update).
ignore_last_residual: bool
When true, the residual connection of the last virtual layer is
scaled to zero.
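The effect of the two residual flags can be sketched as a per-layer scale factor. This is a minimal sketch that assumes each block combines its input and output as scale * x + block(x); the helper names are hypothetical and the crate's internal formulation may differ.

```rust
/// Residual scale for a given virtual layer: 0.0 when the residual is ignored,
/// 1.0 otherwise. Hypothetical helper, not part of the crate's API.
fn residual_scale(
    virtual_idx: usize,
    n_virtual: usize,
    ignore_first_residual: bool,
    ignore_last_residual: bool,
) -> f32 {
    let is_first = virtual_idx == 0;
    let is_last = virtual_idx + 1 == n_virtual;
    if (is_first && ignore_first_residual) || (is_last && ignore_last_residual) {
        0.0
    } else {
        1.0
    }
}

/// Assumed combination rule: out = scale * x + block(x).
/// With scale == 0.0 the block acts as a pure projection of its input.
fn combine(x: f32, block_out: f32, scale: f32) -> f32 {
    scale * x + block_out
}
```

With a single virtual layer, that layer is both first and last, so either flag zeroes its residual.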
Implementations§
impl<B: Backend + Mamba2BackendExt> Mamba2Layers<B>
pub fn forward(
    &self,
    x: Tensor<B, 3>,
    caches: Option<Mamba2Caches<B>>,
    ssd_path: Mamba2SsdPath,
) -> (Tensor<B, 3>, Mamba2Caches<B>)
Process a full sequence through every (virtual) layer.
Internally each layer calls Mamba2::forward, which runs the
chunkwise SSD algorithm. This is efficient for training because the
intra-chunk products can exploit GEMM / tensor cores.
If caches is None, zero-initialised caches are created automatically.
§Arguments
- x: input tensor, shape [batch, sequence, d_model]
- caches: optional pre-filled layer caches (useful for prefill followed by decode)
- ssd_path: SSD algorithm and chunk length selection
§Returns
(output, updated_caches) where output has shape
[batch, sequence, d_model].
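The prefill-then-decode pattern mentioned for caches can be illustrated with a toy scalar model. The sketch below does not use the Burn tensor API or the chunkwise SSD algorithm: ToyLayer, its fields, and prefill_then_decode are hypothetical stand-ins, and a single f32 plays the role of Mamba2Caches. It only shows that carrying the state from a full-sequence pass into per-token steps reproduces the output of one full-sequence pass.

```rust
/// Toy scalar state-space layer: h_t = a * h_{t-1} + b * x_t, y_t = c * h_t.
/// Hypothetical stand-in for a real layer; the `f32` state plays the role of the cache.
struct ToyLayer {
    a: f32,
    b: f32,
    c: f32,
}

impl ToyLayer {
    /// One recurrent tick: returns (output, updated state).
    fn step(&self, x: f32, h: f32) -> (f32, f32) {
        let h_next = self.a * h + self.b * x;
        (self.c * h_next, h_next)
    }

    /// Full-sequence pass (sequential here, not chunkwise); `None` starts from a zero state.
    fn forward(&self, xs: &[f32], cache: Option<f32>) -> (Vec<f32>, f32) {
        let mut h = cache.unwrap_or(0.0);
        let mut ys = Vec::with_capacity(xs.len());
        for &x in xs {
            let (y, h_next) = self.step(x, h);
            ys.push(y);
            h = h_next;
        }
        (ys, h)
    }
}

/// Prefill on `prompt` with `forward`, then decode `rest` token by token with `step`.
fn prefill_then_decode(layer: &ToyLayer, prompt: &[f32], rest: &[f32]) -> Vec<f32> {
    let (mut ys, mut h) = layer.forward(prompt, None);
    for &x in rest {
        let (y, h_next) = layer.step(x, h);
        ys.push(y);
        h = h_next;
    }
    ys
}
```

Splitting a sequence at any point and passing the returned state along gives the same outputs as processing the whole sequence in one call, which is the property the returned caches are meant to preserve.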
pub fn step(
    &self,
    x: Tensor<B, 2>,
    caches: Option<Mamba2Caches<B>>,
) -> (Tensor<B, 2>, Mamba2Caches<B>)
Process a single token through every (virtual) layer.
Each layer calls Mamba2::step, which runs one tick of the recurrent
SSM: hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ, yₜ = Cₜᵀ hₜ + D xₜ.
This costs O(H·P·N) per step (H heads, head dimension P, state dimension N),
independent of sequence length, and requires no KV-cache.
§Arguments
- x: current token embedding, shape [batch, d_model]
- caches: layer caches from the previous step (or None for the first token, in which case zero caches are created)
§Returns
(output, updated_caches) where output has shape [batch, d_model].
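The recurrence hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ, yₜ = Cₜᵀ hₜ + D xₜ can be sketched for a single head in plain Rust. This sketch assumes a scalar Ā per head and a P x N state, as is usual for Mamba-2; ssm_step_head and its argument layout are hypothetical and do not reflect how Mamba2::step is actually implemented with tensors.

```rust
/// One recurrent tick for a single head (hypothetical helper).
/// `h` is the P x N state, `x` has length P, `b_bar` and `c` have length N.
/// Updates `h` in place and returns the output of length P.
fn ssm_step_head(
    h: &mut [Vec<f32>],
    a_bar: f32,
    b_bar: &[f32],
    c: &[f32],
    d: f32,
    x: &[f32],
) -> Vec<f32> {
    let mut y = Vec::with_capacity(x.len());
    for (row, &x_p) in h.iter_mut().zip(x) {
        let mut acc = d * x_p; // skip term D * x_t
        for ((h_pn, &b_n), &c_n) in row.iter_mut().zip(b_bar).zip(c) {
            *h_pn = a_bar * *h_pn + b_n * x_p; // state update
            acc += c_n * *h_pn; // readout C_t^T h_t
        }
        y.push(acc);
    }
    y
}
```

The two nested loops touch each of the P·N state entries once, so repeating this over H heads gives the O(H·P·N) cost per token, with no dependence on how many tokens came before.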