Module network

§Mamba-3 Language Model Network

This module assembles a complete autoregressive language model from the Mamba-3 components:

  tokens [B, T]
      │
      ▼
  Embedding  (vocab_size → d_model)
      │
      ▼  (×n_layers)
  Mamba3Layer  [Pre-LN residual block]
      │
      ▼
  RMSNorm  (final normalisation)
      │
      ▼
  LM head  (d_model → vocab_size)
      │
      ▼
  logits [B, T, vocab_size]
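
To make the "Pre-LN residual block" label concrete, here is a minimal sketch of one block applied to a single d_model vector. It assumes the pre-norm inside each block is an RMSNorm with a learned gain and an eps of 1e-5 (both assumptions), and it models the Mamba-3 sequence mixer as a hypothetical `mixer` closure. The real layer operates on [B, T, d_model] tensors and carries state across time steps.

```rust
/// RMSNorm over one d_model vector: y_i = x_i / sqrt(mean(x^2) + eps) * g_i.
fn rms_norm(x: &[f32], gain: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv = 1.0 / (mean_sq + eps).sqrt();
    x.iter().zip(gain).map(|(v, g)| v * inv * g).collect()
}

/// Pre-LN residual block: x + mixer(norm(x)).
/// `mixer` is a hypothetical stand-in for the Mamba-3 mixer.
fn pre_ln_block(x: &[f32], gain: &[f32], mixer: impl Fn(&[f32]) -> Vec<f32>) -> Vec<f32> {
    let normed = rms_norm(x, gain, 1e-5);
    let h = mixer(normed.as_slice());
    // Residual connection: the un-normalised input is added back.
    x.iter().zip(&h).map(|(a, b)| a + b).collect()
}

fn main() {
    let x = vec![3.0_f32, 4.0];
    let gain = vec![1.0_f32; 2];
    println!("{:?}", rms_norm(&x, &gain, 1e-5));
    // With a mixer that outputs zeros, only the residual path remains.
    let y = pre_ln_block(&x, &gain, |h: &[f32]| vec![0.0; h.len()]);
    println!("{y:?}"); // [3.0, 4.0]
}
```

Normalising before the mixer while keeping an untouched residual path is what makes deep Pre-LN stacks stable to train; the final RMSNorm in the diagram is needed because the residual stream itself is never normalised inside the blocks.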

§Vocabulary padding

The vocabulary dimension of the embedding and of the LM head is rounded up to the nearest multiple of pad_vocab_size_multiple. This improves memory alignment on the GPU. The extra slots correspond to no real token and never appear as training targets, so in practice they are never sampled.
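
The rounding rule is ordinary ceiling division. The sketch below uses a hypothetical helper name (`padded_vocab_size` is not part of this module), and the tokenizer size is illustrative. It also shows one defensive option for sampling, not necessarily what this module does: restrict greedy selection to the first vocab_size logits so a padding slot can never be chosen, even if its logit happens to be large.

```rust
/// Hypothetical helper: rounds `vocab_size` up to the next multiple of
/// `multiple` (the role played by pad_vocab_size_multiple).
fn padded_vocab_size(vocab_size: usize, multiple: usize) -> usize {
    assert!(multiple > 0, "pad_vocab_size_multiple must be positive");
    // Ceiling division, then scale back up to a multiple.
    (vocab_size + multiple - 1) / multiple * multiple
}

/// Greedy pick over real tokens only: padding slots at indices
/// >= vocab_size are sliced away before taking the argmax.
fn greedy_token(logits: &[f32], vocab_size: usize) -> usize {
    logits[..vocab_size]
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .expect("vocab_size must be non-zero")
}

fn main() {
    let vocab_size = 50_277; // illustrative tokenizer size
    let padded = padded_vocab_size(vocab_size, 16);
    println!("{vocab_size} -> {padded} ({} padding slots)", padded - vocab_size);

    // Padded vocabulary of 4 where only the first 3 tokens are real.
    let logits = [0.1_f32, 0.5, 0.2, 9.0];
    println!("greedy token: {}", greedy_token(&logits, 3));
}
```

This prints `50277 -> 50288 (11 padding slots)` followed by `greedy token: 1`: the padding slot at index 3 is ignored despite having the largest logit.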

§Tied / untied LM head

When missing_lm_head = true, the logit projection reuses the transposed embedding weight matrix (lm_head = None; the projection is applied on the fly as a linear map). This removes the separate vocab_size × d_model output matrix, which halves the combined parameter count of the embedding and the output projection (for a bias-free head), and is standard practice in many LLM implementations. When missing_lm_head = false, a separate Linear layer is allocated.
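
As a sketch of what "reuses the transposed embedding weight" means, the code below computes tied logits for one hidden vector with plain slices rather than the module's tensor types. It assumes the embedding table is laid out as [vocab_size][d_model], so the logit for token v is the dot product of the hidden state with row v (that is, hidden · Eᵀ). The parameter-count helper is hypothetical and assumes the untied head has no bias.

```rust
/// Tied LM head: logits[v] = <hidden, E[v]>, i.e. hidden · Eᵀ,
/// where `embedding` is the [vocab_size][d_model] embedding table.
fn tied_logits(hidden: &[f32], embedding: &[Vec<f32>]) -> Vec<f32> {
    embedding
        .iter()
        .map(|row| row.iter().zip(hidden).map(|(w, h)| w * h).sum::<f32>())
        .collect()
}

/// Hypothetical helper: parameters in embedding + output projection,
/// assuming a bias-free untied head of shape d_model -> vocab_size.
fn embedding_and_head_params(vocab_size: usize, d_model: usize, tied: bool) -> usize {
    let embedding = vocab_size * d_model;
    if tied { embedding } else { 2 * embedding }
}

fn main() {
    // Toy table: 3 tokens, d_model = 2.
    let embedding = vec![vec![1.0_f32, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
    let logits = tied_logits(&[1.0, 2.0], &embedding);
    println!("{logits:?}"); // [1.0, 2.0, 3.0]

    println!(
        "tied: {}, untied: {}",
        embedding_and_head_params(50_288, 768, true),
        embedding_and_head_params(50_288, 768, false)
    );
}
```

No new weights are read in the tied case: the same rows that map token ids to vectors on the input side score vectors against token ids on the output side.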

§Two execution modes

| Method | Input shape | Use case |
|---|---|---|
| Mamba3Network::forward | [B, T] | Training, prefill |
| Mamba3Network::step | [B] | Autoregressive decoding |
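
The two modes typically combine into one generation loop: a single parallel forward pass over the prompt (prefill), then one step call per generated token. The sketch below uses a hypothetical `Decoder` trait and a toy recurrent model (`RunningSumToy`, whose state is a running token sum) in place of Mamba3Network; the real methods take and return tensors, and how recurrent state is passed between forward and step is an assumption here, not this module's API.

```rust
/// Hypothetical stand-in for the two execution modes.
trait Decoder {
    /// Parallel pass: tokens [B][T] -> logits [B][T][V]; primes the state.
    fn forward(&mut self, tokens: &[Vec<usize>]) -> Vec<Vec<Vec<f32>>>;
    /// One recurrent update: tokens [B] -> logits [B][V].
    fn step(&mut self, tokens: &[usize]) -> Vec<Vec<f32>>;
}

fn one_hot(index: usize, len: usize) -> Vec<f32> {
    let mut v = vec![0.0; len];
    v[index] = 1.0;
    v
}

fn argmax(xs: &[f32]) -> usize {
    xs.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .expect("empty logits")
}

/// Toy recurrent model: state = running sum of tokens; predicts sum % vocab.
struct RunningSumToy {
    vocab: usize,
    state: Vec<usize>,
}

impl Decoder for RunningSumToy {
    fn forward(&mut self, tokens: &[Vec<usize>]) -> Vec<Vec<Vec<f32>>> {
        self.state.clear();
        let mut batch = Vec::with_capacity(tokens.len());
        for row in tokens {
            let mut sum = 0;
            let mut seq = Vec::with_capacity(row.len());
            for &t in row {
                sum += t;
                seq.push(one_hot(sum % self.vocab, self.vocab));
            }
            self.state.push(sum); // final state per sequence
            batch.push(seq);
        }
        batch
    }

    fn step(&mut self, tokens: &[usize]) -> Vec<Vec<f32>> {
        let v = self.vocab;
        self.state
            .iter_mut()
            .zip(tokens)
            .map(|(s, &t)| {
                *s += t;
                one_hot(*s % v, v)
            })
            .collect()
    }
}

/// Greedy generation: one `forward` prefill, then one `step` per new token.
fn generate<M: Decoder>(model: &mut M, prompt: &[Vec<usize>], n_new: usize) -> Vec<Vec<usize>> {
    let mut out = vec![Vec::new(); prompt.len()];
    if n_new == 0 {
        return out;
    }
    // Prefill: only the logits at the last prompt position are needed.
    let logits = model.forward(prompt);
    let mut next: Vec<usize> = logits
        .iter()
        .map(|seq| argmax(seq.last().expect("empty prompt")))
        .collect();
    for i in 0..n_new {
        for (b, &t) in next.iter().enumerate() {
            out[b].push(t);
        }
        if i + 1 < n_new {
            // Decode: feed back the tokens just produced, one per sequence.
            next = model.step(&next).iter().map(|l| argmax(l)).collect();
        }
    }
    out
}

fn main() {
    let mut model = RunningSumToy { vocab: 10, state: Vec::new() };
    let tokens = generate(&mut model, &[vec![1, 2], vec![4]], 3);
    println!("{tokens:?}"); // [[3, 6, 2], [4, 8, 6]]
}
```

The split matters because prefill can process all T prompt positions in parallel, while decoding is inherently sequential; a recurrent step only touches the fixed-size state rather than re-reading the whole history.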

Structs§

Mamba3Network
A complete Mamba-3 language model.
Mamba3NetworkConfig
Configuration / factory for Mamba3Network.
Mamba3NetworkRecord
The record type for the module.
Mamba3NetworkRecordItem
The record item type for the module.