§Mamba-3 Language Model Network
This module assembles a complete autoregressive language model from the Mamba-3 components:
tokens [B, T]
│
▼
Embedding (vocab_size → d_model)
│
▼ (×n_layers)
Mamba3Layer [Pre-LN residual block]
│
▼
RMSNorm (final normalisation)
│
▼
LM head (d_model → vocab_size)
│
▼
logits [B, T, vocab_size]
§Vocabulary padding
The embedding and LM head dimensions are rounded up to the nearest
multiple of pad_vocab_size_multiple. This improves memory alignment on
GPU without exposing the extra token slots to the model (they are never
sampled from in practice).
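The rounding rule above can be sketched as follows. This is a minimal illustration, not the crate's code: `padded_vocab_size` is a hypothetical helper name, and the only inputs assumed are the raw vocabulary size and the value of pad_vocab_size_multiple.

```rust
/// Rounds `vocab_size` up to the nearest multiple of `multiple`.
/// Hypothetical helper for illustration; not part of the crate's API.
fn padded_vocab_size(vocab_size: usize, multiple: usize) -> usize {
    assert!(multiple > 0, "multiple must be positive");
    // Adding `multiple - 1` before the integer division rounds the quotient
    // up, so the result is the smallest multiple that is >= vocab_size.
    (vocab_size + multiple - 1) / multiple * multiple
}
```

For example, a vocabulary of 50277 tokens with a multiple of 8 is padded to 50280, which leaves three unused token slots at the end of the embedding table.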
§Tied / untied LM head
When missing_lm_head = true, no separate output layer is stored
(lm_head = None); the logit projection reuses the transposed embedding
weight matrix, applied as a linear layer on the fly. This removes the
output projection's own weights, halving the combined embedding and
output-projection weight count, and is standard in many LLM
implementations. When missing_lm_head = false, a separate Linear layer
is allocated.
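To make the tied case concrete, here is a small sketch in plain Rust without any tensor library. The `Embedding` struct and its `embed` and `tied_logits` methods are hypothetical names chosen for illustration, and the row-major `Vec<f32>` layout is an assumption; the crate's own types will differ. The sketch only shows that one matrix can serve both as the lookup table and as the transposed output projection.

```rust
/// Embedding table stored row-major: `vocab` rows of `d_model` values each.
/// Hypothetical type for illustration only.
struct Embedding {
    weight: Vec<f32>,
    vocab: usize,
    d_model: usize,
}

impl Embedding {
    /// Looks up the embedding row for one token id.
    fn embed(&self, token: usize) -> &[f32] {
        &self.weight[token * self.d_model..(token + 1) * self.d_model]
    }

    /// Tied LM head: logits[v] = dot(hidden, weight[v]), i.e. hidden · Wᵀ.
    /// No second matrix is stored; the embedding rows are reused.
    fn tied_logits(&self, hidden: &[f32]) -> Vec<f32> {
        assert_eq!(hidden.len(), self.d_model);
        (0..self.vocab)
            .map(|v| {
                self.embed(v)
                    .iter()
                    .zip(hidden)
                    .map(|(w, h)| w * h)
                    .sum::<f32>()
            })
            .collect()
    }
}
```

An untied head would instead hold a second `vocab × d_model` matrix of its own, which is the extra allocation described for missing_lm_head = false.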
§Two execution modes
| Method | Input shape | Use case |
|---|---|---|
| Mamba3Network::forward | [B, T] | Training, prefill |
| Mamba3Network::step | [B] | Autoregressive decoding |
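The relationship between the two modes can be illustrated with a toy scalar recurrence. Everything here is an assumption made for the sketch: `ToyRecurrence`, its method signatures, the explicit caller-held state, and the update rule are hypothetical stand-ins, not the signatures of Mamba3Network::forward or Mamba3Network::step, and the batch dimension B is dropped for brevity. The point is only that a sequence pass over T inputs and T single-token steps with a carried state can produce the same outputs.

```rust
/// Toy stand-in for a recurrent layer: h_t = decay * h_{t-1} + x_t.
/// Hypothetical; the real Mamba-3 state update is far richer.
struct ToyRecurrence {
    decay: f32,
}

impl ToyRecurrence {
    /// Sequence mode: consumes all T inputs at once, starting from a zero state.
    fn forward(&self, xs: &[f32]) -> Vec<f32> {
        let mut h = 0.0_f32;
        xs.iter()
            .map(|x| {
                h = self.decay * h + x;
                h
            })
            .collect()
    }

    /// Step mode: consumes one input and updates the caller-held state.
    fn step(&self, x: f32, state: &mut f32) -> f32 {
        *state = self.decay * *state + x;
        *state
    }
}
```

In this sketch, calling `step` once per element with a state initialised to zero yields the same values as a single `forward` call over the whole slice, which is why the sequence mode suits training and prefill while the step mode suits token-by-token decoding.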
Structs§
- Mamba3Network - A complete Mamba-3 language model.
- Mamba3NetworkConfig - Configuration / factory for Mamba3Network.
- Mamba3NetworkRecord - The record type for the module.
- Mamba3NetworkRecordItem - The record item type for the module.