1use crate::mamba3::prelude::*;
2use burn::prelude::*;
3
/// Selects which SSD implementation [`Mamba3SsdPath::run`] dispatches to.
///
/// Every variant carries an optional chunk length. `None` means "not configured":
/// [`Mamba3SsdPath::chunk_len_or_optimal`] then falls back to
/// [`Mamba3SsdPath::optimal_default`]. `run` itself does not read this value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Mamba3SsdPath {
    /// Dispatches to `Mamba3::ssd_minimal`.
    Minimal(Option<usize>),
    /// Dispatches to `Mamba3::ssd_serial`.
    Serial(Option<usize>),
    /// Dispatches to `Mamba3::ssd_serial_recalculated`; this is the `Default` variant
    /// (with no chunk length configured).
    SerialRecalculated(Option<usize>),
}
43
/// Tensors consumed by one SSD evaluation via [`Mamba3SsdPath::run`].
///
/// The axis-letter suffix of each field names its layout. The shapes listed per field are the
/// ones this file's tests construct (see `random_input` in `mod tests`); other callers are
/// presumably expected to match them — TODO confirm against the `ssd_*` implementations.
pub struct Mamba3SsdInput<B: Backend> {
    /// Tests build this as `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`.
    pub v_bnlrhp: Tensor<B, 6>,

    /// Tests build this as `[batch, nchunks, chunk_len, nheads]`, sampled around -0.5.
    /// NOTE(review): presumably a per-step log-decay (`dA`) — confirm.
    pub da_bnlh: Tensor<B, 4>,

    /// Tests build this as `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`.
    pub b_bnlrhn: Tensor<B, 6>,

    /// Tests build this as `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`.
    pub c_bnlrhn: Tensor<B, 6>,

    /// Tests build this as `[batch, nheads, per_head_dim, state_rank]`; the tests also
    /// differentiate through it.
    pub initial_state_bhpr: Tensor<B, 4>,

    /// Optional rank-3 state; the tests in this file always pass `None`.
    /// NOTE(review): the name suggests `[nheads, per_head_dim, state_rank]` and a
    /// batch-independent initial state — how it interacts with `initial_state_bhpr` is not
    /// visible here; confirm in the `ssd_*` implementations.
    pub init_state_hpr: Option<Tensor<B, 3>>,
}
86
87impl<B: Backend> Mamba3SsdInput<B> {
88 pub fn sanity(&self) {
89 use crate::utils::sanity::sanity as san;
90 san(&self.v_bnlrhp);
91 san(&self.da_bnlh);
92 san(&self.b_bnlrhn);
93 san(&self.c_bnlrhn);
94 san(&self.initial_state_bhpr);
95 if let Some(ref init_state_hpr) = self.init_state_hpr {
96 san(init_state_hpr);
97 }
98 }
99}
100
101impl Mamba3SsdPath {
102 pub fn optimal_default(state_rank: usize, per_head_dim: usize) -> usize {
104 (state_rank * per_head_dim)
105 .isqrt()
106 .next_multiple_of(32) .min(512) }
109
110 pub fn core_optimal(state_rank: usize, per_head_dim: usize) -> Self {
114 let optim = Self::optimal_default(state_rank, per_head_dim);
115 Self::Minimal(Some(optim))
116 }
117
118 pub fn core_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self {
122 Self::core_optimal(block.state_rank, block.per_head_dim())
123 }
124
125 pub fn chunked_optimal(state_rank: usize, per_head_dim: usize) -> Self {
129 let optim = Self::optimal_default(state_rank, per_head_dim);
130 Self::Serial(Some(optim))
131 }
132
133 pub fn chunked_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self {
137 Self::chunked_optimal(block.state_rank, block.per_head_dim())
138 }
139
140 pub fn chunked_recalculated_optimal(state_rank: usize, per_head_dim: usize) -> Self {
144 let optim = Self::optimal_default(state_rank, per_head_dim);
145 Self::SerialRecalculated(Some(optim))
146 }
147
148 pub fn chunked_recalculated_optimal_from_block<B: Backend>(block: &Mamba3<B>) -> Self {
152 Self::chunked_recalculated_optimal(block.state_rank, block.per_head_dim())
153 }
154
155 pub fn chunk_len(&self) -> Option<usize> {
156 match self {
157 Mamba3SsdPath::Minimal(chunk_len) => *chunk_len,
158 Mamba3SsdPath::Serial(chunk_len) => *chunk_len,
159 Mamba3SsdPath::SerialRecalculated(chunk_len) => *chunk_len,
160 }
161 }
162
163 pub fn chunk_len_or_optimal(&self, state_rank: usize, per_head_dim: usize) -> usize {
164 match self {
165 Mamba3SsdPath::Minimal(chunk_len) => {
166 chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
167 }
168 Mamba3SsdPath::Serial(chunk_len) => {
169 chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
170 }
171 Mamba3SsdPath::SerialRecalculated(chunk_len) => {
172 chunk_len.unwrap_or_else(|| Self::optimal_default(state_rank, per_head_dim))
173 }
174 }
175 }
176
177 pub fn run<B: Backend + Mamba3BackendExt>(
185 &self,
186 input: Mamba3SsdInput<B>,
187 ) -> (Tensor<B, 6>, Tensor<B, 4>) {
188 match self {
189 Mamba3SsdPath::Minimal(_) => Mamba3::<B>::ssd_minimal(input),
190 Mamba3SsdPath::Serial(_) => Mamba3::<B>::ssd_serial(input),
191 Mamba3SsdPath::SerialRecalculated(_) => Mamba3::<B>::ssd_serial_recalculated(input),
192 }
193 }
194}
195
196impl Default for Mamba3SsdPath {
197 fn default() -> Mamba3SsdPath {
198 Mamba3SsdPath::SerialRecalculated(None)
200 }
201}
202
#[cfg(all(test, feature = "backend-flex"))]
mod tests {
    //! Cross-checks the three [`Mamba3SsdPath`] implementations against each other: forward
    //! outputs and the gradients w.r.t. every input must agree within a tolerance.

