burn_mamba/mamba3/single_ssd/ssd/serial.rs
1//! # SingleSsd Serial (K1–K5) SSD
2//!
3//! Chunk-serial counterpart to [`crate::mamba3::single_ssd::ssd::minimal`].
4//! Whereas the Minimal variant uses a segsum-based quadratic state passing,
5//! this one reuses the K1–K4 helpers from [`crate::mamba3::double_ssd::ssd::serial`]
6//! (which run a sequential loop for K4) and supplies a **new K5** that bakes
7//! in the single-ssd logic:
8//!
9//! - Strict lower-triangular intra-chunk path (the same-time-step block is
10//! excluded from the SSM sum; it is the “diagonal correction” territory).
11//! - K is scaled by `scaleₜ = γₜ + (1−λₜ₊₁) Δₜ₊₁` per source-time column.
12//! - Same-time-step block contributes via an explicit `γₜ · (C·Bᵀ at t) · Vₜ`
13//! correction term, restoring the right diagonal weighting.
14//!
15//! K1–K4 are identical to the double-SSD because:
16//! - K1 (`da_cumsum`, `da_chunk_end`) depends only on `da = Δ·A`.
17//! - K2 (`cb = C · Bᵀ`) is computed on **unscaled** B / C; the single-ssd
18//! algorithm wants the unscaled CB so it can apply `scaleₜ` per-column
19//! (lower triangular) and reuse the same-step block for the γ-correction.
20//! - K3 (chunk-end state from V·decay·K) is form-invariant: passing the
21//! scale-multiplied K (`K_scaled = scaleₜ · B`) recovers the single-ssd
22//! chunk state, with no other changes needed.
23//! - K4 (sequential state passing across chunks) operates on a `[H, P, R]`
24//! per-chunk state and a per-chunk decay total; both are mode-agnostic.
25//!
26//! Reference kernels (same as `single_ssd_minimal`):
27//! - `refs/state-spaces/mamba/mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py`
28//! - `refs/state-spaces/mamba/mamba_ssm/ops/tilelang/mamba3/mamba3_mimo_fwd.py`
29
30#![allow(non_snake_case)]
31
32pub use crate::mamba3::double_ssd::ssd::serial::{
33 k1_ssd_chunk_cumsum, k2_ssd_bmm, k3_ssd_chunk_state, k4_ssd_state_passing,
34};
35use crate::mamba3::single_ssd::prelude::*;
36use burn::prelude::*;
37
38impl Mamba3SingleSsdInput {
39 /// MIMO-first Single-SSD — chunk-serial (K1–K5) variant.
40 ///
41 /// Sequence of kernels (matches the double-ssd `ssd_serial`):
42 /// 1. **K1**: intra-chunk cumulative log-decay and per-chunk decay totals.
43 /// 2. **K2**: `cb = C · Bᵀ` block matrix (unscaled).
44 /// 3. **K3**: per-chunk hidden state assuming zero initial state, fed
45 /// `K_scaled = scaleₜ · B`.
46 /// 4. **K4**: sequential state passing across chunks (loop over chunks).
47 /// 5. **K5** (this module's new function): single-ssd chunk scan with
48 /// strict lower-triangular masking, scale broadcasting, and the
49 /// `γₜ`-weighted same-step diagonal correction.
50 ///
51 /// # Returns
52 /// - `y_bnlmhp`: `[batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]`
53 /// - `final_state_bhpr`: `[batch, nheads, per_head_dim, state_rank]` —
54 /// the single-ssd accumulator at the last token.
55 pub fn single_ssd_serial(self) -> (Tensor<6>, Tensor<4>) {
56 let input = self;
57 input.sanity();
58 let [batch, nchunks, chunk_len, _mimo_rank, nheads, per_head_dim] = input.v_bnlmhp.dims();
59 let [.., state_rank] = input.b_bnlmhr.dims();
60
61 assert!(
62 input.init_state_hpr.is_none(),
63 "init_state_hpr is not yet supported in single_ssd_serial; use single_ssd_minimal instead"
64 );
65 assert!(nchunks > 0, "sequence length must be at least 1");
66 assert_eq!(
67 [batch, nchunks, chunk_len, nheads],
68 input.gamma_bnlh.dims(),
69 "gamma must align with da"
70 );
71 assert_eq!(
72 [batch, nchunks, chunk_len, nheads],
73 input.scale_bnlh.dims(),
74 "scale must align with da"
75 );
76
77 // ── K1: chunk cumulative decay ────────────────────────────────────────
78 let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(input.da_bnlh.clone());
79
80 // ── K2: CB matrix on unscaled B/C ─────────────────────────────────────
81 // SingleSsd K5 applies the `scale` and `gamma` weights post-hoc, so K2 is
82 // identical to the double-ssd K2.
83 let cb_bnhLMLM: Tensor<5> = k2_ssd_bmm(input.c_bnlmhr.clone(), input.b_bnlmhr.clone());
84
85 // ── K3: chunk state using K_scaled = scaleₜ · B ───────────────────────
86 // The existing K3 computes `state = (V * decay)^T @ B_input`, so passing
87 // `B_input = K_scaled` recovers the single-ssd per-chunk state.
88 let scale_bnlh11 = input.scale_bnlh.clone().unsqueeze_dims::<6>(&[3, 5]);
89 let k_scaled_bnlmhr = input.b_bnlmhr.clone() * scale_bnlh11;
90 let intra_chunk_state_bnhpr: Tensor<5> = k3_ssd_chunk_state(
91 input.v_bnlmhp.clone(),
92 k_scaled_bnlmhr,
93 da_cumsum_bhnl.clone(),
94 );
95
96 // ── K4: sequential state passing across chunks ────────────────────────
97 let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<5>, Tensor<4>) =
98 k4_ssd_state_passing(
99 intra_chunk_state_bnhpr,
100 da_chunk_end_bhn,
101 input.initial_state_bhpr,
102 );
103 assert_eq!(
104 [batch, nchunks, nheads, per_head_dim, state_rank],
105 chunk_input_state_bnhpr.dims()
106 );
107
108 // ── K5: single-ssd chunk scan (strict-lower + diag γ-correction + Y_off)
109 let y_bnlmhp = k5_single_ssd_chunk_scan(
110 da_cumsum_bhnl,
111 input.v_bnlmhp,
112 input.c_bnlmhr,
113 input.b_bnlmhr,
114 cb_bnhLMLM,
115 input.gamma_bnlh,
116 input.scale_bnlh,
117 chunk_input_state_bnhpr,
118 );
119
120 (y_bnlmhp, final_state_bhpr)
121 }
122}
123
124// ---------------------------------------------------------------------------
125// K5 (single-ssd) — strict-lower intra-chunk + γ-correction + state-to-output
126// ---------------------------------------------------------------------------
127
128/// SingleSsd chunk scan.
129///
130/// Computes the per-chunk output from three contributions:
131/// - **Strict lower triangular intra-chunk** (`t1 > t2`):
132/// `(cb[i,j] · scale[t2] · exp(cumA[t1] − cumA[t2])) · V[t2]`
133/// - **Same-time-step (`t1 == t2`) γ-correction**:
134/// `γ[t] · (Σₙ C[t,r_out,n] · B[t,r_in,n]) · V[t,r_in,p]`
135/// - **State-to-output (Y_off)** — same formula as the double-ssd K5:
136/// `exp(cumA[t]) · C[t] · h'[n-1]`
137///
138/// `cb_bnhLMLM` is the unscaled `C · Bᵀ` matrix from K2; `b_bnlmhr` is the
139/// unscaled K/B tensor (used for the γ-correction matmul). The strict-lower
140/// MIMO mask excludes the same-step `R × R` block, leaving only `t1 > t2`
141/// contributions in the masked CB.
142///
143/// # Shapes
144/// - `da_cumsum_bhnl`: `[B, H, N, L]` (base time grid, not fused)
145/// - `v_bnlmhp`: `[B, N, L, M, H, P]`
146/// - `c_bnlmhr`, `b_bnlmhr`: `[B, N, L, M, H, R]`
147/// - `cb_bnhLMLM`: `[B, N, H, L·M, L·M]` (output of K2)
148/// - `gamma_bnlh`, `scale_bnlh`: `[B, N, L, H]`
149/// - `chunk_input_state_bnhpr`: `[B, N, H, P, R]` (h' at chunk start)
150///
151/// # Returns
152/// - `y_bnlmhp`: `[B, N, L, M, H, P]`
153#[allow(clippy::too_many_arguments)]
154pub fn k5_single_ssd_chunk_scan(
155 da_cumsum_bhnl: Tensor<4>,
156 v_bnlmhp: Tensor<6>,
157 c_bnlmhr: Tensor<6>,
158 b_bnlmhr: Tensor<6>,
159 cb_bnhLMLM: Tensor<5>,
160 gamma_bnlh: Tensor<4>,
161 scale_bnlh: Tensor<4>,
162 chunk_input_state_bnhpr: Tensor<5>,
163) -> Tensor<6> {
164 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
165 let [.., state_rank] = c_bnlmhr.dims();
166 let device = v_bnlmhp.device();
167 let fused = chunk_len * mimo_rank;
168
169 // Fuse mimo_rank into chunk_len for the SSM-style matmul.
170 let v_bnLMhp = v_bnlmhp
171 .clone()
172 .reshape([batch, nchunks, fused, nheads, per_head_dim]);
173 let c_bnLMhr = c_bnlmhr
174 .clone()
175 .reshape([batch, nchunks, fused, nheads, state_rank]);
176
177 // Per-fused-step cumulative decay (interleave-expand the base grid).
178 let da_cumsum_bhnLM = da_cumsum_bhnl
179 .unsqueeze_dim::<5>(4)
180 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
181 .reshape([batch, nheads, nchunks, fused]);
182
183 // ── Y_off: exp(cumA[t]) · C[t] · h'[n-1] (same form as double-ssd K5) ──
184 let exp_da_bnhLMp = da_cumsum_bhnLM
185 .clone()
186 .exp()
187 .permute([0, 2, 1, 3]) // bnhLM
188 .unsqueeze_dim::<5>(4) // bnhLM1
189 .expand([batch, nchunks, nheads, fused, per_head_dim]);
190
191 let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
192 let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
193 let ch_bnhLMp = c_bnhLMr.matmul(chunk_input_state_bnhrp);
194 let y_off_bnhLMp = ch_bnhLMp * exp_da_bnhLMp;
195
196 // ── Y_lower: strict lower-tri intra-chunk with scale and decay ────────
197 //
198 // Mask `cb` to keep only `t1 > t2`, multiply by `exp(cumA[t1] - cumA[t2])`
199 // and by `scale[t2]` along the source axis, then matmul with V.
200 let da_cumsum_bnhLM = da_cumsum_bhnLM.permute([0, 2, 1, 3]); // bnhLM
201 let target_da_cumsum_bnhLMLM = da_cumsum_bnhLM
202 .clone()
203 .unsqueeze_dim::<5>(4) // bnhLM1
204 .expand([batch, nchunks, nheads, fused, fused]);
205 let source_da_cumsum_bnhLMLM = da_cumsum_bnhLM
206 .unsqueeze_dim::<5>(3) // bnh1LM
207 .expand([batch, nchunks, nheads, fused, fused]);
208 let diff_bnhLMLM = target_da_cumsum_bnhLMLM - source_da_cumsum_bnhLMLM;
209
210 // Strict-upper -inf mask on the base time grid (`t1 <= t2` → -inf),
211 // then interleave-expand to fused length so that MIMO same-time blocks
212 // are zeroed out.
213 let inf_upper_ll =
214 Tensor::<2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device).triu(0); // upper triangle INCLUDING diagonal
215 let inf_upper_bnhll = inf_upper_ll
216 .unsqueeze_dims::<5>(&[0, 1, 2])
217 .expand([batch, nchunks, nheads, chunk_len, chunk_len]);
218 let inf_upper_bnhLMLM = inf_upper_bnhll
219 .unsqueeze_dim::<6>(4)
220 .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len])
221 .reshape([batch, nchunks, nheads, fused, chunk_len])
222 .unsqueeze_dim::<6>(5)
223 .expand([batch, nchunks, nheads, fused, chunk_len, mimo_rank])
224 .reshape([batch, nchunks, nheads, fused, fused]);
225 let decay_strict_bnhLMLM = (diff_bnhLMLM + inf_upper_bnhLMLM).exp();
226
227 // Per-column scale: `scale[t2]` lives on the source axis (column).
228 let scale_bnhLM = scale_bnlh
229 .permute([0, 1, 3, 2]) // bnhl
230 .unsqueeze_dim::<5>(4) // bnhl1
231 .expand([batch, nchunks, nheads, chunk_len, mimo_rank])
232 .reshape([batch, nchunks, nheads, fused]);
233 let scale_col_bnhLMLM = scale_bnhLM
234 .unsqueeze_dim::<5>(3) // bnh1LM
235 .expand([batch, nchunks, nheads, fused, fused]);
236
237 let kernel_bnhLMLM = decay_strict_bnhLMLM * scale_col_bnhLMLM;
238 let masked_cb_bnhLMLM = cb_bnhLMLM * kernel_bnhLMLM;
239
240 let v_bnhLMp = v_bnLMhp.permute([0, 1, 3, 2, 4]);
241 let y_lower_bnhLMp = masked_cb_bnhLMLM.matmul(v_bnhLMp);
242
243 // ── Y_diag: γ-weighted same-step correction ───────────────────────────
244 //
245 // y_diag[t, m_out, h, p] = γ[t] · Σ_{m_in} (Σ_n C[t,m_out,n] · B[t,m_in,n]) · V[t,m_in,p]
246 //
247 // Computed fresh (small same-step matmul) rather than extracting the
248 // block-diagonal from `cb_bnhLMLM` (which would require a fiddly reshape).
249 let c_bnlhmr = c_bnlmhr.permute([0, 1, 2, 4, 3, 5]);
250 let b_bnlhrm = b_bnlmhr.permute([0, 1, 2, 4, 5, 3]);
251 let qk_dot_bnlhmM = c_bnlhmr.matmul(b_bnlhrm); // bnlhm_outm_in
252 let v_bnlhmp = v_bnlmhp.permute([0, 1, 2, 4, 3, 5]);
253 let y_d_bnlhmp = qk_dot_bnlhmM.matmul(v_bnlhmp); // bnlhm_outp
254 let gamma_bnlh11 = gamma_bnlh.unsqueeze_dims::<6>(&[4, 5]);
255 let y_d_bnlhmp_scaled = y_d_bnlhmp * gamma_bnlh11;
256
257 // Back to fused layout `[B, N, H, L·M, P]` to match y_lower / y_off.
258 let y_diag_bnlmhp = y_d_bnlhmp_scaled.permute([0, 1, 2, 4, 3, 5]);
259 let y_diag_bnLMhp = y_diag_bnlmhp.reshape([batch, nchunks, fused, nheads, per_head_dim]);
260 let y_diag_bnhLMp = y_diag_bnLMhp.permute([0, 1, 3, 2, 4]);
261
262 // ── Combine and reshape ───────────────────────────────────────────────
263 let y_bnhLMp = y_off_bnhLMp + y_lower_bnhLMp + y_diag_bnhLMp;
264 let y_bnLMhp = y_bnhLMp.permute([0, 1, 3, 2, 4]);
265 y_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim])
266}