Skip to main content

burn_mamba/mamba2/
layer.rs

1//! # Mamba-2 Layer and Layer Stack
2//!
3//! A **Mamba-2 layer** is the standard Pre-LN residual block used throughout
4//! the network.  It wraps a single [`Mamba2`] SSM block with an RMSNorm
5//! (applied to the input, *before* the block) and adds the input back as a
6//! residual connection:
7//!
8//! ```text
9//!   y = x + Mamba2( RMSNorm(x) )
10//! ```
11//!
12//! This matches the architecture described in §5 of the Mamba-2 paper and is
13//! identical in structure to Pre-LN Transformer layers.
14//!
15//! ## Virtual layers
16//!
17//! [`Mamba2Layers`] supports *virtual layers*: a larger logical depth achieved
18//! by cycling through a smaller set of *real* (weight-bearing) layers
19//! according to a [`Schedule`].  For example, 48 virtual layers over 12 real
20//! layers repeats each weight set 4 times.  Each virtual layer still has its
21//! **own cache** (the hidden state evolves independently), but shares the
22//! underlying parameters.
23//!
24//! ## Residual scale
25//!
//! The skip term of the first and/or last virtual layer in the stack can
//! optionally be scaled to zero (`ignore_first_residual` /
//! `ignore_last_residual`), so that layer outputs only `Mamba2(RMSNorm(x))`.
//! This is useful when composing Mamba-2 blocks with other module types (e.g.
//! in a hybrid Mamba-2 + attention architecture where neighbouring blocks
//! already carry residuals).
31
32use crate::mamba2::prelude::*;
33use crate::schedule::Schedule;
34use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
35use burn::prelude::*;
36
37// ---------------------------------------------------------------------------
38// Mamba2Layers  (the full layer stack)
39// ---------------------------------------------------------------------------
40
/// A stack of Mamba-2 layers with optional virtual-layer scheduling.
///
/// The stack owns `n_real_layers` distinct weight sets but runs
/// `n_virtual_layers` logical layers, mapping each virtual index to a real
/// layer through the provided [`Schedule`].  Every virtual layer has its own
/// cache slot, even when it shares weights with another virtual layer.
#[derive(Module, Debug)]
pub struct Mamba2Layers<B: Backend> {
    /// Number of real (weight-bearing) layers.
    ///
    /// [`Mamba2LayersConfig::init`] sets this to the length of `real_layers`.
    pub n_real_layers: usize,

    /// Optional `(n_virtual_layers, schedule)` for weight-sharing.
    ///
    /// When `None`, the virtual layer count falls back to `n_real_layers` and
    /// virtual layer `i` runs real layer `i` (no sharing).  Marked
    /// `module(skip)` so Burn does not treat it as part of the module's
    /// parameters.
    #[module(skip)]
    pub n_virtual_layers: Option<(usize, Schedule)>,

    /// The actual weight-bearing layer instances.
    ///
    /// Length: `n_real_layers`.
    pub real_layers: Vec<Mamba2Layer<B>>,

    /// When `true`, the skip (residual) term of the **first** virtual layer is
    /// scaled to zero, so that layer outputs only `Mamba2(RMSNorm(x))`
    /// instead of a residual update `x + Mamba2(RMSNorm(x))`.
    pub ignore_first_residual: bool,

    /// When `true`, the skip (residual) term of the **last** virtual layer is
    /// scaled to zero, so that layer outputs only `Mamba2(RMSNorm(x))`.
    pub ignore_last_residual: bool,
}
73
/// Configuration / factory for [`Mamba2Layers`].
#[derive(Config, Debug)]
pub struct Mamba2LayersConfig {
    /// Number of distinct weight sets to allocate.
    pub n_real_layers: usize,

    /// Optional `(n_virtual_layers, schedule)` pair for weight-sharing.
    /// Defaults to `None`, meaning one virtual layer per real layer.
    /// See [`Mamba2Layers`] for details.
    #[config(default = "None")]
    pub n_virtual_layers: Option<(usize, Schedule)>,

    /// Configuration shared by all Mamba-2 blocks in the stack: every real
    /// layer is initialised from this same config.
    pub mamba_block: Mamba2Config,

    /// See [`Mamba2Layers::ignore_first_residual`].  Defaults to `false`.
    #[config(default = false)]
    pub ignore_first_residual: bool,