    use super::*;
    use burn::backend::{Autodiff, Flex};
    use burn::module::Param;
    use burn::tensor::Distribution;

    /// Backend on which inputs are sampled and outputs/gradients are compared.
    type InnerB = Flex;
    /// Autodiff wrapper used to run each path so gradients can be taken.
    type B = Autodiff<InnerB>;

    type Device = <InnerB as burn::tensor::backend::BackendTypes>::Device;

    /// Samples one random SSD problem, returned as `(v, da, b, c, initial_state)`.
    ///
    /// Shapes:
    /// - `v`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
    /// - `da`: `[batch, nchunks, chunk_len, nheads]`, sampled around -0.5
    /// - `b`, `c`: `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
    /// - `initial_state`: `[batch, nheads, per_head_dim, state_rank]`
    fn random_input(
        batch: usize,
        nchunks: usize,
        chunk_len: usize,
        mimo_rank: usize,
        nheads: usize,
        per_head_dim: usize,
        state_rank: usize,
        device: &Device,
    ) -> (
        Tensor<InnerB, 6>,
        Tensor<InnerB, 4>,
        Tensor<InnerB, 6>,
        Tensor<InnerB, 6>,
        Tensor<InnerB, 4>,
    ) {
        let v = Tensor::<InnerB, 6>::random(
            [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let da = Tensor::<InnerB, 4>::random(
            [batch, nchunks, chunk_len, nheads],
            Distribution::Normal(-0.5, 0.1),
            device,
        );
        let b = Tensor::<InnerB, 6>::random(
            [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let c = Tensor::<InnerB, 6>::random(
            [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank],
            Distribution::Normal(0.0, 1.0),
            device,
        );
        let initial_state = Tensor::<InnerB, 4>::random(
            [batch, nheads, per_head_dim, state_rank],
            Distribution::Normal(0.0, 0.1),
            device,
        );
        (v, da, b, c, initial_state)
    }

    /// The SSD inputs, each wrapped in a [`Param`] on the autodiff backend so that
    /// [`run_path`] can read its gradient after `backward`.
    struct Inputs {
        v: Param<Tensor<B, 6>>,
        da: Param<Tensor<B, 4>>,
        b: Param<Tensor<B, 6>>,
        c: Param<Tensor<B, 6>>,
        initial_state: Param<Tensor<B, 4>>,
    }

    impl Inputs {
        /// Lifts inner-backend tensors onto the autodiff backend as parameters.
        fn from_inner(
            v: Tensor<InnerB, 6>,
            da: Tensor<InnerB, 4>,
            b: Tensor<InnerB, 6>,
            c: Tensor<InnerB, 6>,
            initial_state: Tensor<InnerB, 4>,
        ) -> Self {
            Self {
                v: Param::from_tensor(Tensor::from_inner(v)),
                da: Param::from_tensor(Tensor::from_inner(da)),
                b: Param::from_tensor(Tensor::from_inner(b)),
                c: Param::from_tensor(Tensor::from_inner(c)),
                initial_state: Param::from_tensor(Tensor::from_inner(initial_state)),
            }
        }

        /// Builds the SSD input from the parameters; `init_state_hpr` is always `None`.
        fn ssd_input(&self) -> Mamba3SsdInput<B> {
            Mamba3SsdInput {
                v_bnlrhp: self.v.val(),
                da_bnlh: self.da.val(),
                b_bnlrhn: self.b.val(),
                c_bnlrhn: self.c.val(),
                initial_state_bhpr: self.initial_state.val(),
                init_state_hpr: None,
            }
        }
    }

    /// Forward outputs and input gradients of one path, on the inner backend.
    struct PathRun {
        y: Tensor<InnerB, 6>,
        state: Tensor<InnerB, 4>,
        d_v: Tensor<InnerB, 6>,
        d_da: Tensor<InnerB, 4>,
        d_b: Tensor<InnerB, 6>,
        d_c: Tensor<InnerB, 6>,
        d_init_state: Tensor<InnerB, 4>,
    }

    /// Scalar loss `sum(y * y_head) + sum(final_state * s_head)`.
    ///
    /// The loss is linear in the outputs, so its gradient w.r.t. `y` / `final_state` is
    /// exactly `y_head` / `s_head`. The resulting input gradients are therefore the path's
    /// vector-Jacobian products with those random cotangents.
    fn loss_from_outputs(
        y_bnlrhp: Tensor<B, 6>,
        final_state_bhpr: Tensor<B, 4>,
        y_head: Tensor<InnerB, 6>,
        s_head: Tensor<InnerB, 4>,
    ) -> Tensor<B, 1> {
        let y_head = Tensor::from_inner(y_head);
        let s_head = Tensor::from_inner(s_head);
        (y_bnlrhp * y_head).sum() + (final_state_bhpr * s_head).sum()
    }

    /// Runs `path` forward on `inputs`, backpropagates [`loss_from_outputs`], and collects
    /// the outputs plus the gradient of every input.
    ///
    /// # Panics
    /// If any input ends up without a gradient.
    fn run_path(
        path: Mamba3SsdPath,
        inputs: &Inputs,
        y_head: Tensor<InnerB, 6>,
        s_head: Tensor<InnerB, 4>,
    ) -> PathRun {
        let (y, state) = path.run(inputs.ssd_input());
        let y_inner = y.clone().inner();
        let state_inner = state.clone().inner();

        let loss = loss_from_outputs(y, state, y_head, s_head);
        let grads = loss.backward();

        PathRun {
            y: y_inner,
            state: state_inner,
            d_v: inputs.v.val().grad(&grads).expect("grad v"),
            d_da: inputs.da.val().grad(&grads).expect("grad da"),
            d_b: inputs.b.val().grad(&grads).expect("grad b"),
            d_c: inputs.c.val().grad(&grads).expect("grad c"),
            d_init_state: inputs
                .initial_state
                .val()
                .grad(&grads)
                .expect("grad initial_state"),
        }
    }

    /// Returns `true` only when `diff` is strictly below `tol`.
    ///
    /// Written as a positive `<` test on purpose: a NaN difference compares false against
    /// everything, so it counts as *out of* tolerance instead of silently passing (which a
    /// `diff >= tol` failure test would allow).
    fn within_tol<F: PartialOrd>(diff: F, tol: F) -> bool {
        diff < tol
    }

    /// Checks that `Serial` and `SerialRecalculated` match `Minimal` on one random problem.
    ///
    /// - The forward outputs (`y`, final state) must agree to within `1e-4`.
    /// - The gradients w.r.t. all five inputs must agree to within `1e-3`.
    ///
    /// Every gradient mismatch is collected and reported in a single assertion.
    fn run_minimal_matches_serial(
        batch: usize,
        nchunks: usize,
        chunk_len: usize,
        mimo_rank: usize,
        nheads: usize,
        per_head_dim: usize,
        state_rank: usize,
    ) {
        let device: Device = Default::default();
        let (v, da, b, c, init) = random_input(
            batch,
            nchunks,
            chunk_len,
            mimo_rank,
            nheads,
            per_head_dim,
            state_rank,
            &device,
        );

        // Random cotangents for the outputs (see `loss_from_outputs`).
        let y_head = Tensor::<InnerB, 6>::random(
            [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim],
            Distribution::Normal(0.0, 1.0),
            &device,
        );
        let s_head = Tensor::<InnerB, 4>::random(
            [batch, nheads, per_head_dim, state_rank],
            Distribution::Normal(0.0, 1.0),
            &device,
        );

        // Same values, but a separate set of parameters per path.
        let inputs_min =
            Inputs::from_inner(v.clone(), da.clone(), b.clone(), c.clone(), init.clone());
        let inputs_ser =
            Inputs::from_inner(v.clone(), da.clone(), b.clone(), c.clone(), init.clone());
        let inputs_rec = Inputs::from_inner(v, da, b, c, init);

        let r_min = run_path(
            Mamba3SsdPath::Minimal(Some(chunk_len)),
            &inputs_min,
            y_head.clone(),
            s_head.clone(),
        );
        let r_ser = run_path(
            Mamba3SsdPath::Serial(Some(chunk_len)),
            &inputs_ser,
            y_head.clone(),
            s_head.clone(),
        );
        let r_rec = run_path(
            Mamba3SsdPath::SerialRecalculated(Some(chunk_len)),
            &inputs_rec,
            y_head,
            s_head,
        );

        /// Max absolute elementwise difference between two same-shaped tensors.
        macro_rules! diff {
            ($a:expr, $b:expr) => {
                ($a.clone() - $b.clone()).abs().max().into_scalar()
            };
        }

        let tol = 1e-4;
        let dy_ser = diff!(r_min.y, r_ser.y);
        let ds_ser = diff!(r_min.state, r_ser.state);
        let dy_rec = diff!(r_min.y, r_rec.y);
        let ds_rec = diff!(r_min.state, r_rec.state);
        assert!(
            within_tol(dy_ser, tol),
            "Minimal vs Serial: y max abs diff = {dy_ser:.6} (tol {tol})"
        );
        assert!(
            within_tol(ds_ser, tol),
            "Minimal vs Serial: final_state max abs diff = {ds_ser:.6} (tol {tol})"
        );
        assert!(
            within_tol(dy_rec, tol),
            "Minimal vs SerialRecalculated: y max abs diff = {dy_rec:.6} (tol {tol})"
        );
        assert!(
            within_tol(ds_rec, tol),
            "Minimal vs SerialRecalculated: final_state max abs diff = {ds_rec:.6} (tol {tol})"
        );

        let grad_tol = 1e-3;

        let mut failures: Vec<String> = Vec::new();
        macro_rules! check_grad {
            ($field:ident, $name:expr) => {{
                let d_ser = diff!(r_min.$field, r_ser.$field);
                let d_rec = diff!(r_min.$field, r_rec.$field);
                eprintln!(
                    "grad {:>14} | min↔ser = {:>10.6} | min↔rec = {:>10.6}",
                    $name, d_ser, d_rec
                );
                // `!within_tol` (not `>= grad_tol`) so a NaN difference is reported too.
                if !within_tol(d_ser, grad_tol) {
                    failures.push(format!(
                        "Minimal vs Serial: grad of {} max abs diff = {:.6} (tol {})",
                        $name, d_ser, grad_tol
                    ));
                }
                if !within_tol(d_rec, grad_tol) {
                    failures.push(format!(
                        "Minimal vs SerialRecalculated: grad of {} max abs diff = {:.6} (tol {})",
                        $name, d_rec, grad_tol
                    ));
                }
            }};
        }
        check_grad!(d_v, "v");
        check_grad!(d_da, "da");
        check_grad!(d_b, "b");
        check_grad!(d_c, "c");
        check_grad!(d_init_state, "initial_state");

        assert!(
            failures.is_empty(),
            "gradient mismatches:\n  {}",
            failures.join("\n  ")
        );
    }

    #[test]
    fn paths_agree_siso() {
        run_minimal_matches_serial(2, 3, 4, 1, 2, 8, 8);
    }

    #[test]
    fn paths_agree_mimo() {
        run_minimal_matches_serial(2, 3, 4, 2, 2, 8, 8);
    }

    #[test]
    fn paths_agree_single_chunk() {
        run_minimal_matches_serial(2, 1, 4, 1, 2, 8, 8);
    }
}