burn_mamba/mamba3/layer.rs
1//! # Mamba-3 Layer and Layer Stack
2//!
3//! A **Mamba-3 layer** is the standard Pre-LN residual block used throughout
4//! the network. It wraps a single [`Mamba3`] SSM block with an RMSNorm
5//! (applied to the input, *before* the block) and adds the input back as a
6//! residual connection:
7//!
8//! ```text
9//! y = x + Mamba3( RMSNorm(x) )
10//! ```
11//!
12//! This matches the architecture described in §5 of the Mamba-2 paper and is
13//! identical in structure to Pre-LN Transformer layers.
14//!
15//! ## Virtual layers
16//!
17//! [`Mamba3Layers`] supports *virtual layers*: a larger logical depth achieved
18//! by cycling through a smaller set of *real* (weight-bearing) layers
19//! according to a [`Schedule`]. For example, 48 virtual layers over 12 real
20//! layers repeats each weight set 4 times. Each virtual layer still has its
21//! **own cache** (the hidden state evolves independently), but shares the
22//! underlying parameters.
23//!
24//! ## Residual scale
25//!
26//! The first and/or last residual connection in the stack can optionally be
27//! zeroed out (`ignore_first_residual` / `ignore_last_residual`), which is
28//! useful when composing Mamba-3 blocks with other module types (e.g. in a
29//! hybrid Mamba-3 + attention architecture where neighbouring blocks already
30//! carry residuals).
31
32use crate::mamba3::prelude::*;
33use crate::schedule::Schedule;
34use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
35use burn::prelude::*;
36
37// ---------------------------------------------------------------------------
38// Mamba3Layers (the full layer stack)
39// ---------------------------------------------------------------------------
40
41/// A stack of Mamba-3 layers with optional virtual-layer scheduling.
42///
43/// The stack maintains `n_real_layers` distinct weight sets but can execute
44/// `n_virtual_layers` logical forward passes, cycling through weights
45/// according to the provided [`Schedule`].
46#[derive(Module, Debug)]
47pub struct Mamba3Layers<B: Backend> {
48 /// Number of real (weight-bearing) layers.
49 pub n_real_layers: usize,
50
51 /// Optional `(n_virtual_layers, schedule)` for weight-sharing.
52 ///
53 /// When `None`, the virtual layer count falls back to `n_real_layers` (no
54 /// sharing). Marked `module(skip)` so Burn does not treat it as a
55 /// trainable parameter.
56 #[module(skip)]
57 pub n_virtual_layers: Option<(usize, Schedule)>,
58
59 /// The actual weight-bearing layer instances.
60 ///
61 /// Length: `n_real_layers`.
62 pub real_layers: Vec<Mamba3Layer<B>>,
63
64 /// When `true`, the residual connection of the **first** virtual layer is
65 /// scaled to zero (i.e. the first block acts as a pure projection, not a
66 /// residual update).
67 pub ignore_first_residual: bool,
68
69 /// When `true`, the residual connection of the **last** virtual layer is
70 /// scaled to zero.
71 pub ignore_last_residual: bool,
72}
73
74/// Configuration / factory for [`Mamba3Layers`].
75#[derive(Config, Debug)]
76pub struct Mamba3LayersConfig {
77 /// Number of distinct weight sets to allocate.
78 pub n_real_layers: usize,
79
80 /// Optional virtual-layer scheduling. See [`Mamba3Layers`] for details.
81 #[config(default = "None")]
82 pub n_virtual_layers: Option<(usize, Schedule)>,
83
84 /// Configuration shared by all Mamba-3 blocks in the stack.
85 pub mamba_block: Mamba3Config,
86
87 /// See [`Mamba3Layers::ignore_first_residual`].
88 #[config(default = false)]
89 pub ignore_first_residual: bool,
90
91 /// See [`Mamba3Layers::ignore_last_residual`].
92 #[config(default = false)]
93 pub ignore_last_residual: bool,
94}
95
96impl Mamba3LayersConfig {
97 /// Allocate and initialise all layers on `device`.
98 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3Layers<B> {
99 let real_layers = (0..self.n_real_layers)
100 .map(|_| Mamba3LayerConfig::new(self.mamba_block.clone()).init(device))
101 .collect();
102
103 Mamba3Layers {
104 n_real_layers: self.n_real_layers,
105 n_virtual_layers: self.n_virtual_layers.clone(),
106 real_layers,
107 ignore_first_residual: self.ignore_first_residual,
108 ignore_last_residual: self.ignore_last_residual,
109 }
110 }
111}
112
113impl<B: Backend + Mamba3BackendExt> Mamba3Layers<B> {
114 // -----------------------------------------------------------------------
115 // forward (chunked SSD — used for training / prefill)
116 // -----------------------------------------------------------------------
117
118 /// Process a full sequence through every (virtual) layer.
119 ///
120 /// Internally each layer calls [`Mamba3::forward`], which runs the
121 /// chunkwise SSD algorithm. This is efficient for training because the
122 /// intra-chunk products can exploit GEMM / tensor cores.
123 ///
124 /// If `caches` is `None`, zero-initialised caches are created automatically.
125 ///
126 /// # Arguments
127 /// - `x` — input tensor, shape `[batch, sequence, d_model]`
128 /// - `caches` — optional pre-filled layer caches (useful for prefill
129 /// followed by decode)
130 /// - `ssd_path` — SSD algorithm and chunk length selection.
131 ///
132 /// # Returns
133 /// `(output, updated_caches)` where `output` has shape
134 /// `[batch, sequence, d_model]`.
135 pub fn forward(
136 &self,
137 mut x: Tensor<B, 3>,
138 caches: Option<Mamba3Caches<B>>,
139 ssd_path: Mamba3SsdPath,
140 ) -> (Tensor<B, 3>, Mamba3Caches<B>) {
141 // The effective number of forward passes equals the number of *virtual*
142 // layers. When no scheduling is configured this equals n_real_layers.
143 let n_virtual_layers = self.n_virtual_count();
144
145 // Lazily allocate zero caches the first time (e.g. during training or
146 // the first prefill call).
147 let caches = caches.unwrap_or_else(|| self.make_zero_caches(&x, n_virtual_layers));
148
149 assert_eq!(
150 caches.caches.len(),
151 n_virtual_layers,
152 "cache count must match the number of virtual layers; \
153 layers in forward() cannot share caches"
154 );
155
156 // Unwrap each cache slot into an `Option` so we can `take` it in the
157 // loop without cloning (Burn tensors are reference-counted).
158 let mut caches: Vec<Option<Mamba3Cache<B>>> = caches.caches.into_iter().map(Some).collect();
159
160 #[allow(clippy::needless_range_loop)]
161 for i in 0..n_virtual_layers {
162 // Map virtual layer index → real (weight-bearing) layer index.
163 let layer_idx = self.real_idx(i);
164 let layer = &self.real_layers[layer_idx];
165
166 // The residual scale is 0.0 for the first/last layer if the
167 // corresponding `ignore_*_residual` flag is set, and 1.0 otherwise.
168 let residual_scale = self.residual_scale(i, n_virtual_layers);
169
170 let cache = caches[i].take().unwrap();
171 let (x_, cache_) = layer.forward(x, Some(cache), ssd_path.clone(), residual_scale);
172 x = x_;
173 caches[i] = Some(cache_);
174 }
175
176 let caches = Mamba3Caches {
177 caches: caches.into_iter().map(Option::unwrap).collect(),
178 };
179 (x, caches)
180 }
181
182 // -----------------------------------------------------------------------
183 // step (recurrent SSM — used for autoregressive decoding)
184 // -----------------------------------------------------------------------
185
186 /// Process a **single token** through every (virtual) layer.
187 ///
188 /// Each layer calls [`Mamba3::step`], which runs one tick of the recurrent
189 /// SSM: `hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ`, `yₜ = Cₜᵀ hₜ + D xₜ`.
190 /// This is O(H·P·N) per step — independent of sequence length — and
191 /// requires no KV-cache.
192 ///
193 /// # Arguments
194 /// - `x` — current token embedding, shape `[batch, d_model]`
195 /// - `caches` — layer caches from the previous step (or `None` for the
196 /// first token, in which case zero caches are created)
197 ///
198 /// # Returns
199 /// `(output, updated_caches)` where `output` has shape `[batch, d_model]`.
200 pub fn step(
201 &self,
202 mut x: Tensor<B, 2>,
203 caches: Option<Mamba3Caches<B>>,
204 ) -> (Tensor<B, 2>, Mamba3Caches<B>) {
205 let n_virtual_layers = self.n_virtual_count();
206 let caches = caches.unwrap_or_else(|| self.make_zero_caches_2d(&x, n_virtual_layers));
207
208 assert_eq!(
209 caches.caches.len(),
210 n_virtual_layers,
211 "cache count must match the number of virtual layers; \
212 layers in step() cannot share caches"
213 );
214
215 let mut caches: Vec<Option<Mamba3Cache<B>>> = caches.caches.into_iter().map(Some).collect();
216
217 #[allow(clippy::needless_range_loop)]
218 for i in 0..n_virtual_layers {
219 let layer_idx = self.real_idx(i);
220 let layer = &self.real_layers[layer_idx];
221 let residual_scale = self.residual_scale(i, n_virtual_layers);
222
223 let cache = caches[i].take().unwrap();
224 let (x_, cache_) = layer.step(x, Some(cache), residual_scale);
225 x = x_;
226 caches[i] = Some(cache_);
227 }
228
229 let caches = Mamba3Caches {
230 caches: caches.into_iter().map(Option::unwrap).collect(),
231 };
232 (x, caches)
233 }
234
235 // -----------------------------------------------------------------------
236 // Private helpers
237 // -----------------------------------------------------------------------
238
239 /// Effective number of forward passes (virtual layers).
240 fn n_virtual_count(&self) -> usize {
241 self.n_virtual_layers
242 .as_ref()
243 .map(|(l, _)| *l)
244 .unwrap_or(self.n_real_layers)
245 }
246
247 /// Map a virtual layer index to the corresponding real layer index using
248 /// the configured schedule (or identity when no schedule is set).
249 fn real_idx(&self, virtual_idx: usize) -> usize {
250 if let Some((n_virtual_layers, schedule)) = &self.n_virtual_layers {
251 schedule.real_idx(virtual_idx, *n_virtual_layers, self.n_real_layers)
252 } else {
253 virtual_idx
254 }
255 }
256
257 /// Returns 0.0 if this layer's residual should be suppressed, else 1.0.
258 fn residual_scale(&self, i: usize, n_virtual: usize) -> f32 {
259 let is_first = self.ignore_first_residual && i == 0;
260 let is_last = self.ignore_last_residual && i + 1 == n_virtual;
261 if is_first || is_last { 0.0 } else { 1.0 }
262 }
263
264 /// Build zero-initialised caches from a 3-D input tensor `[B, S, D]`.
265 fn make_zero_caches(&self, x: &Tensor<B, 3>, n_virtual: usize) -> Mamba3Caches<B> {
266 let device = &x.device();
267 let [batch, _sequence, _d_model] = x.dims();
268 let layer0 = &self.real_layers[0].mamba_block;
269
270 Mamba3CachesConfig::new(
271 n_virtual,
272 Mamba3CacheConfig {
273 batch,
274 state_rank: layer0.state_rank,
275 num_rope_angles: layer0.num_rope_angles,
276 per_head_dim: layer0.per_head_dim(),
277 nheads: layer0.nheads(),
278 mimo_rank: layer0.mimo_rank,
279 },
280 )
281 .init(device)
282 }
283
284 /// Build zero-initialised caches from a 2-D input tensor `[B, D]`.
285 fn make_zero_caches_2d(&self, x: &Tensor<B, 2>, n_virtual: usize) -> Mamba3Caches<B> {
286 let device = &x.device();
287 let [batch, _d_model] = x.dims();
288 let layer0 = &self.real_layers[0].mamba_block;
289
290 Mamba3CachesConfig::new(
291 n_virtual,
292 Mamba3CacheConfig {
293 batch,
294 state_rank: layer0.state_rank,
295 num_rope_angles: layer0.num_rope_angles,
296 per_head_dim: layer0.per_head_dim(),
297 nheads: layer0.nheads(),
298 mimo_rank: layer0.mimo_rank,
299 },
300 )
301 .init(device)
302 }
303}
304
305// ---------------------------------------------------------------------------
306// Mamba3Layer (single Pre-LN residual block)
307// ---------------------------------------------------------------------------
308
309/// A single Mamba-3 residual block:
310///
311/// ```text
312/// output = x·scale + Mamba3( RMSNorm(x) )
313/// ```
314///
315/// where `scale` is 1.0 normally and 0.0 when the residual connection is
316/// intentionally suppressed by the layer stack configuration.
317#[derive(Module, Debug)]
318pub struct Mamba3Layer<B: Backend> {
319 /// Pre-norm applied to the input before the SSM block.
320 ///
321 /// Using RMSNorm *before* the block (Pre-LN) is standard practice in
322 /// modern LLMs and improves training stability.
323 pub norm: RmsNorm<B>,
324
325 /// The Mamba-3 SSM block (see [`Mamba3`]).
326 pub mamba_block: Mamba3<B>,
327}
328
329/// Configuration / factory for [`Mamba3Layer`].
330#[derive(Config, Debug)]
331pub struct Mamba3LayerConfig {
332 /// Configuration for the inner Mamba-3 block.
333 pub mamba_block: Mamba3Config,
334}
335
336impl Mamba3LayerConfig {
337 /// Allocate and initialise the layer on `device`.
338 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3Layer<B> {
339 Mamba3Layer {
340 norm: RmsNormConfig::new(self.mamba_block.d_model).init(device),
341 mamba_block: self.mamba_block.init(device),
342 }
343 }
344}
345
346impl<B: Backend + Mamba3BackendExt> Mamba3Layer<B> {
347 // -----------------------------------------------------------------------
348 // forward (full sequence)
349 // -----------------------------------------------------------------------
350
351 /// Run the Pre-LN residual block over a full sequence.
352 ///
353 /// Computes:
354 /// ```text
355 /// output = x · residual_scale + Mamba3( RMSNorm(x) )
356 /// ```
357 ///
358 /// # Shapes
359 /// - `x` : `[batch, sequence, d_model]`
360 /// - output : `[batch, sequence, d_model]`
361 pub fn forward(
362 &self,
363 x: Tensor<B, 3>,
364 cache: Option<Mamba3Cache<B>>,
365 ssd_path: Mamba3SsdPath,
366 residual_scale: f32,
367 ) -> (Tensor<B, 3>, Mamba3Cache<B>) {
368 let [batch, sequence, d_model] = x.dims();
369
370 // Save the (optionally scaled) residual *before* normalisation so that
371 // the norm does not affect the skip path.
372 let res_bsm = x.clone() * residual_scale;
373
374 let normed_bsm = self.norm.forward(x);
375 assert_eq!([batch, sequence, d_model], normed_bsm.dims());
376
377 let (mamba_out_bsm, cache) = self
378 .mamba_block
379 .forward(normed_bsm, cache, ssd_path.clone());
380 assert_eq!([batch, sequence, d_model], mamba_out_bsm.dims());
381
382 // Residual addition: y = x · scale + Mamba3(norm(x))
383 let out_bsm = mamba_out_bsm + res_bsm;
384 assert_eq!([batch, sequence, d_model], out_bsm.dims());
385
386 (out_bsm, cache)
387 }
388
389 // -----------------------------------------------------------------------
390 // step (single token)
391 // -----------------------------------------------------------------------
392
393 /// Run the Pre-LN residual block for a **single** decoding step.
394 ///
395 /// Computes:
396 /// ```text
397 /// output = x · residual_scale + Mamba3.step( RMSNorm(x) )
398 /// ```
399 ///
400 /// # Shapes
401 /// - `x` : `[batch, d_model]`
402 /// - output: `[batch, d_model]`
403 pub fn step(
404 &self,
405 x: Tensor<B, 2>,
406 cache: Option<Mamba3Cache<B>>,
407 residual_scale: f32,
408 ) -> (Tensor<B, 2>, Mamba3Cache<B>) {
409 let [batch, d_model] = x.dims();
410
411 let res_bm = x.clone() * residual_scale;
412
413 let normed_bm = self.norm.forward(x);
414 assert_eq!([batch, d_model], normed_bm.dims());
415
416 let (mamba_out_bm, cache) = self.mamba_block.step(normed_bm, cache);
417 assert_eq!([batch, d_model], mamba_out_bm.dims());
418
419 let out_bm = mamba_out_bm + res_bm;
420 assert_eq!([batch, d_model], out_bm.dims());
421
422 (out_bm, cache)
423 }
424}