pub struct Mamba3Network<B: Backend> {
    pub embedding: Embedding<B>,
    pub layers: Mamba3Layers<B>,
    pub norm_f: RmsNorm<B>,
    pub lm_head: Option<Linear<B>>,
}
A complete Mamba-3 language model.
See the module-level documentation for an overview of the architecture and the two execution modes.
Fields
embedding: Embedding<B>
Token embedding table.
Shape of weight matrix: [padded_vocab_size, d_model].
Maps integer token IDs to d_model-dimensional vectors.
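As a rough illustration of what the lookup does, here is a plain-Rust sketch over a row-major table. The padding rule (rounding the vocabulary up to a multiple such as 8) is an assumption borrowed from common Mamba configurations, not something this page states, and `padded_vocab_size` and `embed` are hypothetical helpers rather than part of this crate.

```rust
/// Hypothetical helper: round `vocab_size` up to the next multiple of
/// `multiple` (assumed padding rule; `multiple` must be non-zero).
fn padded_vocab_size(vocab_size: usize, multiple: usize) -> usize {
    (vocab_size + multiple - 1) / multiple * multiple
}

/// Hypothetical helper: look up token IDs in a row-major
/// `[padded_vocab_size, d_model]` table, returning one `d_model`-long
/// vector per ID.
fn embed(table: &[f32], d_model: usize, ids: &[usize]) -> Vec<Vec<f32>> {
    ids.iter()
        .map(|&id| table[id * d_model..(id + 1) * d_model].to_vec())
        .collect()
}
```

Under this assumed rule, a 50,277-token vocabulary padded to a multiple of 8 gives 50,280 rows; the extra rows are never selected by a real token ID.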
layers: Mamba3Layers<B>
The stack of Mamba-3 residual blocks.
norm_f: RmsNorm<B>
Final RMS normalisation, applied after all Mamba-3 blocks and before
the LM head. This corresponds to norm_f in the original implementation.
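The normalisation is the standard RMSNorm formula. A minimal sketch on a single d_model vector, assuming a learned per-channel weight and a small eps (the eps value this crate uses is not stated here); `rms_norm` is a hypothetical name:

```rust
/// Hypothetical helper: RMSNorm of one vector,
/// y_i = x_i / sqrt(mean(x^2) + eps) * weight_i.
fn rms_norm(x: &[f32], weight: &[f32], eps: f32) -> Vec<f32> {
    // Mean of squares over the d_model channels.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    // Scale each channel by the inverse RMS and its learned weight.
    x.iter().zip(weight).map(|(v, w)| v * inv_rms * w).collect()
}
```

Unlike LayerNorm, no mean is subtracted: with unit weights the output simply has a mean square of 1.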
lm_head: Option<Linear<B>>
Optional separate LM head projection.
- Some(linear): a dedicated weight matrix of shape [d_model, padded_vocab_size].
- None: the embedding weights are reused (transposed). This is the “weight-tied” variant and is selected when missing_lm_head = true.
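To see why the tied variant needs the embedding transposed, the sketch below computes logits both ways: the tied head takes dot products with rows of the [padded_vocab_size, d_model] embedding table, while the untied head multiplies by a [d_model, padded_vocab_size] matrix. Both helpers are hypothetical, work on plain slices, and ignore any bias term the Linear layer might carry.

```rust
/// Hypothetical tied head: `emb` is the row-major [vocab, d_model]
/// embedding table, so logits[v] = dot(h, emb[v]), i.e. h times the
/// transposed table.
fn tied_logits(h: &[f32], emb: &[f32], d_model: usize) -> Vec<f32> {
    emb.chunks(d_model)
        .map(|row| row.iter().zip(h).map(|(e, x)| e * x).sum::<f32>())
        .collect()
}

/// Hypothetical untied head without bias: `w` is row-major
/// [d_model, vocab], so logits[v] = sum over d of h[d] * w[d][v].
fn untied_logits(h: &[f32], w: &[f32], vocab: usize) -> Vec<f32> {
    let mut out = vec![0.0; vocab];
    // Accumulate each hidden channel's contribution to every logit.
    for (x, row) in h.iter().zip(w.chunks(vocab)) {
        for (o, wv) in out.iter_mut().zip(row) {
            *o += x * wv;
        }
    }
    out
}
```

When the untied matrix equals the transposed embedding table, both heads produce identical logits; weight tying exploits this to avoid storing a second d_model × padded_vocab_size matrix.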
Implementations
impl<B: Backend + Mamba3BackendExt> Mamba3Network<B>
pub fn forward(
    &self,
    x: Tensor<B, 2, Int>,
    caches: Option<Mamba3Caches<B>>,
    ssd_path: Mamba3SsdPath,
) -> (Tensor<B, 3>, Mamba3Caches<B>)
Process a full token sequence and return next-token logits.
Internally this calls Mamba3Layers::forward, which runs the
chunkwise SSD algorithm over every layer. This is the mode to use
during training (backpropagation through the entire sequence) and
during the prefill phase of inference.
Arguments
- x: integer token IDs, shape [batch, sequence].
- caches: optional pre-filled layer caches. Pass None to start from a zero state (training) or to create fresh caches that can be returned and reused for a subsequent decoding step.
- ssd_path: SSD algorithm and chunk length selection.
Returns
(logits, caches) where:
- logits has shape [batch, sequence, padded_vocab_size].
- caches contains the SSM and convolution state at the end of the sequence, ready to be passed to the first Self::step call.
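The caches returned here can seed Self::step because a chunkwise scan and a token-by-token recurrence compute the same states. The toy below shows this for a scalar recurrence h_t = a_t h_{t-1} + u_t, with u_t standing in for B̄ₜ xₜ and the output taken as y_t = h_t. It illustrates the chunking idea only and is not this crate's SSD kernel: inside each chunk, outputs are computed in closed form from the state entering the chunk, and only the chunk's final state is carried forward. `sequential` and `chunked` are hypothetical names.

```rust
/// Toy reference (not the crate's kernel): sequential scan of
/// h_t = a_t * h_{t-1} + u_t, returning every h_t and the final state.
fn sequential(a: &[f32], u: &[f32], h0: f32) -> (Vec<f32>, f32) {
    let mut h = h0;
    let mut ys = Vec::with_capacity(a.len());
    for (ai, ui) in a.iter().zip(u) {
        h = ai * h + ui;
        ys.push(h);
    }
    (ys, h)
}

/// Toy chunkwise scan: within a chunk, each output is computed in closed
/// form from the incoming state; only the chunk's last state is passed on.
/// `chunk` must be non-zero.
fn chunked(a: &[f32], u: &[f32], h0: f32, chunk: usize) -> (Vec<f32>, f32) {
    let mut ys = Vec::with_capacity(a.len());
    let mut h_in = h0;
    for (ac, uc) in a.chunks(chunk).zip(u.chunks(chunk)) {
        for t in 0..ac.len() {
            // Decay of the incoming state up to position t.
            let decay: f32 = ac[..=t].iter().product();
            // Each input s <= t, decayed by a_{s+1} * ... * a_t.
            let intra: f32 = (0..=t)
                .map(|s| ac[s + 1..=t].iter().product::<f32>() * uc[s])
                .sum();
            ys.push(decay * h_in + intra);
        }
        // State handed to the next chunk (chunks are never empty).
        h_in = *ys.last().unwrap();
    }
    (ys, h_in)
}
```

Running the chunked version over a prefix and then continuing with the sequential recurrence from its final state reproduces the full sequential result, which is the same hand-off that forward followed by step relies on.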
pub fn step(
    &self,
    x: Tensor<B, 1, Int>,
    caches: Option<Mamba3Caches<B>>,
) -> (Tensor<B, 2>, Mamba3Caches<B>)
Process a single token and return next-token logits.
Internally this calls Mamba3Layers::step, which advances each
layer’s recurrent state by one step:
hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ
yₜ = Cₜᵀ hₜ + D xₜ
This costs O(H·P·N) per token in each layer (H heads, head dimension P, state size N), independent of sequence length, and is the correct mode for token-by-token generation after prefill.
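Writing the update out for one head on plain slices makes the O(P·N) per-head cost visible: each state entry is touched exactly once. This is a simplified sketch assuming a scalar per-head decay Āₜ (as in Mamba-2) and an already discretised B̄ₜ; it leaves out whatever discretisation and state structure the actual Mamba-3 kernel uses, and `head_step` is a hypothetical name.

```rust
/// Hypothetical single-head step with state h of shape [P, N] (row-major).
/// Assumes a scalar decay `a_bar` and a discretised input vector `b_bar`
/// of length N:
///   h[p][n] <- a_bar * h[p][n] + x[p] * b_bar[n]
///   y[p]     = sum_n c[n] * h[p][n] + d * x[p]
fn head_step(h: &mut [f32], x: &[f32], a_bar: f32, b_bar: &[f32], c: &[f32], d: f32) -> Vec<f32> {
    let n = b_bar.len();
    h.chunks_mut(n)
        .zip(x)
        .map(|(row, &xp)| {
            let mut acc = 0.0;
            // Update one row of the state and read it out with C in the same pass.
            for ((hpn, &bn), &cn) in row.iter_mut().zip(b_bar).zip(c) {
                *hpn = a_bar * *hpn + xp * bn;
                acc += cn * *hpn;
            }
            // Skip connection D x.
            acc + d * xp
        })
        .collect()
}
```

Because only the [P, N] state is carried between calls, the work per token stays the same no matter how many tokens preceded it.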
Arguments
- x: current token IDs, shape [batch].
- caches: layer caches from the previous step (or None for the very first token, which starts from a zero hidden state).
Returns
(logits, caches) where:
- logits has shape [batch, padded_vocab_size].
- caches contains the updated state for the next step.
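The two methods are meant to be chained: forward over the prompt, then step once per generated token, threading the returned caches through each call. The sketch below shows that greedy-decoding pattern against a hypothetical `ToyModel` whose `forward` and `step` mirror the shape contract above (batch of 1, Vec in place of Tensor, a running-sum "cache"). None of these types or functions are part of this crate.

```rust
/// Hypothetical vocabulary size for the toy model.
const VOCAB: usize = 5;

/// Hypothetical cache: a running sum of the token IDs seen so far.
#[derive(Clone, Copy, Debug, PartialEq)]
struct ToyCache {
    sum: usize,
}

/// Hypothetical model whose logits put all mass on (sum mod VOCAB).
struct ToyModel;

impl ToyModel {
    fn logits(sum: usize) -> Vec<f32> {
        (0..VOCAB)
            .map(|v| if v == sum % VOCAB { 1.0 } else { 0.0 })
            .collect()
    }

    /// Like `forward`: one logit row per position, plus the final cache.
    fn forward(&self, ids: &[usize], cache: Option<ToyCache>) -> (Vec<Vec<f32>>, ToyCache) {
        let mut c = cache.unwrap_or(ToyCache { sum: 0 });
        let mut rows = Vec::with_capacity(ids.len());
        for &id in ids {
            c.sum += id;
            rows.push(Self::logits(c.sum));
        }
        (rows, c)
    }

    /// Like `step`: one token in, one logit row and the updated cache out.
    fn step(&self, id: usize, cache: Option<ToyCache>) -> (Vec<f32>, ToyCache) {
        let mut c = cache.unwrap_or(ToyCache { sum: 0 });
        c.sum += id;
        (Self::logits(c.sum), c)
    }
}

/// Index of the largest logit (greedy sampling).
fn argmax(xs: &[f32]) -> usize {
    xs.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .expect("logits must be non-empty")
}

/// Prefill the prompt with `forward`, then generate `n_new` tokens with
/// `step`, passing each returned cache into the next call.
fn generate(model: &ToyModel, prompt: &[usize], n_new: usize) -> Vec<usize> {
    let (rows, mut cache) = model.forward(prompt, None);
    let mut next = argmax(rows.last().expect("prompt must be non-empty"));
    let mut out = Vec::with_capacity(n_new);
    for _ in 0..n_new {
        out.push(next);
        let (logits, new_cache) = model.step(next, Some(cache));
        cache = new_cache;
        next = argmax(&logits);
    }
    out
}
```

For the toy model, running forward over the prompt plus the generated tokens ends in the same cache and final logits as the prefill-then-step path; preserving that equivalence is what the returned caches are for.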