Skip to main content

burn_mamba/mamba3/double_ssd/ssd/ssd_path/
mod.rs

1//! # SSD input bundle for the Mamba-3 double-SSD pathway
2//!
3//! [`Mamba3DoubleSsdInput`] gathers the pre-processed tensors a single standard
4//! SSD pass consumes (B/C already QK-normed, RoPE-applied, bias-added, and
5//! GQA-expanded to per-head; `v` already scaled by the trapezoid coefficient γ
6//! or β; `da = Δ·A` pre-combined).  [`Mamba3DoubleSsdInput::run`] dispatches to
7//! the algorithm chosen by the shared [`Mamba3SsdPath`].
8//!
9//! [`Mamba3SsdPath`]: crate::mamba3::ssd_path::Mamba3SsdPath
10
11use crate::mamba3::prelude::*;
12use burn::prelude::*;
13
14/// MIMO-first SSD input.
15///
16/// All tensors are pre-processed: B/C are already QK-normed, RoPE-applied, bias-added, and
17/// expanded to per-head (not per-group). V is already scaled by the (double-ssd) trapezoidal
18/// coefficient (γ or β). The combined log-decay `da = Δ·A` is pre-computed. D skip is handled
19/// by the caller.
/// MIMO-first SSD input.
///
/// All tensors are pre-processed: B/C are already QK-normed, RoPE-applied, bias-added, and
/// expanded to per-head (not per-group). V is already scaled by the (double-ssd) trapezoidal
/// coefficient (γ or β). The combined log-decay `da = Δ·A` is pre-computed. D skip is handled
/// by the caller.
///
/// # Field-name suffixes
/// Each field name ends with one letter per tensor axis, matching the documented shapes:
/// `b` = batch, `n` = nchunks, `l` = chunk_len, `m` = mimo_rank, `h` = nheads,
/// `p` = per_head_dim, `r` = state_rank.
pub struct Mamba3DoubleSsdInput {
    /// Value tensor, already scaled by (double-ssd) trapezoidal coefficient (γ or β).
    ///
    /// # Shape
    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
    pub v_bnlmhp: Tensor<6>,

    /// Pre-combined log-decay `Δ·A` (negative).
    ///
    /// # Shape
    /// - `[batch, nchunks, chunk_len, nheads]`
    pub da_bnlh: Tensor<4>,

    /// Key/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head, per-rank.
    ///
    /// # Shape
    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
    pub b_bnlmhr: Tensor<6>,

    /// Query/C tensor: same processing as B.
    ///
    /// # Shape
    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
    pub c_bnlmhr: Tensor<6>,

    /// Initial SSM hidden state.
    ///
    /// # Shape
    /// - `[batch, nheads, per_head_dim, state_rank]`
    pub initial_state_bhpr: Tensor<4>,

    /// Optional learnable initial state (broadcast over batch).
    ///
    /// NOTE(review): how this combines with `initial_state_bhpr` (added, substituted, …) is
    /// decided by the `double_ssd_*` implementations, not in this file — confirm there.
    ///
    /// # Shape
    /// - `[nheads, per_head_dim, state_rank]`
    pub init_state_hpr: Option<Tensor<3>>,
}
57
58impl Mamba3DoubleSsdInput {
59    /// Run the [`NaN`/`Inf` guards](crate::utils::sanity) on every input tensor.
60    pub fn sanity(&self) {
61        use crate::modules::sanity as san;
62        san(&self.v_bnlmhp);
63        san(&self.da_bnlh);
64        san(&self.b_bnlmhr);
65        san(&self.c_bnlmhr);
66        san(&self.initial_state_bhpr);
67        if let Some(ref init_state_hpr) = self.init_state_hpr {
68            san(init_state_hpr);
69        }
70    }
71}
72
73impl Mamba3DoubleSsdInput {
74    /// Run the selected double-ssd algorithm on this MIMO-first input.
75    ///
76    /// Dispatches by [`Mamba3SsdPath`] variant to `double_ssd_minimal`,
77    /// `double_ssd_serial`, or `double_ssd_serial_recalculated`.
78    ///
79    /// # Returns
80    /// - `y_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
81    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]`
82    pub fn run(self, path: &Mamba3SsdPath) -> (Tensor<6>, Tensor<4>) {
83        match path {
84            Mamba3SsdPath::Minimal(_) => self.double_ssd_minimal(),
85            Mamba3SsdPath::Serial(_) => self.double_ssd_serial(),
86            Mamba3SsdPath::SerialRecalculated(_) => self.double_ssd_serial_recalculated(),
87        }
88    }
89}
90
91// ---------------------------------------------------------------------------
92// Tests
93// ---------------------------------------------------------------------------
94
95#[cfg(all(test, feature = "_dev-test"))]
96mod tests;