    /// See [`Mamba2Layers::ignore_last_residual`].  Defaults to `false`.
    #[config(default = false)]
    pub ignore_last_residual: bool,
}
95
96impl Mamba2LayersConfig {
97    /// Allocate and initialise all layers on `device`.
98    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Layers<B> {
99        let real_layers = (0..self.n_real_layers)
100            .map(|_| Mamba2LayerConfig::new(self.mamba_block.clone()).init(device))
101            .collect();
102
103        Mamba2Layers {
104            n_real_layers: self.n_real_layers,
105            n_virtual_layers: self.n_virtual_layers.clone(),
106            real_layers,
107            ignore_first_residual: self.ignore_first_residual,
108            ignore_last_residual: self.ignore_last_residual,
109        }
110    }
111}
112
113impl<B: Backend + Mamba2BackendExt> Mamba2Layers<B> {
114    // -----------------------------------------------------------------------
115    // forward  (chunked SSD — used for training / prefill)
116    // -----------------------------------------------------------------------
117
118    /// Process a full sequence through every (virtual) layer.
119    ///
120    /// Internally each layer calls [`Mamba2::forward`], which runs the
121    /// chunkwise SSD algorithm.  This is efficient for training because the
122    /// intra-chunk products can exploit GEMM / tensor cores.
123    ///
124    /// If `caches` is `None`, zero-initialised caches are created automatically.
125    ///
126    /// # Arguments
127    /// - `x` — input tensor, shape `[batch, sequence, d_model]`
128    /// - `caches` — optional pre-filled layer caches (useful for prefill
129    ///   followed by decode)
130    /// - `ssd_path` — SSD algorithm and chunk length selection.
131    ///
132    /// # Returns
133    /// `(output, updated_caches)` where `output` has shape
134    /// `[batch, sequence, d_model]`.
135    pub fn forward(
136        &self,
137        mut x: Tensor<B, 3>,
138        caches: Option<Mamba2Caches<B>>,
139        ssd_path: Mamba2SsdPath,
140    ) -> (Tensor<B, 3>, Mamba2Caches<B>) {
141        // The effective number of forward passes equals the number of *virtual*
142        // layers.  When no scheduling is configured this equals n_real_layers.
143        let n_virtual_layers = self.n_virtual_count();
144
145        // Lazily allocate zero caches the first time (e.g. during training or
146        // the first prefill call).
147        let caches = caches.unwrap_or_else(|| self.make_zero_caches(&x, n_virtual_layers));
148
149        assert_eq!(
150            caches.caches.len(),
151            n_virtual_layers,
152            "cache count must match the number of virtual layers; \
153             layers in forward() cannot share caches"
154        );
155
156        // Unwrap each cache slot into an `Option` so we can `take` it in the
157        // loop without cloning (Burn tensors are reference-counted).
158        let mut caches: Vec<Option<Mamba2Cache<B>>> = caches.caches.into_iter().map(Some).collect();
159
160        #[allow(clippy::needless_range_loop)]
161        for i in 0..n_virtual_layers {
162            // Map virtual layer index → real (weight-bearing) layer index.
163            let layer_idx = self.real_idx(i);
164            let layer = &self.real_layers[layer_idx];
165
166            // The residual scale is 0.0 for the first/last layer if the
167            // corresponding `ignore_*_residual` flag is set, and 1.0 otherwise.
168            let residual_scale = self.residual_scale(i, n_virtual_layers);
169
170            let cache = caches[i].take().unwrap();
171            let (x_, cache_) = layer.forward(x, Some(cache), ssd_path.clone(), residual_scale);
172            x = x_;
173            caches[i] = Some(cache_);
174        }
175
176        let caches = Mamba2Caches {
177            caches: caches.into_iter().map(Option::unwrap).collect(),
178        };
179        (x, caches)
180    }
181
182    // -----------------------------------------------------------------------
183    // step  (recurrent SSM — used for autoregressive decoding)
184    // -----------------------------------------------------------------------
185
186    /// Process a **single token** through every (virtual) layer.
187    ///
188    /// Each layer calls [`Mamba2::step`], which runs one tick of the recurrent
189    /// SSM:  `hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ`,  `yₜ = Cₜᵀ hₜ + D xₜ`.
190    /// This is O(H·P·N) per step — independent of sequence length — and
191    /// requires no KV-cache.
192    ///
193    /// # Arguments
194    /// - `x` — current token embedding, shape `[batch, d_model]`
195    /// - `caches` — layer caches from the previous step (or `None` for the
196    ///   first token, in which case zero caches are created)
197    ///
198    /// # Returns
199    /// `(output, updated_caches)` where `output` has shape `[batch, d_model]`.
200    pub fn step(
201        &self,
202        mut x: Tensor<B, 2>,
203        caches: Option<Mamba2Caches<B>>,
204    ) -> (Tensor<B, 2>, Mamba2Caches<B>) {
205        let n_virtual_layers = self.n_virtual_count();
206        let caches = caches.unwrap_or_else(|| self.make_zero_caches_2d(&x, n_virtual_layers));
207
208        assert_eq!(
209            caches.caches.len(),
210            n_virtual_layers,
211            "cache count must match the number of virtual layers; \
212             layers in step() cannot share caches"
213        );
214
215        let mut caches: Vec<Option<Mamba2Cache<B>>> = caches.caches.into_iter().map(Some).collect();
216
217        #[allow(clippy::needless_range_loop)]
218        for i in 0..n_virtual_layers {
219            let layer_idx = self.real_idx(i);
220            let layer = &self.real_layers[layer_idx];
221            let residual_scale = self.residual_scale(i, n_virtual_layers);
222
223            let cache = caches[i].take().unwrap();
224            let (x_, cache_) = layer.step(x, Some(cache), residual_scale);
225            x = x_;
226            caches[i] = Some(cache_);
227        }
228
229        let caches = Mamba2Caches {
230            caches: caches.into_iter().map(Option::unwrap).collect(),
231        };
232        (x, caches)
233    }
234
235    // -----------------------------------------------------------------------
236    // Private helpers
237    // -----------------------------------------------------------------------
238
239    /// Effective number of forward passes (virtual layers).
240    fn n_virtual_count(&self) -> usize {
241        self.n_virtual_layers
242            .as_ref()
243            .map(|(l, _)| *l)
244            .unwrap_or(self.n_real_layers)
245    }
246
247    /// Map a virtual layer index to the corresponding real layer index using
248    /// the configured schedule (or identity when no schedule is set).
249    fn real_idx(&self, virtual_idx: usize) -> usize {
250        if let Some((n_virtual_layers, schedule)) = &self.n_virtual_layers {
251            schedule.real_idx(virtual_idx, *n_virtual_layers, self.n_real_layers)
252        } else {
253            virtual_idx
254        }
255    }
256
257    /// Returns 0.0 if this layer's residual should be suppressed, else 1.0.
258    fn residual_scale(&self, i: usize, n_virtual: usize) -> f32 {
259        let is_first = self.ignore_first_residual && i == 0;
260        let is_last = self.ignore_last_residual && i + 1 == n_virtual;
261        if is_first || is_last { 0.0 } else { 1.0 }
262    }
263
264    /// Build zero-initialised caches from a 3-D input tensor `[B, S, D]`.
265    fn make_zero_caches(&self, x: &Tensor<B, 3>, n_virtual: usize) -> Mamba2Caches<B> {
266        let device = &x.device();
267        let [batch, _sequence, _d_model] = x.dims();
268        let layer0 = &self.real_layers[0].mamba_block;
269        let [conv_dim, _, conv_kernel] = layer0.conv1d.weight.dims();
270
271        Mamba2CachesConfig::new(
272            n_virtual,
273            Mamba2CacheConfig {
274                batch,
275                state_rank: layer0.state_rank,
276                conv_kernel,
277                conv_dim,
278                per_head_dim: layer0.per_head_dim(),
279                nheads: layer0.nheads(),
280            },
281        )
282        .init(device)
283    }
284
285    /// Build zero-initialised caches from a 2-D input tensor `[B, D]`.
286    fn make_zero_caches_2d(&self, x: &Tensor<B, 2>, n_virtual: usize) -> Mamba2Caches<B> {
287        let device = &x.device();
288        let [batch, _d_model] = x.dims();
289        let layer0 = &self.real_layers[0].mamba_block;
290        let [conv_dim, _, conv_kernel] = layer0.conv1d.weight.dims();
291
292        Mamba2CachesConfig::new(
293            n_virtual,
294            Mamba2CacheConfig {
295                batch,
296                state_rank: layer0.state_rank,
297                conv_kernel,
298                conv_dim,
299                per_head_dim: layer0.per_head_dim(),
300                nheads: layer0.nheads(),
301            },
302        )
303        .init(device)
304    }
305}
306
307// ---------------------------------------------------------------------------
308// Mamba2Layer  (single Pre-LN residual block)
309// ---------------------------------------------------------------------------
310
/// A single Mamba-2 residual block:
///
/// ```text
///   output = x·scale + Mamba2( RMSNorm(x) )
/// ```
///
/// where `scale` is the `residual_scale` argument passed to
/// [`Mamba2Layer::forward`] / [`Mamba2Layer::step`].  [`Mamba2Layers`] passes
/// 1.0 normally and 0.0 when the residual connection is intentionally
/// suppressed by the stack configuration.
#[derive(Module, Debug)]
pub struct Mamba2Layer<B: Backend> {
    /// Pre-norm applied to the input before the SSM block.
    ///
    /// Its output feeds only the SSM block; the skip path uses the
    /// un-normalised input.  Using RMSNorm *before* the block (Pre-LN) is
    /// standard practice in modern LLMs and improves training stability.
    pub norm: RmsNorm<B>,

