1use crate::modules::LayersBuilder;
2use crate::modules::{ResidualsConfig, RmsNorm, RmsNormConfig};
3use crate::prelude::*;
4use crate::utils::Schedule;
5use crate::utils::class::{
6 assert_step_compatible, class_marker_output_indices, class_step_injections, init_class_emb,
7 insert_class_markers,
8};
9use burn::config::Config;
10use burn::module::Param;
11use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig};
12use burn::prelude::*;
13
/// Network over continuous inputs: an input projection, a stack of layers
/// built from block type `M`, and an output projection, plus optional
/// class-token embeddings.
#[derive(Module, Debug)]
pub struct LatentNetwork<M: Module> {
    /// Linear projection applied to the input before the layer stack.
    pub in_proj: Linear,
    /// Layer stack that processes the projected input.
    pub layers: Layers<M>,
    /// Linear projection applied to the output of the layer stack.
    pub out_proj: Linear,
    /// Class-token descriptors; marked `#[module(skip)]`, so they are not
    /// treated as a sub-module.
    #[module(skip)]
    pub class_tokens: Vec<ClassToken>,
    /// Class-token embedding table, if one was created. `step` reads one row
    /// per class token along dim 0 and broadcasts it to the input's shape.
    pub class_tokens_emb: Option<Param<Tensor<2>>>,
}
35
36impl<M: MambaBlock> LatentNetwork<M>
37where
38 M::SsdPath: Clone,
39{
40 pub fn class_token_output_indices(&self, orig_len: usize) -> Vec<usize> {
42 class_marker_output_indices(&self.class_tokens, orig_len)
43 }
44
45 fn insert_tokens(&self, x: Tensor<3>) -> Tensor<3> {
47 if self.class_tokens_emb.is_none() {
48 return x;
49 }
50 insert_class_markers(x, &self.class_tokens, self.class_tokens_emb.as_ref()).0
51 }
52
53 pub fn forward(
57 &self,
58 x: Tensor<3>,
59 caches: Option<M::Caches>,
60 ssd_path: M::SsdPath,
61 ) -> (Tensor<3>, M::Caches) {
62 let x = self.insert_tokens(x);
63 let x = self.in_proj.forward(x);
64 let (x, caches) = self.layers.forward(x, caches, ssd_path);
65 let x = self.out_proj.forward(x);
66 (x, caches)
67 }
68
69 pub fn step(
86 &self,
87 x: Tensor<2>,
88 caches: Option<M::Caches>,
89 own_index: Option<&mut usize>,
90 mut layers_own_index: Option<&mut usize>,
91 mut layer_indices: Option<&mut Vec<usize>>,
92 ) -> (Tensor<2>, M::Caches) {
93 if let Some(cursor) = own_index {
98 let [batch, input_size] = x.dims();
99 let inj = class_step_injections(&self.class_tokens, "LatentNetwork");
100 let emb = self.class_tokens_emb.as_ref();
101 let mut caches = caches;
102 while let Some(i) = inj.iter().position(|&p| p == *cursor) {
103 let row = emb
104 .unwrap()
105 .val()
106 .narrow(0, i, 1)
107 .expand([batch, input_size]);
108 let (_discard, c) = self.step(
109 row,
110 caches,
111 None,
112 layers_own_index.as_deref_mut(),
113 layer_indices.as_deref_mut(),
114 );
115 caches = Some(c);
116 *cursor += 1;
117 }
118 let (out, caches) = self.step(x, caches, None, layers_own_index, layer_indices);
119 *cursor += 1;
120 return (out, caches);
121 }
122
123 assert_step_compatible(&self.class_tokens, "LatentNetwork");
125 let x = self.in_proj.forward(x);
126 let (x, caches) = self.layers.step(x, caches, layers_own_index, layer_indices);
127 let x = self.out_proj.forward(x);
128 (x, caches)
129 }
130
131 pub fn step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
135 assert_step_compatible(&self.class_tokens, "LatentNetwork");
136 let x = self.in_proj.forward(x);
137 let x = self.layers.step_infinite(x);
138 self.out_proj.forward(x)
139 }
140
141 pub fn step_n_approx(
145 &self,
146 x: Tensor<2>,
147 n: usize,
148 caches: Option<M::Caches>,
149 ) -> (Tensor<2>, M::Caches) {
150 assert_step_compatible(&self.class_tokens, "LatentNetwork");
151 let x = self.in_proj.forward(x);
152 let (x, caches) = self.layers.step_n_approx(x, n, caches);
153 (self.out_proj.forward(x), caches)
154 }
155}
156
/// Description from which a [`LatentNetwork`] is initialised.
pub struct LatentNetworkBuilder<C> {
    /// Input width of `in_proj` and of each class-token embedding row.
    pub input_size: usize,
    /// Builder for the layer stack; its block config supplies `d_model`.
    pub layers: LayersBuilder<C>,
    /// Output width of `out_proj`.
    pub output_size: usize,
    /// Class tokens copied into the built network; their count is passed to
    /// `init_class_emb`.
    pub class_tokens: Vec<ClassToken>,
}
168
169impl<C: MambaBlockConfig> LatentNetworkBuilder<C> {
170 pub fn init(&self, device: &Device) -> LatentNetwork<C::Block> {
172 let d_model = self.layers.mamba_block.d_model();
173 LatentNetwork {
174 in_proj: LinearConfig::new(self.input_size, d_model)
175 .with_bias(true)
176 .init(device),
177 layers: self.layers.init(device),
178 out_proj: LinearConfig::new(d_model, self.output_size)
179 .with_bias(true)
180 .init(device),
181 class_tokens_emb: init_class_emb(self.class_tokens.len(), self.input_size, device),
182 class_tokens: self.class_tokens.clone(),
183 }
184 }
185}
186
/// Network over integer token ids: embedding, a stack of layers built from
/// block type `M`, a final RMS norm, and a projection to logits.
#[derive(Module, Debug)]
pub struct VocabNetwork<M: Module> {
    /// Token embedding table.
    pub embedding: Embedding,
    /// Layer stack that processes the embedded tokens.
    pub layers: Layers<M>,
    /// Normalisation applied to the layer stack's output.
    pub norm_f: RmsNorm,
    /// Output projection; when `None`, the transposed embedding weight is
    /// used as the projection instead.
    pub lm_head: Option<Linear>,
}
214
215impl<M: MambaBlock> VocabNetwork<M>
216where
217 M::SsdPath: Clone,
218{
219 pub fn forward(
222 &self,
223 x: Tensor<2, Int>,
224 caches: Option<M::Caches>,
225 ssd_path: M::SsdPath,
226 ) -> (Tensor<3>, M::Caches) {
227 let x = self.embedding.forward(x);
228 let (x, caches) = self.layers.forward(x, caches, ssd_path);
229 let x = self.norm_f.forward(x);
230 (self.apply_lm_head(x), caches)
231 }
232
233 pub fn step(
240 &self,
241 x: Tensor<1, Int>,
242 caches: Option<M::Caches>,
243 layers_own_index: Option<&mut usize>,
244 layer_indices: Option<&mut Vec<usize>>,
245 ) -> (Tensor<2>, M::Caches) {
246 let x = self
248 .embedding
249 .forward(x.unsqueeze_dim::<2>(1))
250 .squeeze_dim(1);
251 let (x, caches) = self.layers.step(x, caches, layers_own_index, layer_indices);
252 let x = self.norm_f.forward(x);
253 let logits = self.apply_lm_head(x.unsqueeze_dim(1)).squeeze_dim(1);
255 (logits, caches)
256 }
257
258 pub fn step_infinite(&self, x: Tensor<1, Int>) -> Tensor<2> {
262 let x = self
263 .embedding
264 .forward(x.unsqueeze_dim::<2>(1))
265 .squeeze_dim(1);
266 let x = self.layers.step_infinite(x);
267 let x = self.norm_f.forward(x);
268 self.apply_lm_head(x.unsqueeze_dim(1)).squeeze_dim(1)
269 }
270
271 pub fn step_n_approx(
275 &self,
276 x: Tensor<1, Int>,
277 n: usize,
278 caches: Option<M::Caches>,
279 ) -> (Tensor<2>, M::Caches) {
280 let x = self
281 .embedding
282 .forward(x.unsqueeze_dim::<2>(1))
283 .squeeze_dim(1);
284 let (x, caches) = self.layers.step_n_approx(x, n, caches);
285 let x = self.norm_f.forward(x);
286 let logits = self.apply_lm_head(x.unsqueeze_dim(1)).squeeze_dim(1);
287 (logits, caches)
288 }
289
290 fn apply_lm_head(&self, x: Tensor<3>) -> Tensor<3> {
293 if let Some(lm_head) = &self.lm_head {
294 lm_head.forward(x)
295 } else {
296 let weight = self.embedding.weight.clone().map(|w| w.permute([1, 0]));
298 Linear { weight, bias: None }.forward(x)
299 }
300 }
301}
302
/// Description from which a [`VocabNetwork`] is initialised.
pub struct VocabNetworkBuilder<C> {
    /// Requested vocabulary size, before padding.
    pub vocab_size: usize,
    /// The vocabulary size is rounded up to a multiple of this value.
    pub pad_vocab_size_multiple: usize,
    /// Builder for the layer stack; its block config supplies `d_model`.
    pub layers: LayersBuilder<C>,
    /// When `true`, no `lm_head` is created for the network.
    pub missing_lm_head: bool,
}
315
316impl<C: MambaBlockConfig> VocabNetworkBuilder<C> {
317 fn padded_vocab(vocab_size: usize, multiple: usize) -> usize {
319 if vocab_size.is_multiple_of(multiple) {
320 vocab_size
321 } else {
322 ((vocab_size / multiple) + 1) * multiple
323 }
324 }
325
326 pub fn init(&self, device: &Device) -> VocabNetwork<C::Block> {
328 let d_model = self.layers.mamba_block.d_model();
329 let padded_vocab = Self::padded_vocab(self.vocab_size, self.pad_vocab_size_multiple);
330 let lm_head = if self.missing_lm_head {
331 None
332 } else {
333 Some(
334 LinearConfig::new(d_model, padded_vocab)
335 .with_bias(false)
336 .init(device),
337 )
338 };
339 VocabNetwork {
340 embedding: EmbeddingConfig::new(padded_vocab, d_model).init(device),
341 layers: self.layers.init(device),
342 norm_f: RmsNormConfig::new(d_model).init(device),
343 lm_head,
344 }
345 }
346}
347
/// A [`LatentNetwork`] over one of the Mamba block families; each variant
/// exists only when its cargo feature is enabled.
#[derive(Module, Debug)]
pub enum MambaLatentNet {
    /// Latent network built from Mamba-1 blocks.
    #[cfg(feature = "mamba1")]
    Mamba1(LatentNetwork<crate::mamba1::prelude::Mamba1>),
    /// Latent network built from Mamba-2 blocks.
    #[cfg(feature = "mamba2")]
    Mamba2(LatentNetwork<crate::mamba2::prelude::Mamba2>),
    /// Latent network built from Mamba-3 blocks.
    #[cfg(feature = "mamba3")]
    Mamba3(LatentNetwork<crate::mamba3::prelude::Mamba3>),
}
366
367impl MambaLatentNet {
368 pub fn forward(
371 &self,
372 x: Tensor<3>,
373 caches: Option<MambaCaches>,
374 ssd_path: MambaSsdPath,
375 ) -> (Tensor<3>, MambaCaches) {
376 match self {
377 #[cfg(feature = "mamba1")]
378 Self::Mamba1(net) => {
379 let caches = caches.map(|c| match c {
380 MambaCaches::Mamba1(c) => c,
381 #[allow(unreachable_patterns)]
382 _ => panic!("cache family does not match Mamba-1 network"),
383 });
384 match ssd_path {
385 MambaSsdPath::Mamba1 => {}
386 #[allow(unreachable_patterns)]
387 _ => panic!("ssd_path family does not match Mamba-1 network"),
388 }
389 let (y, c) = net.forward(x, caches, ());
390 (y, MambaCaches::Mamba1(c))
391 }
392 #[cfg(feature = "mamba2")]
393 Self::Mamba2(net) => {
394 let caches = caches.map(|c| match c {
395 MambaCaches::Mamba2(c) => c,
396 #[allow(unreachable_patterns)]
397 _ => panic!("cache family does not match Mamba-2 network"),
398 });
399 let path = match ssd_path {
400 MambaSsdPath::Mamba2(p) => p,
401 #[allow(unreachable_patterns)]
402 _ => panic!("ssd_path family does not match Mamba-2 network"),
403 };
404 let (y, c) = net.forward(x, caches, path);
405 (y, MambaCaches::Mamba2(c))
406 }
407 #[cfg(feature = "mamba3")]
408 Self::Mamba3(net) => {
409 let caches = caches.map(|c| match c {
410 MambaCaches::Mamba3(c) => c,
411 #[allow(unreachable_patterns)]
412 _ => panic!("cache family does not match Mamba-3 network"),
413 });
414 let path = match ssd_path {
415 MambaSsdPath::Mamba3(p) => p,
416 #[allow(unreachable_patterns)]
417 _ => panic!("ssd_path family does not match Mamba-3 network"),
418 };
419 let (y, c) = net.forward(x, caches, path);
420 (y, MambaCaches::Mamba3(c))
421 }
422 }
423 }
424
425 pub fn step(
431 &self,
432 x: Tensor<2>,
433 caches: Option<MambaCaches>,
434 own_index: Option<&mut usize>,
435 layers_own_index: Option<&mut usize>,
436 layer_indices: Option<&mut Vec<usize>>,
437 ) -> (Tensor<2>, MambaCaches) {
438 match self {
439 #[cfg(feature = "mamba1")]
440 Self::Mamba1(net) => {
441 let caches = caches.map(|c| match c {
442 MambaCaches::Mamba1(c) => c,
443 #[allow(unreachable_patterns)]
444 _ => panic!("cache family does not match Mamba-1 network"),
445 });
446 let (y, c) = net.step(x, caches, own_index, layers_own_index, layer_indices);
447 (y, MambaCaches::Mamba1(c))
448 }
449 #[cfg(feature = "mamba2")]
450 Self::Mamba2(net) => {
451 let caches = caches.map(|c| match c {
452 MambaCaches::Mamba2(c) => c,
453 #[allow(unreachable_patterns)]
454 _ => panic!("cache family does not match Mamba-2 network"),
455 });
456 let (y, c) = net.step(x, caches, own_index, layers_own_index, layer_indices);
457 (y, MambaCaches::Mamba2(c))
458 }
459 #[cfg(feature = "mamba3")]
460 Self::Mamba3(net) => {
461 let caches = caches.map(|c| match c {
462 MambaCaches::Mamba3(c) => c,
463 #[allow(unreachable_patterns)]
464 _ => panic!("cache family does not match Mamba-3 network"),
465 });
466 let (y, c) = net.step(x, caches, own_index, layers_own_index, layer_indices);
467 (y, MambaCaches::Mamba3(c))
468 }
469 }
470 }
471
472 pub fn step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
476 match self {
477 #[cfg(feature = "mamba1")]
478 Self::Mamba1(net) => net.step_infinite(x),
479 #[cfg(feature = "mamba2")]
480 Self::Mamba2(net) => net.step_infinite(x),
481 #[cfg(feature = "mamba3")]
482 Self::Mamba3(net) => net.step_infinite(x),
483 }
484 }
485
486 pub fn step_n_approx(
490 &self,
491 x: Tensor<2>,
492 n: usize,
493 caches: Option<MambaCaches>,
494 ) -> (Tensor<2>, MambaCaches) {
495 match self {
496 #[cfg(feature = "mamba1")]
497 Self::Mamba1(net) => {
498 let caches = caches.map(|c| match c {
499 MambaCaches::Mamba1(c) => c,
500 #[allow(unreachable_patterns)]
501 _ => panic!("cache family does not match Mamba-1 network"),
502 });
503 let (y, c) = net.step_n_approx(x, n, caches);
504 (y, MambaCaches::Mamba1(c))
505 }
506 #[cfg(feature = "mamba2")]
507 Self::Mamba2(net) => {
508 let caches = caches.map(|c| match c {
509 MambaCaches::Mamba2(c) => c,
510 #[allow(unreachable_patterns)]
511 _ => panic!("cache family does not match Mamba-2 network"),
512 });
513 let (y, c) = net.step_n_approx(x, n, caches);
514 (y, MambaCaches::Mamba2(c))
515 }
516 #[cfg(feature = "mamba3")]
517 Self::Mamba3(net) => {
518 let caches = caches.map(|c| match c {
519 MambaCaches::Mamba3(c) => c,
520 #[allow(unreachable_patterns)]
521 _ => panic!("cache family does not match Mamba-3 network"),
522 });
523 let (y, c) = net.step_n_approx(x, n, caches);
524 (y, MambaCaches::Mamba3(c))
525 }
526 }
527 }
528}
529
/// Configuration for a [`MambaLatentNet`]; the variant selects the Mamba
/// block family and exists only when its cargo feature is enabled.
#[derive(Config, Debug)]
pub enum MambaLatentNetConfig {
    /// Latent network built from Mamba-1 blocks.
    #[cfg(feature = "mamba1")]
    Mamba1 {
        /// Input width of the network's input projection.
        input_size: usize,
        /// Layer count passed to `LayersBuilder::new`.
        n_real_layers: usize,
        /// Optional setting forwarded to `LayersBuilder::with_n_virtual_layers`.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Block configuration; its `d_model()` sets the model width.
        mamba_block: crate::mamba1::prelude::Mamba1Config,
        /// Output width of the network's output projection.
        output_size: usize,
        /// Class tokens given to the network.
        class_tokens: Vec<ClassToken>,
        /// Forwarded to `LayersBuilder::with_ignore_first_residual`.
        ignore_first_residual: bool,
        /// Forwarded to `LayersBuilder::with_ignore_last_residual`.
        ignore_last_residual: bool,
        /// Forwarded to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
    /// Latent network built from Mamba-2 blocks.
    #[cfg(feature = "mamba2")]
    Mamba2 {
        /// Input width of the network's input projection.
        input_size: usize,
        /// Layer count passed to `LayersBuilder::new`.
        n_real_layers: usize,
        /// Optional setting forwarded to `LayersBuilder::with_n_virtual_layers`.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Block configuration; its `d_model()` sets the model width.
        mamba_block: crate::mamba2::prelude::Mamba2Config,
        /// Output width of the network's output projection.
        output_size: usize,
        /// Class tokens given to the network.
        class_tokens: Vec<ClassToken>,
        /// Forwarded to `LayersBuilder::with_ignore_first_residual`.
        ignore_first_residual: bool,
        /// Forwarded to `LayersBuilder::with_ignore_last_residual`.
        ignore_last_residual: bool,
        /// Forwarded to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
    /// Latent network built from Mamba-3 blocks.
    #[cfg(feature = "mamba3")]
    Mamba3 {
        /// Input width of the network's input projection.
        input_size: usize,
        /// Layer count passed to `LayersBuilder::new`.
        n_real_layers: usize,
        /// Optional setting forwarded to `LayersBuilder::with_n_virtual_layers`.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Block configuration; its `d_model()` sets the model width.
        mamba_block: crate::mamba3::prelude::Mamba3Config,
        /// Output width of the network's output projection.
        output_size: usize,
        /// Class tokens given to the network.
        class_tokens: Vec<ClassToken>,
        /// Forwarded to `LayersBuilder::with_ignore_first_residual`.
        ignore_first_residual: bool,
        /// Forwarded to `LayersBuilder::with_ignore_last_residual`.
        ignore_last_residual: bool,
        /// Forwarded to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
}
608
609impl MambaLatentNetConfig {
610 pub fn init(&self, device: &Device) -> MambaLatentNet {
612 match self {
613 #[cfg(feature = "mamba1")]
614 Self::Mamba1 {
615 input_size,
616 n_real_layers,
617 n_virtual_layers,
618 mamba_block,
619 output_size,
620 class_tokens,
621 ignore_first_residual,
622 ignore_last_residual,
623 residuals,
624 } => MambaLatentNet::Mamba1(
625 LatentNetworkBuilder {
626 input_size: *input_size,
627 layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
628 .with_n_virtual_layers(n_virtual_layers.clone())
629 .with_residuals(residuals.clone())
630 .with_ignore_first_residual(*ignore_first_residual)
631 .with_ignore_last_residual(*ignore_last_residual),
632 output_size: *output_size,
633 class_tokens: class_tokens.clone(),
634 }
635 .init(device),
636 ),
637 #[cfg(feature = "mamba2")]
638 Self::Mamba2 {
639 input_size,
640 n_real_layers,
641 n_virtual_layers,
642 mamba_block,
643 output_size,
644 class_tokens,
645 ignore_first_residual,
646 ignore_last_residual,
647 residuals,
648 } => MambaLatentNet::Mamba2(
649 LatentNetworkBuilder {
650 input_size: *input_size,
651 layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
652 .with_n_virtual_layers(n_virtual_layers.clone())
653 .with_residuals(residuals.clone())
654 .with_ignore_first_residual(*ignore_first_residual)
655 .with_ignore_last_residual(*ignore_last_residual),
656 output_size: *output_size,
657 class_tokens: class_tokens.clone(),
658 }
659 .init(device),
660 ),
661 #[cfg(feature = "mamba3")]
662 Self::Mamba3 {
663 input_size,
664 n_real_layers,
665 n_virtual_layers,
666 mamba_block,
667 output_size,
668 class_tokens,
669 ignore_first_residual,
670 ignore_last_residual,
671 residuals,
672 } => MambaLatentNet::Mamba3(
673 LatentNetworkBuilder {
674 input_size: *input_size,
675 layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
676 .with_n_virtual_layers(n_virtual_layers.clone())
677 .with_residuals(residuals.clone())
678 .with_ignore_first_residual(*ignore_first_residual)
679 .with_ignore_last_residual(*ignore_last_residual),
680 output_size: *output_size,
681 class_tokens: class_tokens.clone(),
682 }
683 .init(device),
684 ),
685 }
686 }
687}
688
/// A [`VocabNetwork`] over one of the Mamba block families; each variant
/// exists only when its cargo feature is enabled.
#[derive(Module, Debug)]
pub enum MambaVocabNet {
    /// Vocabulary network built from Mamba-1 blocks.
    #[cfg(feature = "mamba1")]
    Mamba1(VocabNetwork<crate::mamba1::prelude::Mamba1>),
    /// Vocabulary network built from Mamba-2 blocks.
    #[cfg(feature = "mamba2")]
    Mamba2(VocabNetwork<crate::mamba2::prelude::Mamba2>),
    /// Vocabulary network built from Mamba-3 blocks.
    #[cfg(feature = "mamba3")]
    Mamba3(VocabNetwork<crate::mamba3::prelude::Mamba3>),
}
704
705impl MambaVocabNet {
706 pub fn forward(
710 &self,
711 x: Tensor<2, Int>,
712 caches: Option<MambaCaches>,
713 ssd_path: MambaSsdPath,
714 ) -> (Tensor<3>, MambaCaches) {
715 match self {
716 #[cfg(feature = "mamba1")]
717 Self::Mamba1(net) => {
718 let caches = caches.map(|c| match c {
719 MambaCaches::Mamba1(c) => c,
720 #[allow(unreachable_patterns)]
721 _ => panic!("cache family does not match Mamba-1 network"),
722 });
723 match ssd_path {
724 MambaSsdPath::Mamba1 => {}
725 #[allow(unreachable_patterns)]
726 _ => panic!("ssd_path family does not match Mamba-1 network"),
727 }
728 let (y, c) = net.forward(x, caches, ());
729 (y, MambaCaches::Mamba1(c))
730 }
731 #[cfg(feature = "mamba2")]
732 Self::Mamba2(net) => {
733 let caches = caches.map(|c| match c {
734 MambaCaches::Mamba2(c) => c,
735 #[allow(unreachable_patterns)]
736 _ => panic!("cache family does not match Mamba-2 network"),
737 });
738 let path = match ssd_path {
739 MambaSsdPath::Mamba2(p) => p,
740 #[allow(unreachable_patterns)]
741 _ => panic!("ssd_path family does not match Mamba-2 network"),
742 };
743 let (y, c) = net.forward(x, caches, path);
744 (y, MambaCaches::Mamba2(c))
745 }
746 #[cfg(feature = "mamba3")]
747 Self::Mamba3(net) => {
748 let caches = caches.map(|c| match c {
749 MambaCaches::Mamba3(c) => c,
750 #[allow(unreachable_patterns)]
751 _ => panic!("cache family does not match Mamba-3 network"),
752 });
753 let path = match ssd_path {
754 MambaSsdPath::Mamba3(p) => p,
755 #[allow(unreachable_patterns)]
756 _ => panic!("ssd_path family does not match Mamba-3 network"),
757 };
758 let (y, c) = net.forward(x, caches, path);
759 (y, MambaCaches::Mamba3(c))
760 }
761 }
762 }
763
764 pub fn step(
769 &self,
770 x: Tensor<1, Int>,
771 caches: Option<MambaCaches>,
772 layers_own_index: Option<&mut usize>,
773 layer_indices: Option<&mut Vec<usize>>,
774 ) -> (Tensor<2>, MambaCaches) {
775 match self {
776 #[cfg(feature = "mamba1")]
777 Self::Mamba1(net) => {
778 let caches = caches.map(|c| match c {
779 MambaCaches::Mamba1(c) => c,
780 #[allow(unreachable_patterns)]
781 _ => panic!("cache family does not match Mamba-1 network"),
782 });
783 let (y, c) = net.step(x, caches, layers_own_index, layer_indices);
784 (y, MambaCaches::Mamba1(c))
785 }
786 #[cfg(feature = "mamba2")]
787 Self::Mamba2(net) => {
788 let caches = caches.map(|c| match c {
789 MambaCaches::Mamba2(c) => c,
790 #[allow(unreachable_patterns)]
791 _ => panic!("cache family does not match Mamba-2 network"),
792 });
793 let (y, c) = net.step(x, caches, layers_own_index, layer_indices);
794 (y, MambaCaches::Mamba2(c))
795 }
796 #[cfg(feature = "mamba3")]
797 Self::Mamba3(net) => {
798 let caches = caches.map(|c| match c {
799 MambaCaches::Mamba3(c) => c,
800 #[allow(unreachable_patterns)]
801 _ => panic!("cache family does not match Mamba-3 network"),
802 });
803 let (y, c) = net.step(x, caches, layers_own_index, layer_indices);
804 (y, MambaCaches::Mamba3(c))
805 }
806 }
807 }
808
809 pub fn step_infinite(&self, x: Tensor<1, Int>) -> Tensor<2> {
813 match self {
814 #[cfg(feature = "mamba1")]
815 Self::Mamba1(net) => net.step_infinite(x),
816 #[cfg(feature = "mamba2")]
817 Self::Mamba2(net) => net.step_infinite(x),
818 #[cfg(feature = "mamba3")]
819 Self::Mamba3(net) => net.step_infinite(x),
820 }
821 }
822
823 pub fn step_n_approx(
827 &self,
828 x: Tensor<1, Int>,
829 n: usize,
830 caches: Option<MambaCaches>,
831 ) -> (Tensor<2>, MambaCaches) {
832 match self {
833 #[cfg(feature = "mamba1")]
834 Self::Mamba1(net) => {
835 let caches = caches.map(|c| match c {
836 MambaCaches::Mamba1(c) => c,
837 #[allow(unreachable_patterns)]
838 _ => panic!("cache family does not match Mamba-1 network"),
839 });
840 let (y, c) = net.step_n_approx(x, n, caches);
841 (y, MambaCaches::Mamba1(c))
842 }
843 #[cfg(feature = "mamba2")]
844 Self::Mamba2(net) => {
845 let caches = caches.map(|c| match c {
846 MambaCaches::Mamba2(c) => c,
847 #[allow(unreachable_patterns)]
848 _ => panic!("cache family does not match Mamba-2 network"),
849 });
850 let (y, c) = net.step_n_approx(x, n, caches);
851 (y, MambaCaches::Mamba2(c))
852 }
853 #[cfg(feature = "mamba3")]
854 Self::Mamba3(net) => {
855 let caches = caches.map(|c| match c {
856 MambaCaches::Mamba3(c) => c,
857 #[allow(unreachable_patterns)]
858 _ => panic!("cache family does not match Mamba-3 network"),
859 });
860 let (y, c) = net.step_n_approx(x, n, caches);
861 (y, MambaCaches::Mamba3(c))
862 }
863 }
864 }
865}
866
/// Configuration for a [`MambaVocabNet`]; the variant selects the Mamba
/// block family and exists only when its cargo feature is enabled.
#[derive(Config, Debug)]
pub enum MambaVocabNetConfig {
    /// Vocabulary network built from Mamba-1 blocks.
    #[cfg(feature = "mamba1")]
    Mamba1 {
        /// Layer count passed to `LayersBuilder::new`.
        n_real_layers: usize,
        /// Optional setting forwarded to `LayersBuilder::with_n_virtual_layers`.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Requested vocabulary size, before padding.
        vocab_size: usize,
        /// The vocabulary size is rounded up to a multiple of this value.
        pad_vocab_size_multiple: usize,
        /// Block configuration; its `d_model()` sets the model width.
        mamba_block: crate::mamba1::prelude::Mamba1Config,
        /// When `true`, the network is built without an `lm_head`.
        missing_lm_head: bool,
        /// Forwarded to `LayersBuilder::with_ignore_first_residual`.
        ignore_first_residual: bool,
        /// Forwarded to `LayersBuilder::with_ignore_last_residual`.
        ignore_last_residual: bool,
        /// Forwarded to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
    /// Vocabulary network built from Mamba-2 blocks.
    #[cfg(feature = "mamba2")]
    Mamba2 {
        /// Layer count passed to `LayersBuilder::new`.
        n_real_layers: usize,
        /// Optional setting forwarded to `LayersBuilder::with_n_virtual_layers`.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Requested vocabulary size, before padding.
        vocab_size: usize,
        /// The vocabulary size is rounded up to a multiple of this value.
        pad_vocab_size_multiple: usize,
        /// Block configuration; its `d_model()` sets the model width.
        mamba_block: crate::mamba2::prelude::Mamba2Config,
        /// When `true`, the network is built without an `lm_head`.
        missing_lm_head: bool,
        /// Forwarded to `LayersBuilder::with_ignore_first_residual`.
        ignore_first_residual: bool,
        /// Forwarded to `LayersBuilder::with_ignore_last_residual`.
        ignore_last_residual: bool,
        /// Forwarded to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
    /// Vocabulary network built from Mamba-3 blocks.
    #[cfg(feature = "mamba3")]
    Mamba3 {
        /// Layer count passed to `LayersBuilder::new`.
        n_real_layers: usize,
        /// Optional setting forwarded to `LayersBuilder::with_n_virtual_layers`.
        n_virtual_layers: Option<(usize, Schedule)>,
        /// Requested vocabulary size, before padding.
        vocab_size: usize,
        /// The vocabulary size is rounded up to a multiple of this value.
        pad_vocab_size_multiple: usize,
        /// Block configuration; its `d_model()` sets the model width.
        mamba_block: crate::mamba3::prelude::Mamba3Config,
        /// When `true`, the network is built without an `lm_head`.
        missing_lm_head: bool,
        /// Forwarded to `LayersBuilder::with_ignore_first_residual`.
        ignore_first_residual: bool,
        /// Forwarded to `LayersBuilder::with_ignore_last_residual`.
        ignore_last_residual: bool,
        /// Forwarded to `LayersBuilder::with_residuals`.
        residuals: ResidualsConfig,
    },
}
945
946impl MambaVocabNetConfig {
947 pub fn init(&self, device: &Device) -> MambaVocabNet {
949 match self {
950 #[cfg(feature = "mamba1")]
951 Self::Mamba1 {
952 n_real_layers,
953 n_virtual_layers,
954 vocab_size,
955 pad_vocab_size_multiple,
956 mamba_block,
957 missing_lm_head,
958 ignore_first_residual,
959 ignore_last_residual,
960 residuals,
961 } => MambaVocabNet::Mamba1(
962 VocabNetworkBuilder {
963 vocab_size: *vocab_size,
964 pad_vocab_size_multiple: *pad_vocab_size_multiple,
965 layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
966 .with_n_virtual_layers(n_virtual_layers.clone())
967 .with_residuals(residuals.clone())
968 .with_ignore_first_residual(*ignore_first_residual)
969 .with_ignore_last_residual(*ignore_last_residual),
970 missing_lm_head: *missing_lm_head,
971 }
972 .init(device),
973 ),
974 #[cfg(feature = "mamba2")]
975 Self::Mamba2 {
976 n_real_layers,
977 n_virtual_layers,
978 vocab_size,
979 pad_vocab_size_multiple,
980 mamba_block,
981 missing_lm_head,
982 ignore_first_residual,
983 ignore_last_residual,
984 residuals,
985 } => MambaVocabNet::Mamba2(
986 VocabNetworkBuilder {
987 vocab_size: *vocab_size,
988 pad_vocab_size_multiple: *pad_vocab_size_multiple,
989 layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
990 .with_n_virtual_layers(n_virtual_layers.clone())
991 .with_residuals(residuals.clone())
992 .with_ignore_first_residual(*ignore_first_residual)
993 .with_ignore_last_residual(*ignore_last_residual),
994 missing_lm_head: *missing_lm_head,
995 }
996 .init(device),
997 ),
998 #[cfg(feature = "mamba3")]
999 Self::Mamba3 {
1000 n_real_layers,
1001 n_virtual_layers,
1002 vocab_size,
1003 pad_vocab_size_multiple,
1004 mamba_block,
1005 missing_lm_head,
1006 ignore_first_residual,
1007 ignore_last_residual,
1008 residuals,
1009 } => MambaVocabNet::Mamba3(
1010 VocabNetworkBuilder {
1011 vocab_size: *vocab_size,
1012 pad_vocab_size_multiple: *pad_vocab_size_multiple,
1013 layers: LayersBuilder::new(*n_real_layers, mamba_block.clone())
1014 .with_n_virtual_layers(n_virtual_layers.clone())
1015 .with_residuals(residuals.clone())
1016 .with_ignore_first_residual(*ignore_first_residual)
1017 .with_ignore_last_residual(*ignore_last_residual),
1018 missing_lm_head: *missing_lm_head,
1019 }
1020 .init(device),
1021 ),
1022 }
1023 }
1024}