burn_mamba/mamba2/network.rs
1//! # Mamba-2 Language Model Network
2//!
3//! This module assembles a complete autoregressive language model from the
4//! Mamba-2 components:
5//!
6//! ```text
7//! tokens [B, T]
8//! │
9//! ▼
10//! Embedding (vocab_size → d_model)
11//! │
12//! ▼ (×n_layers)
13//! Mamba2Layer [Pre-LN residual block]
14//! │
15//! ▼
16//! RMSNorm (final normalisation)
17//! │
18//! ▼
19//! LM head (d_model → vocab_size)
20//! │
21//! ▼
22//! logits [B, T, vocab_size]
23//! ```
24//!
25//! ## Vocabulary padding
26//!
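//! The embedding and LM head dimensions are rounded up to the nearest
//! multiple of `pad_vocab_size_multiple`. This improves memory alignment on
//! GPU. The extra token slots correspond to no real token and are never
//! sampled from in practice; note, however, that the returned logits keep
//! the padded width (`padded_vocab_size`), so callers should ignore or mask
//! the trailing `padded_vocab_size - vocab_size` entries.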
31//!
32//! ## Tied / untied LM head
33//!
//! When `missing_lm_head = true`, the logit projection reuses the *transposed*
//! embedding weight matrix (`lm_head = None`, applied as a linear layer on
//! the fly). This removes the separate output-projection matrix — halving the
//! combined embedding + LM-head parameter count — and is standard in many LLM
//! implementations. When `missing_lm_head = false`, a separate bias-free
//! [`Linear`] layer is allocated.
39//!
40//! ## Two execution modes
41//!
42//! | Method | Input shape | Use case |
43//! |--------|-------------|----------|
44//! | [`Mamba2Network::forward`] | `[B, T]` | Training, prefill |
45//! | [`Mamba2Network::step`] | `[B]` | Autoregressive decoding |
46
47use crate::mamba2::prelude::*;
48use crate::schedule::Schedule;
49use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
50use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig};
51use burn::prelude::*;
52
53// ---------------------------------------------------------------------------
54// Mamba2Network
55// ---------------------------------------------------------------------------
56
/// A complete Mamba-2 language model.
///
/// Data flows through the fields in declaration order: `embedding` →
/// `layers` → `norm_f` → LM head (`lm_head`, or the transposed embedding
/// table when `lm_head` is `None`).
///
/// See the [module-level documentation](self) for an overview of the
/// architecture and the two execution modes.
#[derive(Module, Debug)]
pub struct Mamba2Network<B: Backend> {
    /// Token embedding table.
    ///
    /// Shape of weight matrix: `[padded_vocab_size, d_model]`.
    /// Maps integer token IDs to `d_model`-dimensional vectors.
    /// When `lm_head` is `None` this same matrix (transposed) also serves as
    /// the output projection.
    pub embedding: Embedding<B>,

    /// The stack of Mamba-2 residual blocks.
    ///
    /// Driven through `layers.forward` for full sequences and `layers.step`
    /// for single-token decoding.
    pub layers: Mamba2Layers<B>,

    /// Final layer normalisation applied after all Mamba-2 blocks and before
    /// the LM head. This is the `norm_f` in the original implementation.
    pub norm_f: RmsNorm<B>,

    /// Optional separate LM head projection.
    ///
    /// - `Some(linear)` — dedicated bias-free weight matrix of shape
    ///   `[d_model, padded_vocab_size]`.
    /// - `None` — the embedding weights are reused (transposed). This is the
    ///   "weight-tied" variant and is selected when `missing_lm_head = true`.
    pub lm_head: Option<Linear<B>>,
}
84
85// ---------------------------------------------------------------------------
86// Mamba2NetworkConfig
87// ---------------------------------------------------------------------------
88
/// Configuration / factory for [`Mamba2Network`].
///
/// Call [`Mamba2NetworkConfig::init`] to allocate the network on a device.
#[derive(Config, Debug)]
pub struct Mamba2NetworkConfig {
    /// Number of real (weight-bearing) Mamba-2 layers.
    ///
    /// Forwarded unchanged to the layer-stack configuration.
    pub n_real_layers: usize,

    /// Optional virtual-layer scheduling. See [`Mamba2Layers`] for details.
    ///
    /// Defaults to `None`; forwarded unchanged to the layer-stack
    /// configuration.
    #[config(default = "None")]
    pub n_virtual_layers: Option<(usize, Schedule)>,

    /// The *unpadded* vocabulary size as specified by the tokenizer.
    ///
    /// At initialisation this value is rounded up to the nearest multiple of
    /// `pad_vocab_size_multiple` to obtain the actual embedding / logit
    /// dimension `padded_vocab_size`.
    pub vocab_size: usize,

    /// Vocabulary size will be rounded up to a multiple of this value.
    ///
    /// Set to `1` to disable rounding. Common values: 8, 16, 64.
    /// Expected to be at least `1`.
    pub pad_vocab_size_multiple: usize,

    /// Configuration shared by all Mamba-2 blocks.
    ///
    /// Its `d_model` also sizes the embedding, the final norm and the LM
    /// head.
    pub mamba_block: Mamba2Config,

    /// When `true`, the LM head weight is not allocated separately; instead
    /// the transposed embedding matrix is used directly (weight tying).
    /// When `false`, a dedicated bias-free linear layer is created.
    pub missing_lm_head: bool,
}
118
119impl Mamba2NetworkConfig {
120 /// Allocate and initialise the full network on `device`.
121 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Network<B> {
122 let padded_vocab_size = Self::padded_vocab(self.vocab_size, self.pad_vocab_size_multiple);
123
124 let layers = Mamba2LayersConfig {
125 n_real_layers: self.n_real_layers,
126 n_virtual_layers: self.n_virtual_layers.clone(),
127 mamba_block: self.mamba_block.clone(),
128 ignore_first_residual: false,
129 ignore_last_residual: false,
130 }
131 .init(device);
132
133 let lm_head = if self.missing_lm_head {
134 None
135 } else {
136 Some(
137 LinearConfig::new(self.mamba_block.d_model, padded_vocab_size)
138 .with_bias(false)
139 .init(device),
140 )
141 };
142
143 Mamba2Network {
144 embedding: EmbeddingConfig::new(padded_vocab_size, self.mamba_block.d_model)
145 .init(device),
146 layers,
147 norm_f: RmsNormConfig::new(self.mamba_block.d_model).init(device),
148 lm_head,
149 }
150 }
151
/// Round `vocab_size` up to the next multiple of `multiple`.
///
/// A value that is already a multiple is returned unchanged. A `multiple`
/// of `0` is treated like `1` (no padding) instead of triggering a
/// division-by-zero panic.
///
/// # Panics
/// In debug builds, panics if the rounded value would overflow `usize`.
fn padded_vocab(vocab_size: usize, multiple: usize) -> usize {
    match multiple {
        // Nothing to round to: leave the vocabulary size untouched.
        0 | 1 => vocab_size,
        m => vocab_size.next_multiple_of(m),
    }
}
160}
161
162// ---------------------------------------------------------------------------
163// Inference implementations
164// ---------------------------------------------------------------------------
165
166impl<B: Backend + Mamba2BackendExt> Mamba2Network<B> {
167 // -----------------------------------------------------------------------
168 // forward (full sequence — training / prefill)
169 // -----------------------------------------------------------------------
170
171 /// Process a full token sequence and return next-token logits.
172 ///
173 /// Internally this calls [`Mamba2Layers::forward`], which runs the
174 /// chunkwise SSD algorithm over every layer. This is the mode to use
175 /// during training (backpropagation through the entire sequence) and
176 /// during the prefill phase of inference.
177 ///
178 /// # Arguments
179 /// - `x` — integer token IDs, shape `[batch, sequence]`
180 /// - `caches` — optional pre-filled layer caches. Pass `None` to
181 /// start from a zero state (training) or to create fresh
182 /// caches that can be returned and reused for a subsequent
183 /// decoding step.
184 /// - `ssd_path` — SSD algorithm and chunk length selection.
185 ///
186 /// # Returns
187 /// `(logits, caches)` where:
188 /// - `logits` has shape `[batch, sequence, padded_vocab_size]`
189 /// - `caches` contains the SSM and convolution state at the end of the
190 /// sequence, ready to be passed to the first [`Self::step`] call.
191 pub fn forward(
192 &self,
193 x: Tensor<B, 2, Int>,
194 caches: Option<Mamba2Caches<B>>,
195 ssd_path: Mamba2SsdPath,
196 ) -> (Tensor<B, 3>, Mamba2Caches<B>) {
197 let [batch, sequence] = x.dims();
198 let [padded_vocab, d_model] = self.embedding.weight.dims();
199
200 // Embed token IDs → dense vectors.
201 let x_bsm = self.embedding.forward(x);
202 assert_eq!([batch, sequence, d_model], x_bsm.dims());
203
204 // Run the Mamba-2 layer stack (chunkwise SSD).
205 let (mut x_bsm, caches) = self.layers.forward(x_bsm, caches, ssd_path);
206 assert_eq!([batch, sequence, d_model], x_bsm.dims());
207
208 // Final normalisation before projection.
209 x_bsm = self.norm_f.forward(x_bsm);
210 assert_eq!([batch, sequence, d_model], x_bsm.dims());
211
212 // Project to vocabulary logits.
213 x_bsm = self.apply_lm_head(x_bsm, d_model, padded_vocab);
214 assert_eq!([batch, sequence, padded_vocab], x_bsm.dims());
215
216 (x_bsm, caches)
217 }
218
219 // -----------------------------------------------------------------------
220 // step (single token — autoregressive decoding)
221 // -----------------------------------------------------------------------
222
223 /// Process a **single** token and return next-token logits.
224 ///
225 /// Internally this calls [`Mamba2Layers::step`], which advances each
226 /// layer's recurrent state by one step:
227 ///
228 /// ```text
229 /// hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ
230 /// yₜ = Cₜᵀ hₜ + D xₜ
231 /// ```
232 ///
233 /// This is O(H·P·N) per token — independent of sequence length — and is
234 /// the correct mode for token-by-token generation after prefill.
235 ///
236 /// # Arguments
237 /// - `x` — current token IDs, shape `[batch]`
238 /// - `caches` — layer caches from the previous step (or `None` for the
239 /// very first token, which starts from a zero hidden state)
240 ///
241 /// # Returns
242 /// `(logits, caches)` where:
243 /// - `logits` has shape `[batch, padded_vocab_size]`
244 /// - `caches` contains the updated state for the **next** step.
245 pub fn step(
246 &self,
247 x: Tensor<B, 1, Int>,
248 caches: Option<Mamba2Caches<B>>,
249 ) -> (Tensor<B, 2>, Mamba2Caches<B>) {
250 let [batch] = x.dims();
251 let [padded_vocab, d_model] = self.embedding.weight.dims();
252
253 // Embed the single token. We temporarily add a sequence dimension so
254 // that the embedding module (which expects `[B, T]`) is satisfied, then
255 // immediately squeeze it out.
256 let x_b1 = x.unsqueeze_dim::<2>(1);
257 assert_eq!([batch, 1], x_b1.dims());
258
259 let x_b1m = self.embedding.forward(x_b1);
260 assert_eq!([batch, 1, d_model], x_b1m.dims());
261
262 let x_bm = x_b1m.squeeze_dim(1);
263 assert_eq!([batch, d_model], x_bm.dims());
264
265 // Advance each layer's recurrent state by one step.
266 let (mut x_bm, caches) = self.layers.step(x_bm, caches);
267 assert_eq!([batch, d_model], x_bm.dims());
268
269 // Final normalisation.
270 x_bm = self.norm_f.forward(x_bm);
271 assert_eq!([batch, d_model], x_bm.dims());
272
273 // Project to vocabulary logits.
274 // Re-use the `apply_lm_head` helper by temporarily unsqueezing the
275 // sequence dimension then squeezing it back out.
276 let x_b1m = x_bm.unsqueeze_dim(1);
277 let logits_b1v = self.apply_lm_head(x_b1m, d_model, padded_vocab);
278 assert_eq!([batch, 1, padded_vocab], logits_b1v.dims());
279
280 let logits_bv = logits_b1v.squeeze_dim(1);
281 assert_eq!([batch, padded_vocab], logits_bv.dims());
282
283 (logits_bv, caches)
284 }
285
286 // -----------------------------------------------------------------------
287 // Private helpers
288 // -----------------------------------------------------------------------
289
290 /// Apply the LM head projection to a `[batch, sequence, d_model]` tensor,
291 /// returning `[batch, sequence, padded_vocab]`.
292 ///
293 /// Uses the dedicated `lm_head` linear layer when available, or the
294 /// transposed embedding weight matrix otherwise (weight tying).
295 fn apply_lm_head(
296 &self,
297 x_bsm: Tensor<B, 3>,
298 d_model: usize,
299 padded_vocab: usize,
300 ) -> Tensor<B, 3> {
301 if let Some(lm_head) = &self.lm_head {
302 lm_head.forward(x_bsm)
303 } else {
304 // Weight-tied variant: reuse embedding.weight^T as the projection.
305 // embedding.weight has shape [padded_vocab, d_model], so we need
306 // to transpose it to [d_model, padded_vocab].
307 let weight_mv = self.embedding.weight.clone().map(|w| w.permute([1, 0]));
308 assert_eq!([d_model, padded_vocab], weight_mv.dims());
309 let tied_linear = Linear {
310 weight: weight_mv,
311 bias: None,
312 };
313 tied_linear.forward(x_bsm)
314 }
315 }
316}