1use crate::mamba3::prelude::*;
45use crate::utils::sanity::sanity as san;
46use crate::utils::{
47 rms_norm::{RmsNorm, RmsNormConfig},
48 rms_norm_gated::{RmsNormGated, RmsNormGatedConfig},
49 silu::Silu,
50 softplus::softplus,
51};
52use burn::prelude::*;
53use burn::{
54 module::{Module, Param},
55 nn::{Initializer, Linear, LinearConfig},
56};
57
58fn sigmoid<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
60 ((-x).exp() + 1.).recip()
61}
62
#[derive(Module, Debug)]
/// Mamba-3 mixer block: parameters and static hyper-parameters.
///
/// Built by `Mamba3Config::init`; run over a whole sequence with `forward` or one token at a
/// time with `step`.
pub struct Mamba3<B: Backend> {
    /// Fused input projection `d_model -> d_in_proj`. Its output is split into
    /// `[z, x, B, C, dt, A, lambda, theta]` by `forward`/`step`.
    pub in_proj: Linear<B>,

    /// Per-head bias added to the raw `dt` before the softplus. Initialised to the inverse
    /// softplus of a log-uniform sample in `[dt_min, dt_max]` (floored at `dt_init_floor`).
    pub dt_bias_h: Param<Tensor<B, 1>>,

    /// `(min, max)` range that `dt` is clamped to after the softplus.
    pub dt_limit: (f64, f64),

    /// Positive floor on the decay-rate magnitude (intended: `A <= -a_floor`).
    pub a_floor: f64,

    /// Per-head skip coefficient (`D`), shape `[nheads]`. It scales the (MIMO-weighted) input
    /// `x` that is added to the SSM output. Initialised to ones.
    pub d_h: Param<Tensor<B, 1>>,

    /// RMS norm (sized `state_rank`) applied to the projected `B` before its bias and RoPE.
    pub b_norm: RmsNorm<B>,

    /// RMS norm (sized `state_rank`) applied to the projected `C` before its bias and RoPE.
    pub c_norm: RmsNorm<B>,

    /// Learned bias added to the normalised `B`, shape `[nheads, mimo_rank, state_rank]`.
    /// Initialised to ones.
    pub b_bias_hrn: Param<Tensor<B, 3>>,

    /// Learned bias added to the normalised `C`, shape `[nheads, mimo_rank, state_rank]`.
    /// Initialised to ones.
    pub c_bias_hrn: Param<Tensor<B, 3>>,

    /// Per-rank input weights `[nheads, mimo_rank, per_head_dim]` that scale `x` for each MIMO
    /// rank. `Some` only when `mimo_rank > 1`.
    pub mimo_x: Option<Param<Tensor<B, 3>>>,

    /// Per-rank gate weights `[nheads, mimo_rank, per_head_dim]` that scale `z` for each MIMO
    /// rank. `Some` only when `mimo_rank > 1`.
    pub mimo_z: Option<Param<Tensor<B, 3>>>,

    /// Per-rank output weights `[nheads, mimo_rank, per_head_dim]`, used to combine the rank
    /// outputs by a weighted sum over the rank axis. `Some` only when `mimo_rank > 1`.
    pub mimo_o: Option<Param<Tensor<B, 3>>>,

    /// Optional gated RMS norm (sized `per_head_dim`) combining the SSM output with the gate
    /// `z`. When `None`, the output is multiplied by `silu(z)` instead.
    pub out_norm: Option<RmsNormGated<B>>,

    /// Output projection `d_inner -> d_model`. Its weight shape also defines `d_inner()`.
    pub out_proj: Linear<B>,

    /// Optional learnable initial SSM state `[nheads, per_head_dim, state_rank]`
    /// (zero-initialised). `forward` hands it to the SSD path; `step` does not read it.
    pub init_state_hpr: Option<Param<Tensor<B, 3>>>,

    /// SSM state size per head.
    pub state_rank: usize,

    /// Number of `B`/`C` groups; each group is broadcast to `nheads / ngroups` heads.
    pub ngroups: usize,

    /// Number of RoPE angles per head (`rope_dim / 2`).
    pub num_rope_angles: usize,

    /// Number of state channels rotated by RoPE (even, at most `state_rank`).
    pub rope_dim: usize,

    /// MIMO rank; 1 disables the MIMO weights.
    pub mimo_rank: usize,
}
162
163impl<B: Backend> Mamba3<B> {
164 pub fn d_inner(&self) -> usize {
166 let [d_inner, _d_model] = self.out_proj.weight.dims();
167 d_inner
168 }
169
170 pub fn nheads(&self) -> usize {
172 let [nheads] = self.d_h.dims();
173 nheads
174 }
175
176 pub fn per_head_dim(&self) -> usize {
178 self.d_inner() / self.nheads()
179 }
180}
181
#[derive(Config, Debug)]
/// Hyper-parameters for [`Mamba3`]; call `init` to build the module.
pub struct Mamba3Config {
    /// Width of the block's input and output.
    pub d_model: usize,

    #[config(default = 128)]
    /// SSM state size per head. Must be even (RoPE rotates channel pairs).
    pub state_rank: usize,

    #[config(default = 2)]
    /// Expansion factor: `d_inner = expand * d_model`.
    pub expand: usize,

    #[config(default = 64)]
    /// Channels per head. Must be positive and divide `d_inner`.
    pub per_head_dim: usize,

    #[config(default = 1)]
    /// Number of `B`/`C` groups shared across heads. `nheads` must be divisible by it.
    pub ngroups: usize,

    #[config(default = 1)]
    /// MIMO rank. 1 disables MIMO; values > 1 create the `mimo_x`/`mimo_z`/`mimo_o` weights and
    /// widen the `B`/`C` projections by this factor.
    pub mimo_rank: usize,

    #[config(default = "1e-4")]
    /// Positive floor on the decay-rate magnitude (intended: `A <= -a_floor`).
    pub a_floor: f64,

    #[config(default = 1e-3)]
    /// Lower bound of the log-uniform distribution that the initial per-head `dt` is drawn from.
    pub dt_min: f64,

    #[config(default = 0.1)]
    /// Upper bound of the log-uniform distribution that the initial per-head `dt` is drawn from.
    pub dt_max: f64,

    #[config(default = 1e-4)]
    /// Minimum initial `dt` after sampling.
    pub dt_init_floor: f64,

    #[config(default = "(0., 6.5504e+4)")]
    /// Runtime `(min, max)` clamp for `dt` after the softplus. The default upper bound, 65504,
    /// is the largest finite f16 value.
    pub dt_limit: (f64, f64),

    #[config(default = false)]
    /// Whether `in_proj` and `out_proj` carry a bias.
    pub has_proj_bias: bool,

    #[config(default = false)]
    /// Adds a learnable, zero-initialised initial SSM state `[nheads, per_head_dim, state_rank]`.
    pub has_learnable_init_state: bool,

    #[config(default = 0.5)]
    /// Fraction of the `state_rank` channels rotated by RoPE (rounded down to an even count).
    /// `init` accepts only 0.5 or 1.0.
    pub rope_fraction: f64,

