burn_mamba/mamba3/single_ssd/ssd/ssd_path/mod.rs
1//! # Single-Pass SSD — Input Bundle
2//!
3//! Sibling to [`crate::mamba3::double_ssd::ssd::ssd_path`]. Where the double-ssd
4//! pathway runs the *standard* SSD twice (γ-term and β-term), this module's
5//! [`Mamba3SingleSsdInput`] runs **one** merged SSD pass that absorbs both
6//! contributions by scaling `K` with `scaleₜ = γₜ + (1−λₜ₊₁) Δₜ₊₁`. The same-step
7//! diagonal contribution differs (it must use `γₜ`, not `scaleₜ`) and is patched
8//! via an explicit correction term inside each variant.
//!
10//! Reference kernels:
11//! - `refs/state-spaces/mamba/mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py`
12//! - `refs/state-spaces/mamba/mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`
13//!
14//! The interface is MIMO-first (matches the other burn-mamba SSD inputs),
15//! with `mimo_rank = 1` collapsing to the SISO case. The algorithm is selected
16//! by [`Mamba3SsdPath`], shared with the double-ssd pathway.
17
18use crate::mamba3::prelude::*;
19use burn::prelude::*;
20
21/// MIMO-first input bundle for the merged-form SSD.
22///
23/// All tensors are pre-processed by the caller (`Mamba3::forward_single_ssd`): B/C are
24/// already QK-normed, RoPE-applied, bias-added, and expanded to per-head; V is
25/// the raw, *unscaled* MIMO-expanded value. The combined log-decay `da = Δ·A`
26/// is pre-computed. The two trapezoidal coefficients `gammaₜ` and `scaleₜ` are
27/// supplied separately because the SSD itself does the K-scaling and γ-weighted
28/// diagonal correction internally. D-skip and Z-gating are handled by the
29/// caller.
30pub struct Mamba3SingleSsdInput {
31 /// Value tensor, MIMO-expanded but **not** trapezoidally scaled.
32 ///
33 /// # Shape
34 /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
35 pub v_bnlmhp: Tensor<6>,
36
37 /// K/B tensor: QK-normed, RoPE-applied, bias-added, expanded to per-head.
38 /// Not pre-scaled — the SSD multiplies by `scaleₜ` internally for the
39 /// lower-triangular and state-recurrence paths, while the diagonal
40 /// correction reuses the unscaled tensor.
41 ///
42 /// # Shape
43 /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
44 pub b_bnlmhr: Tensor<6>,
45
46 /// Q/C tensor: same processing as `b_bnlmhr`.
47 ///
48 /// # Shape
49 /// - `[batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]`
50 pub c_bnlmhr: Tensor<6>,
51
52 /// Pre-combined log-decay `Δ·A` (negative).
53 ///
54 /// # Shape
55 /// - `[batch, nchunks, chunk_len, nheads]`
56 pub da_bnlh: Tensor<4>,
57
58 /// `γₜ = λₜ · Δₜ` — used as the per-token diagonal multiplier.
59 ///
60 /// # Shape
61 /// - `[batch, nchunks, chunk_len, nheads]`
62 pub gamma_bnlh: Tensor<4>,
63
64 /// `scaleₜ = γₜ + (1 − λₜ₊₁) · Δₜ₊₁` — K is multiplied by this for the
65 /// lower-triangular and state recurrence paths. The shifted term is zero
66 /// at the very last sequence position (no future token exists).
67 ///
68 /// # Shape
69 /// - `[batch, nchunks, chunk_len, nheads]`
70 pub scale_bnlh: Tensor<4>,
71
72 /// Initial SSM hidden state (merged-form accumulator).
73 ///
74 /// When continuing from a prior call, this should already include the
75 /// boundary β contribution `(1 − λ₀) · Δ₀ · Σₘ Kₜ₋₁[m] ⊗ (xₜ₋₁ ⊙ mimo_xₘ)`
76 /// (which the previous call could not yet add because it did not know
77 /// `λ₀, Δ₀`).
78 ///
79 /// # Shape
80 /// - `[batch, nheads, per_head_dim, state_rank]`
81 pub initial_state_bhpr: Tensor<4>,
82
83 /// Optional learnable initial state (broadcast over batch).
84 ///
85 /// # Shape
86 /// - `[nheads, per_head_dim, state_rank]`
87 pub init_state_hpr: Option<Tensor<3>>,
88}
89
90impl Mamba3SingleSsdInput {
91 /// Run the [`NaN`/`Inf` guards](crate::utils::sanity) on every input tensor.
92 pub fn sanity(&self) {
93 use crate::modules::sanity as san;
94 san(&self.v_bnlmhp);
95 san(&self.b_bnlmhr);
96 san(&self.c_bnlmhr);
97 san(&self.da_bnlh);
98 san(&self.gamma_bnlh);
99 san(&self.scale_bnlh);
100 san(&self.initial_state_bhpr);
101 if let Some(ref init_state_hpr) = self.init_state_hpr {
102 san(init_state_hpr);
103 }
104 }
105}
106
107impl Mamba3SingleSsdInput {
108 /// Run the selected merged-form (single-ssd) algorithm on this MIMO-first input.
109 ///
110 /// Dispatches by [`Mamba3SsdPath`] variant to `single_ssd_minimal`,
111 /// `single_ssd_serial`, or `single_ssd_serial_recalculated`.
112 ///
113 /// # Returns
114 /// - `y_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
115 /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]` —
116 /// the merged-form accumulator at the last token (to be stored in the
117 /// cache for streaming).
118 pub fn run(self, path: &Mamba3SsdPath) -> (Tensor<6>, Tensor<4>) {
119 match path {
120 Mamba3SsdPath::Minimal(_) => self.single_ssd_minimal(),
121 Mamba3SsdPath::Serial(_) => self.single_ssd_serial(),
122 Mamba3SsdPath::SerialRecalculated(_) => self.single_ssd_serial_recalculated(),
123 }
124 }
125}
126
127// ---------------------------------------------------------------------------
128// Tests — Minimal ≡ Serial (forward outputs + input gradients)
129// ---------------------------------------------------------------------------
130
131#[cfg(all(test, feature = "_dev-test"))]
132mod tests;