Skip to main content

burn_mamba/mamba3/single_ssd/ssd/
serial.rs

1//! # SingleSsd Serial (K1–K5) SSD
2//!
3//! Chunk-serial counterpart to [`crate::mamba3::single_ssd::ssd::minimal`].
4//! Whereas the Minimal variant uses a segsum-based quadratic state passing,
5//! this one reuses the K1–K4 helpers from [`crate::mamba3::double_ssd::ssd::serial`]
6//! (which run a sequential loop for K4) and supplies a **new K5** that bakes
7//! in the single-ssd logic:
8//!
9//! - Strict lower-triangular intra-chunk path (the same-time-step block is
10//!   excluded from the SSM sum; it is the “diagonal correction” territory).
11//! - K is scaled by `scaleₜ = γₜ + (1−λₜ₊₁) Δₜ₊₁` per source-time column.
12//! - Same-time-step block contributes via an explicit `γₜ · (C·Bᵀ at t) · Vₜ`
13//!   correction term, restoring the right diagonal weighting.
14//!
15//! K1–K4 are identical to the double-SSD because:
16//! - K1 (`da_cumsum`, `da_chunk_end`) depends only on `da = Δ·A`.
17//! - K2 (`cb = C · Bᵀ`) is computed on **unscaled** B / C; the single-ssd
18//!   algorithm wants the unscaled CB so it can apply `scaleₜ` per-column
19//!   (lower triangular) and reuse the same-step block for the γ-correction.
20//! - K3 (chunk-end state from V·decay·K) is form-invariant: passing the
21//!   scale-multiplied K (`K_scaled = scaleₜ · B`) recovers the single-ssd
22//!   chunk state, with no other changes needed.
23//! - K4 (sequential state passing across chunks) operates on a `[H, P, R]`
24//!   per-chunk state and a per-chunk decay total; both are mode-agnostic.
25//!
26//! Reference kernels (same as `single_ssd_minimal`):
27//! - `refs/state-spaces/mamba/mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py`
28//! - `refs/state-spaces/mamba/mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`
29
30#![allow(non_snake_case)]
31
32pub use crate::mamba3::double_ssd::ssd::serial::{
33    k1_ssd_chunk_cumsum, k2_ssd_bmm, k3_ssd_chunk_state, k4_ssd_state_passing,
34};
35use crate::mamba3::single_ssd::prelude::*;
36use burn::prelude::*;
37
38impl Mamba3SingleSsdInput {
39    /// MIMO-first Single-SSD — chunk-serial (K1–K5) variant.
40    ///
41    /// Sequence of kernels (matches the double-ssd `ssd_serial`):
42    /// 1. **K1**: intra-chunk cumulative log-decay and per-chunk decay totals.
43    /// 2. **K2**: `cb = C · Bᵀ` block matrix (unscaled).
44    /// 3. **K3**: per-chunk hidden state assuming zero initial state, fed
45    ///    `K_scaled = scaleₜ · B`.
46    /// 4. **K4**: sequential state passing across chunks (loop over chunks).
47    /// 5. **K5** (this module's new function): single-ssd chunk scan with
48    ///    strict lower-triangular masking, scale broadcasting, and the
49    ///    `γₜ`-weighted same-step diagonal correction.
50    ///
51    /// # Returns
52    /// - `y_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
53    /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]` —
54    ///   the single-ssd accumulator at the last token.
55    pub fn single_ssd_serial(self) -> (Tensor<6>, Tensor<4>) {
56        let input = self;
57        input.sanity();
58        let [batch, nchunks, chunk_len, _mimo_rank, nheads, per_head_dim] = input.v_bnlmhp.dims();
59        let [.., state_rank] = input.b_bnlmhr.dims();
60
61        assert!(
62            input.init_state_hpr.is_none(),
63            "init_state_hpr is not yet supported in single_ssd_serial; use single_ssd_minimal instead"
64        );
65        assert!(nchunks > 0, "sequence length must be at least 1");
66        assert_eq!(
67            [batch, nchunks, chunk_len, nheads],
68            input.gamma_bnlh.dims(),
69            "gamma must align with da"
70        );
71        assert_eq!(
72            [batch, nchunks, chunk_len, nheads],
73            input.scale_bnlh.dims(),
74            "scale must align with da"
75        );
76
77        // ── K1: chunk cumulative decay ────────────────────────────────────────
78        let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(input.da_bnlh.clone());
79
80        // ── K2: CB matrix on unscaled B/C ─────────────────────────────────────
81        // SingleSsd K5 applies the `scale` and `gamma` weights post-hoc, so K2 is
82        // identical to the double-ssd K2.
83        let cb_bnhLMLM: Tensor<5> = k2_ssd_bmm(input.c_bnlmhr.clone(), input.b_bnlmhr.clone());
84
85        // ── K3: chunk state using K_scaled = scaleₜ · B ───────────────────────
86        // The existing K3 computes `state = (V * decay)^T @ B_input`, so passing
87        // `B_input = K_scaled` recovers the single-ssd per-chunk state.
88        let scale_bnlh11 = input.scale_bnlh.clone().unsqueeze_dims::<6>(&[3, 5]);
89        let k_scaled_bnlmhr = input.b_bnlmhr.clone() * scale_bnlh11;
90        let intra_chunk_state_bnhpr: Tensor<5> = k3_ssd_chunk_state(
91            input.v_bnlmhp.clone(),
92            k_scaled_bnlmhr,
93            da_cumsum_bhnl.clone(),
94        );
95
96        // ── K4: sequential state passing across chunks ────────────────────────
97        let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<5>, Tensor<4>) =
98            k4_ssd_state_passing(
99                intra_chunk_state_bnhpr,
100                da_chunk_end_bhn,
101                input.initial_state_bhpr,
102            );
103        assert_eq!(
104            [batch, nchunks, nheads, per_head_dim, state_rank],
105            chunk_input_state_bnhpr.dims()
106        );
107
108        // ── K5: single-ssd chunk scan (strict-lower + diag γ-correction + Y_off)
109        let y_bnlmhp = k5_single_ssd_chunk_scan(
110            da_cumsum_bhnl,
111            input.v_bnlmhp,
112            input.c_bnlmhr,
113            input.b_bnlmhr,
114            cb_bnhLMLM,
115            input.gamma_bnlh,
116            input.scale_bnlh,
117            chunk_input_state_bnhpr,
118        );
119
120        (y_bnlmhp, final_state_bhpr)
121    }
122}
123
124// ---------------------------------------------------------------------------
125// K5 (single-ssd) — strict-lower intra-chunk + γ-correction + state-to-output
126// ---------------------------------------------------------------------------
127
128/// SingleSsd chunk scan.
129///
130/// Computes the per-chunk output from three contributions:
131/// - **Strict lower triangular intra-chunk** (`t1 > t2`):
132///   `(cb[i,j] · scale[t2] · exp(cumA[t1] − cumA[t2])) · V[t2]`
133/// - **Same-time-step (`t1 == t2`) γ-correction**:
134///   `γ[t] · (Σₙ C[t,r_out,n] · B[t,r_in,n]) · V[t,r_in,p]`
135/// - **State-to-output (Y_off)** — same formula as the double-ssd K5:
136///   `exp(cumA[t]) · C[t] · h'[n-1]`
137///
138/// `cb_bnhLMLM` is the unscaled `C · Bᵀ` matrix from K2; `b_bnlmhr` is the
139/// unscaled K/B tensor (used for the γ-correction matmul). The strict-lower
140/// MIMO mask excludes the same-step `R × R` block, leaving only `t1 > t2`
141/// contributions in the masked CB.
142///
143/// # Shapes
144/// - `da_cumsum_bhnl`: `[B, H, N, L]` (base time grid, not fused)
145/// - `v_bnlmhp`: `[B, N, L, M, H, P]`
146/// - `c_bnlmhr`, `b_bnlmhr`: `[B, N, L, M, H, R]`
147/// - `cb_bnhLMLM`: `[B, N, H, L·M, L·M]` (output of K2)
148/// - `gamma_bnlh`, `scale_bnlh`: `[B, N, L, H]`
149/// - `chunk_input_state_bnhpr`: `[B, N, H, P, R]` (h' at chunk start)
150///
151/// # Returns
152/// - `y_bnlmhp`: `[B, N, L, M, H, P]`
153#[allow(clippy::too_many_arguments)]
154pub fn k5_single_ssd_chunk_scan(
155    da_cumsum_bhnl: Tensor<4>,
156    v_bnlmhp: Tensor<6>,
157    c_bnlmhr: Tensor<6>,
158    b_bnlmhr: Tensor<6>,
159    cb_bnhLMLM: Tensor<5>,
160    gamma_bnlh: Tensor<4>,
161    scale_bnlh: Tensor<4>,
162    chunk_input_state_bnhpr: Tensor<5>,
163) -> Tensor<6> {
164    let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
165    let [.., state_rank] = c_bnlmhr.dims();
166    let device = v_bnlmhp.device();
167    let fused = chunk_len * mimo_rank;
168
169    // Fuse mimo_rank into chunk_len for the SSM-style matmul.
170    let v_bnLMhp = v_bnlmhp
171        .clone()
172        .reshape([batch, nchunks, fused, nheads, per_head_dim]);
173    let c_bnLMhr = c_bnlmhr
174        .clone()
175        .reshape([batch, nchunks, fused, nheads, state_rank]);
176
177    // Per-fused-step cumulative decay (interleave-expand the base grid).
178    let da_cumsum_bhnLM = da_cumsum_bhnl
179        .unsqueeze_dim::<5>(4)
180        .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
181        .reshape([batch, nheads, nchunks, fused]);
182
183    // ── Y_off: exp(cumA[t]) · C[t] · h'[n-1]  (same form as double-ssd K5) ──
184    let exp_da_bnhLMp = da_cumsum_bhnLM
185        .clone()
186        .exp()
187        .permute([0, 2, 1, 3]) // bnhLM
188        .unsqueeze_dim::<5>(4) // bnhLM1
189        .expand([batch, nchunks, nheads, fused, per_head_dim]);
190
191    let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
192    let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
193    let ch_bnhLMp = c_bnhLMr.matmul(chunk_input_state_bnhrp);
194    let y_off_bnhLMp = ch_bnhLMp * exp_da_bnhLMp;
195
196    // ── Y_lower: strict lower-tri intra-chunk with scale and decay ────────
197    //
198    // Mask `cb` to keep only `t1 > t2`, multiply by `exp(cumA[t1] - cumA[t2])`
199    // and by `scale[t2]` along the source axis, then matmul with V.
200    let da_cumsum_bnhLM = da_cumsum_bhnLM.permute([0, 2, 1, 3]); // bnhLM
201    let target_da_cumsum_bnhLMLM = da_cumsum_bnhLM
202        .clone()
203        .unsqueeze_dim::<5>(4) // bnhLM1
204        .expand([batch, nchunks, nheads, fused, fused]);
205    let source_da_cumsum_bnhLMLM = da_cumsum_bnhLM
206        .unsqueeze_dim::<5>(3) // bnh1LM
207        .expand([batch, nchunks, nheads, fused, fused]);
208    let diff_bnhLMLM = target_da_cumsum_bnhLMLM - source_da_cumsum_bnhLMLM;
209
210    // Strict-upper -inf mask on the base time grid (`t1 <= t2` → -inf),
211    // then interleave-expand to fused length so that MIMO same-time blocks
212    // are zeroed out.
213    let inf_upper_ll =
214        Tensor::<2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device).triu(0); // upper triangle INCLUDING diagonal
215    let inf_upper_bnhll = inf_upper_ll
216        .unsqueeze_dims::<5>(&[0, 1, 2])
217        .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
218    let inf_upper_bnhLMLM = inf_upper_bnhll
219        .unsqueeze_dim::<6>(4)
220        .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len])
221        .reshape([batch, nchunks, nheads, fused, chunk_len])
222        .unsqueeze_dim::<6>(5)
223        .expand([batch, nchunks, nheads, fused, chunk_len, mimo_rank])
224        .reshape([batch, nchunks, nheads, fused, fused]);
225    let decay_strict_bnhLMLM = (diff_bnhLMLM + inf_upper_bnhLMLM).exp();
226
227    // Per-column scale: `scale[t2]` lives on the source axis (column).
228    let scale_bnhLM = scale_bnlh
229        .permute([0, 1, 3, 2]) // bnhl
230        .unsqueeze_dim::<5>(4) // bnhl1
231        .expand([batch, nchunks, nheads, chunk_len, mimo_rank])
232        .reshape([batch, nchunks, nheads, fused]);
233    let scale_col_bnhLMLM = scale_bnhLM
234        .unsqueeze_dim::<5>(3) // bnh1LM
235        .expand([batch, nchunks, nheads, fused, fused]);
236
237    let kernel_bnhLMLM = decay_strict_bnhLMLM * scale_col_bnhLMLM;
238    let masked_cb_bnhLMLM = cb_bnhLMLM * kernel_bnhLMLM;
239
240    let v_bnhLMp = v_bnLMhp.permute([0, 1, 3, 2, 4]);
241    let y_lower_bnhLMp = masked_cb_bnhLMLM.matmul(v_bnhLMp);
242
243    // ── Y_diag: γ-weighted same-step correction ───────────────────────────
244    //
245    // y_diag[t, m_out, h, p] = γ[t] · Σ_{m_in} (Σ_n C[t,m_out,n] · B[t,m_in,n]) · V[t,m_in,p]
246    //
247    // Computed fresh (small same-step matmul) rather than extracting the
248    // block-diagonal from `cb_bnhLMLM` (which would require a fiddly reshape).
249    let c_bnlhmr = c_bnlmhr.permute([0, 1, 2, 4, 3, 5]);
250    let b_bnlhrm = b_bnlmhr.permute([0, 1, 2, 4, 5, 3]);
251    let qk_dot_bnlhmM = c_bnlhmr.matmul(b_bnlhrm); // bnlhm_outm_in
252    let v_bnlhmp = v_bnlmhp.permute([0, 1, 2, 4, 3, 5]);
253    let y_d_bnlhmp = qk_dot_bnlhmM.matmul(v_bnlhmp); // bnlhm_outp
254    let gamma_bnlh11 = gamma_bnlh.unsqueeze_dims::<6>(&[4, 5]);
255    let y_d_bnlhmp_scaled = y_d_bnlhmp * gamma_bnlh11;
256
257    // Back to fused layout `[B, N, H, L·M, P]` to match y_lower / y_off.
258    let y_diag_bnlmhp = y_d_bnlhmp_scaled.permute([0, 1, 2, 4, 3, 5]);
259    let y_diag_bnLMhp = y_diag_bnlmhp.reshape([batch, nchunks, fused, nheads, per_head_dim]);
260    let y_diag_bnhLMp = y_diag_bnLMhp.permute([0, 1, 3, 2, 4]);
261
262    // ── Combine and reshape ───────────────────────────────────────────────
263    let y_bnhLMp = y_off_bnhLMp + y_lower_bnhLMp + y_diag_bnhLMp;
264    let y_bnLMhp = y_bnhLMp.permute([0, 1, 3, 2, 4]);
265    y_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim])
266}