Skip to main content

burn_mamba/modules/
bidi.rs

1use crate::modules::{Residuals, ResidualsConfig, RmsNorm, RmsNormConfig};
2use crate::prelude::*;
3use crate::utils::BidiSchedule;
4use crate::utils::ClassLatent;
5use crate::utils::class::{class_marker_output_indices, init_class_emb, insert_class_markers};
6use burn::config::Config;
7use burn::module::Param;
8use burn::nn::{Linear, LinearConfig};
9use burn::prelude::*;
10
11#[cfg(test)]
12mod tests;
13
14// ===========================================================================
15// Bidirectional support (family-generic; forward-only, non-autoregressive)
16// ===========================================================================
17//
18// A `BidiLayerPair<M>` runs a straight (→) and a reversed (← via `flip`) Pre-LN
19// pass and merges them with an [`OutputMerge`]; `BidiLayers<M>` stacks pairs with
20// a [`BidiSchedule`]. The block itself is unchanged — only how its two passes are
21// scheduled and combined is bidirectional. Written once for all families; the
22// merge is family-agnostic (`RmsNorm`/`Linear` over `Tensor<3>`).
23
24/// A zero-parameter placeholder for the parameterless `Mean` merge.
25#[derive(Module, Debug)]
26pub struct NoOp;
27
28/// How the two directions of a bidirectional pair are combined.
29#[allow(clippy::large_enum_variant)]
30#[derive(Module, Debug)]
31pub enum OutputMerge {
32    /// Element-wise average of the two directions (no parameters).
33    Mean(NoOp),
34    /// Concatenate along the feature axis and project back down with a learnable
35    /// `[2 · d_model, d_model]` linear layer.
36    CatLinear(Linear),
37}
38
39impl OutputMerge {
40    /// Merge the two directional outputs (each `[batch, sequence, d_model]`).
41    pub fn forward(&self, straight: Tensor<3>, reverse: Tensor<3>) -> Tensor<3> {
42        let [batch, sequence, d_model] = straight.dims();
43        assert_eq!(straight.dims(), reverse.dims());
44        match self {
45            OutputMerge::Mean(_) => (straight + reverse) * 0.5,
46            OutputMerge::CatLinear(proj) => {
47                let cat = Tensor::cat([straight, reverse].to_vec(), 2);
48                assert_eq!([batch, sequence, 2 * d_model], cat.dims());
49                let merged = proj.forward(cat);
50                assert_eq!([batch, sequence, d_model], merged.dims());
51                merged
52            }
53        }
54    }
55}
56
57/// Configuration / factory for [`OutputMerge`].
58#[derive(Config, Debug)]
59pub enum OutputMergeConfig {
60    /// Build an [`OutputMerge::Mean`].
61    Mean,
62    /// Build an [`OutputMerge::CatLinear`].
63    CatLinear,
64}
65
66impl OutputMergeConfig {
67    /// A vector of `n_real_layers / 2` [`Self::Mean`] configs (one per pair).
68    pub fn mean(n_real_layers: usize) -> Vec<Self> {
69        vec![Self::Mean; n_real_layers / 2]
70    }
71    /// A vector of `n_real_layers / 2` [`Self::CatLinear`] configs (one per pair).
72    pub fn cat_linear(n_real_layers: usize) -> Vec<Self> {
73        vec![Self::CatLinear; n_real_layers / 2]
74    }
75    /// Allocate the merge module on `device` for the given `d_model`.
76    pub fn init(&self, d_model: usize, device: &Device) -> OutputMerge {
77        match self {
78            OutputMergeConfig::Mean => OutputMerge::Mean(NoOp),
79            OutputMergeConfig::CatLinear => {
80                OutputMerge::CatLinear(LinearConfig::new(d_model * 2, d_model).init(device))
81            }
82        }
83    }
84}
85
86/// A single bidirectional pair: a straight (→) and a reversed (←) Pre-LN block
87/// whose outputs are merged. The residual is **not** applied here — the
88/// enclosing [`BidiLayers`] adds it (or suppresses it on the first/last pair),
89/// mirroring the [`Layer`](crate::modules::Layer) / [`Layers`](crate::modules::Layers) split.
90#[derive(Module, Debug)]
91pub struct BidiLayerPair<M: Module> {
92    /// Pre-norm for the straight pass.
93    pub straight_norm: RmsNorm,
94    /// Pre-norm for the reversed pass.
95    pub reverse_norm: RmsNorm,
96    /// The block run left-to-right.
97    pub straight_block: M,
98    /// The block run right-to-left (over the flipped sequence).
99    pub reverse_block: M,
100    /// Merge strategy combining the two directions.
101    pub output_merge: OutputMerge,
102    /// Positions of this pair's class latents, spliced in before either
103    /// direction runs (both directions, and the residual, see the lengthened
104    /// sequence). Empty ⇒ none.
105    #[module(skip)]
106    pub class_latents: Vec<ClassLatent>,
107    /// This pair's class-latent embeddings, `[num_class_latents, d_model]`.
108    pub class_latents_emb: Option<Param<Tensor<2>>>,
109}
110
111impl<M: MambaBlock> BidiLayerPair<M>
112where
113    M::SsdPath: Clone,
114{
115    /// Splice this bidi-layer-pair's class latents into `x` (no-op when there are none).
116    fn insert_latents(&self, x: Tensor<3>) -> Tensor<3> {
117        if self.class_latents_emb.is_none() {
118            return x;
119        }
120        insert_class_markers(x, &self.class_latents, self.class_latents_emb.as_ref()).0
121    }
122
123    /// `[batch, sequence, d_model]` → `[batch, sequence, d_model]`, plus the two
124    /// updated direction caches. (`sequence` grows by the class-latent count.)
125    /// Returns the merged directions **without** the residual — the enclosing
126    /// [`BidiLayers`] adds it.
127    pub fn forward(
128        &self,
129        x: Tensor<3>,
130        straight_cache: Option<M::Cache>,
131        reverse_cache: Option<M::Cache>,
132        ssd_path: M::SsdPath,
133    ) -> (Tensor<3>, M::Cache, M::Cache) {
134        let x = self.insert_latents(x);
135        bidi_pair_forward(
136            &self.straight_norm,
137            &self.reverse_norm,
138            &self.straight_block,
139            &self.reverse_block,
140            &self.output_merge,
141            x,
142            straight_cache,
143            reverse_cache,
144            ssd_path,
145        )
146    }
147}
148
149/// The straight + reverse + merge computation of a bidirectional pair, over
150/// **borrowed** sub-modules.
151///
152/// Taking references (rather than owning clones) is load-bearing: a Burn `Param`
153/// that is still lazily-initialised re-runs its random initialiser **on every
154/// clone**, so cloning a not-yet-materialised block per forward would resample
155/// fresh random weights each call. [`BidiLayers`] therefore calls this directly
156/// on its real layers instead of building a transient (cloned) [`BidiLayerPair`].
157#[allow(clippy::too_many_arguments)]
158fn bidi_pair_forward<M: MambaBlock>(
159    straight_norm: &RmsNorm,
160    reverse_norm: &RmsNorm,
161    straight_block: &M,
162    reverse_block: &M,
163    output_merge: &OutputMerge,
164    x: Tensor<3>,
165    straight_cache: Option<M::Cache>,
166    reverse_cache: Option<M::Cache>,
167    ssd_path: M::SsdPath,
168) -> (Tensor<3>, M::Cache, M::Cache)
169where
170    M::SsdPath: Clone,
171{
172    let [batch, sequence, d_model] = x.dims();
173
174    // x reads >x₀>x₁>…; x_rev (flipped) reads the sequence backwards.
175    let x_rev = x.clone().flip([1]);
176    let x = straight_norm.forward(x);
177    let x_rev = reverse_norm.forward(x_rev);
178
179    let (x, straight_cache) = straight_block.block_forward(x, straight_cache, ssd_path.clone());
180    assert_eq!([batch, sequence, d_model], x.dims());
181
182    let (x_rev, reverse_cache) = reverse_block.block_forward(x_rev, reverse_cache, ssd_path);
183    assert_eq!([batch, sequence, d_model], x_rev.dims());
184
185    // Re-align the reversed read, then merge.
186    let x_rev = x_rev.flip([1]);
187    let merged = output_merge.forward(x, x_rev);
188    (merged, straight_cache, reverse_cache)
189}
190
191/// A stack of bidirectional [`Layer`] pairs with optional virtual-layer
192/// scheduling — one struct for every Mamba-x family.
193#[derive(Module, Debug)]
194pub struct BidiLayers<M: Module> {
195    /// Number of real (weight-bearing) layers; must be even (used in pairs).
196    pub n_real_layers: usize,
197    /// Optional `(n_virtual_layers, schedule)` for weight-sharing.
198    #[module(skip)]
199    pub n_virtual_layers: Option<(usize, BidiSchedule)>,
200    /// The weight-bearing layers, length `n_real_layers`.
201    pub real_layers: Vec<Layer<M>>,
202    /// Zero the first virtual pair's residual when `true`.
203    pub ignore_first_residual: bool,
204    /// Zero the last virtual pair's residual when `true`.
205    pub ignore_last_residual: bool,
206    /// One direction-merge per pair, length `n_real_layers / 2`.
207    pub outputs_merge: Vec<OutputMerge>,
208    /// How residuals are threaded between **pairs** (plain additive vs
209    /// Multi-Gate). The MGR unit is the pair: one module per real/virtual pair.
210    pub residuals: Residuals,
211    /// Positions of the stack-level class latents, spliced into the sequence
212    /// once before the first pair (independent of any per-pair class latents).
213    #[module(skip)]
214    pub class_latents: Vec<ClassLatent>,
215    /// The stack-level class-latent embeddings, `[num_class_latents, d_model]`.
216    pub class_latents_emb: Option<Param<Tensor<2>>>,
217}
218
219impl<M: MambaBlock + Clone> BidiLayers<M>
220where
221    M::SsdPath: Clone,
222{
223    /// Output positions of the stack-level class latents for an `orig_len` input.
224    pub fn class_latent_output_indices(&self, orig_len: usize) -> Vec<usize> {
225        class_marker_output_indices(&self.class_latents, orig_len)
226    }
227
228    /// Splice this bidi-layers' class latents into `x` (no-op when there are none).
229    fn insert_latents(&self, x: Tensor<3>) -> Tensor<3> {
230        if self.class_latents_emb.is_none() {
231            return x;
232        }
233        insert_class_markers(x, &self.class_latents, self.class_latents_emb.as_ref()).0
234    }
235
236    /// Seed the MultiGate streams from a full-sequence input — `n_stream` copies
237    /// of `x` as `[batch, sequence, n_stream, d_model]` — or `None` for the
238    /// Standard path. Panics if MultiGate is paired with stack-level class latents.
239    fn multi_gate_streams_seed(&self, x: &Tensor<3>) -> Option<Tensor<4>> {
240        let Residuals::MultiGate(mg) = &self.residuals else {
241            return None;
242        };
243        assert!(
244            self.class_latents_emb.is_none(),
245            "MultiGate residuals do not support stack-level class latents"
246        );
247        let [batch, sequence, d_model] = x.dims();
248        Some(
249            x.clone()
250                .unsqueeze_dim::<4>(2)
251                .expand([batch, sequence, mg.n_stream, d_model]),
252        )
253    }
254
255    /// `[batch, sequence, d_model]` → `[batch, sequence, d_model]`
256    /// (`sequence` grows by the stack-level class-latent count).
257    ///
258    /// Each pair returns its merged transform `F_l` (no residual). With
259    /// [`Residuals::Standard`] the input skip is added per pair (unless
260    /// suppressed). With [`Residuals::MultiGate`] the skip is dropped and
261    /// `n_stream` parallel streams — seeded from `x` — carry the residual between
262    /// pairs: each pair reads their attention-pooled aggregate as input and its
263    /// merged output is gated back into every stream (see [`MultiGate`]).
264    ///
265    /// [`MultiGate`]: crate::modules::MultiGate
266    pub fn forward(
267        &self,
268        mut x: Tensor<3>,
269        caches: Option<M::Caches>,
270        ssd_path: M::SsdPath,
271    ) -> (Tensor<3>, M::Caches) {
272        x = self.insert_latents(x);
273        let n = self
274            .n_virtual_layers
275            .as_ref()
276            .map(|(l, _)| {
277                assert!(l.is_multiple_of(2), "Bidi virtual layers are used in pairs");
278                *l
279            })
280            .unwrap_or_else(|| {
281                assert!(
282                    self.n_real_layers.is_multiple_of(2),
283                    "Bidi layers are used in pairs"
284                );
285                self.n_real_layers
286            });
287
288        let caches =
289            caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_3d(&x, n));
290        assert_eq!(
291            caches.slot_count(),
292            n,
293            "straight and reverse layers cannot share caches"
294        );
295
296        let mut slots = caches.into_slots();
297        // MultiGate keeps `n_stream` parallel streams (seeded from the input);
298        // Standard threads the single tensor `x` directly (streams stays `None`).
299        let mut streams = self.multi_gate_streams_seed(&x);
300        for i in 0..n / 2 {
301            let (straight_i, reverse_i) = (i * 2, i * 2 + 1);
302            let (straight_idx, reverse_idx) =
303                if let Some((n_virtual, schedule)) = &self.n_virtual_layers {
304                    (
305                        schedule.real_idx(straight_i, *n_virtual, self.n_real_layers),
306                        schedule.real_idx(reverse_i, *n_virtual, self.n_real_layers),
307                    )
308                } else {
309                    (straight_i, reverse_i)
310                };
311            let straight_layer = &self.real_layers[straight_idx];
312            let reverse_layer = &self.real_layers[reverse_idx];
313
314            let straight_cache = slots[straight_i].take().unwrap();
315            let reverse_cache = slots[reverse_i].take().unwrap();
316
317            let first = self.ignore_first_residual && i == 0;
318            let last = self.ignore_last_residual && i + 1 == n / 2;
319
320            // For the Standard path the residual is the (pre-pair) input skip;
321            // clone it before the pair consumes `x`, and only when it is used.
322            // MultiGate carries the residual in its streams, so clones nothing.
323            let residual = match &self.residuals {
324                Residuals::Standard(_) if !(first || last) => Some(x.clone()),
325                _ => None,
326            };
327
328            // Run the pair directly on the (borrowed) real layers — never clone a
329            // block, since cloning a lazily-initialised `Param` resamples its
330            // random weights (see [`bidi_pair_forward`]). Stack-level class
331            // latents were already spliced above; pairs carry none of their own.
332            //
333            // The pair returns its merged transform `F_l` without the residual.
334            // The merge is a per-real-pair weight set (`n_real_layers / 2` of
335            // them), so it is indexed by the *real* pair `straight_idx / 2` — not
336            // the virtual pair `i` — sharing weights under virtual scheduling just
337            // like the blocks (and matching the MGR real-pair index below). In the
338            // non-virtual case `straight_idx == i * 2`, so this is `i`.
339            let (merged, sc, rc) = bidi_pair_forward(
340                &straight_layer.norm,
341                &reverse_layer.norm,
342                &straight_layer.mamba_block,
343                &reverse_layer.mamba_block,
344                &self.outputs_merge[straight_idx / 2],
345                x,
346                Some(straight_cache),
347                Some(reverse_cache),
348                ssd_path.clone(),
349            );
350            slots[straight_i] = Some(sc);
351            slots[reverse_i] = Some(rc);
352
353            match &self.residuals {
354                Residuals::Standard(_noop) => {
355                    // Add the input skip here (the pair already consumed `x`), or
356                    // output the bare transform when the residual is suppressed.
357                    x = match residual {
358                        Some(r) => merged + r,
359                        None => merged,
360                    };
361                }
362                Residuals::MultiGate(mg) => {
363                    let s = streams.take().unwrap();
364                    // A skipped residual is β ≡ 1 in the mixer (`new_streams =
365                    // F_l`), the aggregator then collapsing to `F_l` — both
366                    // branches shortcut that (mirrors `Layers::forward`). The MGR
367                    // unit is the pair: virtual pair `i`, real pair `straight_idx
368                    // / 2` (the straight index of a pair is even).
369                    if last {
370                        x = merged;
371                        streams = Some(s);
372                    } else if first {
373                        let [b, seq, d] = merged.dims();
374                        streams = Some(merged.clone().unsqueeze_dim::<4>(2).expand([
375                            b,
376                            seq,
377                            mg.n_stream,
378                            d,
379                        ]));
380                        x = merged;
381                    } else {
382                        let idx = mg.module_index(i, straight_idx / 2);
383                        let (new_h, new_streams) = mg.layers[idx].forward(merged, s);
384                        x = new_h;
385                        streams = Some(new_streams);
386                    }
387                }
388            }
389        }
390
391        (x, M::Caches::from_slots(slots))
392    }
393}
394
395/// Plain (non-serde) factory for [`BidiLayers`].
396pub struct BidiLayersBuilder<C> {
397    /// Number of real (weight-bearing) layers (must be even).
398    pub n_real_layers: usize,
399    /// Optional virtual-layer scheduling.
400    pub n_virtual_layers: Option<(usize, BidiSchedule)>,
401    /// Shared block config.
402    pub mamba_block: C,
403    /// Zero the first virtual pair's residual.
404    pub ignore_first_residual: bool,
405    /// Zero the last virtual pair's residual.
406    pub ignore_last_residual: bool,
407    /// One merge config per pair, length `n_real_layers / 2`.
408    pub outputs_merge: Vec<OutputMergeConfig>,
409    /// Stack-level class latents (spliced once before the first pair).
410    pub class_latents: Vec<ClassLatent>,
411    /// Inter-pair residual scheme (defaults to plain additive).
412    pub residuals: ResidualsConfig,
413}
414
415impl<C: MambaBlockConfig> BidiLayersBuilder<C> {
416    /// Allocate and initialise the bidirectional stack on `device`.
417    pub fn init(&self, device: &Device) -> BidiLayers<C::Block> {
418        let d_model = self.mamba_block.d_model();
419        let real_layers = (0..self.n_real_layers)
420            .map(|_| Layer {
421                norm: RmsNormConfig::new(d_model).init(device),
422                mamba_block: self.mamba_block.init_block(device),
423                class_latents: Vec::new(),
424                class_latents_emb: None,
425            })
426            .collect();
427        let outputs_merge = (0..self.n_real_layers / 2)
428            .map(|i| self.outputs_merge[i].init(d_model, device))
429            .collect();
430        // The MGR unit is the pair, so size the modules by *pairs* (halved real
431        // and virtual layer counts).
432        let n_virtual = self
433            .n_virtual_layers
434            .as_ref()
435            .map(|(l, _)| *l)
436            .unwrap_or(self.n_real_layers);
437        let residuals = self
438            .residuals
439            .init(d_model, self.n_real_layers / 2, n_virtual / 2, device);
440        BidiLayers {
441            n_real_layers: self.n_real_layers,
442            n_virtual_layers: self.n_virtual_layers.clone(),
443            real_layers,
444            ignore_first_residual: self.ignore_first_residual,
445            ignore_last_residual: self.ignore_last_residual,
446            outputs_merge,
447            residuals,
448            class_latents_emb: init_class_emb(self.class_latents.len(), d_model, device),
449            class_latents: self.class_latents.clone(),
450        }
451    }
452}
453
454// ===========================================================================
455// Unifying enums: one runtime + one serializable Config across all families
456// ===========================================================================
457
458/// A runtime-selectable bidirectional stack: the same paired straight/reverse
459/// structure over any Mamba-x family, chosen at runtime. The forward-only
460/// counterpart of [`MambaLatentNet`] for non-autoregressive tasks.
461#[derive(Module, Debug)]
462pub enum MambaBidiLayers {
463    /// Mamba-1 bidirectional stack.
464    #[cfg(feature = "mamba1")]
465    Mamba1(BidiLayers<crate::mamba1::prelude::Mamba1>),
466    /// Mamba-2 bidirectional stack.
467    #[cfg(feature = "mamba2")]
468    Mamba2(BidiLayers<crate::mamba2::prelude::Mamba2>),
469    /// Mamba-3 bidirectional stack.
470    #[cfg(feature = "mamba3")]
471    Mamba3(BidiLayers<crate::mamba3::prelude::Mamba3>),
472}
473
474impl MambaBidiLayers {
475    /// Output positions of the stack-level class latents for an `orig_len`
476    /// input (so a caller can read a class latent back out of the lengthened
477    /// `forward` output — e.g. as a pooled summary).
478    pub fn class_latent_output_indices(&self, orig_len: usize) -> Vec<usize> {
479        match self {
480            #[cfg(feature = "mamba1")]
481            Self::Mamba1(layers) => layers.class_latent_output_indices(orig_len),
482            #[cfg(feature = "mamba2")]
483            Self::Mamba2(layers) => layers.class_latent_output_indices(orig_len),
484            #[cfg(feature = "mamba3")]
485            Self::Mamba3(layers) => layers.class_latent_output_indices(orig_len),
486        }
487    }
488
489    /// Full-sequence bidirectional pass. The `ssd_path` must match the stack's
490    /// family; a mismatch is a caller error and panics.
491    pub fn forward(
492        &self,
493        x: Tensor<3>,
494        caches: Option<MambaCaches>,
495        ssd_path: MambaSsdPath,
496    ) -> (Tensor<3>, MambaCaches) {
497        match self {
498            #[cfg(feature = "mamba1")]
499            Self::Mamba1(layers) => {
500                let caches = caches.map(|c| match c {
501                    MambaCaches::Mamba1(c) => c,
502                    #[allow(unreachable_patterns)]
503                    _ => panic!("cache family does not match Mamba-1 bidi stack"),
504                });
505                match ssd_path {
506                    MambaSsdPath::Mamba1 => {}
507                    #[allow(unreachable_patterns)]
508                    _ => panic!("ssd_path family does not match Mamba-1 bidi stack"),
509                }
510                let (y, c) = layers.forward(x, caches, ());
511                (y, MambaCaches::Mamba1(c))
512            }
513            #[cfg(feature = "mamba2")]
514            Self::Mamba2(layers) => {
515                let caches = caches.map(|c| match c {
516                    MambaCaches::Mamba2(c) => c,
517                    #[allow(unreachable_patterns)]
518                    _ => panic!("cache family does not match Mamba-2 bidi stack"),
519                });
520                let path = match ssd_path {
521                    MambaSsdPath::Mamba2(p) => p,
522                    #[allow(unreachable_patterns)]
523                    _ => panic!("ssd_path family does not match Mamba-2 bidi stack"),
524                };
525                let (y, c) = layers.forward(x, caches, path);
526                (y, MambaCaches::Mamba2(c))
527            }
528            #[cfg(feature = "mamba3")]
529            Self::Mamba3(layers) => {
530                let caches = caches.map(|c| match c {
531                    MambaCaches::Mamba3(c) => c,
532                    #[allow(unreachable_patterns)]
533                    _ => panic!("cache family does not match Mamba-3 bidi stack"),
534                });
535                let path = match ssd_path {
536                    MambaSsdPath::Mamba3(p) => p,
537                    #[allow(unreachable_patterns)]
538                    _ => panic!("ssd_path family does not match Mamba-3 bidi stack"),
539                };
540                let (y, c) = layers.forward(x, caches, path);
541                (y, MambaCaches::Mamba3(c))
542            }
543        }
544    }
545}
546
547/// The serializable config for [`MambaBidiLayers`]. Each variant is concrete
548/// (per-family), so `#[derive(Config)]` applies; `init` builds the matching
549/// stack variant.
550#[derive(Config, Debug)]
551pub enum MambaBidiLayersConfig {
552    /// Build a Mamba-1 bidirectional stack.
553    #[cfg(feature = "mamba1")]
554    Mamba1 {
555        /// Number of real layers (must be even — used in pairs).
556        n_real_layers: usize,
557        n_virtual_layers: Option<(usize, BidiSchedule)>,
558        /// Shared block config.
559        mamba_block: crate::mamba1::prelude::Mamba1Config,
560        ignore_first_residual: bool,
561        ignore_last_residual: bool,
562        /// One merge config per pair, length `n_real_layers / 2`.
563        outputs_merge: Vec<OutputMergeConfig>,
564        /// Stack-level class latents, spliced into the sequence before the
565        /// first pair (e.g. a `Middle` summary latent in place of mean-pooling).
566        class_latents: Vec<ClassLatent>,
567        /// Inter-pair residual scheme (plain additive vs Multi-Gate).
568        residuals: ResidualsConfig,
569    },
570    /// Build a Mamba-2 bidirectional stack.
571    #[cfg(feature = "mamba2")]
572    Mamba2 {
573        /// Number of real layers (must be even — used in pairs).
574        n_real_layers: usize,
575        n_virtual_layers: Option<(usize, BidiSchedule)>,
576        /// Shared block config.
577        mamba_block: crate::mamba2::prelude::Mamba2Config,
578        ignore_first_residual: bool,
579        ignore_last_residual: bool,
580        /// One merge config per pair, length `n_real_layers / 2`.
581        outputs_merge: Vec<OutputMergeConfig>,
582        /// Stack-level class latents, spliced into the sequence before the
583        /// first pair (e.g. a `Middle` summary latent in place of mean-pooling).
584        class_latents: Vec<ClassLatent>,
585        /// Inter-pair residual scheme (plain additive vs Multi-Gate).
586        residuals: ResidualsConfig,
587    },
588    /// Build a Mamba-3 bidirectional stack.
589    #[cfg(feature = "mamba3")]
590    Mamba3 {
591        /// Number of real layers (must be even — used in pairs).
592        n_real_layers: usize,
593        n_virtual_layers: Option<(usize, BidiSchedule)>,
594        /// Shared block config.
595        mamba_block: crate::mamba3::prelude::Mamba3Config,
596        ignore_first_residual: bool,
597        ignore_last_residual: bool,
598        /// One merge config per pair, length `n_real_layers / 2`.
599        outputs_merge: Vec<OutputMergeConfig>,
600        /// Stack-level class latents, spliced into the sequence before the
601        /// first pair (e.g. a `Middle` summary latent in place of mean-pooling).
602        class_latents: Vec<ClassLatent>,
603        /// Inter-pair residual scheme (plain additive vs Multi-Gate).
604        residuals: ResidualsConfig,
605    },
606}
607
608impl MambaBidiLayersConfig {
609    /// Allocate and initialise the selected bidirectional stack on `device`.
610    pub fn init(&self, device: &Device) -> MambaBidiLayers {
611        match self {
612            #[cfg(feature = "mamba1")]
613            Self::Mamba1 {
614                n_real_layers,
615                n_virtual_layers,
616                mamba_block,
617                ignore_first_residual,
618                ignore_last_residual,
619                outputs_merge,
620                class_latents,
621                residuals,
622            } => MambaBidiLayers::Mamba1(
623                BidiLayersBuilder {
624                    n_real_layers: *n_real_layers,
625                    n_virtual_layers: n_virtual_layers.clone(),
626                    mamba_block: mamba_block.clone(),
627                    ignore_first_residual: *ignore_first_residual,
628                    ignore_last_residual: *ignore_last_residual,
629                    outputs_merge: outputs_merge.clone(),
630                    class_latents: class_latents.clone(),
631                    residuals: residuals.clone(),
632                }
633                .init(device),
634            ),
635            #[cfg(feature = "mamba2")]
636            Self::Mamba2 {
637                n_real_layers,
638                n_virtual_layers,
639                mamba_block,
640                ignore_first_residual,
641                ignore_last_residual,
642                outputs_merge,
643                class_latents,
644                residuals,
645            } => MambaBidiLayers::Mamba2(
646                BidiLayersBuilder {
647                    n_real_layers: *n_real_layers,
648                    n_virtual_layers: n_virtual_layers.clone(),
649                    mamba_block: mamba_block.clone(),
650                    ignore_first_residual: *ignore_first_residual,
651                    ignore_last_residual: *ignore_last_residual,
652                    outputs_merge: outputs_merge.clone(),
653                    class_latents: class_latents.clone(),
654                    residuals: residuals.clone(),
655                }
656                .init(device),
657            ),
658            #[cfg(feature = "mamba3")]
659            Self::Mamba3 {
660                n_real_layers,
661                n_virtual_layers,
662                mamba_block,
663                ignore_first_residual,
664                ignore_last_residual,
665                outputs_merge,
666                class_latents,
667                residuals,
668            } => MambaBidiLayers::Mamba3(
669                BidiLayersBuilder {
670                    n_real_layers: *n_real_layers,
671                    n_virtual_layers: n_virtual_layers.clone(),
672                    mamba_block: mamba_block.clone(),
673                    ignore_first_residual: *ignore_first_residual,
674                    ignore_last_residual: *ignore_last_residual,
675                    outputs_merge: outputs_merge.clone(),
676                    class_latents: class_latents.clone(),
677                    residuals: residuals.clone(),
678                }
679                .init(device),
680            ),
681        }
682    }
683}