1use crate::mamba2::prelude::*;
2use burn::prelude::*;
3
/// Selects which SSD routine `Mamba2SsdPath::run` dispatches to.
///
/// Every variant carries an optional chunk length. `None` means that no
/// explicit length was chosen; `Mamba2SsdPath::chunk_len_or_optimal` then
/// falls back to `Mamba2SsdPath::optimal_default`.
///
/// The type is comparable (`PartialEq`/`Eq`) so that configurations and test
/// expectations can be checked with `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Mamba2SsdPath {
    /// Dispatches to `Mamba2::ssd_minimal`.
    Minimal(Option<usize>),
    /// Dispatches to `Mamba2::ssd_serial`.
    Serial(Option<usize>),
    /// Dispatches to `Mamba2::ssd_serial_recalculated`.
    ///
    /// With `None` as payload this is the value produced by the `Default`
    /// implementation of this enum.
    SerialRecalculated(Option<usize>),
}
43
44pub struct Mamba2SsdInput<B: Backend> {
45 pub x_bnlhp: Tensor<B, 5>,
48 pub dt_bnlh: Tensor<B, 4>,
51 pub a_decay_h: Tensor<B, 1>,
54 pub b_bnlgr: Tensor<B, 5>,
57 pub c_bnlgr: Tensor<B, 5>,
60 pub d_h: Tensor<B, 1>,
63 pub initial_state_bhpr: Tensor<B, 4>,
66 pub init_state_hpr: Option<Tensor<B, 3>>,
69}
70
71impl<B: Backend> Mamba2SsdInput<B> {
72 pub fn sanity(&self) {
73 use crate::utils::sanity::sanity as san;
74 san(&self.x_bnlhp);
75 san(&self.dt_bnlh);
76 san(&self.a_decay_h);
77 san(&self.b_bnlgr);
78 san(&self.c_bnlgr);
79 san(&self.d_h);
80 san(&self.initial_state_bhpr);
81 if let Some(ref init_state_hpr) = self.init_state_hpr {
82 san(init_state_hpr);
83 }
84 }
85}
86
87impl Mamba2SsdPath {
88 pub fn optimal_default(state_rank: usize, per_head_dim: usize) -> usize {
90 (state_rank * per_head_dim)
91 .isqrt()
92 .next_multiple_of(32) .min(512) }
95
96 pub fn core_optimal(state_rank: usize, per_head_dim: usize) -> Self {
100 let optim = Self::optimal_default(state_rank, per_head_dim);
101 Self::Minimal(Some(optim))
102 }
103
104 pub fn core_optimal_from_block<B: Backend>(block: &Mamba2<B>) -> Self {
108 Self::core_optimal(block.state_rank, block.per_head_dim())
109 }
110
111 pub fn chunked_optimal(state_rank: usize, per_head_dim: usize) -> Self {
115 let optim = Self::optimal_default(state_rank, per_head_dim);
116 Self::Serial(Some(optim))
117 }
118
119 pub fn chunked_optimal_from_block<B: Backend>(block: &Mamba2<B>) -> Self {
123 Self::chunked_optimal(block.state_rank, block.per_head_dim())
124 }
125
126 pub fn chunked_recalculated_optimal(state_rank: usize, per_head_dim: usize) -> Self {
130 let optim = Self::optimal_default(state_rank, per_head_dim);
131 Self::SerialRecalculated(Some(optim))
132 }
133
134 pub fn chunked_recalculated_optimal_from_block<B: Backend>(block: &Mamba2<B>) -> Self {
138 Self::chunked_recalculated_optimal(block.state_rank, block.per_head_dim())
139 }
140
141 pub fn chunk_len(&self) -> Option<usize> {
142 match self {
143 Mamba2SsdPath::Minimal(chunk_len) => *chunk_len,
144 Mamba2SsdPath::Serial(chunk_len) => *chunk_len,
145 Mamba2SsdPath::SerialRecalculated(chunk_len) => *chunk_len,
146 }
147 }
148
149 pub fn chunk_len_or_optimal(&self, state_rank: usize, per_head_dim: usize) -> usize {
150 match self {
151 Mamba2SsdPath::Minimal(chunk_len) => {
152 chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
153 }
154 Mamba2SsdPath::Serial(chunk_len) => {
155 chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
156 }
157 Mamba2SsdPath::SerialRecalculated(chunk_len) => {
158 chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
159 }
160 }
161 }
162
163 pub fn run<B: Backend + Mamba2BackendExt>(
171 &self,
172 input: Mamba2SsdInput<B>,
173 ) -> (Tensor<B, 5>, Tensor<B, 4>) {
174 match self {
175 Mamba2SsdPath::Minimal(_) => Mamba2::<B>::ssd_minimal(input),
176 Mamba2SsdPath::Serial(_) => Mamba2::<B>::ssd_serial(input),
177 Mamba2SsdPath::SerialRecalculated(_) => Mamba2::<B>::ssd_serial_recalculated(input),
178 }
179 }
180}
181
182impl Default for Mamba2SsdPath {
183 fn default() -> Mamba2SsdPath {
184 Mamba2SsdPath::SerialRecalculated(None)
186 }
187}
188
/// Cross-checks the three SSD paths against each other: forward outputs and
/// the gradients of every input must agree within a tolerance.
#[cfg(all(test, feature = "backend-flex"))]
mod tests {
    use super::*;
    use burn::backend::{Autodiff, Flex};
    use burn::module::Param;
    use burn::tensor::Distribution;

