burn_mamba/modules/bidi.rs
1use crate::modules::{Residuals, ResidualsConfig, RmsNorm, RmsNormConfig};
2use crate::prelude::*;
3use crate::utils::BidiSchedule;
4use crate::utils::ClassLatent;
5use crate::utils::class::{class_marker_output_indices, init_class_emb, insert_class_markers};
6use burn::config::Config;
7use burn::module::Param;
8use burn::nn::{Linear, LinearConfig};
9use burn::prelude::*;
10
11#[cfg(test)]
12mod tests;
13
14// ===========================================================================
15// Bidirectional support (family-generic; forward-only, non-autoregressive)
16// ===========================================================================
17//
18// A `BidiLayerPair<M>` runs a straight (→) and a reversed (← via `flip`) Pre-LN
19// pass and merges them with an [`OutputMerge`]; `BidiLayers<M>` stacks pairs with
20// a [`BidiSchedule`]. The block itself is unchanged — only how its two passes are
21// scheduled and combined is bidirectional. Written once for all families; the
22// merge is family-agnostic (`RmsNorm`/`Linear` over `Tensor<3>`).
23
24/// A zero-parameter placeholder for the parameterless `Mean` merge.
25#[derive(Module, Debug)]
26pub struct NoOp;
27
28/// How the two directions of a bidirectional pair are combined.
29#[allow(clippy::large_enum_variant)]
30#[derive(Module, Debug)]
31pub enum OutputMerge {
32 /// Element-wise average of the two directions (no parameters).
33 Mean(NoOp),
34 /// Concatenate along the feature axis and project back down with a learnable
35 /// `[2 · d_model, d_model]` linear layer.
36 CatLinear(Linear),
37}
38
39impl OutputMerge {
40 /// Merge the two directional outputs (each `[batch, sequence, d_model]`).
41 pub fn forward(&self, straight: Tensor<3>, reverse: Tensor<3>) -> Tensor<3> {
42 let [batch, sequence, d_model] = straight.dims();
43 assert_eq!(straight.dims(), reverse.dims());
44 match self {
45 OutputMerge::Mean(_) => (straight + reverse) * 0.5,
46 OutputMerge::CatLinear(proj) => {
47 let cat = Tensor::cat([straight, reverse].to_vec(), 2);
48 assert_eq!([batch, sequence, 2 * d_model], cat.dims());
49 let merged = proj.forward(cat);
50 assert_eq!([batch, sequence, d_model], merged.dims());
51 merged
52 }
53 }
54 }
55}
56
57/// Configuration / factory for [`OutputMerge`].
58#[derive(Config, Debug)]
59pub enum OutputMergeConfig {
60 /// Build an [`OutputMerge::Mean`].
61 Mean,
62 /// Build an [`OutputMerge::CatLinear`].
63 CatLinear,
64}
65
66impl OutputMergeConfig {
67 /// A vector of `n_real_layers / 2` [`Self::Mean`] configs (one per pair).
68 pub fn mean(n_real_layers: usize) -> Vec<Self> {
69 vec![Self::Mean; n_real_layers / 2]
70 }
71 /// A vector of `n_real_layers / 2` [`Self::CatLinear`] configs (one per pair).
72 pub fn cat_linear(n_real_layers: usize) -> Vec<Self> {
73 vec![Self::CatLinear; n_real_layers / 2]
74 }
75 /// Allocate the merge module on `device` for the given `d_model`.
76 pub fn init(&self, d_model: usize, device: &Device) -> OutputMerge {
77 match self {
78 OutputMergeConfig::Mean => OutputMerge::Mean(NoOp),
79 OutputMergeConfig::CatLinear => {
80 OutputMerge::CatLinear(LinearConfig::new(d_model * 2, d_model).init(device))
81 }
82 }
83 }
84}
85
86/// A single bidirectional pair: a straight (→) and a reversed (←) Pre-LN block
87/// whose outputs are merged. The residual is **not** applied here — the
88/// enclosing [`BidiLayers`] adds it (or suppresses it on the first/last pair),
89/// mirroring the [`Layer`](crate::modules::Layer) / [`Layers`](crate::modules::Layers) split.
90#[derive(Module, Debug)]
91pub struct BidiLayerPair<M: Module> {
92 /// Pre-norm for the straight pass.
93 pub straight_norm: RmsNorm,
94 /// Pre-norm for the reversed pass.
95 pub reverse_norm: RmsNorm,
96 /// The block run left-to-right.
97 pub straight_block: M,
98 /// The block run right-to-left (over the flipped sequence).
99 pub reverse_block: M,
100 /// Merge strategy combining the two directions.
101 pub output_merge: OutputMerge,
102 /// Positions of this pair's class latents, spliced in before either
103 /// direction runs (both directions, and the residual, see the lengthened
104 /// sequence). Empty ⇒ none.
105 #[module(skip)]
106 pub class_latents: Vec<ClassLatent>,
107 /// This pair's class-latent embeddings, `[num_class_latents, d_model]`.
108 pub class_latents_emb: Option<Param<Tensor<2>>>,
109}
110
111impl<M: MambaBlock> BidiLayerPair<M>
112where
113 M::SsdPath: Clone,
114{
115 /// Splice this bidi-layer-pair's class latents into `x` (no-op when there are none).
116 fn insert_latents(&self, x: Tensor<3>) -> Tensor<3> {
117 if self.class_latents_emb.is_none() {
118 return x;
119 }
120 insert_class_markers(x, &self.class_latents, self.class_latents_emb.as_ref()).0
121 }
122
123 /// `[batch, sequence, d_model]` → `[batch, sequence, d_model]`, plus the two
124 /// updated direction caches. (`sequence` grows by the class-latent count.)
125 /// Returns the merged directions **without** the residual — the enclosing
126 /// [`BidiLayers`] adds it.
127 pub fn forward(
128 &self,
129 x: Tensor<3>,
130 straight_cache: Option<M::Cache>,
131 reverse_cache: Option<M::Cache>,
132 ssd_path: M::SsdPath,
133 ) -> (Tensor<3>, M::Cache, M::Cache) {
134 let x = self.insert_latents(x);
135 bidi_pair_forward(
136 &self.straight_norm,
137 &self.reverse_norm,
138 &self.straight_block,
139 &self.reverse_block,
140 &self.output_merge,
141 x,
142 straight_cache,
143 reverse_cache,
144 ssd_path,
145 )
146 }
147}
148
149/// The straight + reverse + merge computation of a bidirectional pair, over
150/// **borrowed** sub-modules.
151///
152/// Taking references (rather than owning clones) is load-bearing: a Burn `Param`
153/// that is still lazily-initialised re-runs its random initialiser **on every
154/// clone**, so cloning a not-yet-materialised block per forward would resample
155/// fresh random weights each call. [`BidiLayers`] therefore calls this directly
156/// on its real layers instead of building a transient (cloned) [`BidiLayerPair`].
157#[allow(clippy::too_many_arguments)]
158fn bidi_pair_forward<M: MambaBlock>(
159 straight_norm: &RmsNorm,
160 reverse_norm: &RmsNorm,
161 straight_block: &M,
162 reverse_block: &M,
163 output_merge: &OutputMerge,
164 x: Tensor<3>,
165 straight_cache: Option<M::Cache>,
166 reverse_cache: Option<M::Cache>,
167 ssd_path: M::SsdPath,
168) -> (Tensor<3>, M::Cache, M::Cache)
169where
170 M::SsdPath: Clone,
171{
172 let [batch, sequence, d_model] = x.dims();
173
174 // x reads >x₀>x₁>…; x_rev (flipped) reads the sequence backwards.
175 let x_rev = x.clone().flip([1]);
176 let x = straight_norm.forward(x);
177 let x_rev = reverse_norm.forward(x_rev);
178
179 let (x, straight_cache) = straight_block.block_forward(x, straight_cache, ssd_path.clone());
180 assert_eq!([batch, sequence, d_model], x.dims());
181
182 let (x_rev, reverse_cache) = reverse_block.block_forward(x_rev, reverse_cache, ssd_path);
183 assert_eq!([batch, sequence, d_model], x_rev.dims());
184
185 // Re-align the reversed read, then merge.
186 let x_rev = x_rev.flip([1]);
187 let merged = output_merge.forward(x, x_rev);
188 (merged, straight_cache, reverse_cache)
189}
190
191/// A stack of bidirectional [`Layer`] pairs with optional virtual-layer
192/// scheduling — one struct for every Mamba-x family.
193#[derive(Module, Debug)]
194pub struct BidiLayers<M: Module> {
195 /// Number of real (weight-bearing) layers; must be even (used in pairs).
196 pub n_real_layers: usize,
197 /// Optional `(n_virtual_layers, schedule)` for weight-sharing.
198 #[module(skip)]
199 pub n_virtual_layers: Option<(usize, BidiSchedule)>,
200 /// The weight-bearing layers, length `n_real_layers`.
201 pub real_layers: Vec<Layer<M>>,
202 /// Zero the first virtual pair's residual when `true`.
203 pub ignore_first_residual: bool,
204 /// Zero the last virtual pair's residual when `true`.
205 pub ignore_last_residual: bool,
206 /// One direction-merge per pair, length `n_real_layers / 2`.
207 pub outputs_merge: Vec<OutputMerge>,
208 /// How residuals are threaded between **pairs** (plain additive vs
209 /// Multi-Gate). The MGR unit is the pair: one module per real/virtual pair.
210 pub residuals: Residuals,
211 /// Positions of the stack-level class latents, spliced into the sequence
212 /// once before the first pair (independent of any per-pair class latents).
213 #[module(skip)]
214 pub class_latents: Vec<ClassLatent>,
215 /// The stack-level class-latent embeddings, `[num_class_latents, d_model]`.
216 pub class_latents_emb: Option<Param<Tensor<2>>>,
217}
218
219impl<M: MambaBlock + Clone> BidiLayers<M>
220where
221 M::SsdPath: Clone,
222{
223 /// Output positions of the stack-level class latents for an `orig_len` input.
224 pub fn class_latent_output_indices(&self, orig_len: usize) -> Vec<usize> {
225 class_marker_output_indices(&self.class_latents, orig_len)
226 }
227
228 /// Splice this bidi-layers' class latents into `x` (no-op when there are none).
229 fn insert_latents(&self, x: Tensor<3>) -> Tensor<3> {
230 if self.class_latents_emb.is_none() {
231 return x;
232 }
233 insert_class_markers(x, &self.class_latents, self.class_latents_emb.as_ref()).0
234 }
235
236 /// Seed the MultiGate streams from a full-sequence input — `n_stream` copies
237 /// of `x` as `[batch, sequence, n_stream, d_model]` — or `None` for the
238 /// Standard path. Panics if MultiGate is paired with stack-level class latents.
239 fn multi_gate_streams_seed(&self, x: &Tensor<3>) -> Option<Tensor<4>> {
240 let Residuals::MultiGate(mg) = &self.residuals else {
241 return None;
242 };
243 assert!(
244 self.class_latents_emb.is_none(),
245 "MultiGate residuals do not support stack-level class latents"
246 );
247 let [batch, sequence, d_model] = x.dims();
248 Some(
249 x.clone()
250 .unsqueeze_dim::<4>(2)
251 .expand([batch, sequence, mg.n_stream, d_model]),
252 )
253 }
254
255 /// `[batch, sequence, d_model]` → `[batch, sequence, d_model]`
256 /// (`sequence` grows by the stack-level class-latent count).
257 ///
258 /// Each pair returns its merged transform `F_l` (no residual). With
259 /// [`Residuals::Standard`] the input skip is added per pair (unless
260 /// suppressed). With [`Residuals::MultiGate`] the skip is dropped and
261 /// `n_stream` parallel streams — seeded from `x` — carry the residual between
262 /// pairs: each pair reads their attention-pooled aggregate as input and its
263 /// merged output is gated back into every stream (see [`MultiGate`]).
264 ///
265 /// [`MultiGate`]: crate::modules::MultiGate
266 pub fn forward(
267 &self,
268 mut x: Tensor<3>,
269 caches: Option<M::Caches>,
270 ssd_path: M::SsdPath,
271 ) -> (Tensor<3>, M::Caches) {
272 x = self.insert_latents(x);
273 let n = self
274 .n_virtual_layers
275 .as_ref()
276 .map(|(l, _)| {
277 assert!(l.is_multiple_of(2), "Bidi virtual layers are used in pairs");
278 *l
279 })
280 .unwrap_or_else(|| {
281 assert!(
282 self.n_real_layers.is_multiple_of(2),
283 "Bidi layers are used in pairs"
284 );
285 self.n_real_layers
286 });
287
288 let caches =
289 caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_3d(&x, n));
290 assert_eq!(
291 caches.slot_count(),
292 n,
293 "straight and reverse layers cannot share caches"
294 );
295
296 let mut slots = caches.into_slots();
297 // MultiGate keeps `n_stream` parallel streams (seeded from the input);
298 // Standard threads the single tensor `x` directly (streams stays `None`).
299 let mut streams = self.multi_gate_streams_seed(&x);
300 for i in 0..n / 2 {
301 let (straight_i, reverse_i) = (i * 2, i * 2 + 1);
302 let (straight_idx, reverse_idx) =
303 if let Some((n_virtual, schedule)) = &self.n_virtual_layers {
304 (
305 schedule.real_idx(straight_i, *n_virtual, self.n_real_layers),
306 schedule.real_idx(reverse_i, *n_virtual, self.n_real_layers),
307 )
308 } else {
309 (straight_i, reverse_i)
310 };
311 let straight_layer = &self.real_layers[straight_idx];
312 let reverse_layer = &self.real_layers[reverse_idx];
313
314 let straight_cache = slots[straight_i].take().unwrap();
315 let reverse_cache = slots[reverse_i].take().unwrap();
316
317 let first = self.ignore_first_residual && i == 0;
318 let last = self.ignore_last_residual && i + 1 == n / 2;
319
320 // For the Standard path the residual is the (pre-pair) input skip;
321 // clone it before the pair consumes `x`, and only when it is used.
322 // MultiGate carries the residual in its streams, so clones nothing.
323 let residual = match &self.residuals {
324 Residuals::Standard(_) if !(first || last) => Some(x.clone()),
325 _ => None,
326 };
327
328 // Run the pair directly on the (borrowed) real layers — never clone a
329 // block, since cloning a lazily-initialised `Param` resamples its
330 // random weights (see [`bidi_pair_forward`]). Stack-level class
331 // latents were already spliced above; pairs carry none of their own.
332 //
333 // The pair returns its merged transform `F_l` without the residual.
334 // The merge is a per-real-pair weight set (`n_real_layers / 2` of
335 // them), so it is indexed by the *real* pair `straight_idx / 2` — not
336 // the virtual pair `i` — sharing weights under virtual scheduling just
337 // like the blocks (and matching the MGR real-pair index below). In the
338 // non-virtual case `straight_idx == i * 2`, so this is `i`.
339 let (merged, sc, rc) = bidi_pair_forward(
340 &straight_layer.norm,
341 &reverse_layer.norm,
342 &straight_layer.mamba_block,
343 &reverse_layer.mamba_block,
344 &self.outputs_merge[straight_idx / 2],
345 x,
346 Some(straight_cache),
347 Some(reverse_cache),
348 ssd_path.clone(),
349 );
350 slots[straight_i] = Some(sc);
351 slots[reverse_i] = Some(rc);
352
353 match &self.residuals {
354 Residuals::Standard(_noop) => {
355 // Add the input skip here (the pair already consumed `x`), or
356 // output the bare transform when the residual is suppressed.
357 x = match residual {
358 Some(r) => merged + r,
359 None => merged,
360 };
361 }
362 Residuals::MultiGate(mg) => {
363 let s = streams.take().unwrap();
364 // A skipped residual is β ≡ 1 in the mixer (`new_streams =
365 // F_l`), the aggregator then collapsing to `F_l` — both
366 // branches shortcut that (mirrors `Layers::forward`). The MGR
367 // unit is the pair: virtual pair `i`, real pair `straight_idx
368 // / 2` (the straight index of a pair is even).
369 if last {
370 x = merged;
371 streams = Some(s);
372 } else if first {
373 let [b, seq, d] = merged.dims();
374 streams = Some(merged.clone().unsqueeze_dim::<4>(2).expand([
375 b,
376 seq,
377 mg.n_stream,
378 d,
379 ]));
380 x = merged;
381 } else {
382 let idx = mg.module_index(i, straight_idx / 2);
383 let (new_h, new_streams) = mg.layers[idx].forward(merged, s);
384 x = new_h;
385 streams = Some(new_streams);
386 }
387 }
388 }
389 }
390
391 (x, M::Caches::from_slots(slots))
392 }
393}
394
395/// Plain (non-serde) factory for [`BidiLayers`].
396pub struct BidiLayersBuilder<C> {
397 /// Number of real (weight-bearing) layers (must be even).
398 pub n_real_layers: usize,
399 /// Optional virtual-layer scheduling.
400 pub n_virtual_layers: Option<(usize, BidiSchedule)>,
401 /// Shared block config.
402 pub mamba_block: C,
403 /// Zero the first virtual pair's residual.
404 pub ignore_first_residual: bool,
405 /// Zero the last virtual pair's residual.
406 pub ignore_last_residual: bool,
407 /// One merge config per pair, length `n_real_layers / 2`.
408 pub outputs_merge: Vec<OutputMergeConfig>,
409 /// Stack-level class latents (spliced once before the first pair).
410 pub class_latents: Vec<ClassLatent>,
411 /// Inter-pair residual scheme (defaults to plain additive).
412 pub residuals: ResidualsConfig,
413}
414
415impl<C: MambaBlockConfig> BidiLayersBuilder<C> {
416 /// Allocate and initialise the bidirectional stack on `device`.
417 pub fn init(&self, device: &Device) -> BidiLayers<C::Block> {
418 let d_model = self.mamba_block.d_model();
419 let real_layers = (0..self.n_real_layers)
420 .map(|_| Layer {
421 norm: RmsNormConfig::new(d_model).init(device),
422 mamba_block: self.mamba_block.init_block(device),
423 class_latents: Vec::new(),
424 class_latents_emb: None,
425 })
426 .collect();
427 let outputs_merge = (0..self.n_real_layers / 2)
428 .map(|i| self.outputs_merge[i].init(d_model, device))
429 .collect();
430 // The MGR unit is the pair, so size the modules by *pairs* (halved real
431 // and virtual layer counts).
432 let n_virtual = self
433 .n_virtual_layers
434 .as_ref()
435 .map(|(l, _)| *l)
436 .unwrap_or(self.n_real_layers);
437 let residuals = self
438 .residuals
439 .init(d_model, self.n_real_layers / 2, n_virtual / 2, device);
440 BidiLayers {
441 n_real_layers: self.n_real_layers,
442 n_virtual_layers: self.n_virtual_layers.clone(),
443 real_layers,
444 ignore_first_residual: self.ignore_first_residual,
445 ignore_last_residual: self.ignore_last_residual,
446 outputs_merge,
447 residuals,
448 class_latents_emb: init_class_emb(self.class_latents.len(), d_model, device),
449 class_latents: self.class_latents.clone(),
450 }
451 }
452}
453
454// ===========================================================================
455// Unifying enums: one runtime + one serializable Config across all families
456// ===========================================================================
457
458/// A runtime-selectable bidirectional stack: the same paired straight/reverse
459/// structure over any Mamba-x family, chosen at runtime. The forward-only
460/// counterpart of [`MambaLatentNet`] for non-autoregressive tasks.
461#[derive(Module, Debug)]
462pub enum MambaBidiLayers {
463 /// Mamba-1 bidirectional stack.
464 #[cfg(feature = "mamba1")]
465 Mamba1(BidiLayers<crate::mamba1::prelude::Mamba1>),
466 /// Mamba-2 bidirectional stack.
467 #[cfg(feature = "mamba2")]
468 Mamba2(BidiLayers<crate::mamba2::prelude::Mamba2>),
469 /// Mamba-3 bidirectional stack.
470 #[cfg(feature = "mamba3")]
471 Mamba3(BidiLayers<crate::mamba3::prelude::Mamba3>),
472}
473
474impl MambaBidiLayers {
475 /// Output positions of the stack-level class latents for an `orig_len`
476 /// input (so a caller can read a class latent back out of the lengthened
477 /// `forward` output — e.g. as a pooled summary).
478 pub fn class_latent_output_indices(&self, orig_len: usize) -> Vec<usize> {
479 match self {
480 #[cfg(feature = "mamba1")]
481 Self::Mamba1(layers) => layers.class_latent_output_indices(orig_len),
482 #[cfg(feature = "mamba2")]
483 Self::Mamba2(layers) => layers.class_latent_output_indices(orig_len),
484 #[cfg(feature = "mamba3")]
485 Self::Mamba3(layers) => layers.class_latent_output_indices(orig_len),
486 }
487 }
488
489 /// Full-sequence bidirectional pass. The `ssd_path` must match the stack's
490 /// family; a mismatch is a caller error and panics.
491 pub fn forward(
492 &self,
493 x: Tensor<3>,
494 caches: Option<MambaCaches>,
495 ssd_path: MambaSsdPath,
496 ) -> (Tensor<3>, MambaCaches) {
497 match self {
498 #[cfg(feature = "mamba1")]
499 Self::Mamba1(layers) => {
500 let caches = caches.map(|c| match c {
501 MambaCaches::Mamba1(c) => c,
502 #[allow(unreachable_patterns)]
503 _ => panic!("cache family does not match Mamba-1 bidi stack"),
504 });
505 match ssd_path {
506 MambaSsdPath::Mamba1 => {}
507 #[allow(unreachable_patterns)]
508 _ => panic!("ssd_path family does not match Mamba-1 bidi stack"),
509 }
510 let (y, c) = layers.forward(x, caches, ());
511 (y, MambaCaches::Mamba1(c))
512 }
513 #[cfg(feature = "mamba2")]
514 Self::Mamba2(layers) => {
515 let caches = caches.map(|c| match c {
516 MambaCaches::Mamba2(c) => c,
517 #[allow(unreachable_patterns)]
518 _ => panic!("cache family does not match Mamba-2 bidi stack"),
519 });
520 let path = match ssd_path {
521 MambaSsdPath::Mamba2(p) => p,
522 #[allow(unreachable_patterns)]
523 _ => panic!("ssd_path family does not match Mamba-2 bidi stack"),
524 };
525 let (y, c) = layers.forward(x, caches, path);
526 (y, MambaCaches::Mamba2(c))
527 }
528 #[cfg(feature = "mamba3")]
529 Self::Mamba3(layers) => {
530 let caches = caches.map(|c| match c {
531 MambaCaches::Mamba3(c) => c,
532 #[allow(unreachable_patterns)]
533 _ => panic!("cache family does not match Mamba-3 bidi stack"),
534 });
535 let path = match ssd_path {
536 MambaSsdPath::Mamba3(p) => p,
537 #[allow(unreachable_patterns)]
538 _ => panic!("ssd_path family does not match Mamba-3 bidi stack"),
539 };
540 let (y, c) = layers.forward(x, caches, path);
541 (y, MambaCaches::Mamba3(c))
542 }
543 }
544 }
545}
546
547/// The serializable config for [`MambaBidiLayers`]. Each variant is concrete
548/// (per-family), so `#[derive(Config)]` applies; `init` builds the matching
549/// stack variant.
550#[derive(Config, Debug)]
551pub enum MambaBidiLayersConfig {
552 /// Build a Mamba-1 bidirectional stack.
553 #[cfg(feature = "mamba1")]
554 Mamba1 {
555 /// Number of real layers (must be even — used in pairs).
556 n_real_layers: usize,
557 n_virtual_layers: Option<(usize, BidiSchedule)>,
558 /// Shared block config.
559 mamba_block: crate::mamba1::prelude::Mamba1Config,
560 ignore_first_residual: bool,
561 ignore_last_residual: bool,
562 /// One merge config per pair, length `n_real_layers / 2`.
563 outputs_merge: Vec<OutputMergeConfig>,
564 /// Stack-level class latents, spliced into the sequence before the
565 /// first pair (e.g. a `Middle` summary latent in place of mean-pooling).
566 class_latents: Vec<ClassLatent>,
567 /// Inter-pair residual scheme (plain additive vs Multi-Gate).
568 residuals: ResidualsConfig,
569 },
570 /// Build a Mamba-2 bidirectional stack.
571 #[cfg(feature = "mamba2")]
572 Mamba2 {
573 /// Number of real layers (must be even — used in pairs).
574 n_real_layers: usize,
575 n_virtual_layers: Option<(usize, BidiSchedule)>,
576 /// Shared block config.
577 mamba_block: crate::mamba2::prelude::Mamba2Config,
578 ignore_first_residual: bool,
579 ignore_last_residual: bool,
580 /// One merge config per pair, length `n_real_layers / 2`.
581 outputs_merge: Vec<OutputMergeConfig>,
582 /// Stack-level class latents, spliced into the sequence before the
583 /// first pair (e.g. a `Middle` summary latent in place of mean-pooling).
584 class_latents: Vec<ClassLatent>,
585 /// Inter-pair residual scheme (plain additive vs Multi-Gate).
586 residuals: ResidualsConfig,
587 },
588 /// Build a Mamba-3 bidirectional stack.
589 #[cfg(feature = "mamba3")]
590 Mamba3 {
591 /// Number of real layers (must be even — used in pairs).
592 n_real_layers: usize,
593 n_virtual_layers: Option<(usize, BidiSchedule)>,
594 /// Shared block config.
595 mamba_block: crate::mamba3::prelude::Mamba3Config,
596 ignore_first_residual: bool,
597 ignore_last_residual: bool,
598 /// One merge config per pair, length `n_real_layers / 2`.
599 outputs_merge: Vec<OutputMergeConfig>,
600 /// Stack-level class latents, spliced into the sequence before the
601 /// first pair (e.g. a `Middle` summary latent in place of mean-pooling).
602 class_latents: Vec<ClassLatent>,
603 /// Inter-pair residual scheme (plain additive vs Multi-Gate).
604 residuals: ResidualsConfig,
605 },
606}
607
608impl MambaBidiLayersConfig {
609 /// Allocate and initialise the selected bidirectional stack on `device`.
610 pub fn init(&self, device: &Device) -> MambaBidiLayers {
611 match self {
612 #[cfg(feature = "mamba1")]
613 Self::Mamba1 {
614 n_real_layers,
615 n_virtual_layers,
616 mamba_block,
617 ignore_first_residual,
618 ignore_last_residual,
619 outputs_merge,
620 class_latents,
621 residuals,
622 } => MambaBidiLayers::Mamba1(
623 BidiLayersBuilder {
624 n_real_layers: *n_real_layers,
625 n_virtual_layers: n_virtual_layers.clone(),
626 mamba_block: mamba_block.clone(),
627 ignore_first_residual: *ignore_first_residual,
628 ignore_last_residual: *ignore_last_residual,
629 outputs_merge: outputs_merge.clone(),
630 class_latents: class_latents.clone(),
631 residuals: residuals.clone(),
632 }
633 .init(device),
634 ),
635 #[cfg(feature = "mamba2")]
636 Self::Mamba2 {
637 n_real_layers,
638 n_virtual_layers,
639 mamba_block,
640 ignore_first_residual,
641 ignore_last_residual,
642 outputs_merge,
643 class_latents,
644 residuals,
645 } => MambaBidiLayers::Mamba2(
646 BidiLayersBuilder {
647 n_real_layers: *n_real_layers,
648 n_virtual_layers: n_virtual_layers.clone(),
649 mamba_block: mamba_block.clone(),
650 ignore_first_residual: *ignore_first_residual,
651 ignore_last_residual: *ignore_last_residual,
652 outputs_merge: outputs_merge.clone(),
653 class_latents: class_latents.clone(),
654 residuals: residuals.clone(),
655 }
656 .init(device),
657 ),
658 #[cfg(feature = "mamba3")]
659 Self::Mamba3 {
660 n_real_layers,
661 n_virtual_layers,
662 mamba_block,
663 ignore_first_residual,
664 ignore_last_residual,
665 outputs_merge,
666 class_latents,
667 residuals,
668 } => MambaBidiLayers::Mamba3(
669 BidiLayersBuilder {
670 n_real_layers: *n_real_layers,
671 n_virtual_layers: n_virtual_layers.clone(),
672 mamba_block: mamba_block.clone(),
673 ignore_first_residual: *ignore_first_residual,
674 ignore_last_residual: *ignore_last_residual,
675 outputs_merge: outputs_merge.clone(),
676 class_latents: class_latents.clone(),
677 residuals: residuals.clone(),
678 }
679 .init(device),
680 ),
681 }
682 }
683}