    /// The Mamba-2 SSM block (see [`Mamba2`]).
    pub mamba_block: Mamba2<B>,
}
330
/// Configuration / factory for [`Mamba2Layer`].
#[derive(Config, Debug)]
pub struct Mamba2LayerConfig {
    /// Configuration for the inner Mamba-2 block.  Its `d_model` also sizes
    /// the layer's RMSNorm.
    pub mamba_block: Mamba2Config,
}
337
338impl Mamba2LayerConfig {
339    /// Allocate and initialise the layer on `device`.
340    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Layer<B> {
341        Mamba2Layer {
342            norm: RmsNormConfig::new(self.mamba_block.d_model).init(device),
343            mamba_block: self.mamba_block.init(device),
344        }
345    }
346}
347
348impl<B: Backend + Mamba2BackendExt> Mamba2Layer<B> {
349    // -----------------------------------------------------------------------
350    // forward  (full sequence)
351    // -----------------------------------------------------------------------
352
353    /// Run the Pre-LN residual block over a full sequence.
354    ///
355    /// Computes:
356    /// ```text
357    ///   output = x · residual_scale + Mamba2( RMSNorm(x) )
358    /// ```
359    ///
360    /// # Shapes
361    /// - `x`      : `[batch, sequence, d_model]`
362    /// - output   : `[batch, sequence, d_model]`
363    pub fn forward(
364        &self,
365        x: Tensor<B, 3>,
366        cache: Option<Mamba2Cache<B>>,
367        ssd_path: Mamba2SsdPath,
368        residual_scale: f32,
369    ) -> (Tensor<B, 3>, Mamba2Cache<B>) {
370        let [batch, sequence, d_model] = x.dims();
371
372        // Save the (optionally scaled) residual *before* normalisation so that
373        // the norm does not affect the skip path.
374        let res_bsm = x.clone() * residual_scale;
375
376        let normed_bsm = self.norm.forward(x);
377        assert_eq!([batch, sequence, d_model], normed_bsm.dims());
378
379        let (mamba_out_bsm, cache) = self
380            .mamba_block
381            .forward(normed_bsm, cache, ssd_path.clone());
382        assert_eq!([batch, sequence, d_model], mamba_out_bsm.dims());
383
384        // Residual addition:  y = x · scale + Mamba2(norm(x))
385        let out_bsm = mamba_out_bsm + res_bsm;
386        assert_eq!([batch, sequence, d_model], out_bsm.dims());
387
388        (out_bsm, cache)
389    }
390
391    // -----------------------------------------------------------------------
392    // step  (single token)
393    // -----------------------------------------------------------------------
394
395    /// Run the Pre-LN residual block for a **single** decoding step.
396    ///
397    /// Computes:
398    /// ```text
399    ///   output = x · residual_scale + Mamba2.step( RMSNorm(x) )
400    /// ```
401    ///
402    /// # Shapes
403    /// - `x`  : `[batch, d_model]`
404    /// - output: `[batch, d_model]`
405    pub fn step(
406        &self,
407        x: Tensor<B, 2>,
408        cache: Option<Mamba2Cache<B>>,
409        residual_scale: f32,
410    ) -> (Tensor<B, 2>, Mamba2Cache<B>) {
411        let [batch, d_model] = x.dims();
412
413        let res_bm = x.clone() * residual_scale;
414
415        let normed_bm = self.norm.forward(x);
416        assert_eq!([batch, d_model], normed_bm.dims());
417
418        let (mamba_out_bm, cache) = self.mamba_block.step(normed_bm, cache);
419        assert_eq!([batch, d_model], mamba_out_bm.dims());
420
421        let out_bm = mamba_out_bm + res_bm;
422        assert_eq!([batch, d_model], out_bm.dims());
423
424        (out_bm, cache)
425    }
426}