Skip to main content

burn_mamba/mamba2/
network.rs

1//! # Mamba-2 Language Model Network
2//!
3//! This module assembles a complete autoregressive language model from the
4//! Mamba-2 components:
5//!
6//! ```text
7//!   tokens [B, T]
8//!       │
9//!       ▼
10//!   Embedding  (vocab_size → d_model)
11//!       │
12//!       ▼  (×n_layers)
13//!   Mamba2Layer  [Pre-LN residual block]
14//!       │
15//!       ▼
16//!   RMSNorm  (final normalisation)
17//!       │
18//!       ▼
19//!   LM head  (d_model → vocab_size)
20//!       │
21//!       ▼
22//!   logits [B, T, vocab_size]
23//! ```
24//!
25//! ## Vocabulary padding
26//!
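//! The embedding table and the LM head output dimension are rounded up to the
//! nearest multiple of `pad_vocab_size_multiple` (e.g. a vocabulary of 50 277
//! with a multiple of 8 becomes 50 280).  This improves memory alignment on
//! the GPU.  The extra slots correspond to no tokenizer entry; they still
//! appear in the returned logits, so callers should ignore or mask them when
//! sampling.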
31//!
32//! ## Tied / untied LM head
33//!
//! When `missing_lm_head = true`, no separate LM head is allocated
//! (`lm_head = None`); instead the logits are computed with the *transposed*
//! embedding weight matrix, wrapped in a linear layer on the fly.  Sharing
//! these weights removes the output projection's own parameters (roughly
//! halving the combined embedding + LM head parameter count) and is standard
//! in many LLM implementations.  When `missing_lm_head = false`, a separate
//! [`Linear`] layer is allocated.
39//!
40//! ## Two execution modes
41//!
42//! | Method | Input shape | Use case |
43//! |--------|-------------|----------|
44//! | [`Mamba2Network::forward`] | `[B, T]` | Training, prefill |
45//! | [`Mamba2Network::step`]    | `[B]`    | Autoregressive decoding |
46
47use crate::mamba2::prelude::*;
48use crate::schedule::Schedule;
49use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
50use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig};
51use burn::prelude::*;
52
53// ---------------------------------------------------------------------------
54// Mamba2Network
55// ---------------------------------------------------------------------------
56
57/// A complete Mamba-2 language model.
58///
59/// See the [module-level documentation](self) for an overview of the
60/// architecture and the two execution modes.
61#[derive(Module, Debug)]
62pub struct Mamba2Network<B: Backend> {
63    /// Token embedding table.
64    ///
65    /// Shape of weight matrix: `[padded_vocab_size, d_model]`.
66    /// Maps integer token IDs to `d_model`-dimensional vectors.
67    pub embedding: Embedding<B>,
68
69    /// The stack of Mamba-2 residual blocks.
70    pub layers: Mamba2Layers<B>,
71
72    /// Final layer normalisation applied after all Mamba-2 blocks and before
73    /// the LM head.  This is the `norm_f` in the original implementation.
74    pub norm_f: RmsNorm<B>,
75
76    /// Optional separate LM head projection.
77    ///
78    /// - `Some(linear)` — dedicated weight matrix of shape
79    ///   `[d_model, padded_vocab_size]`.
80    /// - `None` — the embedding weights are reused (transposed).  This is the
81    ///   "weight-tied" variant and is selected when `missing_lm_head = true`.
82    pub lm_head: Option<Linear<B>>,
83}
84
85// ---------------------------------------------------------------------------
86// Mamba2NetworkConfig
87// ---------------------------------------------------------------------------
88
89/// Configuration / factory for [`Mamba2Network`].
90#[derive(Config, Debug)]
91pub struct Mamba2NetworkConfig {
92    /// Number of real (weight-bearing) Mamba-2 layers.
93    pub n_real_layers: usize,
94
95    /// Optional virtual-layer scheduling.  See [`Mamba2Layers`] for details.
96    #[config(default = "None")]
97    pub n_virtual_layers: Option<(usize, Schedule)>,
98
99    /// The *unpadded* vocabulary size as specified by the tokenizer.
100    ///
101    /// At initialisation this value is rounded up to the nearest multiple of
102    /// `pad_vocab_size_multiple` to obtain the actual embedding / logit
103    /// dimension `padded_vocab_size`.
104    pub vocab_size: usize,
105
106    /// Vocabulary size will be rounded up to a multiple of this value.
107    ///
108    /// Set to `1` to disable rounding.  Common values: 8, 16, 64.
109    pub pad_vocab_size_multiple: usize,
110
111    /// Configuration shared by all Mamba-2 blocks.
112    pub mamba_block: Mamba2Config,
113
114    /// When `true`, the LM head weight is not allocated separately; instead
115    /// the transposed embedding matrix is used directly (weight tying).
116    pub missing_lm_head: bool,
117}
118
119impl Mamba2NetworkConfig {
120    /// Allocate and initialise the full network on `device`.
121    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Network<B> {
122        let padded_vocab_size = Self::padded_vocab(self.vocab_size, self.pad_vocab_size_multiple);
123
124        let layers = Mamba2LayersConfig {
125            n_real_layers: self.n_real_layers,
126            n_virtual_layers: self.n_virtual_layers.clone(),
127            mamba_block: self.mamba_block.clone(),
128            ignore_first_residual: false,
129            ignore_last_residual: false,
130        }
131        .init(device);
132
133        let lm_head = if self.missing_lm_head {
134            None
135        } else {
136            Some(
137                LinearConfig::new(self.mamba_block.d_model, padded_vocab_size)
138                    .with_bias(false)
139                    .init(device),
140            )
141        };
142
143        Mamba2Network {
144            embedding: EmbeddingConfig::new(padded_vocab_size, self.mamba_block.d_model)
145                .init(device),
146            layers,
147            norm_f: RmsNormConfig::new(self.mamba_block.d_model).init(device),
148            lm_head,
149        }
150    }
151
152    /// Round `vocab_size` up to the next multiple of `multiple`.
153    fn padded_vocab(vocab_size: usize, multiple: usize) -> usize {
154        if vocab_size.is_multiple_of(multiple) {
155            vocab_size
156        } else {
157            ((vocab_size / multiple) + 1) * multiple
158        }
159    }
160}
161
162// ---------------------------------------------------------------------------
163// Inference implementations
164// ---------------------------------------------------------------------------
165
166impl<B: Backend + Mamba2BackendExt> Mamba2Network<B> {
167    // -----------------------------------------------------------------------
168    // forward  (full sequence — training / prefill)
169    // -----------------------------------------------------------------------
170
171    /// Process a full token sequence and return next-token logits.
172    ///
173    /// Internally this calls [`Mamba2Layers::forward`], which runs the
174    /// chunkwise SSD algorithm over every layer. This is the mode to use
175    /// during training (backpropagation through the entire sequence) and
176    /// during the prefill phase of inference.
177    ///
178    /// # Arguments
179    /// - `x` — integer token IDs, shape `[batch, sequence]`
180    /// - `caches` — optional pre-filled layer caches.  Pass `None` to
181    ///   start from a zero state (training) or to create fresh
182    ///   caches that can be returned and reused for a subsequent
183    ///   decoding step.
184    /// - `ssd_path` — SSD algorithm and chunk length selection.
185    ///
186    /// # Returns
187    /// `(logits, caches)` where:
188    /// - `logits` has shape `[batch, sequence, padded_vocab_size]`
189    /// - `caches` contains the SSM and convolution state at the end of the
190    ///   sequence, ready to be passed to the first [`Self::step`] call.
191    pub fn forward(
192        &self,
193        x: Tensor<B, 2, Int>,
194        caches: Option<Mamba2Caches<B>>,
195        ssd_path: Mamba2SsdPath,
196    ) -> (Tensor<B, 3>, Mamba2Caches<B>) {
197        let [batch, sequence] = x.dims();
198        let [padded_vocab, d_model] = self.embedding.weight.dims();
199
200        // Embed token IDs → dense vectors.
201        let x_bsm = self.embedding.forward(x);
202        assert_eq!([batch, sequence, d_model], x_bsm.dims());
203
204        // Run the Mamba-2 layer stack (chunkwise SSD).
205        let (mut x_bsm, caches) = self.layers.forward(x_bsm, caches, ssd_path);
206        assert_eq!([batch, sequence, d_model], x_bsm.dims());
207
208        // Final normalisation before projection.
209        x_bsm = self.norm_f.forward(x_bsm);
210        assert_eq!([batch, sequence, d_model], x_bsm.dims());
211
212        // Project to vocabulary logits.
213        x_bsm = self.apply_lm_head(x_bsm, d_model, padded_vocab);
214        assert_eq!([batch, sequence, padded_vocab], x_bsm.dims());
215
216        (x_bsm, caches)
217    }
218
219    // -----------------------------------------------------------------------
220    // step  (single token — autoregressive decoding)
221    // -----------------------------------------------------------------------
222
223    /// Process a **single** token and return next-token logits.
224    ///
225    /// Internally this calls [`Mamba2Layers::step`], which advances each
226    /// layer's recurrent state by one step:
227    ///
228    /// ```text
229    ///   hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ
230    ///   yₜ = Cₜᵀ hₜ + D xₜ
231    /// ```
232    ///
233    /// This is O(H·P·N) per token — independent of sequence length — and is
234    /// the correct mode for token-by-token generation after prefill.
235    ///
236    /// # Arguments
237    /// - `x` — current token IDs, shape `[batch]`
238    /// - `caches` — layer caches from the previous step (or `None` for the
239    ///   very first token, which starts from a zero hidden state)
240    ///
241    /// # Returns
242    /// `(logits, caches)` where:
243    /// - `logits` has shape `[batch, padded_vocab_size]`
244    /// - `caches` contains the updated state for the **next** step.
245    pub fn step(
246        &self,
247        x: Tensor<B, 1, Int>,
248        caches: Option<Mamba2Caches<B>>,
249    ) -> (Tensor<B, 2>, Mamba2Caches<B>) {
250        let [batch] = x.dims();
251        let [padded_vocab, d_model] = self.embedding.weight.dims();
252
253        // Embed the single token.  We temporarily add a sequence dimension so
254        // that the embedding module (which expects `[B, T]`) is satisfied, then
255        // immediately squeeze it out.
256        let x_b1 = x.unsqueeze_dim::<2>(1);
257        assert_eq!([batch, 1], x_b1.dims());
258
259        let x_b1m = self.embedding.forward(x_b1);
260        assert_eq!([batch, 1, d_model], x_b1m.dims());
261
262        let x_bm = x_b1m.squeeze_dim(1);
263        assert_eq!([batch, d_model], x_bm.dims());
264
265        // Advance each layer's recurrent state by one step.
266        let (mut x_bm, caches) = self.layers.step(x_bm, caches);
267        assert_eq!([batch, d_model], x_bm.dims());
268
269        // Final normalisation.
270        x_bm = self.norm_f.forward(x_bm);
271        assert_eq!([batch, d_model], x_bm.dims());
272
273        // Project to vocabulary logits.
274        // Re-use the `apply_lm_head` helper by temporarily unsqueezing the
275        // sequence dimension then squeezing it back out.
276        let x_b1m = x_bm.unsqueeze_dim(1);
277        let logits_b1v = self.apply_lm_head(x_b1m, d_model, padded_vocab);
278        assert_eq!([batch, 1, padded_vocab], logits_b1v.dims());
279
280        let logits_bv = logits_b1v.squeeze_dim(1);
281        assert_eq!([batch, padded_vocab], logits_bv.dims());
282
283        (logits_bv, caches)
284    }
285
286    // -----------------------------------------------------------------------
287    // Private helpers
288    // -----------------------------------------------------------------------
289
290    /// Apply the LM head projection to a `[batch, sequence, d_model]` tensor,
291    /// returning `[batch, sequence, padded_vocab]`.
292    ///
293    /// Uses the dedicated `lm_head` linear layer when available, or the
294    /// transposed embedding weight matrix otherwise (weight tying).
295    fn apply_lm_head(
296        &self,
297        x_bsm: Tensor<B, 3>,
298        d_model: usize,
299        padded_vocab: usize,
300    ) -> Tensor<B, 3> {
301        if let Some(lm_head) = &self.lm_head {
302            lm_head.forward(x_bsm)
303        } else {
304            // Weight-tied variant: reuse embedding.weight^T as the projection.
305            // embedding.weight has shape [padded_vocab, d_model], so we need
306            // to transpose it to [d_model, padded_vocab].
307            let weight_mv = self.embedding.weight.clone().map(|w| w.permute([1, 0]));
308            assert_eq!([d_model, padded_vocab], weight_mv.dims());
309            let tied_linear = Linear {
310                weight: weight_mv,
311                bias: None,
312            };
313            tied_linear.forward(x_bsm)
314        }
315    }
316}