// burn_mamba/mamba3/layer.rs

//! # Mamba-3 Layer and Layer Stack
//!
//! A **Mamba-3 layer** is the standard Pre-LN residual block used throughout
//! the network.  It wraps a single [`Mamba3`] SSM block with an RMSNorm
//! (applied to the input, *before* the block) and adds the input back as a
//! residual connection:
//!
//! ```text
//!   y = x + Mamba3( RMSNorm(x) )
//! ```
//!
//! This follows the pre-norm residual block layout used by the Mamba-2
//! architecture and is identical in structure to Pre-LN Transformer layers.
//!
//! ## Virtual layers
//!
//! [`Mamba3Layers`] supports *virtual layers*: a larger logical depth achieved
//! by cycling through a smaller set of *real* (weight-bearing) layers
//! according to a [`Schedule`].  For example, 48 virtual layers over 12 real
//! layers reuse each weight set 4 times.  Each virtual layer still has its
//! **own cache** (its hidden state evolves independently), but shares the
//! underlying parameters with the other virtual layers mapped to the same
//! real layer.
//!
//! ## Residual scale
//!
//! The residual connection of the first and/or last virtual layer in the
//! stack can optionally be dropped (`ignore_first_residual` /
//! `ignore_last_residual`).  This is useful when composing Mamba-3 blocks with
//! other module types, e.g. in a hybrid Mamba-3 + attention architecture where
//! neighbouring blocks already carry residuals.

use burn::prelude::*;

use crate::mamba3::prelude::*;
use crate::schedule::Schedule;
use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};

