Skip to main content

burn_mamba/modules/
network.rs

1use crate::modules::LayersBuilder;
2use crate::modules::{ResidualsConfig, RmsNorm, RmsNormConfig};
3use crate::prelude::*;
4use crate::utils::Schedule;
5use crate::utils::class::{
6    assert_step_compatible, class_marker_output_indices, class_step_injections, init_class_emb,
7    insert_class_markers,
8};
9use burn::config::Config;
10use burn::module::Param;
11use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig};
12use burn::prelude::*;
13
14// ===========================================================================
15// LatentNetwork<M>
16// ===========================================================================
17
18/// A feature/regression network on latents:
19/// `in_proj (input_size → d_model) → Layers<M> → out_proj (d_model → output_size)`.
20#[derive(Module, Debug)]
21pub struct LatentNetwork<M: Module> {
22    /// Linear projection `input_size → d_model`.
23    pub in_proj: Linear,
24    /// The shared Mamba-x layer stack.
25    pub layers: Layers<M>,
26    /// Linear projection `d_model → output_size`.
27    pub out_proj: Linear,
28    /// Positions of the network's class tokens, spliced into the input sequence
29    /// (at `input_size` width) **before** `in_proj`. Empty ⇒ none.
30    #[module(skip)]
31    pub class_tokens: Vec<ClassToken>,
32    /// The class-token embeddings, `[num_class_tokens, input_size]`.
33    pub class_tokens_emb: Option<Param<Tensor<2>>>,
34}
35
36impl<M: MambaBlock> LatentNetwork<M>
37where
38    M::SsdPath: Clone,
39{
40    /// Output positions of the class tokens for an `orig_len` input.
41    pub fn class_token_output_indices(&self, orig_len: usize) -> Vec<usize> {
42        class_marker_output_indices(&self.class_tokens, orig_len)
43    }
44
45    /// Splice this network's class latents into `x` (no-op when there are none).
46    fn insert_tokens(&self, x: Tensor<3>) -> Tensor<3> {
47        if self.class_tokens_emb.is_none() {
48            return x;
49        }
50        insert_class_markers(x, &self.class_tokens, self.class_tokens_emb.as_ref()).0
51    }
52
53    /// `in_proj → layers → out_proj` over a full sequence
54    /// (`[batch, sequence, input_size]` → `[batch, sequence (+ class tokens),
55    /// output_size]`).
56    pub fn forward(
57        &self,
58        x: Tensor<3>,
59        caches: Option<M::Caches>,
60        ssd_path: M::SsdPath,
61    ) -> (Tensor<3>, M::Caches) {
62        let x = self.insert_tokens(x);
63        let x = self.in_proj.forward(x);
64        let (x, caches) = self.layers.forward(x, caches, ssd_path);
65        let x = self.out_proj.forward(x);
66        (x, caches)
67    }
68
69    /// Single-token step (`[batch, input_size]` → `[batch, output_size]`).
70    ///
71    /// Three independent class cursors:
72    /// - `own_index` — the network's own [`Self::class_tokens`] (spliced before
73    ///   `in_proj`). When it lands on a class-token position those tokens are
74    ///   stepped first (each a full network pass, advancing `own_index`), then
75    ///   the user token; only the user token's output is returned.
76    /// - `layers_own_index` / `layer_indices` — forwarded straight to the inner
77    ///   [`Layers::step`] (stack-level latents, and the per-virtual-layer cursor
78    ///   vector respectively).
79    ///
80    /// As in `forward`, the network's class tokens are part of the sequence that
81    /// enters the layers, so each is threaded through the layers (carrying the
82    /// inner cursors) just like the user token — only the user token's output is
83    /// returned. A `None` cursor skips that level; `Middle`/`End` markers panic
84    /// for the cursored level (use `forward`).
85    pub fn step(
86        &self,
87        x: Tensor<2>,
88        caches: Option<M::Caches>,
89        own_index: Option<&mut usize>,
90        mut layers_own_index: Option<&mut usize>,
91        mut layer_indices: Option<&mut Vec<usize>>,
92    ) -> (Tensor<2>, M::Caches) {
93        // Network-level class-token injection. Each class token is run through a
94        // full network pass (carrying the inner cursors, so the layers splice
95        // their own latents around it exactly as in `forward`), then the user
96        // token; only the user token's output is returned.
97        if let Some(cursor) = own_index {
98            let [batch, input_size] = x.dims();
99            let inj = class_step_injections(&self.class_tokens, "LatentNetwork");
100            let emb = self.class_tokens_emb.as_ref();
101            let mut caches = caches;
102            while let Some(i) = inj.iter().position(|&p| p == *cursor) {
103                let row = emb
104                    .unwrap()
105                    .val()
106                    .narrow(0, i, 1)
107                    .expand([batch, input_size]);
108                let (_discard, c) = self.step(
109                    row,
110                    caches,
111                    None,
112                    layers_own_index.as_deref_mut(),
113                    layer_indices.as_deref_mut(),
114                );
115                caches = Some(c);
116                *cursor += 1;
117            }
118            let (out, caches) = self.step(x, caches, None, layers_own_index, layer_indices);
119            *cursor += 1;
120            return (out, caches);
121        }
122
123        // The actual one-token work: forward the inner cursors to the stack.
124        assert_step_compatible(&self.class_tokens, "LatentNetwork");
125        let x = self.in_proj.forward(x);
126        let (x, caches) = self.layers.step(x, caches, layers_own_index, layer_indices);
127        let x = self.out_proj.forward(x);
128        (x, caches)
129    }
130
131    /// Stationary fixed point of the network under a constant input token:
132    /// `in_proj → `[`Layers::step_infinite`]` → out_proj`, no caches.
133    /// Cursorless (class tokens are not injected).
134    pub fn step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
135        assert_step_compatible(&self.class_tokens, "LatentNetwork");
136        let x = self.in_proj.forward(x);
137        let x = self.layers.step_infinite(x);
138        self.out_proj.forward(x)
139    }
140
141    /// Approximate jump of `n` consecutive cursorless [`Self::step`] calls on
142    /// the same constant token — see [`Layers::step_n_approx`] for the
143    /// approximation contract.
144    pub fn step_n_approx(
145        &self,
146        x: Tensor<2>,
147        n: usize,
148        caches: Option<M::Caches>,
149    ) -> (Tensor<2>, M::Caches) {
150        assert_step_compatible(&self.class_tokens, "LatentNetwork");
151        let x = self.in_proj.forward(x);
152        let (x, caches) = self.layers.step_n_approx(x, n, caches);
153        (self.out_proj.forward(x), caches)
154    }
155}
156
157/// Plain factory for [`LatentNetwork`].
158pub struct LatentNetworkBuilder<C> {
159    /// Width of the input features fed to `in_proj`.
160    pub input_size: usize,
161    /// Builder for the layer stack.
162    pub layers: LayersBuilder<C>,
163    /// Width of the output features produced by `out_proj`.
164    pub output_size: usize,
165    /// Network-level class tokens (spliced into the input before `in_proj`).
166    pub class_tokens: Vec<ClassToken>,
167}
168
169impl<C: MambaBlockConfig> LatentNetworkBuilder<C> {
170    /// Allocate and initialise the network on `device`.
171    pub fn init(&self, device: &Device) -> LatentNetwork<C::Block> {
172        let d_model = self.layers.mamba_block.d_model();
173        LatentNetwork {
174            in_proj: LinearConfig::new(self.input_size, d_model)
175                .with_bias(true)
176                .init(device),
177            layers: self.layers.init(device),
178            out_proj: LinearConfig::new(d_model, self.output_size)
179                .with_bias(true)
180                .init(device),
181            class_tokens_emb: init_class_emb(self.class_tokens.len(), self.input_size, device),
182            class_tokens: self.class_tokens.clone(),
183        }
184    }
185}
186
187// ===========================================================================
188// VocabNetwork<M>
189// ===========================================================================
190
191/// A complete autoregressive language model over a token vocabulary:
192/// `Embedding (vocab → d_model) → Layers<M> → norm_f → LM head (d_model →
193/// vocab)`.
194///
195/// This is the token-LM counterpart of [`LatentNetwork`]; both are built on the
196/// shared [`Layers`] core. The only differences are the I/O boundary (a token
197/// `Embedding` and a vocab logit head, instead of two latent `Linear`s) and a
198/// final pre-head [`RmsNorm`].
199///
200/// The LM head is **tied** (`lm_head = None`, the transposed embedding weight is
201/// reused) or **untied** (a dedicated `Linear`); the vocabulary is rounded up to
202/// a multiple for GPU alignment (see [`VocabNetworkBuilder`]).
203#[derive(Module, Debug)]
204pub struct VocabNetwork<M: Module> {
205    /// Token embedding table, weight shape `[padded_vocab, d_model]`.
206    pub embedding: Embedding,
207    /// The shared Mamba-x layer stack.
208    pub layers: Layers<M>,
209    /// Final RMSNorm applied before the LM head (`norm_f`).
210    pub norm_f: RmsNorm,
211    /// Optional dedicated LM head. `None` ⇒ weight-tied (reuse embedding`ᵀ`).
212    pub lm_head: Option<Linear>,
213}
214
215impl<M: MambaBlock> VocabNetwork<M>
216where
217    M::SsdPath: Clone,
218{
219    /// Full-sequence pass: token IDs `[batch, sequence]` → logits
220    /// `[batch, sequence, padded_vocab]`.
221    pub fn forward(
222        &self,
223        x: Tensor<2, Int>,
224        caches: Option<M::Caches>,
225        ssd_path: M::SsdPath,
226    ) -> (Tensor<3>, M::Caches) {
227        let x = self.embedding.forward(x);
228        let (x, caches) = self.layers.forward(x, caches, ssd_path);
229        let x = self.norm_f.forward(x);
230        (self.apply_lm_head(x), caches)
231    }
232
233    /// Single-token step: token IDs `[batch]` → logits `[batch, padded_vocab]`.
234    ///
235    /// The vocab network has no class tokens of its own (those would duplicate
236    /// the layers' class latents); it simply forwards the inner [`Layers`]
237    /// cursors — `layers_own_index` (stack-level latents) and `layer_indices`
238    /// (per-virtual-layer) — to [`Layers::step`].
239    pub fn step(
240        &self,
241        x: Tensor<1, Int>,
242        caches: Option<M::Caches>,
243        layers_own_index: Option<&mut usize>,
244        layer_indices: Option<&mut Vec<usize>>,
245    ) -> (Tensor<2>, M::Caches) {
246        // Embed the single token via a temporary unit sequence axis.
247        let x = self
248            .embedding
249            .forward(x.unsqueeze_dim::<2>(1))
250            .squeeze_dim(1);
251        let (x, caches) = self.layers.step(x, caches, layers_own_index, layer_indices);
252        let x = self.norm_f.forward(x);
253        // Reuse the 3-D head by lifting/lowering the sequence axis.
254        let logits = self.apply_lm_head(x.unsqueeze_dim(1)).squeeze_dim(1);
255        (logits, caches)
256    }
257
258    /// Stationary fixed point of the LM under a constant token: logits
259    /// `[batch, padded_vocab]` after infinitely many repeats of `x`, no caches
260    /// (see [`Layers::step_infinite`]).
261    pub fn step_infinite(&self, x: Tensor<1, Int>) -> Tensor<2> {
262        let x = self
263            .embedding
264            .forward(x.unsqueeze_dim::<2>(1))
265            .squeeze_dim(1);
266        let x = self.layers.step_infinite(x);
267        let x = self.norm_f.forward(x);
268        self.apply_lm_head(x.unsqueeze_dim(1)).squeeze_dim(1)
269    }
270
271    /// Approximate jump of `n` consecutive [`Self::step`] calls on the same
272    /// constant token — see [`Layers::step_n_approx`] for the approximation
273    /// contract.
274    pub fn step_n_approx(
275        &self,
276        x: Tensor<1, Int>,
277        n: usize,
278        caches: Option<M::Caches>,
279    ) -> (Tensor<2>, M::Caches) {
280        let x = self
281            .embedding
282            .forward(x.unsqueeze_dim::<2>(1))
283            .squeeze_dim(1);
284        let (x, caches) = self.layers.step_n_approx(x, n, caches);
285        let x = self.norm_f.forward(x);
286        let logits = self.apply_lm_head(x.unsqueeze_dim(1)).squeeze_dim(1);
287        (logits, caches)
288    }
289
290    /// Project `[batch, sequence, d_model]` → `[batch, sequence, padded_vocab]`
291    /// using the dedicated head, or the tied (transposed embedding) weight.
292    fn apply_lm_head(&self, x: Tensor<3>) -> Tensor<3> {
293        if let Some(lm_head) = &self.lm_head {
294            lm_head.forward(x)
295        } else {
296            // Weight tying: reuse embedding.weight^T ([d_model, padded_vocab]).
297            let weight = self.embedding.weight.clone().map(|w| w.permute([1, 0]));
298            Linear { weight, bias: None }.forward(x)
299        }
300    }
301}
302
303/// Plain factory for [`VocabNetwork`]. Mirrors [`LatentNetworkBuilder`] but adds
304/// vocab padding and the tied/untied LM-head choice.
305pub struct VocabNetworkBuilder<C> {
306    /// Unpadded vocabulary size (rounded up at init).
307    pub vocab_size: usize,
308    /// Round `vocab_size` up to a multiple of this (1 disables rounding).
309    pub pad_vocab_size_multiple: usize,
310    /// Builder for the layer stack.
311    pub layers: LayersBuilder<C>,
312    /// When `true`, tie the LM head to the (transposed) embedding weights.
313    pub missing_lm_head: bool,
314}
315
316impl<C: MambaBlockConfig> VocabNetworkBuilder<C> {
317    /// Round `vocab_size` up to the next multiple of `multiple`.
318    fn padded_vocab(vocab_size: usize, multiple: usize) -> usize {
319        if vocab_size.is_multiple_of(multiple) {
320            vocab_size
321        } else {
322            ((vocab_size / multiple) + 1) * multiple
323        }
324    }
325
326    /// Allocate and initialise the network on `device`.
327    pub fn init(&self, device: &Device) -> VocabNetwork<C::Block> {
328        let d_model = self.layers.mamba_block.d_model();
329        let padded_vocab = Self::padded_vocab(self.vocab_size, self.pad_vocab_size_multiple);
330        let lm_head = if self.missing_lm_head {
331            None
332        } else {
333            Some(
334                LinearConfig::new(d_model, padded_vocab)
335                    .with_bias(false)
336                    .init(device),
337            )
338        };
339        VocabNetwork {
340            embedding: EmbeddingConfig::new(padded_vocab, d_model).init(device),
341            layers: self.layers.init(device),
342            norm_f: RmsNormConfig::new(d_model).init(device),
343            lm_head,
344        }
345    }
346}
347
348// ===========================================================================
349// Unifying enums: one runtime + one serializable Config across all families
350// ===========================================================================
351
352/// A runtime-selectable latent network: the same `in_proj → Layers → out_proj`
353/// shape over any Mamba-x family, chosen at runtime.
354#[derive(Module, Debug)]
355pub enum MambaLatentNet {
356    /// Mamba-1 latent network.
357    #[cfg(feature = "mamba1")]
358    Mamba1(LatentNetwork<crate::mamba1::prelude::Mamba1>),
359    /// Mamba-2 latent network.
360    #[cfg(feature = "mamba2")]
361    Mamba2(LatentNetwork<crate::mamba2::prelude::Mamba2>),
362    /// Mamba-3 latent network.
363    #[cfg(feature = "mamba3")]
364    Mamba3(LatentNetwork<crate::mamba3::prelude::Mamba3>),
365}
366
367impl MambaLatentNet {
368    /// Full-sequence pass. The `ssd_path` must match the network's family; a
369    /// mismatch is a caller error and panics with an explanatory message.
370    pub fn forward(
371        &self,
372        x: Tensor<3>,
373        caches: Option<MambaCaches>,
374        ssd_path: MambaSsdPath,
375    ) -> (Tensor<3>, MambaCaches) {
376        match self {
377            #[cfg(feature = "mamba1")]
378            Self::Mamba1(net) => {
379                let caches = caches.map(|c| match c {
380                    MambaCaches::Mamba1(c) => c,
381                    #[allow(unreachable_patterns)]
382                    _ => panic!("cache family does not match Mamba-1 network"),
383                });
384                match ssd_path {
385                    MambaSsdPath::Mamba1 => {}
386                    #[allow(unreachable_patterns)]
387                    _ => panic!("ssd_path family does not match Mamba-1 network"),
388                }
389                let (y, c) = net.forward(x, caches, ());
390                (y, MambaCaches::Mamba1(c))
391            }
392            #[cfg(feature = "mamba2")]
393            Self::Mamba2(net) => {
394                let caches = caches.map(|c| match c {
395                    MambaCaches::Mamba2(c) => c,
396                    #[allow(unreachable_patterns)]
397                    _ => panic!("cache family does not match Mamba-2 network"),
398                });
399                let path = match ssd_path {
400                    MambaSsdPath::Mamba2(p) => p,
401                    #[allow(unreachable_patterns)]
402                    _ => panic!("ssd_path family does not match Mamba-2 network"),
403                };
404                let (y, c) = net.forward(x, caches, path);
405                (y, MambaCaches::Mamba2(c))
406            }
407            #[cfg(feature = "mamba3")]
408            Self::Mamba3(net) => {
409                let caches = caches.map(|c| match c {
410                    MambaCaches::Mamba3(c) => c,
411                    #[allow(unreachable_patterns)]
412                    _ => panic!("cache family does not match Mamba-3 network"),
413                });
414                let path = match ssd_path {
415                    MambaSsdPath::Mamba3(p) => p,
416                    #[allow(unreachable_patterns)]
417                    _ => panic!("ssd_path family does not match Mamba-3 network"),
418                };
419                let (y, c) = net.forward(x, caches, path);
420                (y, MambaCaches::Mamba3(c))
421            }
422        }
423    }
424
425    /// Single-token step. No path argument (decoding is recurrent for all
426    /// families). Cache family must match the network. The three class cursors
427    /// (`own_index` for the network's class tokens, `layers_own_index` for the
428    /// stack-level class latents, `layer_indices` for the per-virtual-layer
429    /// vector) are threaded to the inner network — see [`LatentNetwork::step`].
430    pub fn step(
431        &self,
432        x: Tensor<2>,
433        caches: Option<MambaCaches>,
434        own_index: Option<&mut usize>,
435        layers_own_index: Option<&mut usize>,
436        layer_indices: Option<&mut Vec<usize>>,
437    ) -> (Tensor<2>, MambaCaches) {
438        match self {
439            #[cfg(feature = "mamba1")]
440            Self::Mamba1(net) => {
441                let caches = caches.map(|c| match c {
442                    MambaCaches::Mamba1(c) => c,
443                    #[allow(unreachable_patterns)]
444                    _ => panic!("cache family does not match Mamba-1 network"),
445                });
446                let (y, c) = net.step(x, caches, own_index, layers_own_index, layer_indices);
447                (y, MambaCaches::Mamba1(c))
448            }
449            #[cfg(feature = "mamba2")]
450            Self::Mamba2(net) => {
451                let caches = caches.map(|c| match c {
452                    MambaCaches::Mamba2(c) => c,
453                    #[allow(unreachable_patterns)]
454                    _ => panic!("cache family does not match Mamba-2 network"),
455                });
456                let (y, c) = net.step(x, caches, own_index, layers_own_index, layer_indices);
457                (y, MambaCaches::Mamba2(c))
458            }
459            #[cfg(feature = "mamba3")]
460            Self::Mamba3(net) => {
461                let caches = caches.map(|c| match c {
462                    MambaCaches::Mamba3(c) => c,
463                    #[allow(unreachable_patterns)]
464                    _ => panic!("cache family does not match Mamba-3 network"),
465                });
466                let (y, c) = net.step(x, caches, own_index, layers_own_index, layer_indices);
467                (y, MambaCaches::Mamba3(c))
468            }
469        }
470    }
471
472    /// Stationary fixed point under a constant token (no caches) — see
473    /// [`LatentNetwork::step_infinite`]. Only the Mamba-3 family implements the
474    /// closed form; the other variants panic.
475    pub fn step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
476        match self {
477            #[cfg(feature = "mamba1")]
478            Self::Mamba1(net) => net.step_infinite(x),
479            #[cfg(feature = "mamba2")]
480            Self::Mamba2(net) => net.step_infinite(x),
481            #[cfg(feature = "mamba3")]
482            Self::Mamba3(net) => net.step_infinite(x),
483        }
484    }
485
486    /// Approximate jump of `n` consecutive constant-token steps — see
487    /// [`LatentNetwork::step_n_approx`]. Cache family must match the network;
488    /// only the Mamba-3 family implements the closed form (others panic).
489    pub fn step_n_approx(
490        &self,
491        x: Tensor<2>,
492        n: usize,
493        caches: Option<MambaCaches>,
494    ) -> (Tensor<2>, MambaCaches) {
495        match self {
496            #[cfg(feature = "mamba1")]
497            Self::Mamba1(net) => {
498                let caches = caches.map(|c| match c {
499                    MambaCaches::Mamba1(c) => c,
500                    #[allow(unreachable_patterns)]
501                    _ => panic!("cache family does not match Mamba-1 network"),
502                });
503                let (y, c) = net.step_n_approx(x, n, caches);
504                (y, MambaCaches::Mamba1(c))
505            }
506            #[cfg(feature = "mamba2")]
507            Self::Mamba2(net) => {
508                let caches = caches.map(|c| match c {
509                    MambaCaches::Mamba2(c) => c,
510                    #[allow(unreachable_patterns)]
511                    _ => panic!("cache family does not match Mamba-2 network"),
512                });
513                let (y, c) = net.step_n_approx(x, n, caches);
514                (y, MambaCaches::Mamba2(c))
515            }
516            #[cfg(feature = "mamba3")]
517            Self::Mamba3(net) => {
518                let caches = caches.map(|c| match c {
519                    MambaCaches::Mamba3(c) => c,
520                    #[allow(unreachable_patterns)]
521                    _ => panic!("cache family does not match Mamba-3 network"),
522                });
523                let (y, c) = net.step_n_approx(x, n, caches);
524                (y, MambaCaches::Mamba3(c))
525            }
526        }
527    }
528}
529
530/// The serializable, documentation-friendly config for [`MambaLatentNet`]. Each
531/// variant is concrete (per-family), so `#[derive(Config)]` applies; `init`
532/// builds the matching network variant.
533#[derive(Config, Debug)]
534pub enum MambaLatentNetConfig {
535    /// Build a Mamba-1 latent network.
536    #[cfg(feature = "mamba1")]
537    Mamba1 {
538        /// Input feature width.
539        input_size: usize,
540        /// Number of real layers.
541        n_real_layers: usize,
542        /// Optional virtual-layer scheduling.
543        n_virtual_layers: Option<(usize, Schedule)>,
544        /// Shared block config.
545        mamba_block: crate::mamba1::prelude::Mamba1Config,
546        /// Output feature width.
547        output_size: usize,
548        /// Network-level class tokens, spliced into the input before `in_proj`.
549        class_tokens: Vec<ClassToken>,
550        /// Suppress the first virtual layer's residual (Pre-LN skip / MultiGate
551        /// seed carry). See [`Layers`](crate::modules::Layers).
552        ignore_first_residual: bool,
553        /// Suppress the last virtual layer's residual (output is the last
554        /// layer's transform alone). See [`Layers`](crate::modules::Layers).
555        ignore_last_residual: bool,
556        /// Inter-layer residual scheme (plain additive vs Multi-Gate).
557        residuals: ResidualsConfig,
558    },
559    /// Build a Mamba-2 latent network.
560    #[cfg(feature = "mamba2")]
561    Mamba2 {
562        /// Input feature width.
563        input_size: usize,
564        /// Number of real layers.
565        n_real_layers: usize,
566        /// Optional virtual-layer scheduling.
567        n_virtual_layers: Option<(usize, Schedule)>,
568        /// Shared block config.
569        mamba_block: crate::mamba2::prelude::Mamba2Config,
570        /// Output feature width.
571        output_size: usize,
572        /// Network-level class tokens, spliced into the input before `in_proj`.
573        class_tokens: Vec<ClassToken>,
574        /// Suppress the first virtual layer's residual (Pre-LN skip / MultiGate
575        /// seed carry). See [`Layers`](crate::modules::Layers).
576        ignore_first_residual: bool,
577        /// Suppress the last virtual layer's residual (output is the last
578        /// layer's transform alone). See [`Layers`](crate::modules::Layers).
579        ignore_last_residual: bool,
580        /// Inter-layer residual scheme (plain additive vs Multi-Gate).
581        residuals: ResidualsConfig,
582    },
583    /// Build a Mamba-3 latent network.
584    #[cfg(feature = "mamba3")]
585    Mamba3 {
586        /// Input feature width.
587        input_size: usize,
588        /// Number of real layers.
589        n_real_layers: usize,
590        /// Optional virtual-layer scheduling.
591        n_virtual_layers: Option<(usize, Schedule)>,
592        /// Shared block config.
593        mamba_block: crate::mamba3::prelude::Mamba3Config,
594        /// Output feature width.
595        output_size: usize,
596        /// Network-level class tokens, spliced into the input before `in_proj`.
597        class_tokens: Vec<ClassToken>,
598        /// Suppress the first virtual layer's residual (Pre-LN skip / MultiGate
599        /// seed carry). See [`Layers`](crate::modules::Layers).
600        ignore_first_residual: bool,
601        /// Suppress the last virtual layer's residual (output is the last
602        /// layer's transform alone). See [`Layers`](crate::modules::Layers).
603        ignore_last_residual: bool,
604        /// Inter-layer residual scheme (plain additive vs Multi-Gate).
605        residuals: ResidualsConfig,
606    },
607}
608
609impl MambaLatentNetConfig {
610    /// Allocate and initialise the selected network on `device`.
611    pub fn init(&self, device: &Device) -> MambaLatentNet {
612        match self {
613            #[cfg(feature = "mamba1")]
614            Self::Mamba1 {
615                input_size,
616                n_real_layers,
617                n_virtual_layers,
618                mamba_block,
619                output_size,
620                class_tokens,
621                ignore_first_residual,
622                ignore_last_residual,
623                residuals,
624            } => MambaLatentNet::Mamba1(
625                LatentNetworkBuilder {
626                    input_size: *input_size,
627                    layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
628                        .with_n_virtual_layers(n_virtual_layers.clone())
629                        .with_residuals(residuals.clone())
630                        .with_ignore_first_residual(*ignore_first_residual)
631                        .with_ignore_last_residual(*ignore_last_residual),
632                    output_size: *output_size,
633                    class_tokens: class_tokens.clone(),
634                }
635                .init(device),
636            ),
637            #[cfg(feature = "mamba2")]
638            Self::Mamba2 {
639                input_size,
640                n_real_layers,
641                n_virtual_layers,
642                mamba_block,
643                output_size,
644                class_tokens,
645                ignore_first_residual,
646                ignore_last_residual,
647                residuals,
648            } => MambaLatentNet::Mamba2(
649                LatentNetworkBuilder {
650                    input_size: *input_size,
651                    layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
652                        .with_n_virtual_layers(n_virtual_layers.clone())
653                        .with_residuals(residuals.clone())
654                        .with_ignore_first_residual(*ignore_first_residual)
655                        .with_ignore_last_residual(*ignore_last_residual),
656                    output_size: *output_size,
657                    class_tokens: class_tokens.clone(),
658                }
659                .init(device),
660            ),
661            #[cfg(feature = "mamba3")]
662            Self::Mamba3 {
663                input_size,
664                n_real_layers,
665                n_virtual_layers,
666                mamba_block,
667                output_size,
668                class_tokens,
669                ignore_first_residual,
670                ignore_last_residual,
671                residuals,
672            } => MambaLatentNet::Mamba3(
673                LatentNetworkBuilder {
674                    input_size: *input_size,
675                    layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
676                        .with_n_virtual_layers(n_virtual_layers.clone())
677                        .with_residuals(residuals.clone())
678                        .with_ignore_first_residual(*ignore_first_residual)
679                        .with_ignore_last_residual(*ignore_last_residual),
680                    output_size: *output_size,
681                    class_tokens: class_tokens.clone(),
682                }
683                .init(device),
684            ),
685        }
686    }
687}
688
689/// A runtime-selectable token language model: the same `Embedding → Layers →
690/// norm_f → LM head` shape over any Mamba-x family, chosen at runtime. The
691/// vocabulary counterpart of [`MambaLatentNet`].
692#[derive(Module, Debug)]
693pub enum MambaVocabNet {
694    /// Mamba-1 language model.
695    #[cfg(feature = "mamba1")]
696    Mamba1(VocabNetwork<crate::mamba1::prelude::Mamba1>),
697    /// Mamba-2 language model.
698    #[cfg(feature = "mamba2")]
699    Mamba2(VocabNetwork<crate::mamba2::prelude::Mamba2>),
700    /// Mamba-3 language model.
701    #[cfg(feature = "mamba3")]
702    Mamba3(VocabNetwork<crate::mamba3::prelude::Mamba3>),
703}
704
705impl MambaVocabNet {
706    /// Full-sequence pass: token IDs `[batch, sequence]` → logits
707    /// `[batch, sequence, padded_vocab]`. The `ssd_path`/`caches` family must
708    /// match the network; a mismatch is a caller error and panics.
709    pub fn forward(
710        &self,
711        x: Tensor<2, Int>,
712        caches: Option<MambaCaches>,
713        ssd_path: MambaSsdPath,
714    ) -> (Tensor<3>, MambaCaches) {
715        match self {
716            #[cfg(feature = "mamba1")]
717            Self::Mamba1(net) => {
718                let caches = caches.map(|c| match c {
719                    MambaCaches::Mamba1(c) => c,
720                    #[allow(unreachable_patterns)]
721                    _ => panic!("cache family does not match Mamba-1 network"),
722                });
723                match ssd_path {
724                    MambaSsdPath::Mamba1 => {}
725                    #[allow(unreachable_patterns)]
726                    _ => panic!("ssd_path family does not match Mamba-1 network"),
727                }
728                let (y, c) = net.forward(x, caches, ());
729                (y, MambaCaches::Mamba1(c))
730            }
731            #[cfg(feature = "mamba2")]
732            Self::Mamba2(net) => {
733                let caches = caches.map(|c| match c {
734                    MambaCaches::Mamba2(c) => c,
735                    #[allow(unreachable_patterns)]
736                    _ => panic!("cache family does not match Mamba-2 network"),
737                });
738                let path = match ssd_path {
739                    MambaSsdPath::Mamba2(p) => p,
740                    #[allow(unreachable_patterns)]
741                    _ => panic!("ssd_path family does not match Mamba-2 network"),
742                };
743                let (y, c) = net.forward(x, caches, path);
744                (y, MambaCaches::Mamba2(c))
745            }
746            #[cfg(feature = "mamba3")]
747            Self::Mamba3(net) => {
748                let caches = caches.map(|c| match c {
749                    MambaCaches::Mamba3(c) => c,
750                    #[allow(unreachable_patterns)]
751                    _ => panic!("cache family does not match Mamba-3 network"),
752                });
753                let path = match ssd_path {
754                    MambaSsdPath::Mamba3(p) => p,
755                    #[allow(unreachable_patterns)]
756                    _ => panic!("ssd_path family does not match Mamba-3 network"),
757                };
758                let (y, c) = net.forward(x, caches, path);
759                (y, MambaCaches::Mamba3(c))
760            }
761        }
762    }
763
764    /// Single-token step: token IDs `[batch]` → logits `[batch, padded_vocab]`.
765    /// Cache family must match the network. The two inner [`Layers`] class
766    /// cursors (`layers_own_index`, `layer_indices`) are forwarded — see
767    /// [`VocabNetwork::step`].
768    pub fn step(
769        &self,
770        x: Tensor<1, Int>,
771        caches: Option<MambaCaches>,
772        layers_own_index: Option<&mut usize>,
773        layer_indices: Option<&mut Vec<usize>>,
774    ) -> (Tensor<2>, MambaCaches) {
775        match self {
776            #[cfg(feature = "mamba1")]
777            Self::Mamba1(net) => {
778                let caches = caches.map(|c| match c {
779                    MambaCaches::Mamba1(c) => c,
780                    #[allow(unreachable_patterns)]
781                    _ => panic!("cache family does not match Mamba-1 network"),
782                });
783                let (y, c) = net.step(x, caches, layers_own_index, layer_indices);
784                (y, MambaCaches::Mamba1(c))
785            }
786            #[cfg(feature = "mamba2")]
787            Self::Mamba2(net) => {
788                let caches = caches.map(|c| match c {
789                    MambaCaches::Mamba2(c) => c,
790                    #[allow(unreachable_patterns)]
791                    _ => panic!("cache family does not match Mamba-2 network"),
792                });
793                let (y, c) = net.step(x, caches, layers_own_index, layer_indices);
794                (y, MambaCaches::Mamba2(c))
795            }
796            #[cfg(feature = "mamba3")]
797            Self::Mamba3(net) => {
798                let caches = caches.map(|c| match c {
799                    MambaCaches::Mamba3(c) => c,
800                    #[allow(unreachable_patterns)]
801                    _ => panic!("cache family does not match Mamba-3 network"),
802                });
803                let (y, c) = net.step(x, caches, layers_own_index, layer_indices);
804                (y, MambaCaches::Mamba3(c))
805            }
806        }
807    }
808
809    /// Stationary fixed point under a constant token (no caches) — see
810    /// [`VocabNetwork::step_infinite`]. Only the Mamba-3 family implements the
811    /// closed form; the other variants panic.
812    pub fn step_infinite(&self, x: Tensor<1, Int>) -> Tensor<2> {
813        match self {
814            #[cfg(feature = "mamba1")]
815            Self::Mamba1(net) => net.step_infinite(x),
816            #[cfg(feature = "mamba2")]
817            Self::Mamba2(net) => net.step_infinite(x),
818            #[cfg(feature = "mamba3")]
819            Self::Mamba3(net) => net.step_infinite(x),
820        }
821    }
822
823    /// Approximate jump of `n` consecutive constant-token steps — see
824    /// [`VocabNetwork::step_n_approx`]. Cache family must match the network;
825    /// only the Mamba-3 family implements the closed form (others panic).
826    pub fn step_n_approx(
827        &self,
828        x: Tensor<1, Int>,
829        n: usize,
830        caches: Option<MambaCaches>,
831    ) -> (Tensor<2>, MambaCaches) {
832        match self {
833            #[cfg(feature = "mamba1")]
834            Self::Mamba1(net) => {
835                let caches = caches.map(|c| match c {
836                    MambaCaches::Mamba1(c) => c,
837                    #[allow(unreachable_patterns)]
838                    _ => panic!("cache family does not match Mamba-1 network"),
839                });
840                let (y, c) = net.step_n_approx(x, n, caches);
841                (y, MambaCaches::Mamba1(c))
842            }
843            #[cfg(feature = "mamba2")]
844            Self::Mamba2(net) => {
845                let caches = caches.map(|c| match c {
846                    MambaCaches::Mamba2(c) => c,
847                    #[allow(unreachable_patterns)]
848                    _ => panic!("cache family does not match Mamba-2 network"),
849                });
850                let (y, c) = net.step_n_approx(x, n, caches);
851                (y, MambaCaches::Mamba2(c))
852            }
853            #[cfg(feature = "mamba3")]
854            Self::Mamba3(net) => {
855                let caches = caches.map(|c| match c {
856                    MambaCaches::Mamba3(c) => c,
857                    #[allow(unreachable_patterns)]
858                    _ => panic!("cache family does not match Mamba-3 network"),
859                });
860                let (y, c) = net.step_n_approx(x, n, caches);
861                (y, MambaCaches::Mamba3(c))
862            }
863        }
864    }
865}
866
/// The serializable, documentation-friendly config for [`MambaVocabNet`]. Each
/// variant is concrete (per-family), so `#[derive(Config)]` applies; `init`
/// builds the matching network variant.
///
/// All three variants carry the same fields and differ only in the type of
/// `mamba_block`. [`MambaVocabNetConfig::init`] forwards the layer-stack fields
/// into a `LayersBuilder` and the vocabulary/head fields into a
/// `VocabNetworkBuilder`. A variant exists only when its `mamba1` / `mamba2` /
/// `mamba3` cargo feature is enabled.
#[derive(Config, Debug)]
pub enum MambaVocabNetConfig {
    /// Build a Mamba-1 language model.
    #[cfg(feature = "mamba1")]
    Mamba1 {
        /// Number of real layers (first argument of `LayersBuilder::new`).
        n_real_layers: usize,
        /// Optional virtual-layer scheduling, forwarded to
        /// `LayersBuilder::with_n_virtual_layers`. The meaning of the
        /// `(usize, Schedule)` pair is defined by `LayersBuilder` — TODO confirm.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Unpadded vocabulary size.
        vocab_size: usize,
        /// Round `vocab_size` up to a multiple of this (1 disables rounding).
        pad_vocab_size_multiple: usize,
        /// Shared block config, cloned into `LayersBuilder::new`.
        mamba_block: crate::mamba1::prelude::Mamba1Config,
        /// Tie the LM head to the (transposed) embedding weights when `true`.
        missing_lm_head: bool,
        /// Suppress the first virtual layer's residual (Pre-LN skip / MultiGate
        /// seed carry). See [`Layers`](crate::modules::Layers).
        ignore_first_residual: bool,
        /// Suppress the last virtual layer's residual (output is the last
        /// layer's transform alone). See [`Layers`](crate::modules::Layers).
        ignore_last_residual: bool,
        /// Inter-layer residual scheme (plain additive vs Multi-Gate), forwarded
        /// to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
    /// Build a Mamba-2 language model.
    #[cfg(feature = "mamba2")]
    Mamba2 {
        /// Number of real layers (first argument of `LayersBuilder::new`).
        n_real_layers: usize,
        /// Optional virtual-layer scheduling, forwarded to
        /// `LayersBuilder::with_n_virtual_layers`. The meaning of the
        /// `(usize, Schedule)` pair is defined by `LayersBuilder` — TODO confirm.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Unpadded vocabulary size.
        vocab_size: usize,
        /// Round `vocab_size` up to a multiple of this (1 disables rounding).
        pad_vocab_size_multiple: usize,
        /// Shared block config, cloned into `LayersBuilder::new`.
        mamba_block: crate::mamba2::prelude::Mamba2Config,
        /// Tie the LM head to the (transposed) embedding weights when `true`.
        missing_lm_head: bool,
        /// Suppress the first virtual layer's residual (Pre-LN skip / MultiGate
        /// seed carry). See [`Layers`](crate::modules::Layers).
        ignore_first_residual: bool,
        /// Suppress the last virtual layer's residual (output is the last
        /// layer's transform alone). See [`Layers`](crate::modules::Layers).
        ignore_last_residual: bool,
        /// Inter-layer residual scheme (plain additive vs Multi-Gate), forwarded
        /// to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
    /// Build a Mamba-3 language model.
    #[cfg(feature = "mamba3")]
    Mamba3 {
        /// Number of real layers (first argument of `LayersBuilder::new`).
        n_real_layers: usize,
        /// Optional virtual-layer scheduling, forwarded to
        /// `LayersBuilder::with_n_virtual_layers`. The meaning of the
        /// `(usize, Schedule)` pair is defined by `LayersBuilder` — TODO confirm.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Unpadded vocabulary size.
        vocab_size: usize,
        /// Round `vocab_size` up to a multiple of this (1 disables rounding).
        pad_vocab_size_multiple: usize,
        /// Shared block config, cloned into `LayersBuilder::new`.
        mamba_block: crate::mamba3::prelude::Mamba3Config,
        /// Tie the LM head to the (transposed) embedding weights when `true`.
        missing_lm_head: bool,
        /// Suppress the first virtual layer's residual (Pre-LN skip / MultiGate
        /// seed carry). See [`Layers`](crate::modules::Layers).
        ignore_first_residual: bool,
        /// Suppress the last virtual layer's residual (output is the last
        /// layer's transform alone). See [`Layers`](crate::modules::Layers).
        ignore_last_residual: bool,
        /// Inter-layer residual scheme (plain additive vs Multi-Gate), forwarded
        /// to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
}
945
946impl MambaVocabNetConfig {
947    /// Allocate and initialise the selected language model on `device`.
948    pub fn init(&self, device: &Device) -> MambaVocabNet {
949        match self {
950            #[cfg(feature = "mamba1")]
951            Self::Mamba1 {
952                n_real_layers,
953                n_virtual_layers,
954                vocab_size,
955                pad_vocab_size_multiple,
956                mamba_block,
957                missing_lm_head,
958                ignore_first_residual,
959                ignore_last_residual,
960                residuals,
961            } => MambaVocabNet::Mamba1(
962                VocabNetworkBuilder {
963                    vocab_size: *vocab_size,
964                    pad_vocab_size_multiple: *pad_vocab_size_multiple,
965                    layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
966                        .with_n_virtual_layers(n_virtual_layers.clone())
967                        .with_residuals(residuals.clone())
968                        .with_ignore_first_residual(*ignore_first_residual)
969                        .with_ignore_last_residual(*ignore_last_residual),
970                    missing_lm_head: *missing_lm_head,
971                }
972                .init(device),
973            ),
974            #[cfg(feature = "mamba2")]
975            Self::Mamba2 {
976                n_real_layers,
977                n_virtual_layers,
978                vocab_size,
979                pad_vocab_size_multiple,
980                mamba_block,
981                missing_lm_head,
982                ignore_first_residual,
983                ignore_last_residual,
984                residuals,
985            } => MambaVocabNet::Mamba2(
986                VocabNetworkBuilder {
987                    vocab_size: *vocab_size,
988                    pad_vocab_size_multiple: *pad_vocab_size_multiple,
989                    layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
990                        .with_n_virtual_layers(n_virtual_layers.clone())
991                        .with_residuals(residuals.clone())
992                        .with_ignore_first_residual(*ignore_first_residual)
993                        .with_ignore_last_residual(*ignore_last_residual),
994                    missing_lm_head: *missing_lm_head,
995                }
996                .init(device),
997            ),
998            #[cfg(feature = "mamba3")]
999            Self::Mamba3 {
1000                n_real_layers,
1001                n_virtual_layers,
1002                vocab_size,
1003                pad_vocab_size_multiple,
1004                mamba_block,
1005                missing_lm_head,
1006                ignore_first_residual,
1007                ignore_last_residual,
1008                residuals,
1009            } => MambaVocabNet::Mamba3(
1010                VocabNetworkBuilder {
1011                    vocab_size: *vocab_size,
1012                    pad_vocab_size_multiple: *pad_vocab_size_multiple,
1013                    layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
1014                        .with_n_virtual_layers(n_virtual_layers.clone())
1015                        .with_residuals(residuals.clone())
1016                        .with_ignore_first_residual(*ignore_first_residual)
1017                        .with_ignore_last_residual(*ignore_last_residual),
1018                    missing_lm_head: *missing_lm_head,
1019                }
1020                .init(device),
1021            ),
1022        }
1023    }
1024}