Skip to main content

burn_mamba/mamba3/single_ssd/ssd/ssd_path/
mod.rs

//! # Single-Pass SSD — Input Bundle
//!
//! Sibling to [`crate::mamba3::double_ssd::ssd::ssd_path`]. Where the double-ssd
//! pathway runs the *standard* SSD twice (γ-term and β-term), this module's
//! [`Mamba3SingleSsdInput`] runs **one** merged SSD pass that absorbs both
//! contributions by scaling `K` with `scaleₜ = γₜ + (1−λₜ₊₁) Δₜ₊₁`. The same-step
//! diagonal contribution differs (it must use `γₜ`, not `scaleₜ`) and is patched
//! via an explicit correction term inside each variant.
//!
//! Reference kernels:
//! - `refs/state-spaces/mamba/mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py`
//! - `refs/state-spaces/mamba/mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`
//!
//! The interface is MIMO-first (matching the other burn-mamba SSD inputs),
//! with `mimo_rank = 1` collapsing to the SISO case. The algorithm is selected
//! by [`Mamba3SsdPath`], shared with the double-ssd pathway.
17
18use crate::mamba3::prelude::*;
19use burn::prelude::*;
20
21/// MIMO-first input bundle for the merged-form SSD.
22///
23/// All tensors are pre-processed by the caller (`Mamba3::forward_single_ssd`): B/C are
24/// already QK-normed, RoPE-applied, bias-added, and expanded to per-head; V is
25/// the raw, *unscaled* MIMO-expanded value. The combined log-decay `da = Δ·A`
26/// is pre-computed. The two trapezoidal coefficients `gammaₜ` and `scaleₜ` are
27/// supplied separately because the SSD itself does the K-scaling and γ-weighted
28/// diagonal correction internally. D-skip and Z-gating are handled by the
29/// caller.
30pub struct Mamba3SingleSsdInput {
31    /// Value tensor, MIMO-expanded but **not** trapezoidally scaled.
32    ///
33    /// # Shape
34    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
35    pub v_bnlmhp: Tensor<6>,
36
37    /// K/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head.
38    /// Not pre-scaled — the SSD multiplies by `scaleₜ` internally for the
39    /// lower-triangular and state-recurrence paths, while the diagonal
40    /// correction reuses the unscaled tensor.
41    ///
42    /// # Shape
43    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
44    pub b_bnlmhr: Tensor<6>,
45
46    /// Q/C tensor: same processing as `b_bnlmhr`.
47    ///
48    /// # Shape
49    /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
50    pub c_bnlmhr: Tensor<6>,
51
52    /// Pre-combined log-decay `Δ·A` (negative).
53    ///
54    /// # Shape
55    /// - `[batch, nchunks, chunk_len, nheads]`
56    pub da_bnlh: Tensor<4>,
57
58    /// `γₜ = λₜ · Δₜ` — used as the per-token diagonal multiplier.
59    ///
60    /// # Shape
61    /// - `[batch, nchunks, chunk_len, nheads]`
62    pub gamma_bnlh: Tensor<4>,
63
64    /// `scaleₜ = γₜ + (1 − λₜ₊₁) · Δₜ₊₁` — K is multiplied by this for the
65    /// lower-triangular and state recurrence paths. The shifted term is zero
66    /// at the very last sequence position (no future token exists).
67    ///
68    /// # Shape
69    /// - `[batch, nchunks, chunk_len, nheads]`
70    pub scale_bnlh: Tensor<4>,
71
72    /// Initial SSM hidden state (merged-form accumulator).
73    ///
74    /// When continuing from a prior call, this should already include the
75    /// boundary β contribution `(1 − λ₀) · Δ₀ · Σₘ Kₜ₋₁[m] ⊗ (xₜ₋₁ ⊙ mimo_xₘ)`
76    /// (which the previous call could not yet add because it did not know
77    /// `λ₀, Δ₀`).
78    ///
79    /// # Shape
80    /// - `[batch, nheads, per_head_dim, state_rank]`
81    pub initial_state_bhpr: Tensor<4>,
82
83    /// Optional learnable initial state (broadcast over batch).
84    ///
85    /// # Shape
86    /// - `[nheads, per_head_dim, state_rank]`
87    pub init_state_hpr: Option<Tensor<3>>,
88}
89
90impl Mamba3SingleSsdInput {
91    /// Run the [`NaN`/`Inf` guards](crate::utils::sanity) on every input tensor.
92    pub fn sanity(&self) {
93        use crate::modules::sanity as san;
94        san(&self.v_bnlmhp);
95        san(&self.b_bnlmhr);
96        san(&self.c_bnlmhr);
97        san(&self.da_bnlh);
98        san(&self.gamma_bnlh);
99        san(&self.scale_bnlh);
100        san(&self.initial_state_bhpr);
101        if let Some(ref init_state_hpr) = self.init_state_hpr {
102            san(init_state_hpr);
103        }
104    }
105}
106
107impl Mamba3SingleSsdInput {
108    /// Run the selected merged-form (single-ssd) algorithm on this MIMO-first input.
109    ///
110    /// Dispatches by [`Mamba3SsdPath`] variant to `single_ssd_minimal`,
111    /// `single_ssd_serial`, or `single_ssd_serial_recalculated`.
112    ///
113    /// # Returns
114    /// - `y_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
115    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]` —
116    ///   the merged-form accumulator at the last token (to be stored in the
117    ///   cache for streaming).
118    pub fn run(self, path: &Mamba3SsdPath) -> (Tensor<6>, Tensor<4>) {
119        match path {
120            Mamba3SsdPath::Minimal(_) => self.single_ssd_minimal(),
121            Mamba3SsdPath::Serial(_) => self.single_ssd_serial(),
122            Mamba3SsdPath::SerialRecalculated(_) => self.single_ssd_serial_recalculated(),
123        }
124    }
125}
126
// ---------------------------------------------------------------------------
// Tests: Minimal ≡ Serial, checked on forward outputs and on input gradients.
// ---------------------------------------------------------------------------

// Compiled only for test builds that also enable the `_dev-test` feature.
#[cfg(all(feature = "_dev-test", test))]
mod tests;