37// ---------------------------------------------------------------------------
38// Mamba3Layers  (the full layer stack)
39// ---------------------------------------------------------------------------
40
41/// A stack of Mamba-3 layers with optional virtual-layer scheduling.
42///
43/// The stack maintains `n_real_layers` distinct weight sets but can execute
44/// `n_virtual_layers` logical forward passes, cycling through weights
45/// according to the provided [`Schedule`].
46#[derive(Module, Debug)]
47pub struct Mamba3Layers<B: Backend> {
48    /// Number of real (weight-bearing) layers.
49    pub n_real_layers: usize,
50
51    /// Optional `(n_virtual_layers, schedule)` for weight-sharing.
52    ///
53    /// When `None`, the virtual layer count falls back to `n_real_layers` (no
54    /// sharing).  Marked `module(skip)` so Burn does not treat it as a
55    /// trainable parameter.
56    #[module(skip)]
57    pub n_virtual_layers: Option<(usize, Schedule)>,
58
59    /// The actual weight-bearing layer instances.
60    ///
61    /// Length: `n_real_layers`.
62    pub real_layers: Vec<Mamba3Layer<B>>,
63
64    /// When `true`, the residual connection of the **first** virtual layer is
65    /// scaled to zero (i.e. the first block acts as a pure projection, not a
66    /// residual update).
67    pub ignore_first_residual: bool,
68
69    /// When `true`, the residual connection of the **last** virtual layer is
70    /// scaled to zero.
71    pub ignore_last_residual: bool,
72}
73
74/// Configuration / factory for [`Mamba3Layers`].
75#[derive(Config, Debug)]
76pub struct Mamba3LayersConfig {
77    /// Number of distinct weight sets to allocate.
78    pub n_real_layers: usize,
79
80    /// Optional virtual-layer scheduling.  See [`Mamba3Layers`] for details.
81    #[config(default = "None")]
82    pub n_virtual_layers: Option<(usize, Schedule)>,
83
84    /// Configuration shared by all Mamba-3 blocks in the stack.
85    pub mamba_block: Mamba3Config,
86
87    /// See [`Mamba3Layers::ignore_first_residual`].
88    #[config(default = false)]
89    pub ignore_first_residual: bool,
90
91    /// See [`Mamba3Layers::ignore_last_residual`].
92    #[config(default = false)]
93    pub ignore_last_residual: bool,
94}
95
96impl Mamba3LayersConfig {
97    /// Allocate and initialise all layers on `device`.
98    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3Layers<B> {
99        let real_layers = (0..self.n_real_layers)
100            .map(|_| Mamba3LayerConfig::new(self.mamba_block.clone()).init(device))
101            .collect();
102
103        Mamba3Layers {
104            n_real_layers: self.n_real_layers,
105            n_virtual_layers: self.n_virtual_layers.clone(),
106            real_layers,
107            ignore_first_residual: self.ignore_first_residual,
108            ignore_last_residual: self.ignore_last_residual,
109        }
110    }
111}
112
113impl<B: Backend + Mamba3BackendExt> Mamba3Layers<B> {
114    // -----------------------------------------------------------------------
115    // forward  (chunked SSD — used for training / prefill)
116    // -----------------------------------------------------------------------
117
118    /// Process a full sequence through every (virtual) layer.
119    ///
120    /// Internally each layer calls [`Mamba3::forward`], which runs the
121    /// chunkwise SSD algorithm.  This is efficient for training because the
122    /// intra-chunk products can exploit GEMM / tensor cores.
123    ///
124    /// If `caches` is `None`, zero-initialised caches are created automatically.
125    ///
126    /// # Arguments
127    /// - `x` — input tensor, shape `[batch, sequence, d_model]`
128    /// - `caches` — optional pre-filled layer caches (useful for prefill
129    ///   followed by decode)
130    /// - `ssd_path` — SSD algorithm and chunk length selection.
131    ///
132    /// # Returns
133    /// `(output, updated_caches)` where `output` has shape
134    /// `[batch, sequence, d_model]`.
135    pub fn forward(
136        &self,
137        mut x: Tensor<B, 3>,
138        caches: Option<Mamba3Caches<B>>,
139        ssd_path: Mamba3SsdPath,
140    ) -> (Tensor<B, 3>, Mamba3Caches<B>) {
141        // The effective number of forward passes equals the number of *virtual*
142        // layers.  When no scheduling is configured this equals n_real_layers.
143        let n_virtual_layers = self.n_virtual_count();
144
145        // Lazily allocate zero caches the first time (e.g. during training or
146        // the first prefill call).
147        let caches = caches.unwrap_or_else(|| self.make_zero_caches(&x, n_virtual_layers));
148
149        assert_eq!(
150            caches.caches.len(),
151            n_virtual_layers,
152            "cache count must match the number of virtual layers; \
153             layers in forward() cannot share caches"
154        );
155
156        // Unwrap each cache slot into an `Option` so we can `take` it in the
157        // loop without cloning (Burn tensors are reference-counted).
158        let mut caches: Vec<Option<Mamba3Cache<B>>> = caches.caches.into_iter().map(Some).collect();
159
160        #[allow(clippy::needless_range_loop)]
161        for i in 0..n_virtual_layers {
162            // Map virtual layer index → real (weight-bearing) layer index.
163            let layer_idx = self.real_idx(i);
164            let layer = &self.real_layers[layer_idx];
165
166            // The residual scale is 0.0 for the first/last layer if the
167            // corresponding `ignore_*_residual` flag is set, and 1.0 otherwise.
168            let residual_scale = self.residual_scale(i, n_virtual_layers);
169
170            let cache = caches[i].take().unwrap();
171            let (x_, cache_) = layer.forward(x, Some(cache), ssd_path.clone(), residual_scale);
172            x = x_;
173            caches[i] = Some(cache_);
174        }
175
176        let caches = Mamba3Caches {
177            caches: caches.into_iter().map(Option::unwrap).collect(),
178        };
179        (x, caches)
180    }
181
182    // -----------------------------------------------------------------------
183    // step  (recurrent SSM — used for autoregressive decoding)
184    // -----------------------------------------------------------------------
185
186    /// Process a **single token** through every (virtual) layer.
187    ///
188    /// Each layer calls [`Mamba3::step`], which runs one tick of the recurrent
189    /// SSM:  `hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ`,  `yₜ = Cₜᵀ hₜ + D xₜ`.
190    /// This is O(H·P·N) per step — independent of sequence length — and
191    /// requires no KV-cache.
192    ///
193    /// # Arguments
194    /// - `x` — current token embedding, shape `[batch, d_model]`
195    /// - `caches` — layer caches from the previous step (or `None` for the
196    ///   first token, in which case zero caches are created)
197    ///
198    /// # Returns
199    /// `(output, updated_caches)` where `output` has shape `[batch, d_model]`.
200    pub fn step(
201        &self,
202        mut x: Tensor<B, 2>,
203        caches: Option<Mamba3Caches<B>>,
204    ) -> (Tensor<B, 2>, Mamba3Caches<B>) {
205        let n_virtual_layers = self.n_virtual_count();
206        let caches = caches.unwrap_or_else(|| self.make_zero_caches_2d(&x, n_virtual_layers));
207
208        assert_eq!(
209            caches.caches.len(),
210            n_virtual_layers,
211            "cache count must match the number of virtual layers; \
212             layers in step() cannot share caches"
213        );
214
215        let mut caches: Vec<Option<Mamba3Cache<B>>> = caches.caches.into_iter().map(Some).collect();
216
217        #[allow(clippy::needless_range_loop)]
218        for i in 0..n_virtual_layers {
219            let layer_idx = self.real_idx(i);
220            let layer = &self.real_layers[layer_idx];
221            let residual_scale = self.residual_scale(i, n_virtual_layers);
222
223            let cache = caches[i].take().unwrap();
224            let (x_, cache_) = layer.step(x, Some(cache), residual_scale);
225            x = x_;
226            caches[i] = Some(cache_);
227        }
228
229        let caches = Mamba3Caches {
230            caches: caches.into_iter().map(Option::unwrap).collect(),
231        };
232        (x, caches)
233    }
234
235    // -----------------------------------------------------------------------
236    // Private helpers
237    // -----------------------------------------------------------------------
238
239    /// Effective number of forward passes (virtual layers).
240    fn n_virtual_count(&self) -> usize {
241        self.n_virtual_layers
242            .as_ref()
243            .map(|(l, _)| *l)
244            .unwrap_or(self.n_real_layers)
245    }
246
247    /// Map a virtual layer index to the corresponding real layer index using
248    /// the configured schedule (or identity when no schedule is set).
249    fn real_idx(&self, virtual_idx: usize) -> usize {
250        if let Some((n_virtual_layers, schedule)) = &self.n_virtual_layers {
251            schedule.real_idx(virtual_idx, *n_virtual_layers, self.n_real_layers)
252        } else {
253            virtual_idx
254        }
255    }
256
257    /// Returns 0.0 if this layer's residual should be suppressed, else 1.0.
258    fn residual_scale(&self, i: usize, n_virtual: usize) -> f32 {
259        let is_first = self.ignore_first_residual && i == 0;
260        let is_last = self.ignore_last_residual && i + 1 == n_virtual;
261        if is_first || is_last { 0.0 } else { 1.0 }
262    }
263
264    /// Build zero-initialised caches from a 3-D input tensor `[B, S, D]`.
265    fn make_zero_caches(&self, x: &Tensor<B, 3>, n_virtual: usize) -> Mamba3Caches<B> {
266        let device = &x.device();
267        let [batch, _sequence, _d_model] = x.dims();
268        let layer0 = &self.real_layers[0].mamba_block;
269
270        Mamba3CachesConfig::new(
271            n_virtual,
272            Mamba3CacheConfig {
273                batch,
274                state_rank: layer0.state_rank,
275                num_rope_angles: layer0.num_rope_angles,
276                per_head_dim: layer0.per_head_dim(),
277                nheads: layer0.nheads(),
278                mimo_rank: layer0.mimo_rank,
279            },
280        )
281        .init(device)
282    }
283
284    /// Build zero-initialised caches from a 2-D input tensor `[B, D]`.
285    fn make_zero_caches_2d(&self, x: &Tensor<B, 2>, n_virtual: usize) -> Mamba3Caches<B> {
286        let device = &x.device();
287        let [batch, _d_model] = x.dims();
288        let layer0 = &self.real_layers[0].mamba_block;
289
290        Mamba3CachesConfig::new(
291            n_virtual,
292            Mamba3CacheConfig {
293                batch,
294                state_rank: layer0.state_rank,
295                num_rope_angles: layer0.num_rope_angles,
296                per_head_dim: layer0.per_head_dim(),
297                nheads: layer0.nheads(),
298                mimo_rank: layer0.mimo_rank,
299            },
300        )
301        .init(device)
302    }
303}
304
305// ---------------------------------------------------------------------------
306// Mamba3Layer  (single Pre-LN residual block)
307// ---------------------------------------------------------------------------
308
309/// A single Mamba-3 residual block:
310///
311/// ```text
312///   output = x·scale + Mamba3( RMSNorm(x) )
313/// ```
314///
315/// where `scale` is 1.0 normally and 0.0 when the residual connection is
316/// intentionally suppressed by the layer stack configuration.
317#[derive(Module, Debug)]
318pub struct Mamba3Layer<B: Backend> {
319    /// Pre-norm applied to the input before the SSM block.
320    ///
321    /// Using RMSNorm *before* the block (Pre-LN) is standard practice in
322    /// modern LLMs and improves training stability.
323    pub norm: RmsNorm<B>,
324
325    /// The Mamba-3 SSM block (see [`Mamba3`]).
326    pub mamba_block: Mamba3<B>,
327}
328
329/// Configuration / factory for [`Mamba3Layer`].
330#[derive(Config, Debug)]
331pub struct Mamba3LayerConfig {
332    /// Configuration for the inner Mamba-3 block.
333    pub mamba_block: Mamba3Config,
334}
335
336impl Mamba3LayerConfig {
337    /// Allocate and initialise the layer on `device`.
338    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3Layer<B> {
339        Mamba3Layer {
340            norm: RmsNormConfig::new(self.mamba_block.d_model).init(device),
341            mamba_block: self.mamba_block.init(device),
342        }
343    }
344}
345
346impl<B: Backend + Mamba3BackendExt> Mamba3Layer<B> {
347    // -----------------------------------------------------------------------
348    // forward  (full sequence)
349    // -----------------------------------------------------------------------
350
351    /// Run the Pre-LN residual block over a full sequence.
352    ///
353    /// Computes:
354    /// ```text
355    ///   output = x · residual_scale + Mamba3( RMSNorm(x) )
356    /// ```
357    ///
358    /// # Shapes
359    /// - `x`      : `[batch, sequence, d_model]`
360    /// - output   : `[batch, sequence, d_model]`
361    pub fn forward(
362        &self,
363        x: Tensor<B, 3>,
364        cache: Option<Mamba3Cache<B>>,
365        ssd_path: Mamba3SsdPath,
366        residual_scale: f32,
367    ) -> (Tensor<B, 3>, Mamba3Cache<B>) {
368        let [batch, sequence, d_model] = x.dims();
369
370        // Save the (optionally scaled) residual *before* normalisation so that
371        // the norm does not affect the skip path.
372        let res_bsm = x.clone() * residual_scale;
373
374        let normed_bsm = self.norm.forward(x);
375        assert_eq!([batch, sequence, d_model], normed_bsm.dims());
376
377        let (mamba_out_bsm, cache) = self
378            .mamba_block
379            .forward(normed_bsm, cache, ssd_path.clone());
380        assert_eq!([batch, sequence, d_model], mamba_out_bsm.dims());
381
382        // Residual addition:  y = x · scale + Mamba3(norm(x))
383        let out_bsm = mamba_out_bsm + res_bsm;
384        assert_eq!([batch, sequence, d_model], out_bsm.dims());
385
386        (out_bsm, cache)
387    }
388
389    // -----------------------------------------------------------------------
390    // step  (single token)
391    // -----------------------------------------------------------------------
392
393    /// Run the Pre-LN residual block for a **single** decoding step.
394    ///
395    /// Computes:
396    /// ```text
397    ///   output = x · residual_scale + Mamba3.step( RMSNorm(x) )
398    /// ```
399    ///
400    /// # Shapes
401    /// - `x`  : `[batch, d_model]`
402    /// - output: `[batch, d_model]`
403    pub fn step(
404        &self,
405        x: Tensor<B, 2>,
406        cache: Option<Mamba3Cache<B>>,
407        residual_scale: f32,
408    ) -> (Tensor<B, 2>, Mamba3Cache<B>) {
409        let [batch, d_model] = x.dims();
410
411        let res_bm = x.clone() * residual_scale;
412
413        let normed_bm = self.norm.forward(x);
414        assert_eq!([batch, d_model], normed_bm.dims());
415
416        let (mamba_out_bm, cache) = self.mamba_block.step(normed_bm, cache);
417        assert_eq!([batch, d_model], mamba_out_bm.dims());
418
419        let out_bm = mamba_out_bm + res_bm;
420        assert_eq!([batch, d_model], out_bm.dims());
421
422        (out_bm, cache)
423    }
424}