    #[config(default = false)]
    /// Use a gated RMS norm (sized `per_head_dim`) before `out_proj` instead of plain
    /// `y * silu(z)` gating.
    pub has_outproj_norm: bool,
}
264
265impl Mamba3Config {
266 pub fn d_inner(&self) -> usize {
267 self.expand * self.d_model
268 }
269 pub fn nheads(&self) -> usize {
270 self.d_inner() / self.per_head_dim
271 }
272
273 pub fn rope_dim(&self) -> usize {
276 let mut d = (self.state_rank as f64 * self.rope_fraction) as usize;
277 if d % 2 != 0 {
278 d -= 1;
279 }
280 d
281 }
282
283 pub fn num_rope_angles(&self) -> usize {
284 self.rope_dim() / 2
285 }
286
287 pub fn d_in_proj(&self) -> usize {
291 2 * self.d_inner()
292 + 2 * self.ngroups * self.state_rank * self.mimo_rank
293 + 3 * self.nheads()
294 + self.num_rope_angles()
295 }
296
297 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3<B> {
299 let d_inner = self.d_inner();
300 let nheads = self.nheads();
301 let ngroups = self.ngroups;
302 let state_rank = self.state_rank;
303 let mimo_rank = self.mimo_rank;
304 let num_rope_angles = self.num_rope_angles();
305
306 assert!(
307 state_rank % 2 == 0,
308 "state_rank must be even for RoPE pairing"
309 );
310 assert!(self.per_head_dim > 0, "per_head_dim must be positive");
311 assert_eq!(
312 nheads * self.per_head_dim,
313 d_inner,
314 "d_inner must be divisible by per_head_dim"
315 );
316 assert_ne!(ngroups, 0, "ngroups must be at least 1");
317 assert_eq!(nheads % ngroups, 0, "nheads must be divisible by ngroups");
318 assert!(self.a_floor > 0.0, "a_floor must be positive");
319 assert!(mimo_rank >= 1, "mimo_rank must be at least 1");
320 assert!(
321 self.rope_fraction == 0.5 || self.rope_fraction == 1.0,
322 "rope_fraction must be 0.5 or 1.0"
323 );
324 assert!(num_rope_angles > 0, "num_rope_angles must be at least 1");
325
326 let uniform_init = |fan_in: usize| {
327 let bound = 1.0 / (fan_in as f64).sqrt();
328 Initializer::Uniform {
329 min: -bound,
330 max: bound,
331 }
332 };
333
334 let in_proj = LinearConfig::new(self.d_model, self.d_in_proj())
335 .with_bias(self.has_proj_bias)
336 .with_initializer(uniform_init(self.d_model))
337 .init::<B>(device);
338
339 let expm1 = |t: Tensor<B, 1>| t.exp() - 1.;
341 let dt_h = Tensor::random(
342 [nheads],
343 burn::tensor::Distribution::Uniform(self.dt_min.ln(), self.dt_max.ln()),
344 device,
345 )
346 .exp();
347 let dt_h = dt_h.clamp(self.dt_init_floor, f64::INFINITY);
348 let inv_dt_h = dt_h.clone() + (-expm1(-dt_h)).log();
349 let dt_bias_h = Param::from_tensor(inv_dt_h);
350
351 let d_h = Initializer::Ones.init::<B, 1, _>([nheads], device);
352
353 let b_norm = RmsNormConfig::new(state_rank).init(device);
354 let c_norm = RmsNormConfig::new(state_rank).init(device);
355
356 let b_bias_hrn = Initializer::Ones.init::<B, 3, _>([nheads, mimo_rank, state_rank], device);
358 let c_bias_hrn = Initializer::Ones.init::<B, 3, _>([nheads, mimo_rank, state_rank], device);
359
360 let (mimo_x, mimo_z, mimo_o) = if mimo_rank > 1 {
362 let per_head_dim = self.per_head_dim;
363 let mx = Param::from_tensor(Tensor::full(
365 [nheads, mimo_rank, per_head_dim],
366 1.0 / mimo_rank as f64,
367 device,
368 ));
369 let mz = Param::from_tensor(Tensor::ones([nheads, mimo_rank, per_head_dim], device));
370 let mo = Param::from_tensor(Tensor::full(
371 [nheads, mimo_rank, per_head_dim],
372 1.0 / mimo_rank as f64,
373 device,
374 ));
375 (Some(mx), Some(mz), Some(mo))
376 } else {
377 (None, None, None)
378 };
379
380 let out_norm = self.has_outproj_norm.then(|| {
382 RmsNormGatedConfig::new(self.per_head_dim)
383 .with_norm_before_gate(true)
384 .init(device)
385 });
386
387 let out_proj = LinearConfig::new(d_inner, self.d_model)
388 .with_bias(self.has_proj_bias)
389 .with_initializer(uniform_init(d_inner))
390 .init(device);
391
392 let init_state_hpr = self.has_learnable_init_state.then(|| {
393 Initializer::Zeros.init::<B, 3, _>([nheads, self.per_head_dim, state_rank], device)
394 });
395
396 Mamba3 {
397 in_proj,
398 dt_bias_h,
399 dt_limit: self.dt_limit,
400 a_floor: self.a_floor,
401 d_h,
402 b_norm,
403 c_norm,
404 b_bias_hrn,
405 c_bias_hrn,
406 mimo_x,
407 mimo_z,
408 mimo_o,
409 out_norm,
410 out_proj,
411 init_state_hpr,
412 state_rank,
413 ngroups,
414 rope_dim: self.rope_dim(),
415 num_rope_angles,
416 mimo_rank,
417 }
418 }
419}
420
421pub fn apply_rope<B: Backend, const D: usize>(
443 x: Tensor<B, D>,
444 angles: Tensor<B, D>,
445 rotate_pairwise: bool,
446) -> Tensor<B, D> {
447 let dims = x.dims();
448 let n = dims[D - 1];
449 let n2 = n / 2;
450 let leading: usize = dims[..D - 1].iter().product();
451
452 let angles_flat = angles.reshape([leading, n2]);
453 let cos = angles_flat.clone().cos();
454 let sin = angles_flat.sin();
455
456 if rotate_pairwise {
457 let x_pairs = x.reshape([leading, n2, 2]);
459 let x0 = x_pairs.clone().narrow(2, 0, 1).squeeze_dim(2);
460 let x1 = x_pairs.narrow(2, 1, 1).squeeze_dim(2);
461
462 let x0r = cos.clone() * x0.clone() - sin.clone() * x1.clone();
463 let x1r = sin * x0 + cos * x1;
464
465 Tensor::cat(
466 vec![x0r.unsqueeze_dim::<3>(2), x1r.unsqueeze_dim::<3>(2)],
467 2,
468 )
469 .reshape(dims)
470 } else {
471 let x_halves = x.reshape([leading, 2, n2]);
473 let x0 = x_halves.clone().narrow(1, 0, 1).squeeze_dim(1);
474 let x1 = x_halves.narrow(1, 1, 1).squeeze_dim(1);
475
476 let x0r = cos.clone() * x0.clone() - sin.clone() * x1.clone();
477 let x1r = sin * x0 + cos * x1;
478
479 Tensor::cat(
480 vec![x0r.unsqueeze_dim::<3>(1), x1r.unsqueeze_dim::<3>(1)],
481 1,
482 )
483 .reshape(dims)
484 }
485}
486
487fn apply_rope_partial<B: Backend, const D: usize>(
504 x: Tensor<B, D>,
505 angles: Tensor<B, D>,
506 rope_dim: usize,
507 rotate_pairwise: bool,
508) -> Tensor<B, D> {
509 let state_rank = x.dims()[D - 1];
510 if rope_dim == state_rank {
511 return apply_rope::<B, D>(x, angles, rotate_pairwise);
512 }
513
514 if rotate_pairwise {
515 let x_rope = x.clone().narrow(D - 1, 0, rope_dim);
519 let x_rest = x.narrow(D - 1, rope_dim, state_rank - rope_dim);
520 let x_rope_rotated = apply_rope::<B, D>(x_rope, angles, true);
521 return Tensor::cat(vec![x_rope_rotated, x_rest], D - 1);
522 }
523
524 let half = state_rank / 2;
528 let num_rope_angles = rope_dim / 2;
529 debug_assert!(
530 num_rope_angles < half,
531 "partial RoPE requires rope_dim < state_rank here"
532 );
533
534 let x_h1 = x.clone().narrow(D - 1, 0, half);
537 let x_h2 = x.narrow(D - 1, half, half);
538 let x_h1_rope = x_h1.clone().narrow(D - 1, 0, num_rope_angles);
539 let x_h1_pass = x_h1.narrow(D - 1, num_rope_angles, half - num_rope_angles);
540 let x_h2_rope = x_h2.clone().narrow(D - 1, 0, num_rope_angles);
541 let x_h2_pass = x_h2.narrow(D - 1, num_rope_angles, half - num_rope_angles);
542
543 let cos = angles.clone().cos();
545 let sin = angles.sin();
546 let x_h1_rot = cos.clone() * x_h1_rope.clone() - sin.clone() * x_h2_rope.clone();
547 let x_h2_rot = sin * x_h1_rope + cos * x_h2_rope;
548
549 let x_h1_out = Tensor::cat(vec![x_h1_rot, x_h1_pass], D - 1);
551 let x_h2_out = Tensor::cat(vec![x_h2_rot, x_h2_pass], D - 1);
552 Tensor::cat(vec![x_h1_out, x_h2_out], D - 1)
553}
554
555fn build_v_mimo<B: Backend>(
568 x_bshp: Tensor<B, 4>,
569 mimo_x_hrp: Option<&Tensor<B, 3>>,
570) -> Tensor<B, 5> {
571 let [batch, seq, nheads, per_head_dim] = x_bshp.dims();
572 match mimo_x_hrp {
573 None => {
574 x_bshp.unsqueeze_dim::<5>(2) }
577 Some(mimo_x) => {
578 let [_, mimo_rank, _] = mimo_x.dims();
579 let x_exp =
581 x_bshp
582 .unsqueeze_dim::<5>(2)
583 .expand([batch, seq, mimo_rank, nheads, per_head_dim]);
584 let mx_exp = mimo_x
586 .clone()
587 .permute([1, 0, 2]) .unsqueeze_dim::<4>(0)
589 .unsqueeze_dim::<5>(0)
590 .expand([batch, seq, mimo_rank, nheads, per_head_dim]);
591 x_exp * mx_exp }
593 }
594}
595
596impl<B: Backend + Mamba3BackendExt> Mamba3<B> {
601 #[allow(non_snake_case)]
611 pub fn forward(
612 &self,
613 input_bsm: Tensor<B, 3>,
614 cache: Option<Mamba3Cache<B>>,
615 ssd_path: Mamba3SsdPath,
616 ) -> (Tensor<B, 3>, Mamba3Cache<B>) {
617 let [batch, sequence, _d_model] = input_bsm.dims();
618 let d_inner = self.d_inner();
619 let nheads = self.nheads();
620 let ngroups = self.ngroups;
621 let per_head_dim = self.per_head_dim();
622 let state_rank = self.state_rank;
623 let num_rope_angles = self.num_rope_angles;
624 let heads_per_group = nheads / ngroups;
625 let mimo_rank = self.mimo_rank;
626 let device = input_bsm.device();
627
628 assert!(sequence > 0, "sequence length must be at least 1");
629 assert_eq!(nheads % ngroups, 0);
630 san(&input_bsm);
631
632 let mut cache = cache.unwrap_or_else(|| {
634 let ssm_bhpr = Tensor::zeros([batch, nheads, per_head_dim, state_rank], &device);
635 let k_state_brhn = Tensor::zeros([batch, mimo_rank, nheads, state_rank], &device);
636 let v_state_bhp = Tensor::zeros([batch, nheads, per_head_dim], &device);
637 let cum_angle_bhr = Tensor::zeros([batch, nheads, num_rope_angles], &device);
638 Mamba3Cache {
639 ssm_bhpr,
640 k_state_brhn,
641 v_state_bhp,
642 cum_angle_bhr,
643 }
644 });
645
646 let proj_bsd = self.in_proj.forward(input_bsm);
648 let bc_size = ngroups * state_rank * mimo_rank;
649
650 let mut parts = proj_bsd
651 .split_with_sizes(
652 vec![
653 d_inner,
654 d_inner,
655 bc_size,
656 bc_size,
657 nheads,
658 nheads,
659 nheads,
660 num_rope_angles,
661 ],
662 2,
663 )
664 .into_iter();
665 let z_bsi = parts.next().unwrap(); let x_bsi = parts.next().unwrap(); let b_raw_bsd = parts.next().unwrap(); let c_raw_bsd = parts.next().unwrap(); let dd_dt_bsh = parts.next().unwrap(); let dd_A_raw_bsh = parts.next().unwrap(); let lam_raw_bsh = parts.next().unwrap(); let theta_bsa = parts.next().unwrap(); san(&z_bsi);
675 san(&x_bsi);
676 san(&dd_dt_bsh);
677
678 let dt_bias_11h = self.dt_bias_h.val().unsqueeze_dims(&[0, 1]);
680 let dt_bsh = softplus(dd_dt_bsh + dt_bias_11h).clamp(self.dt_limit.0, self.dt_limit.1);
681
682 let a_bsh = -softplus(dd_A_raw_bsh).clamp(f64::NEG_INFINITY, -self.a_floor);
683 let da_bsh = dt_bsh.clone() * a_bsh;
684
685 let alpha_bsh = da_bsh.clone().exp();
686 let lam_bsh = sigmoid(lam_raw_bsh);
687 let gamma_bsh = lam_bsh.clone() * dt_bsh.clone();
688 let beta_bsh = (-lam_bsh.clone() + 1.0) * dt_bsh.clone() * alpha_bsh.clone();
689
690 san(&dt_bsh);
691 san(&da_bsh);
692 san(&gamma_bsh);
693 san(&beta_bsh);
694
695 let x_bshp = x_bsi.reshape([batch, sequence, nheads, per_head_dim]);
697
698 let b_bsrhr = {
702 let b_bsrgr = b_raw_bsd.reshape([batch, sequence, mimo_rank, ngroups, state_rank]);
703 let b_norm = self
706 .b_norm
707 .forward(b_bsrgr.reshape([batch * sequence * mimo_rank, ngroups, state_rank]))
708 .reshape([batch, sequence, mimo_rank, ngroups, state_rank]);
709 let b_exp = b_norm
711 .unsqueeze_dim::<6>(4) .expand([
713 batch,
714 sequence,
715 mimo_rank,
716 ngroups,
717 heads_per_group,
718 state_rank,
719 ])
720 .reshape([batch, sequence, mimo_rank, nheads, state_rank]);
721 let bias = self
724 .b_bias_hrn
725 .val()
726 .permute([1, 0, 2]) .unsqueeze_dim::<4>(0)
728 .unsqueeze_dim::<5>(0); b_exp + bias
730 };
731 let c_bsrhr = {
732 let c_bsrgr = c_raw_bsd.reshape([batch, sequence, mimo_rank, ngroups, state_rank]);
733 let c_norm = self
734 .c_norm
735 .forward(c_bsrgr.reshape([batch * sequence * mimo_rank, ngroups, state_rank]))
736 .reshape([batch, sequence, mimo_rank, ngroups, state_rank]);
737 let c_exp = c_norm
738 .unsqueeze_dim::<6>(4)
739 .expand([
740 batch,
741 sequence,
742 mimo_rank,
743 ngroups,
744 heads_per_group,
745 state_rank,
746 ])
747 .reshape([batch, sequence, mimo_rank, nheads, state_rank]);
748 let bias = self
749 .c_bias_hrn
750 .val()
751 .permute([1, 0, 2])
752 .unsqueeze_dim::<4>(0)
753 .unsqueeze_dim::<5>(0);
754 c_exp + bias
755 };
756 assert_eq!(
758 [batch, sequence, mimo_rank, nheads, state_rank],
759 b_bsrhr.dims()
760 );
761 assert_eq!(
762 [batch, sequence, mimo_rank, nheads, state_rank],
763 c_bsrhr.dims()
764 );
765
766 let theta_scaled_bsa = theta_bsa.tanh() * std::f32::consts::PI;
768 let raw_angles_bsha =
769 dt_bsh.clone().unsqueeze_dim::<4>(3) * theta_scaled_bsa.unsqueeze_dim::<4>(2);
770
771 let cumsum_bsha = raw_angles_bsha.cumsum(1);
772 let cum_angles_bsha = cache.cum_angle_bhr.clone().unsqueeze_dim::<4>(1) + cumsum_bsha;
773 assert_eq!(
774 [batch, sequence, nheads, num_rope_angles],
775 cum_angles_bsha.dims()
776 );
777 san(&cum_angles_bsha);
778
779 let angles_exp_bsrha = cum_angles_bsha.clone().unsqueeze_dim::<5>(2).expand([
782 batch,
783 sequence,
784 mimo_rank,
785 nheads,
786 num_rope_angles,
787 ]);
788 let rotate_pairwise = mimo_rank == 1;
791 let rope_dim = self.rope_dim;
792 let b_bsrhn = apply_rope_partial::<B, 5>(
793 b_bsrhr,
794 angles_exp_bsrha.clone(),
795 rope_dim,
796 rotate_pairwise,
797 );
798 let c_bsrhn =
799 apply_rope_partial::<B, 5>(c_bsrhr, angles_exp_bsrha, rope_dim, rotate_pairwise);
800 san(&b_bsrhn);
801 san(&c_bsrhn);
802
803 let x_prev_first_b1hp = cache.v_state_bhp.clone().unsqueeze_dim::<4>(1);
810 let x_prev_bshp = if sequence == 1 {
811 x_prev_first_b1hp
812 } else {
813 Tensor::cat(
814 vec![x_prev_first_b1hp, x_bshp.clone().narrow(1, 0, sequence - 1)],
815 1,
816 )
817 };
818 let b_prev_first_b1rhn = cache.k_state_brhn.clone().unsqueeze_dim::<5>(1);
820 let b_prev_bsrhn = if sequence == 1 {
821 b_prev_first_b1rhn
822 } else {
823 Tensor::cat(
824 vec![
825 b_prev_first_b1rhn,
826 b_bsrhn.clone().narrow(1, 0, sequence - 1),
827 ],
828 1,
829 )
830 };
831
832 let gamma_bsh1 = gamma_bsh.unsqueeze_dim::<4>(3);
835 let beta_bsh1 = beta_bsh.unsqueeze_dim::<4>(3);
836 let x_gamma_bshp = x_bshp.clone() * gamma_bsh1; let x_beta_bshp = x_prev_bshp * beta_bsh1; let b_last_brhn = b_bsrhn
841 .clone()
842 .narrow(1, sequence - 1, 1)
843 .reshape([batch, mimo_rank, nheads, state_rank]);
844
845 let chunk_len = ssd_path.chunk_len_or_optimal(state_rank, per_head_dim);
847 let sequence_padded = sequence.next_multiple_of(chunk_len);
848 let pad = sequence_padded - sequence;
849
850 let (x_gamma_bShp, x_beta_bShp, da_bSh, b_bSrhn, b_prev_bSrhn, c_bSrhn) = if pad == 0 {
851 (
852 x_gamma_bshp,
853 x_beta_bshp,
854 da_bsh,
855 b_bsrhn,
856 b_prev_bsrhn,
857 c_bsrhn,
858 )
859 } else {
860 let pad_hp = Tensor::zeros([batch, pad, nheads, per_head_dim], &device);
861 let pad_h = Tensor::zeros([batch, pad, nheads], &device);
862 let pad_rhn = Tensor::zeros([batch, pad, mimo_rank, nheads, state_rank], &device);
863 (
864 Tensor::cat(vec![x_gamma_bshp, pad_hp.clone()], 1),
865 Tensor::cat(vec![x_beta_bshp, pad_hp], 1),
866 Tensor::cat(vec![da_bsh, pad_h], 1),
867 Tensor::cat(vec![b_bsrhn, pad_rhn.clone()], 1),
868 Tensor::cat(vec![b_prev_bsrhn, pad_rhn.clone()], 1),
869 Tensor::cat(vec![c_bsrhn, pad_rhn], 1),
870 )
871 };
872
873 let nchunks = sequence_padded / chunk_len;
875 let x_gamma_bnlhp = x_gamma_bShp.reshape([batch, nchunks, chunk_len, nheads, per_head_dim]);
876 let x_beta_bnlhp = x_beta_bShp.reshape([batch, nchunks, chunk_len, nheads, per_head_dim]);
877 let da_bnlh = da_bSh.reshape([batch, nchunks, chunk_len, nheads]);
878 let b_bnlrhn = b_bSrhn.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
880 let b_prev_bnlrhn =
881 b_prev_bSrhn.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
882 let c_bnlrhn = c_bSrhn.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
883
884 let mimo_x_val = self.mimo_x.as_ref().map(|p| p.val());
887 let v_gamma_bnlrhp = build_v_mimo_chunked(
888 x_gamma_bnlhp.clone(),
889 mimo_x_val.as_ref(),
890 batch,
891 nchunks,
892 chunk_len,
893 mimo_rank,
894 nheads,
895 per_head_dim,
896 );
897 let v_beta_bnlrhp = build_v_mimo_chunked(
898 x_beta_bnlhp,
899 mimo_x_val.as_ref(),
900 batch,
901 nchunks,
902 chunk_len,
903 mimo_rank,
904 nheads,
905 per_head_dim,
906 );
907
908 let input_gamma = Mamba3SsdInput {
909 v_bnlrhp: v_gamma_bnlrhp,
910 da_bnlh: da_bnlh.clone(),
911 b_bnlrhn: b_bnlrhn.clone(),
912 c_bnlrhn: c_bnlrhn.clone(),
913 initial_state_bhpr: cache.ssm_bhpr,
914 init_state_hpr: self.init_state_hpr.as_ref().map(|s| s.val()),
915 };
916 let (y_gamma_bnlrhp, final_state_gamma) = ssd_path.clone().run(input_gamma);
917
918 let input_beta = Mamba3SsdInput {
919 v_bnlrhp: v_beta_bnlrhp,
920 da_bnlh,
921 b_bnlrhn: b_prev_bnlrhn,
922 c_bnlrhn,
923 initial_state_bhpr: Tensor::zeros([batch, nheads, per_head_dim, state_rank], &device),
924 init_state_hpr: None,
925 };
926 let (y_beta_bnlrhp, final_state_beta) = ssd_path.run(input_beta);
927
928 let y_bnlrhp = y_gamma_bnlrhp + y_beta_bnlrhp;
930 let final_state_bhpr = final_state_gamma + final_state_beta;
931
932 san(&y_bnlrhp);
933 san(&final_state_bhpr);
934
935 cache.ssm_bhpr = final_state_bhpr;
936
937 let y_bSrhp = y_bnlrhp.reshape([batch, sequence_padded, mimo_rank, nheads, per_head_dim]);
939 let y_bsrhp = if pad == 0 {
940 y_bSrhp
941 } else {
942 y_bSrhp.narrow(1, 0, sequence)
943 };
944
945 let v_raw_bsrhp = build_v_mimo::<B>(x_bshp.clone(), mimo_x_val.as_ref());
948
949 let d_11_h1 = self.d_h.val().unsqueeze_dims::<5>(&[0, 1, 2, 4]); let y_bsrhp = y_bsrhp + d_11_h1 * v_raw_bsrhp.clone();
951
952 let y_bsi = if mimo_rank > 1 {
956 let mimo_z_val = self.mimo_z.as_ref().map(|p| p.val()).unwrap();
957 let mimo_o_val = self.mimo_o.as_ref().map(|p| p.val()).unwrap();
958
959 let z_bshp = z_bsi
961 .clone()
962 .reshape([batch, sequence, nheads, per_head_dim]);
963 let z_bsrhp = {
964 let z_exp = z_bshp.unsqueeze_dim::<5>(2).expand([
967 batch,
968 sequence,
969 mimo_rank,
970 nheads,
971 per_head_dim,
972 ]);
973 let mz = mimo_z_val
974 .permute([1, 0, 2]) .unsqueeze_dim::<4>(0)
976 .unsqueeze_dim::<5>(0)
977 .expand([batch, sequence, mimo_rank, nheads, per_head_dim]);
978 z_exp * mz
979 };
980
981 let y_combined_bsrhp = match &self.out_norm {
985 Some(norm) => norm.forward(y_bsrhp, z_bsrhp),
986 None => y_bsrhp * Silu::new().forward(z_bsrhp),
987 };
988
989 let mo = mimo_o_val
992 .permute([1, 0, 2]) .unsqueeze_dim::<4>(0)
994 .unsqueeze_dim::<5>(0)
995 .expand([batch, sequence, mimo_rank, nheads, per_head_dim]);
996 let y_bhp: Tensor<B, 4> = (y_combined_bsrhp * mo).sum_dim(2).squeeze_dim(2);
998 y_bhp.reshape([batch, sequence, d_inner])
999 } else {
1000 let y_bshp: Tensor<B, 4> = y_bsrhp.squeeze_dim(2); let z_bshp = z_bsi.reshape([batch, sequence, nheads, per_head_dim]);
1003 let y_combined_bshp = match &self.out_norm {
1004 Some(norm) => norm.forward(y_bshp, z_bshp),
1005 None => y_bshp * Silu::new().forward(z_bshp),
1006 };
1007 y_combined_bshp.reshape([batch, sequence, d_inner])
1008 };
1009 san(&y_bsi);
1010
1011 let out_bsm = self.out_proj.forward(y_bsi);
1013 san(&out_bsm);
1014
1015 cache.k_state_brhn = b_last_brhn;
1018
1019 cache.v_state_bhp =
1021 x_bshp
1022 .narrow(1, sequence - 1, 1)
1023 .reshape([batch, nheads, per_head_dim]);
1024
1025 cache.cum_angle_bhr =
1027 cum_angles_bsha
1028 .narrow(1, sequence - 1, 1)
1029 .reshape([batch, nheads, num_rope_angles]);
1030
1031 (out_bsm, cache)
1032 }
1033}
1034
1035fn build_v_mimo_chunked<B: Backend>(
1036 x_bnlhp: Tensor<B, 5>,
1037 mimo_x: Option<&Tensor<B, 3>>,
1038 batch: usize,
1039 nchunks: usize,
1040 chunk_len: usize,
1041 mimo_rank: usize,
1042 nheads: usize,
1043 per_head_dim: usize,
1044) -> Tensor<B, 6> {
1045 match mimo_x {
1046 None => x_bnlhp.unsqueeze_dim::<6>(3),
1047 Some(mx) => {
1048 let x_exp = x_bnlhp.unsqueeze_dim::<6>(3).expand([
1049 batch,
1050 nchunks,
1051 chunk_len,
1052 mimo_rank,
1053 nheads,
1054 per_head_dim,
1055 ]);
1056 let mx_exp = mx
1057 .clone()
1058 .permute([1, 0, 2])
1059 .unsqueeze_dim::<4>(0)
1060 .unsqueeze_dim::<5>(0)
1061 .unsqueeze_dim::<6>(0)
1062 .expand([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
1063 x_exp * mx_exp
1064 }
1065 }
1066}
1067
1068mod step {
1073 use super::*;
1074
1075 impl<B: Backend> Mamba3<B> {
1076 #[allow(non_snake_case)]
1096 pub fn step(
1097 &self,
1098 input_bm: Tensor<B, 2>,
1099 cache: Option<Mamba3Cache<B>>,
1100 ) -> (Tensor<B, 2>, Mamba3Cache<B>) {
1101 let [batch, d_model] = input_bm.dims();
1102 let d_inner = self.d_inner();
1103 let nheads = self.nheads();
1104 let ngroups = self.ngroups;
1105 let per_head_dim = self.per_head_dim();
1106 let state_rank = self.state_rank;
1107 let num_rope_angles = self.num_rope_angles;
1108 let heads_per_group = nheads / ngroups;
1109 let mimo_rank = self.mimo_rank;
1110 let device = &input_bm.device();
1111 let ssm_shape = [batch, nheads, per_head_dim, state_rank];
1112
1113 assert_eq!(nheads % ngroups, 0);
1114
1115 let mut cache = cache.unwrap_or_else(|| {
1116 let ssm_bhpr = Tensor::zeros(ssm_shape, device);
1117 let k_state_brhn = Tensor::zeros([batch, mimo_rank, nheads, state_rank], device);
1118 let v_state_bhp = Tensor::zeros([batch, nheads, per_head_dim], device);
1119 let cum_angle_bhr = Tensor::zeros([batch, nheads, num_rope_angles], device);
1120 Mamba3Cache {
1121 ssm_bhpr,
1122 k_state_brhn,
1123 v_state_bhp,
1124 cum_angle_bhr,
1125 }
1126 });
1127
1128 let proj_bd = self.in_proj.forward(input_bm);
1130 let bc_size = ngroups * state_rank * mimo_rank;
1131 let mut parts = proj_bd
1132 .split_with_sizes(
1133 vec![
1134 d_inner,
1135 d_inner,
1136 bc_size,
1137 bc_size,
1138 nheads,
1139 nheads,
1140 nheads,
1141 num_rope_angles,
1142 ],
1143 1,
1144 )
1145 .into_iter();
1146 let z_bi = parts.next().unwrap(); let x_bi = parts.next().unwrap(); let b_raw_bd = parts.next().unwrap(); let c_raw_bd = parts.next().unwrap();
1150 let dd_dt_bh = parts.next().unwrap(); let dd_A_raw_bh = parts.next().unwrap();
1152 let lam_raw_bh = parts.next().unwrap();
1153 let theta_ba = parts.next().unwrap(); let x_bhp = x_bi.reshape([batch, nheads, per_head_dim]);
1157
1158 let dt_bias_1h = self.dt_bias_h.val().unsqueeze_dim(0);
1160 let dt_bh = softplus(dd_dt_bh + dt_bias_1h).clamp(self.dt_limit.0, self.dt_limit.1);
1161 let a_bh = -softplus(dd_A_raw_bh).clamp(f64::NEG_INFINITY, -self.a_floor);
1162 let da_bh = dt_bh.clone() * a_bh;
1163 let alpha_bh = da_bh.exp();
1164 let lam_bh = sigmoid(lam_raw_bh);
1165 let gamma_bh = lam_bh.clone() * dt_bh.clone();
1166 let beta_bh = (-lam_bh.clone() + 1.0) * dt_bh.clone() * alpha_bh.clone();
1167
1168 let b_brhn = {
1171 let b_brgn = b_raw_bd.reshape([batch, mimo_rank, ngroups, state_rank]);
1172 let b_norm = self
1173 .b_norm
1174 .forward(b_brgn.reshape([batch * mimo_rank, ngroups, state_rank]))
1175 .reshape([batch, mimo_rank, ngroups, state_rank]);
1176 let b_exp = b_norm
1177 .unsqueeze_dim::<5>(3) .expand([batch, mimo_rank, ngroups, heads_per_group, state_rank])
1179 .reshape([batch, mimo_rank, nheads, state_rank]);
1180 let bias = self
1182 .b_bias_hrn
1183 .val()
1184 .permute([1, 0, 2])
1185 .unsqueeze_dim::<4>(0);
1186 b_exp + bias
1187 };
1188 let c_brhn = {
1189 let c_brgn = c_raw_bd.reshape([batch, mimo_rank, ngroups, state_rank]);
1190 let c_norm = self
1191 .c_norm
1192 .forward(c_brgn.reshape([batch * mimo_rank, ngroups, state_rank]))
1193 .reshape([batch, mimo_rank, ngroups, state_rank]);
1194 let c_exp = c_norm
1195 .unsqueeze_dim::<5>(3)
1196 .expand([batch, mimo_rank, ngroups, heads_per_group, state_rank])
1197 .reshape([batch, mimo_rank, nheads, state_rank]);
1198 let bias = self
1199 .c_bias_hrn
1200 .val()
1201 .permute([1, 0, 2])
1202 .unsqueeze_dim::<4>(0);
1203 c_exp + bias
1204 };
1205 assert_eq!([batch, mimo_rank, nheads, state_rank], b_brhn.dims());
1206
1207 let theta_scaled_ba = theta_ba.tanh() * std::f32::consts::PI;
1209 let raw_angle_bha = dt_bh.unsqueeze_dim::<3>(2) * theta_scaled_ba.unsqueeze_dim::<3>(1);
1210 let new_cum_angle_bha = cache.cum_angle_bhr.clone() + raw_angle_bha;
1211
1212 let angles_brha = new_cum_angle_bha.clone().unsqueeze_dim::<4>(1).expand([
1214 batch,
1215 mimo_rank,
1216 nheads,
1217 num_rope_angles,
1218 ]);
1219 let rotate_pairwise = mimo_rank == 1;
1222 let rope_dim = self.rope_dim;
1223 let b_brhn =
1224 apply_rope_partial::<B, 4>(b_brhn, angles_brha.clone(), rope_dim, rotate_pairwise);
1225 let c_brhn = apply_rope_partial::<B, 4>(c_brhn, angles_brha, rope_dim, rotate_pairwise);
1226
1227 let mimo_x_val = self.mimo_x.as_ref().map(|p| p.val());
1231 let (x_vals_brhp, xs_vals_brhp) = build_mimo_vals(
1232 x_bhp.clone(),
1233 cache.v_state_bhp.clone(),
1234 mimo_x_val.as_ref(),
1235 batch,
1236 mimo_rank,
1237 nheads,
1238 per_head_dim,
1239 device,
1240 );
1241
1242 let gamma_b1h1 = gamma_bh.clone().unsqueeze_dim::<3>(1).unsqueeze_dim::<4>(3); let beta_b1h1 = beta_bh.clone().unsqueeze_dim::<3>(1).unsqueeze_dim::<4>(3);
1255
1256 let x_gamma_brhp = x_vals_brhp.clone() * gamma_b1h1; let x_beta_brhp = xs_vals_brhp * beta_b1h1; let xBt_state = {
1264 let b_bhrn = b_brhn.clone().permute([0, 2, 1, 3]); let xg_bhpr = x_gamma_brhp.permute([0, 2, 3, 1]); xg_bhpr.matmul(b_bhrn) };
1268 let xBt_prev = {
1269 let b_state_bhrn = cache.k_state_brhn.clone().permute([0, 2, 1, 3]); let xb_bhpr = x_beta_brhp.permute([0, 2, 3, 1]); xb_bhpr.matmul(b_state_bhrn) };
1273
1274 let alpha_bh11 = alpha_bh.unsqueeze_dims::<4>(&[2, 3]);
1275 let new_state_bhpn = alpha_bh11 * cache.ssm_bhpr.clone() + xBt_state + xBt_prev;
1276
1277 let out_r_brhp = {
1281 let c_bhrn = c_brhn.permute([0, 2, 1, 3]); let c_bhnr = c_bhrn.permute([0, 1, 3, 2]); let out_bhpr = new_state_bhpn.clone().matmul(c_bhnr); out_bhpr.permute([0, 3, 1, 2]) };
1291
1292 let d_skip = self
1294 .d_h
1295 .val()
1296 .unsqueeze_dims::<4>(&[0, 1, 3]) .expand([batch, mimo_rank, nheads, per_head_dim]);
1298 let out_r_brhp = out_r_brhp + d_skip * x_vals_brhp;
1299
1300 let z_bhp = z_bi.reshape([batch, nheads, per_head_dim]);
1304 let y_bi = if mimo_rank > 1 {
1305 let mimo_z_val = self.mimo_z.as_ref().map(|p| p.val()).unwrap();
1306 let mimo_o_val = self.mimo_o.as_ref().map(|p| p.val()).unwrap();
1307
1308 let z_exp =
1310 z_bhp
1311 .unsqueeze_dim::<4>(1)
1312 .expand([batch, mimo_rank, nheads, per_head_dim]);
1313 let mz = mimo_z_val.permute([1, 0, 2]).unsqueeze_dim::<4>(0).expand([
1315 batch,
1316 mimo_rank,
1317 nheads,
1318 per_head_dim,
1319 ]);
1320 let z_r = z_exp * mz;
1321
1322 let combined = match &self.out_norm {
1324 Some(norm) => norm.forward(out_r_brhp, z_r),
1325 None => out_r_brhp * Silu::new().forward(z_r),
1326 };
1327
1328 let mo = mimo_o_val.permute([1, 0, 2]).unsqueeze_dim::<4>(0).expand([
1331 batch,
1332 mimo_rank,
1333 nheads,
1334 per_head_dim,
1335 ]);
1336 let out_bhp: Tensor<B, 3> = (combined * mo).sum_dim(1).squeeze_dim(1);
1337 out_bhp.reshape([batch, d_inner])
1338 } else {
1339 let y_bhp: Tensor<B, 3> = out_r_brhp.squeeze_dim(1); let combined = match &self.out_norm {
1342 Some(norm) => norm.forward(y_bhp, z_bhp),
1343 None => y_bhp * Silu::new().forward(z_bhp),
1344 };
1345 combined.reshape([batch, d_inner])
1346 };
1347
1348 let out_bm = self.out_proj.forward(y_bi);
1350 assert_eq!([batch, d_model], out_bm.dims());
1351
1352 cache.ssm_bhpr = new_state_bhpn;
1354 cache.k_state_brhn = b_brhn; cache.v_state_bhp = x_bhp;
1357 cache.cum_angle_bhr = new_cum_angle_bha;
1358
1359 (out_bm, cache)
1360 }
1361 }
1362
1363 fn build_mimo_vals<B: Backend>(
1367 x_bhp: Tensor<B, 3>,
1368 x_state_bhp: Tensor<B, 3>,
1369 mimo_x: Option<&Tensor<B, 3>>,
1370 batch: usize,
1371 mimo_rank: usize,
1372 nheads: usize,
1373 per_head_dim: usize,
1374 _device: &B::Device,
1375 ) -> (Tensor<B, 4>, Tensor<B, 4>) {
1376 match mimo_x {
1377 None => {
1378 (
1380 x_bhp.unsqueeze_dim::<4>(1),
1381 x_state_bhp.unsqueeze_dim::<4>(1),
1382 )
1383 }
1384 Some(mx) => {
1385 let x_exp =
1387 x_bhp
1388 .unsqueeze_dim::<4>(1)
1389 .expand([batch, mimo_rank, nheads, per_head_dim]);
1390 let xs_exp = x_state_bhp.unsqueeze_dim::<4>(1).expand([
1391 batch,
1392 mimo_rank,
1393 nheads,
1394 per_head_dim,
1395 ]);
1396 let mx_exp = mx.clone().permute([1, 0, 2]).unsqueeze_dim::<4>(0).expand([
1398 batch,
1399 mimo_rank,
1400 nheads,
1401 per_head_dim,
1402 ]);
1403 (x_exp * mx_exp.clone(), xs_exp * mx_exp)
1404 }
1405 }
1406 }
1407}
1408
#[cfg(all(test, feature = "backend-flex"))]
mod tests {
    //! Consistency tests: alternative ways of evaluating a `Mamba3` layer
    //! (token-by-token `step`, or `forward` split across a cache boundary) are
    //! checked against a single full-sequence `forward` call, comparing both the
    //! outputs and the gradients of the input and parameters.

    use super::*;
    use burn::backend::{Autodiff, Flex};
    use burn::tensor::Distribution;

    /// Backend on which outputs and gradients are read back and compared.
    type InnerB = Flex;
    /// Autodiff wrapper the model is built on so gradients can be taken.
    type B = Autodiff<InnerB>;

    type Device = <InnerB as burn::tensor::backend::BackendTypes>::Device;

    /// Baseline test config: `d_model = 32`, `state_rank = 8`, `expand = 2`,
    /// `per_head_dim = 8`; every other setting left at its default.
    fn small_config() -> Mamba3Config {
        Mamba3Config::new(32)
            .with_state_rank(8)
            .with_expand(2)
            .with_per_head_dim(8)
    }

    /// [`small_config`] with `mimo_rank = 2`.
    fn small_config_mimo() -> Mamba3Config {
        small_config().with_mimo_rank(2)
    }

    /// [`small_config`] with `per_head_dim = 16` and `ngroups = 2`.
    fn small_config_ngroups2() -> Mamba3Config {
        small_config().with_per_head_dim(16).with_ngroups(2)
    }

    /// Samples a rank-3 tensor of the given shape from a standard normal distribution.
    fn random_normal(shape: [usize; 3], device: &Device) -> Tensor<InnerB, 3> {
        Tensor::<InnerB, 3>::random(shape, Distribution::Normal(0.0, 1.0), device)
    }

    /// Output and gradients collected from one evaluation path, all on the inner backend.
    struct RunGrads {
        out: Tensor<InnerB, 3>,
        d_input: Tensor<InnerB, 3>,
        d_in_proj_w: Tensor<InnerB, 2>,
        d_dt_bias: Tensor<InnerB, 1>,
        d_d: Tensor<InnerB, 1>,
        d_b_norm_gamma: Tensor<InnerB, 1>,
        d_c_norm_gamma: Tensor<InnerB, 1>,
        d_b_bias: Tensor<InnerB, 3>,
        d_c_bias: Tensor<InnerB, 3>,
        d_out_proj_w: Tensor<InnerB, 2>,
        /// `None` when the model has no `mimo_x` parameter or it received no gradient.
        d_mimo_x: Option<Tensor<InnerB, 3>>,
        /// `None` when the model has no `mimo_z` parameter or it received no gradient.
        d_mimo_z: Option<Tensor<InnerB, 3>>,
        /// `None` when the model has no `mimo_o` parameter or it received no gradient.
        d_mimo_o: Option<Tensor<InnerB, 3>>,
    }

    /// Evaluates `forward(model, input)`, backpropagates the scalar loss
    /// `sum(out * head)`, and collects the output together with the gradients of
    /// the input and of the model parameters.
    ///
    /// Callers pass a random `head`, so each output element enters the loss with
    /// its own weight rather than a plain sum.
    ///
    /// # Panics
    /// If the input or any non-optional parameter listed in [`RunGrads`] receives
    /// no gradient.
    fn run_with_grads(
        model: &Mamba3<B>,
        input: &Param<Tensor<B, 3>>,
        head: &Tensor<InnerB, 3>,
        forward: impl FnOnce(&Mamba3<B>, Tensor<B, 3>) -> Tensor<B, 3>,
    ) -> RunGrads {
        let out = forward(model, input.val());
        let out_inner = out.clone().inner();

        let head = Tensor::from_inner(head.clone());
        let loss = (out * head).sum();
        let grads = loss.backward();

        // The MIMO parameters are `Option`s on the model, so their gradients are
        // collected as optional rather than `expect`ed.
        let optional_grad = |param: &Option<Param<Tensor<B, 3>>>| {
            param.as_ref().and_then(|p| p.val().grad(&grads))
        };

        RunGrads {
            out: out_inner,
            d_input: input.val().grad(&grads).expect("grad input"),
            d_in_proj_w: model
                .in_proj
                .weight
                .val()
                .grad(&grads)
                .expect("grad in_proj.weight"),
            d_dt_bias: model.dt_bias_h.val().grad(&grads).expect("grad dt_bias_h"),
            d_d: model.d_h.val().grad(&grads).expect("grad d_h"),
            d_b_norm_gamma: model
                .b_norm
                .gamma
                .val()
                .grad(&grads)
                .expect("grad b_norm.gamma"),
            d_c_norm_gamma: model
                .c_norm
                .gamma
                .val()
                .grad(&grads)
                .expect("grad c_norm.gamma"),
            d_b_bias: model
                .b_bias_hrn
                .val()
                .grad(&grads)
                .expect("grad b_bias_hrn"),
            d_c_bias: model
                .c_bias_hrn
                .val()
                .grad(&grads)
                .expect("grad c_bias_hrn"),
            d_out_proj_w: model
                .out_proj
                .weight
                .val()
                .grad(&grads)
                .expect("grad out_proj.weight"),
            d_mimo_x: optional_grad(&model.mimo_x),
            d_mimo_z: optional_grad(&model.mimo_z),
            d_mimo_o: optional_grad(&model.mimo_o),
        }
    }

    /// Compares every gradient in `a` against `b`. Prints each max absolute
    /// difference to stderr, then panics listing every entry that failed.
    ///
    /// An entry fails unless its difference is strictly below `grad_tol`. A NaN
    /// difference is treated as a failure. Optional (MIMO) gradients are compared
    /// when both runs produced them and reported as a failure when only one did.
    fn check_grads_match(label: &str, a: &RunGrads, b: &RunGrads, grad_tol: f32) {
        let mut failures: Vec<String> = Vec::new();
        macro_rules! compare {
            ($lhs:expr, $rhs:expr, $name:expr) => {{
                let d = ($lhs.clone() - $rhs.clone()).abs().max().into_scalar();
                eprintln!("{:>40} {:>16} | max abs diff = {:>10.6}", label, $name, d);
                // `partial_cmp` returns `None` for NaN, so a NaN difference fails
                // here instead of slipping past a plain `d >= grad_tol` check.
                if d.partial_cmp(&grad_tol) != Some(std::cmp::Ordering::Less) {
                    failures.push(format!(
                        "{}: grad of {} max abs diff = {:.6} (tol {})",
                        label, $name, d, grad_tol
                    ));
                }
            }};
        }
        macro_rules! check {
            ($field:ident, $name:expr) => {
                compare!(a.$field, b.$field, $name)
            };
        }
        macro_rules! check_opt {
            ($field:ident, $name:expr) => {
                match (&a.$field, &b.$field) {
                    (Some(x), Some(y)) => compare!(x, y, $name),
                    (None, None) => {}
                    _ => failures.push(format!(
                        "{}: grad of {} present in only one of the two runs",
                        label, $name
                    )),
                }
            };
        }
        check!(d_input, "input");
        check!(d_in_proj_w, "in_proj.weight");
        check!(d_dt_bias, "dt_bias_h");
        check!(d_d, "d_h");
        check!(d_b_norm_gamma, "b_norm.gamma");
        check!(d_c_norm_gamma, "c_norm.gamma");
        check!(d_b_bias, "b_bias_hrn");
        check!(d_c_bias, "c_bias_hrn");
        check!(d_out_proj_w, "out_proj.weight");
        check_opt!(d_mimo_x, "mimo_x");
        check_opt!(d_mimo_z, "mimo_z");
        check_opt!(d_mimo_o, "mimo_o");
        assert!(
            failures.is_empty(),
            "gradient mismatches:\n  {}",
            failures.join("\n  ")
        );
    }

    /// Wraps a copy of `input` in a new `Param` so each run reads back its own
    /// input gradient.
    fn param_input(input: &Tensor<InnerB, 3>) -> Param<Tensor<B, 3>> {
        Param::from_tensor(Tensor::from_inner(input.clone()))
    }

    /// Shared driver for the consistency tests.
    ///
    /// Builds a model from `cfg` and draws a random `[2, seq_len, d_model]` input
    /// plus a random loss head. It evaluates the input once with a single `forward`
    /// call (the reference) and once with `alt`. Outputs must agree within `1e-4`
    /// and gradients within `1e-3`.
    ///
    /// `alt` receives the model, the input and the SSD path used by the reference.
    /// `what` prefixes the output-mismatch message and `label` tags the gradient
    /// report.
    fn assert_matches_full_forward(
        cfg: Mamba3Config,
        seq_len: usize,
        what: &str,
        label: &str,
        alt: impl FnOnce(&Mamba3<B>, Tensor<B, 3>, &Mamba3SsdPath) -> Tensor<B, 3>,
    ) {
        let device: Device = Default::default();
        let model = cfg.init::<B>(&device);

        let batch = 2;
        let shape = [batch, seq_len, cfg.d_model];
        let input = random_normal(shape, &device);
        let head = random_normal(shape, &device);

        let ssd_path = Mamba3SsdPath::Minimal(Some(4));

        let input_ref = param_input(&input);
        let r_ref = run_with_grads(&model, &input_ref, &head, |m, x| {
            let (out, _) = m.forward(x, None, ssd_path.clone());
            out
        });

        let input_alt = param_input(&input);
        let r_alt = run_with_grads(&model, &input_alt, &head, |m, x| alt(m, x, &ssd_path));

        let diff = (r_ref.out.clone() - r_alt.out.clone())
            .abs()
            .max()
            .into_scalar();
        assert!(
            diff < 1e-4,
            "{what} max absolute difference = {diff:.6} (expected < 1e-4)"
        );

        check_grads_match(label, &r_ref, &r_alt, 1e-3);
    }

    /// Feeding the sequence one token at a time through `step`, passing the cache
    /// forward, must reproduce a single `forward` call.
    fn run_step_matches_forward(cfg: Mamba3Config) {
        let seq_len = 5;
        assert_matches_full_forward(
            cfg,
            seq_len,
            "step() vs forward()",
            "step vs forward",
            |m, x, _ssd_path| {
                let mut cache: Option<Mamba3Cache<B>> = None;
                let mut outs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
                for t in 0..seq_len {
                    let token = x.clone().narrow(1, t, 1).squeeze_dim(1);
                    let (out_t, new_cache) = m.step(token, cache);
                    cache = Some(new_cache);
                    outs.push(out_t);
                }
                Tensor::stack(outs, 1)
            },
        );
    }

    /// Running `forward` on a prefix, then on the suffix with the prefix's cache,
    /// must reproduce a single `forward` over the whole sequence.
    fn run_split_matches_full(cfg: Mamba3Config) {
        let seq_len = 6;
        let split = 2;
        assert_matches_full_forward(
            cfg,
            seq_len,
            "split forward vs full forward",
            "split vs full",
            |m, x, ssd_path| {
                let prefix = x.clone().narrow(1, 0, split);
                let suffix = x.narrow(1, split, seq_len - split);
                let (out_prefix, cache) = m.forward(prefix, None, ssd_path.clone());
                let (out_suffix, _) = m.forward(suffix, Some(cache), ssd_path.clone());
                Tensor::cat(vec![out_prefix, out_suffix], 1)
            },
        );
    }

    #[test]
    fn step_matches_forward() {
        run_step_matches_forward(small_config());
    }

    #[test]
    fn step_matches_forward_ngroups2() {
        run_step_matches_forward(small_config_ngroups2());
    }

    #[test]
    fn step_matches_forward_mimo() {
        run_step_matches_forward(small_config_mimo());
    }

    #[test]
    fn step_matches_forward_mimo_ngroups2() {
        run_step_matches_forward(small_config_ngroups2().with_mimo_rank(2));
    }

    #[test]
    fn split_matches_full() {
        run_split_matches_full(small_config());
    }

    #[test]
    fn split_matches_full_ngroups2() {
        run_split_matches_full(small_config_ngroups2());
    }

    #[test]
    fn split_matches_full_mimo() {
        run_split_matches_full(small_config_mimo());
    }

    #[test]
    fn split_matches_full_mimo_ngroups2() {
        run_split_matches_full(small_config_ngroups2().with_mimo_rank(2));
    }

    #[test]
    fn step_matches_forward_rope_half() {
        run_step_matches_forward(small_config().with_rope_fraction(0.5));
    }

    #[test]
    fn step_matches_forward_rope_half_mimo() {
        run_step_matches_forward(small_config_mimo().with_rope_fraction(0.5));
    }

    #[test]
    fn split_matches_full_rope_half() {
        run_split_matches_full(small_config().with_rope_fraction(0.5));
    }

    #[test]
    fn step_matches_forward_outproj_norm() {
        run_step_matches_forward(small_config().with_has_outproj_norm(true));
    }

    #[test]
    fn step_matches_forward_outproj_norm_mimo() {
        run_step_matches_forward(small_config_mimo().with_has_outproj_norm(true));
    }

    #[test]
    fn split_matches_full_outproj_norm() {
        run_split_matches_full(small_config().with_has_outproj_norm(true));
    }

    #[test]
    fn step_matches_forward_rope_half_outproj_norm_mimo() {
        run_step_matches_forward(
            small_config_mimo()
                .with_rope_fraction(0.5)
                .with_has_outproj_norm(true),
        );
    }
}