    /// Backend on which raw values are generated and compared.
    type InnerB = Flex;
    /// Autodiff wrapper around [`InnerB`]; the SSD paths run on this backend.
    type B = Autodiff<InnerB>;

    type Device = <InnerB as burn::tensor::backend::BackendTypes>::Device;

    /// Draws one random set of SSD inputs on the inner backend.
    ///
    /// Returns `(x, dt, a_decay, b, c, d, initial_state)`. `dt` is sampled
    /// from a positive range and `a_decay` from a negative one.
    // One argument per tensor dimension keeps the call sites self-describing.
    #[allow(clippy::too_many_arguments)]
    fn random_input(
        batch: usize,
        nchunks: usize,
        chunk_len: usize,
        nheads: usize,
        per_head_dim: usize,
        ngroups: usize,
        state_rank: usize,
        device: &Device,
    ) -> (
        Tensor<InnerB, 5>,
        Tensor<InnerB, 4>,
        Tensor<InnerB, 1>,
        Tensor<InnerB, 5>,
        Tensor<InnerB, 5>,
        Tensor<InnerB, 1>,
        Tensor<InnerB, 4>,
    ) {
        let x = Tensor::<InnerB, 5>::random(
            [batch, nchunks, chunk_len, nheads, per_head_dim],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let dt = Tensor::<InnerB, 4>::random(
            [batch, nchunks, chunk_len, nheads],
            Distribution::Uniform(0.05, 0.3),
            device,
        );
        let a_decay =
            Tensor::<InnerB, 1>::random([nheads], Distribution::Uniform(-1.0, -0.5), device);
        let b = Tensor::<InnerB, 5>::random(
            [batch, nchunks, chunk_len, ngroups, state_rank],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let c = Tensor::<InnerB, 5>::random(
            [batch, nchunks, chunk_len, ngroups, state_rank],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let d = Tensor::<InnerB, 1>::random([nheads], Distribution::Normal(0.0, 0.1), device);
        let initial_state = Tensor::<InnerB, 4>::random(
            [batch, nheads, per_head_dim, state_rank],
            Distribution::Normal(0.0, 0.1),
            device,
        );
        (x, dt, a_decay, b, c, d, initial_state)
    }

    /// Autodiff-side copies of the SSD inputs, wrapped in `Param` so their
    /// gradients can be looked up after `backward()`.
    // NOTE(review): presumably `Param::from_tensor` marks the tensor as
    // requiring gradients — confirm against the burn documentation.
    struct Inputs {
        x: Param<Tensor<B, 5>>,
        dt: Param<Tensor<B, 4>>,
        a_decay: Param<Tensor<B, 1>>,
        b: Param<Tensor<B, 5>>,
        c: Param<Tensor<B, 5>>,
        d: Param<Tensor<B, 1>>,
        initial_state: Param<Tensor<B, 4>>,
    }

    impl Inputs {
        /// Lifts inner-backend tensors onto the autodiff backend.
        // One argument per SSD input tensor.
        #[allow(clippy::too_many_arguments)]
        fn from_inner(
            x: Tensor<InnerB, 5>,
            dt: Tensor<InnerB, 4>,
            a_decay: Tensor<InnerB, 1>,
            b: Tensor<InnerB, 5>,
            c: Tensor<InnerB, 5>,
            d: Tensor<InnerB, 1>,
            initial_state: Tensor<InnerB, 4>,
        ) -> Self {
            Self {
                x: Param::from_tensor(Tensor::from_inner(x)),
                dt: Param::from_tensor(Tensor::from_inner(dt)),
                a_decay: Param::from_tensor(Tensor::from_inner(a_decay)),
                b: Param::from_tensor(Tensor::from_inner(b)),
                c: Param::from_tensor(Tensor::from_inner(c)),
                d: Param::from_tensor(Tensor::from_inner(d)),
                initial_state: Param::from_tensor(Tensor::from_inner(initial_state)),
            }
        }

        /// Assembles the input bundle for one SSD run; `init_state_hpr` is
        /// left unset.
        fn ssd_input(&self) -> Mamba2SsdInput<B> {
            Mamba2SsdInput {
                x_bnlhp: self.x.val(),
                dt_bnlh: self.dt.val(),
                a_decay_h: self.a_decay.val(),
                b_bnlgr: self.b.val(),
                c_bnlgr: self.c.val(),
                d_h: self.d.val(),
                initial_state_bhpr: self.initial_state.val(),
                init_state_hpr: None,
            }
        }
    }

    /// Forward outputs and input gradients of one path, on the inner backend.
    struct PathRun {
        y: Tensor<InnerB, 5>,
        state: Tensor<InnerB, 4>,
        d_x: Tensor<InnerB, 5>,
        d_dt: Tensor<InnerB, 4>,
        d_a_decay: Tensor<InnerB, 1>,
        d_b: Tensor<InnerB, 5>,
        d_c: Tensor<InnerB, 5>,
        d_d: Tensor<InnerB, 1>,
        d_init_state: Tensor<InnerB, 4>,
    }

    /// Scalar loss `sum(y * y_head) + sum(final_state * s_head)`.
    ///
    /// The random weights `y_head` / `s_head` make every output element
    /// contribute a distinct amount to the gradient.
    fn loss_from_outputs(
        y_bnlhp: Tensor<B, 5>,
        final_state_bhpr: Tensor<B, 4>,
        y_head: Tensor<InnerB, 5>,
        s_head: Tensor<InnerB, 4>,
    ) -> Tensor<B, 1> {
        let y_head = Tensor::from_inner(y_head);
        let s_head = Tensor::from_inner(s_head);
        (y_bnlhp * y_head).sum() + (final_state_bhpr * s_head).sum()
    }

    /// Runs `path` on `inputs`, back-propagates the weighted loss and
    /// collects outputs plus the gradient of every input.
    fn run_path(
        path: Mamba2SsdPath,
        inputs: &Inputs,
        y_head: Tensor<InnerB, 5>,
        s_head: Tensor<InnerB, 4>,
    ) -> PathRun {
        let (y, state) = path.run(inputs.ssd_input());
        let y_inner = y.clone().inner();
        let state_inner = state.clone().inner();

        let loss = loss_from_outputs(y, state, y_head, s_head);
        let grads = loss.backward();

        PathRun {
            y: y_inner,
            state: state_inner,
            d_x: inputs.x.val().grad(&grads).expect("grad x"),
            d_dt: inputs.dt.val().grad(&grads).expect("grad dt"),
            d_a_decay: inputs.a_decay.val().grad(&grads).expect("grad a_decay"),
            d_b: inputs.b.val().grad(&grads).expect("grad b"),
            d_c: inputs.c.val().grad(&grads).expect("grad c"),
            d_d: inputs.d.val().grad(&grads).expect("grad d"),
            d_init_state: inputs
                .initial_state
                .val()
                .grad(&grads)
                .expect("grad initial_state"),
        }
    }

    /// Runs all three paths on identical random data and asserts that
    /// `Serial` and `SerialRecalculated` match `Minimal` in outputs and in
    /// every input gradient.
    fn run_minimal_matches_serial(
        batch: usize,
        nchunks: usize,
        chunk_len: usize,
        nheads: usize,
        per_head_dim: usize,
        ngroups: usize,
        state_rank: usize,
    ) {
        let device: Device = Default::default();
        let (x, dt, a_decay, b, c, d, init) = random_input(
            batch,
            nchunks,
            chunk_len,
            nheads,
            per_head_dim,
            ngroups,
            state_rank,
            &device,
        );

        let y_head = Tensor::<InnerB, 5>::random(
            [batch, nchunks, chunk_len, nheads, per_head_dim],
            Distribution::Normal(0.0, 1.0),
            &device,
        );
        let s_head = Tensor::<InnerB, 4>::random(
            [batch, nheads, per_head_dim, state_rank],
            Distribution::Normal(0.0, 1.0),
            &device,
        );

        // Each path gets its own set of `Param`s built from the same values,
        // so gradients are looked up per path.
        let fresh_inputs = || {
            Inputs::from_inner(
                x.clone(),
                dt.clone(),
                a_decay.clone(),
                b.clone(),
                c.clone(),
                d.clone(),
                init.clone(),
            )
        };
        let inputs_min = fresh_inputs();
        let inputs_ser = fresh_inputs();
        let inputs_rec = fresh_inputs();

        let r_min = run_path(
            Mamba2SsdPath::Minimal(Some(chunk_len)),
            &inputs_min,
            y_head.clone(),
            s_head.clone(),
        );
        let r_ser = run_path(
            Mamba2SsdPath::Serial(Some(chunk_len)),
            &inputs_ser,
            y_head.clone(),
            s_head.clone(),
        );
        let r_rec = run_path(
            Mamba2SsdPath::SerialRecalculated(Some(chunk_len)),
            &inputs_rec,
            y_head,
            s_head,
        );

        // Largest absolute element-wise difference between two tensors.
        macro_rules! diff {
            ($a:expr, $b:expr) => {
                ($a.clone() - $b.clone()).abs().max().into_scalar()
            };
        }

        let tol = 1e-4;
        let dy_ser = diff!(r_min.y, r_ser.y);
        let ds_ser = diff!(r_min.state, r_ser.state);
        let dy_rec = diff!(r_min.y, r_rec.y);
        let ds_rec = diff!(r_min.state, r_rec.state);
        assert!(
            dy_ser < tol,
            "Minimal vs Serial: y max abs diff = {dy_ser:.6} (tol {tol})"
        );
        assert!(
            ds_ser < tol,
            "Minimal vs Serial: final_state max abs diff = {ds_ser:.6} (tol {tol})"
        );
        assert!(
            dy_rec < tol,
            "Minimal vs SerialRecalculated: y max abs diff = {dy_rec:.6} (tol {tol})"
        );
        assert!(
            ds_rec < tol,
            "Minimal vs SerialRecalculated: final_state max abs diff = {ds_rec:.6} (tol {tol})"
        );

        let grad_tol = 1e-3;

        // Gradient mismatches are collected so that one run reports all of
        // them instead of stopping at the first.
        let mut failures: Vec<String> = Vec::new();
        macro_rules! check_grad {
            ($field:ident, $name:expr) => {{
                let d_ser = diff!(r_min.$field, r_ser.$field);
                let d_rec = diff!(r_min.$field, r_rec.$field);
                eprintln!(
                    "grad {:>14} | min↔ser = {:>10.6} | min↔rec = {:>10.6}",
                    $name, d_ser, d_rec
                );
                // A NaN difference compares false against everything, so the
                // pass condition is "strictly below tolerance"; anything else
                // (including NaN) is a failure.
                let ser_ok = d_ser < grad_tol;
                let rec_ok = d_rec < grad_tol;
                if !ser_ok {
                    failures.push(format!(
                        "Minimal vs Serial: grad of {} max abs diff = {:.6} (tol {})",
                        $name, d_ser, grad_tol
                    ));
                }
                if !rec_ok {
                    failures.push(format!(
                        "Minimal vs SerialRecalculated: grad of {} max abs diff = {:.6} (tol {})",
                        $name, d_rec, grad_tol
                    ));
                }
            }};
        }
        check_grad!(d_x, "x");
        check_grad!(d_dt, "dt");
        check_grad!(d_a_decay, "a_decay");
        check_grad!(d_b, "b");
        check_grad!(d_c, "c");
        check_grad!(d_d, "d");
        check_grad!(d_init_state, "initial_state");

        assert!(
            failures.is_empty(),
            "gradient mismatches:\n {}",
            failures.join("\n ")
        );
    }

    /// As many groups as heads (2 heads, 2 groups).
    #[test]
    fn paths_agree_no_gqa() {
        run_minimal_matches_serial(2, 3, 4, 2, 8, 2, 8);
    }

    /// Several heads share one group (4 heads, 1 group).
    #[test]
    fn paths_agree_gqa() {
        run_minimal_matches_serial(2, 3, 4, 4, 8, 1, 8);
    }

    /// Only one chunk along the sequence.
    #[test]
    fn paths_agree_single_chunk() {
        run_minimal_matches_serial(2, 1, 4, 2, 8, 2, 8);
    }
}