burn_mamba/mamba3/double_ssd/ssd/ssd_path/mod.rs
//! # SSD input bundle for the Mamba-3 double-SSD pathway
//!
//! [`Mamba3DoubleSsdInput`] bundles the pre-processed tensors consumed by a single standard
//! SSD pass:
//! - B/C are already QK-normed, RoPE-applied, bias-added, and GQA-expanded to per-head.
//! - `v` is already scaled by the trapezoid coefficient γ or β.
//! - `da = Δ·A` is pre-combined.
//!
//! [`Mamba3DoubleSsdInput::run`] dispatches to the algorithm selected by the shared
//! [`Mamba3SsdPath`].
//!
//! [`Mamba3SsdPath`]: crate::mamba3::ssd_path::Mamba3SsdPath
10
11use crate::mamba3::prelude::*;
12use burn::prelude::*;
13
/// MIMO-first SSD input.
///
/// All tensors are pre-processed:
/// - B/C are already QK-normed, RoPE-applied, bias-added, and expanded to per-head (not
///   per-group).
/// - V is already scaled by the (double-ssd) trapezoidal coefficient (γ or β).
/// - The combined log-decay `da = Δ·A` is pre-computed.
///
/// The D skip connection is handled by the caller.
///
/// The suffix of each field name spells its axes in order, matching its `# Shape` section:
/// `b` = batch, `n` = nchunks, `l` = chunk_len, `m` = mimo_rank, `h` = nheads,
/// `p` = per_head_dim, `r` = state_rank.
pub struct Mamba3DoubleSsdInput {
    /// Value tensor, already scaled by (double-ssd) trapezoidal coefficient (γ or β).
    ///
    /// # Shape
    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
    pub v_bnlmhp: Tensor<6>,

    /// Pre-combined log-decay `Δ·A` (negative).
    ///
    /// # Shape
    /// - `[batch, nchunks, chunk_len, nheads]`
    pub da_bnlh: Tensor<4>,

    /// Key/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head, per-rank.
    ///
    /// # Shape
    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
    pub b_bnlmhr: Tensor<6>,

    /// Query/C tensor: same processing as B.
    ///
    /// # Shape
    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
    pub c_bnlmhr: Tensor<6>,

    /// Initial SSM hidden state.
    ///
    /// # Shape
    /// - `[batch, nheads, per_head_dim, state_rank]`
    pub initial_state_bhpr: Tensor<4>,

    /// Optional learnable initial state (broadcast over batch).
    ///
    /// NOTE(review): how this combines with `initial_state_bhpr` is decided by the
    /// algorithm implementations, which are not visible here. Confirm there before relying on it.
    ///
    /// # Shape
    /// - `[nheads, per_head_dim, state_rank]`
    pub init_state_hpr: Option<Tensor<3>>,
}
57
58impl Mamba3DoubleSsdInput {
59 /// Run the [`NaN`/`Inf` guards](crate::utils::sanity) on every input tensor.
60 pub fn sanity(&self) {
61 use crate::modules::sanity as san;
62 san(&self.v_bnlmhp);
63 san(&self.da_bnlh);
64 san(&self.b_bnlmhr);
65 san(&self.c_bnlmhr);
66 san(&self.initial_state_bhpr);
67 if let Some(ref init_state_hpr) = self.init_state_hpr {
68 san(init_state_hpr);
69 }
70 }
71}
72
73impl Mamba3DoubleSsdInput {
74 /// Run the selected double-ssd algorithm on this MIMO-first input.
75 ///
76 /// Dispatches by [`Mamba3SsdPath`] variant to `double_ssd_minimal`,
77 /// `double_ssd_serial`, or `double_ssd_serial_recalculated`.
78 ///
79 /// # Returns
80 /// - `y_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
81 /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
82 pub fn run(self, path: &Mamba3SsdPath) -> (Tensor<6>, Tensor<4>) {
83 match path {
84 Mamba3SsdPath::Minimal(_) => self.double_ssd_minimal(),
85 Mamba3SsdPath::Serial(_) => self.double_ssd_serial(),
86 Mamba3SsdPath::SerialRecalculated(_) => self.double_ssd_serial_recalculated(),
87 }
88 }
89}
90
91// ---------------------------------------------------------------------------
92// Tests
93// ---------------------------------------------------------------------------
94
// Compiled only for test builds that also enable the `_dev-test` feature.
// As this file is a `mod.rs`, the module is loaded from a sibling `tests.rs`
// (or `tests/mod.rs`) in this directory.
#[cfg(all(test, feature = "_dev-test"))]
mod tests;