burn_mamba/mamba2/mamba2.rs
1//! # Mamba-2 SSM Block — Structured State Space Duality (SSD)
2//!
3//! This module implements the core **SSD layer** from the paper
4//! *"Transformers are SSMs: Generalized Models and Efficient Algorithms
5//! through Structured State Space Duality"* (Dao & Gu, 2024).
6//!
7//! ## The SSD Model
8//!
9//! The SSD layer is a multi-head selective SSM. Each head processes a
10//! sequence of `P`-dimensional inputs `X ∈ ℝ^{T×P}` through the recurrence
11//! (Eq. 1–2 of the paper):
12//!
13//! ```text
14//! hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ (state update)
15//! yₜ = Cₜᵀ hₜ (output readout)
16//! ```
17//!
18//! where:
19//! - `hₜ ∈ ℝ^{N×P}` is the hidden state (N = `state_rank`, P = `per_head_dim`)
20//! - `Āₜ = exp(Δₜ A) ∈ ℝ` is a scalar decay (the key SSD constraint:
21//! Āₜ = αₜ · I, i.e. scalar times identity, rather than a diagonal matrix)
22//! - `B̄ₜ = Δₜ Bₜ ∈ ℝᴺ` is the (discretised) input projection
23//! - `Cₜ ∈ ℝᴺ` is the output projection
24//! - `Δₜ > 0` is the (input-dependent) discretisation step size
25//! - `A < 0` is a learnable scalar decay rate per head
26//!
27//! ## State Space Duality
28//!
29//! Unrolling the recurrence yields an equivalent **attention-like** form
30//! (Eq. 6–7):
31//!
32//! ```text
33//! M = L ∘ (C Bᵀ) ∈ ℝ^{T×T}
34//! Y = M · X
35//! ```
36//!
37//! where `L` is the **1-semiseparable mask** (Eq. 4–5):
38//!
39//! ```text
40//! Lᵢⱼ = āᵢ · āᵢ₋₁ · ... · āⱼ₊₁ (i ≥ j)
41//! Lᵢⱼ = 0 (i < j)
42//! ```
43//!
44//! This makes the layer equivalent to causal linear attention
45//! `Y = (L ∘ QKᵀ) V` under the renaming `(C, B, X) ↦ (Q, K, V)`.
46//!
47//! ## The Chunkwise SSD Algorithm
48//!
49//! See [`minimal`] for more information.
50//!
51//! ## Notation / Dimension Keys
52//!
//! Throughout this file, tensor names carry a suffix encoding their shape.
54//! The letters used are:
55//!
56//! | Letter | Dimension | Typical value |
57//! |--------|-----------|---------------|
58//! | `b` | batch | varies |
59//! | `s` | sequence length T | varies |
60//! | `m` | d_model | 768, 1024 … |
61//! | `i` | d_inner = expand·d_model | 2·d_model |
62//! | `h` | nheads H | d_inner / P |
63//! | `p` | per_head_dim P | 64, 128 |
64//! | `r` | state_rank N | 64–256 |
65//! | `v` | conv_dim | d_inner + 2·G·N |
66//! | `k` | conv_kernel | 4 |
67//! | `g` | ngroups G | 1–H |
68//! | `n` | nchunks = T/Q | varies |
69//! | `l` | chunk_len Q | 64–256 |
70//! | `N` | 1+nchunks (padded for state scan) | — |
71//! | `f` | P·N (flattened state for matmul) | — |
72
73use crate::mamba2::prelude::*;
74use crate::utils::sanity::sanity as san;
75use crate::utils::{
76 rms_norm_gated::{RmsNormGated, RmsNormGatedConfig},
77 silu::Silu,
78 softplus::softplus,
79};
80use burn::prelude::*;
81use burn::{
82 module::{Module, Param},
83 nn::conv::{Conv1d, Conv1dConfig},
84 nn::{Initializer, Linear, LinearConfig},
85};
86
87// ---------------------------------------------------------------------------
88// Mamba2 (the SSM block)
89// ---------------------------------------------------------------------------
90
/// The Mamba-2 SSM block.
///
/// Implements the full SSD layer as described in §5 of the paper. Supports
/// two execution modes:
///
/// - [`Self::forward`] — chunkwise SSD for training / prefill
///   (exploits tensor cores; linear in sequence length T)
/// - [`Self::step`] — pure recurrent form for token-by-token decoding
///   (O(H·P·N) per step; no KV-cache)
///
/// Both modes thread a `Mamba2Cache` (convolution window + SSM state)
/// through successive calls.
///
/// ## Architecture (one forward pass through the block)
///
/// ```text
/// u [B, T, D]
///   ├─ in_proj ──────────────────────────────────┐
///   │                                            │
///   │  z [B,T,I]   xbc [B,T,V]   dt_raw [B,T,H]  │
///   │               │                            │
///   │          causal Conv1d                     │
///   │               │ SiLU                       │
///   │          split into                        │
///   │  x [B,T,H,P]  B [B,T,G,N]  C [B,T,G,N]     │
///   │                                            │
///   │  Δ = softplus(dt_raw + dt_bias)            │
///   │  Ā = exp(Δ · A)   [scalar per head]        │
///   │  B̄ = Δ · B                                 │
///   │                                            │
///   │  ┌──── chunked selective scan ─────────┐   │
///   │  │  (four-step chunkwise SSD)          │   │
///   │  └─────────────────────────────────────┘   │
///   │  y [B,T,H,P]                               │
///   │  + D skip                                  │
///   │  RmsNormGated(·, z)                        │
///   └─ out_proj ─────────────────────────────────┘
/// output [B, T, D]
/// ```
#[derive(Module, Debug)]
pub struct Mamba2<B: Backend> {
    /// Input projection: maps `d_model → d_inner + conv_dim + nheads`.
    ///
    /// The output is split (in this order) into three parts:
    /// - `z      [B, T, d_inner]`  — multiplicative gate for the output norm
    /// - `xbc    [B, T, conv_dim]` — input to the causal convolution, which
    ///   is then split into (x, B, C) after activation
    /// - `dt_raw [B, T, nheads]`   — raw (pre-softplus) discretisation step Δ
    pub in_proj: Linear<B>,

    /// Causal depthwise Conv1d applied to the `xbc` projection.
    ///
    /// - Input/output channels: `conv_dim`
    /// - Kernel size: `conv_kernel` (typically 4)
    /// - Groups: `conv_dim` (fully depthwise — each channel is independent)
    /// - Padding: **none** (left-padding is applied manually so the convolution
    ///   is strictly causal)
    ///
    /// The convolution provides a local `conv_kernel`-token context window
    /// before the SSM, which helps the model capture short-range dependencies
    /// that the SSM's recurrent form handles less efficiently.
    pub conv1d: Conv1d<B>,

    /// Per-head bias for the discretisation step size Δ.
    ///
    /// Shape: `[nheads]`
    ///
    /// In [`Self::forward`], `Δₜ = softplus(dt_raw_t + dt_bias)`, which is
    /// then clamped to `dt_limit`.
    /// Initialised such that the corresponding initial `Δ` values are
    /// log-uniformly distributed in `[dt_min, dt_max]`.
    pub dt_bias_h: Param<Tensor<B, 1>>,

    /// Hard clamp applied to Δ after softplus: `Δ ∈ [dt_limit.0, dt_limit.1]`.
    ///
    /// Prevents degenerate discretisations (e.g. Δ → 0 causes Ā → 1, meaning
    /// the state never decays; Δ → ∞ causes Ā → 0, meaning the state is
    /// immediately wiped each step).
    pub dt_limit: (f64, f64),

    /// Per-head log-magnitude of the continuous-time decay parameter A.
    ///
    /// Shape: `[nheads]`
    ///
    /// The actual (negative) decay rate is `A = -exp(a_log)`. The discrete
    /// decay is `Āₜ = exp(Δₜ · A) = exp(-Δₜ · exp(a_log)) ∈ (0, 1)`.
    ///
    /// Storing the *log* of the magnitude and negating ensures A < 0
    /// (decaying system) unconditionally and avoids any sign-constraint
    /// during gradient descent.
    pub a_log_h: Param<Tensor<B, 1>>,

    /// Per-head skip (D) coefficient.
    ///
    /// Shape: `[nheads]`
    ///
    /// Adds a direct path from the (post-convolution, pre-SSM) input to the
    /// output: `yₜ += D · xₜ`. Initialised to ones.
    pub d_h: Param<Tensor<B, 1>>,

    /// Gated RMSNorm applied to the SSM output, conditioned on the gate `z`.
    ///
    /// Input channel dimension: `d_inner`.
    ///
    /// This combines the multiplicative gate (from `z`) and a normalisation
    /// step into a single fused operation, matching the architecture in §5.2
    /// of the paper.
    pub norm: RmsNormGated<B>,

    /// Output projection: maps `d_inner → d_model`.
    pub out_proj: Linear<B>,

    /// Optional learnable initial hidden state `h₀`.
    ///
    /// Shape: `[nheads, per_head_dim, state_rank]` (i.e. `[H, P, N]`)
    ///
    /// When `None`, the initial state is zero (the standard default).
    /// When `Some`, the stored tensor is used as the initial condition for
    /// *every* forward call (not per-batch; it is broadcast over the batch
    /// dimension).
    pub init_state_hpr: Option<Param<Tensor<B, 3>>>,

    /// State rank `N` — the number of latent dimensions in the SSM hidden
    /// state `h ∈ ℝ^{N×P}` per head. Corresponds to the paper's `N`.
    pub state_rank: usize,

    /// Number of B/C groups `G` for grouped SSM heads (analogous to
    /// grouped-query attention). `G` divides `nheads`; all `nheads/G` heads
    /// within a group share the same B and C projections while having
    /// independent X, A, and Z projections.
    pub ngroups: usize,
}
219
220impl<B: Backend> Mamba2<B> {
221 /// `d_inner = expand · d_model`. Inferred from the norm's weight shape.
222 pub fn d_inner(&self) -> usize {
223 let [d_inner] = self.norm.gamma.dims();
224 d_inner
225 }
226
227 /// `nheads = d_inner / per_head_dim`. Inferred from `a_log_h`.
228 pub fn nheads(&self) -> usize {
229 let [nheads] = self.a_log_h.dims();
230 nheads
231 }
232
233 /// `per_head_dim P = d_inner / nheads`.
234 pub fn per_head_dim(&self) -> usize {
235 self.d_inner() / self.nheads()
236 }
237
238 /// `conv_dim = d_inner + 2 · ngroups · state_rank`.
239 pub fn conv_dim(&self) -> usize {
240 self.d_inner() + 2 * self.ngroups * self.state_rank
241 }
242}
243
244// ---------------------------------------------------------------------------
245// Mamba2Config (hyperparameters and factory)
246// ---------------------------------------------------------------------------
247
/// Hyperparameters for the Mamba-2 SSM block.
///
/// All computed quantities (e.g. `nheads`, `d_inner`, `conv_dim`) are derived
/// from the stored fields; see the helper methods on [`Mamba2Config`].
/// [`Mamba2Config::init`] turns a configuration into an initialised
/// [`Mamba2`] block.
#[derive(Config, Debug)]
pub struct Mamba2Config {
    /// Model (hidden) dimension D. Every token is represented as a
    /// D-dimensional vector at the input and output of the block.
    pub d_model: usize,

    /// State rank N — the latent dimension of the SSM hidden state.
    ///
    /// Larger N gives a more expressive state but increases memory and compute
    /// per step. The paper uses N ∈ {64, 128, 256} for most experiments.
    #[config(default = 128)]
    pub state_rank: usize,

    /// Causal convolution window length. Typically 4.
    #[config(default = 4)]
    pub conv_kernel: usize,

    /// Expansion factor for `d_inner = expand · d_model`.
    ///
    /// An expansion of 2 doubles the internal width, keeping the parameter
    /// count of the SSM block comparable to a standard attention layer.
    #[config(default = 2)]
    pub expand: usize,

    /// Head dimension P. The total `d_inner` is split into
    /// `nheads = d_inner / P` independent SSM heads.
    ///
    /// Must be positive and must divide `d_inner` exactly; `init` asserts
    /// both conditions.
    ///
    /// Typical values: 64 or 128. Smaller P means more heads and a larger
    /// hidden state per model dimension.
    #[config(default = 64)]
    pub per_head_dim: usize,

    /// Number of B/C groups G. Must divide `nheads` (asserted in `init`).
    ///
    /// Setting G < nheads reduces the B and C projection sizes (analogous to
    /// GQA in attention), saving memory without a large accuracy cost.
    #[config(default = 1)]
    pub ngroups: usize,

    /// Range `[lo, hi]` for the uniform initialisation of the magnitude of A.
    ///
    /// `A = -Uniform(lo, hi)`, stored as `a_log = log(Uniform(lo, hi))`.
    /// `init` asserts `0 < lo < hi`.
    /// The paper uses `[1, 16]` by default.
    #[config(default = "(1., 16.)")]
    pub a_init_range: (f64, f64),

    /// Gated RMSNorm mode: when `true` the norm is applied *before* the gate;
    /// when `false` (default) the gate is applied first (SiLU-gated norm).
    #[config(default = false)]
    pub is_norm_before_gate: bool,

    /// Minimum value of the initial Δ distribution. Used to set `dt_bias`.
    ///
    /// Its natural logarithm is the lower bound of the log-uniform sampling
    /// range, so it has to be strictly positive for the sample to be finite.
    #[config(default = 1e-3)]
    pub dt_min: f64,

    /// Maximum value of the initial Δ distribution. Used to set `dt_bias`.
    ///
    /// Its natural logarithm is the upper bound of the log-uniform sampling
    /// range.
    #[config(default = 0.1)]
    pub dt_max: f64,

    /// Floor clamped onto the sampled initial Δ values before inverting to
    /// obtain `dt_bias`. Prevents numerical issues with very small Δ.
    #[config(default = 1e-4)]
    pub dt_init_floor: f64,

    /// Hard clamp limits for Δ at runtime: `Δ ∈ [dt_limit.0, dt_limit.1]`.
    ///
    /// Defaults to `(0, f16::MAX ≈ 65504)`, effectively only clamping at 0.
    #[config(default = "(0., 6.5504e+4)")]
    pub dt_limit: (f64, f64),

    /// Whether to add a bias term to the `in_proj` and `out_proj` projections.
    #[config(default = false)]
    pub has_proj_bias: bool,

    /// Whether to add a bias term to the depthwise convolution.
    #[config(default = true)]
    pub has_conv_bias: bool,

    /// Whether to allocate a learnable initial SSM state `h₀`.
    ///
    /// When `false` (default), the hidden state starts at zero for every
    /// sequence. When `true`, `init_state_hpr` is allocated as a trainable
    /// parameter of shape `[nheads, per_head_dim, state_rank]`, initialised
    /// to zeros.
    #[config(default = false)]
    pub has_learnable_init_state: bool,
}
338
339impl Mamba2Config {
340 // -----------------------------------------------------------------------
341 // Computed dimensions
342 // -----------------------------------------------------------------------
343
344 /// `d_inner = expand · d_model`.
345 pub fn d_inner(&self) -> usize {
346 self.expand * self.d_model
347 }
348
349 /// `nheads H = d_inner / per_head_dim`.
350 pub fn nheads(&self) -> usize {
351 self.d_inner() / self.per_head_dim
352 }
353
354 /// `conv_dim = d_inner + 2 · ngroups · state_rank`.
355 ///
356 /// The depthwise convolution processes `x`, `B`, and `C` concatenated:
357 /// x contributes `d_inner` channels, B and C each contribute
358 /// `ngroups · state_rank` channels.
359 pub fn conv_dim(&self) -> usize {
360 self.d_inner() + 2 * self.ngroups * self.state_rank
361 }
362
363 // -----------------------------------------------------------------------
364 // Initialisation
365 // -----------------------------------------------------------------------
366
367 /// Allocate and initialise all Mamba-2 block parameters on `device`.
368 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2<B> {
369 let d_inner = self.d_inner();
370 let nheads = self.nheads();
371 let conv_dim = self.conv_dim();
372
373 assert!(self.per_head_dim > 0, "per_head_dim must be positive");
374 assert_eq!(
375 nheads * self.per_head_dim,
376 d_inner,
377 "d_inner must be divisible by per_head_dim"
378 );
379 assert_eq!(
380 nheads % self.ngroups,
381 0,
382 "nheads must be divisible by ngroups"
383 );
384
385 // Uniform initialiser matching PyTorch's default: U(-1/√fan_in, 1/√fan_in).
386 let uniform_init = |fan_in: usize| {
387 let bound = 1.0 / (fan_in as f64).sqrt();
388 Initializer::Uniform {
389 min: -bound,
390 max: bound,
391 }
392 };
393
394 // ── in_proj ──────────────────────────────────────────────────────────
395 // Projects d_model → (z, xbc, dt_raw).
396 // Size: d_inner + conv_dim + nheads
397 let d_in_proj_out = d_inner + conv_dim + nheads;
398 let in_proj = LinearConfig::new(self.d_model, d_in_proj_out)
399 .with_bias(self.has_proj_bias)
400 .with_initializer(uniform_init(self.d_model))
401 .init::<B>(device);
402
403 // ── conv1d ───────────────────────────────────────────────────────────
404 // Causal depthwise convolution. Left-padding is applied manually in
405 // `forward` and `step`, so we request "Valid" (no automatic padding).
406 // The initialiser fan_in is `in_channels / groups * kernel_size = 1 * conv_kernel`.
407 let conv1d = Conv1dConfig::new(conv_dim, conv_dim, self.conv_kernel)
408 .with_padding(burn::nn::PaddingConfig1d::Valid)
409 .with_groups(conv_dim)
410 .with_bias(self.has_conv_bias)
411 .with_initializer(uniform_init(self.conv_kernel))
412 .init::<B>(device);
413
414 // ── dt_bias ──────────────────────────────────────────────────────────
415 // We want the initial Δ values (after softplus) to be log-uniformly
416 // distributed in [dt_min, dt_max]. The inverse softplus (inverse of
417 // softplus(x) = ln(1 + exp(x))) is used to back-solve for the bias:
418 // dt_bias = softplus⁻¹(dt) = dt + ln(1 - exp(-dt)) ≈ dt + ln(dt)
419 // which simplifies to `dt + log(exp(dt) - 1)` in the formula below.
420 let expm1 = |t: Tensor<B, 1>| t.exp() - 1.;
421 let dt_h = Tensor::random(
422 [nheads],
423 burn::tensor::Distribution::Uniform(self.dt_min.ln(), self.dt_max.ln()),
424 device,
425 )
426 .exp();
427 let dt_h = dt_h.clamp(self.dt_init_floor, f64::INFINITY);
428 // Inverse softplus: softplus⁻¹(y) = y + log(1 - e^{-y}) = y + log(e^y - 1) - y = log(e^y - 1)
429 let inv_dt_h = dt_h.clone() + (-expm1(-dt_h)).log();
430 let dt_bias_h = Param::from_tensor(inv_dt_h);
431
432 // ── a_log ─────────────────────────────────────────────────────────────
433 // A is constrained to be negative (decaying system).
434 // We store a_log = log(|A|) and recover A = -exp(a_log) at runtime.
435 // This parameterisation ensures A < 0 unconditionally.
436 assert!(
437 self.a_init_range.0 > 0.0,
438 "a_init_range lower bound must be > 0"
439 );
440 assert!(
441 self.a_init_range.0 < self.a_init_range.1,
442 "a_init_range must satisfy lo < hi"
443 );
444 let a_h = Tensor::random(
445 [nheads],
446 burn::tensor::Distribution::Uniform(self.a_init_range.0, self.a_init_range.1),
447 device,
448 );
449 let a_log_h = Param::from_tensor(a_h.log());
450
451 // ── D (skip connection) ───────────────────────────────────────────────
452 // Initialised to ones, adding a direct residual path from input to output.
453 let d_h = Initializer::Ones.init::<B, 1, _>([nheads], device);
454
455 // ── norm (gated RMSNorm) and out_proj ─────────────────────────────────
456 let norm = RmsNormGatedConfig::new(d_inner)
457 .with_norm_before_gate(self.is_norm_before_gate)
458 .init(device);
459 let out_proj = LinearConfig::new(d_inner, self.d_model)
460 .with_bias(self.has_proj_bias)
461 .with_initializer(uniform_init(d_inner))
462 .init(device);
463
464 // ── learnable initial state (optional) ────────────────────────────────
465 let init_state_hpr = self.has_learnable_init_state.then(|| {
466 Initializer::Zeros.init::<B, 3, _>([nheads, self.per_head_dim, self.state_rank], device)
467 });
468
469 Mamba2 {
470 in_proj,
471 conv1d,
472 dt_bias_h,
473 dt_limit: self.dt_limit,
474 a_log_h,
475 d_h,
476 norm,
477 out_proj,
478 init_state_hpr,
479 state_rank: self.state_rank,
480 ngroups: self.ngroups,
481 }
482 }
483}
484
485// ---------------------------------------------------------------------------
486// Mamba2::forward (chunkwise SSD — training / prefill)
487// ---------------------------------------------------------------------------
488
489impl<B: Backend + Mamba2BackendExt> Mamba2<B> {
490 /// Process a full input sequence using the chunkwise SSD algorithm.
491 ///
492 /// This is the primary training and prefill path. The computation is
493 /// **linear in T** but uses batched matrix multiplications (GEMMs) that
494 /// can exploit GPU tensor cores — unlike the naive sequential recurrence,
495 /// which requires O(T) serial steps.
496 ///
497 /// ## Full dataflow
498 ///
499 /// 1. **In-projection**: `u → (z, xbc, dt_raw)` via a single linear layer.
500 /// 2. **Causal Conv1d + SiLU**: local context mixing over `xbc`.
501 /// 3. **Split**: `xbc → (x, B, C)`.
502 /// 4. **Discretise**: `Δ = softplus(dt_raw + dt_bias)`;
503 /// `Ā = exp(Δ · A)`; `B̄ = Δ · B`.
504 /// 5. **Padding**: sequence padding.
505 /// 6. **Chunked SSD**: four-step chunkwise algorithm (see
506 /// [`Self::chunked_selective_scan`]).
507 /// 7. **Gated RMSNorm**: `y = RMSNorm(y) · σ(z)`.
508 /// 8. **Out-projection**: `y → output`.
509 ///
510 /// ## Sequence padding
511 ///
512 /// If `sequence_unpadded % chunk_len ≠ 0`, the sequence is zero-padded
513 /// to the next multiple of Q. Zero-padding is equivalent to inserting
514 /// identity steps (`Δ = 0 ⇒ Ā = exp(0) = 1, B̄ = 0`), so the SSM
515 /// state is carried forward unchanged through the pad — making it safe to
516 /// read the final state of the padded last chunk as the true final state.
517 ///
518 /// # Shapes
519 /// - `input_bsm` : `[batch, sequence, d_model]`
520 /// - output : `[batch, sequence, d_model]`
521 /// - cache (out) : updated convolution window and SSM state
522 #[allow(non_snake_case)]
523 pub fn forward(
524 &self,
525 input_bsm: Tensor<B, 3>,
526 cache: Option<Mamba2Cache<B>>,
527 ssd_path: Mamba2SsdPath,
528 ) -> (Tensor<B, 3>, Mamba2Cache<B>) {
529 let [batch, sequence, _d_model] = input_bsm.dims();
530 let d_inner = self.d_inner();
531 let ngroups = self.ngroups;
532 let nheads = self.nheads();
533 let per_head_dim = self.per_head_dim();
534 let conv_dim = self.conv_dim();
535 let state_rank = self.state_rank;
536 let [_conv_dim, _, conv_kernel] = self.conv1d.weight.dims();
537 let [_d_model, d_in_proj_out] = self.in_proj.weight.dims();
538 let device = input_bsm.device();
539 assert_eq!(conv_dim, _conv_dim);
540 assert_ne!(ngroups, 0);
541 assert_eq!(conv_dim, _conv_dim);
542 assert_eq!(nheads % ngroups, 0);
543 assert!(sequence > 0, "sequence length must be at least 1");
544 san(&input_bsm);
545
546 // ── Initialise cache if not provided ──────────────────────────────────
547 let mut cache = cache.unwrap_or_else(|| {
548 let conv_bvk = Tensor::zeros(Shape::new([batch, conv_dim, conv_kernel]), &device);
549 let ssm_bhpr = Tensor::zeros(
550 Shape::new([batch, nheads, per_head_dim, state_rank]),
551 &device,
552 );
553 Mamba2Cache { conv_bvk, ssm_bhpr }
554 });
555 cache.sanity();
556
557 // ── Step 1: In-projection ─────────────────────────────────────────────
558 // One linear layer projects the input to all SSM parameters at once.
559 // This "parallel projection" structure (vs. Mamba-1's sequential
560 // projections) enables tensor parallelism with only 1 all-reduce per
561 // layer instead of 2.
562 //
563 // Projection output: [z | xbc | dt_raw]
564 // z [B, T, d_inner] — gate for the output RMSNorm
565 // xbc [B, T, conv_dim] — will become (x, B, C) after conv + split
566 // dt_raw [B, T, nheads] — raw discretisation step (pre-softplus)
567 let (z_gate_bsi, xbc_bsv, dt_raw_bsh) = {
568 let z_xbc_dt_bsd = self.in_proj.forward(input_bsm);
569 assert_eq!([batch, sequence, d_in_proj_out], z_xbc_dt_bsd.dims());
570 assert_eq!(
571 [batch, sequence, d_inner + conv_dim + nheads],
572 z_xbc_dt_bsd.dims(),
573 );
574
575 let mut parts = z_xbc_dt_bsd
576 .split_with_sizes(vec![d_inner, conv_dim, nheads], 2)
577 .into_iter();
578 (
579 parts.next().unwrap(), // z [B, T, d_inner]
580 parts.next().unwrap(), // xbc [B, T, conv_dim]
581 parts.next().unwrap(), // dt_raw [B, T, nheads]
582 )
583 };
584 assert_eq!([batch, sequence, d_inner], z_gate_bsi.dims());
585 assert_eq!([batch, sequence, conv_dim], xbc_bsv.dims());
586 assert_eq!([batch, sequence, nheads], dt_raw_bsh.dims());
587 san(&z_gate_bsi);
588 san(&xbc_bsv);
589 san(&dt_raw_bsh);
590
591 // ── Step 2: Causal depthwise Conv1d ───────────────────────────────────
592 // Apply the causal 1-D depthwise convolution to `xbc`. To maintain
593 // strict causality, the input is left-padded with the last
594 // `(conv_kernel - 1)` columns from the cache (the tail of the previous
595 // chunk), giving a padded input of length `(conv_kernel-1) + sequence`.
596 // After the convolution the output has length `sequence` (Valid padding).
597 //
598 // The right-most `conv_kernel` columns of the padded input become the
599 // new convolution cache for the next call.
600 let xbc_bvs = xbc_bsv.permute([0, 2, 1]); // [B, conv_dim, T]
601 assert_eq!([batch, conv_dim, sequence], xbc_bvs.dims());
602
603 // Build the causally-padded input: [cached tail | new input]
604 let xbc_padded_bvS = if conv_kernel >= 2 {
605 // Drop the oldest (leftmost) element of the cache, keeping the
606 // last (conv_kernel - 1) columns.
607 let tail_bvK = cache.conv_bvk.slice(s![.., .., 1..]);
608 assert_eq!([batch, conv_dim, conv_kernel - 1], tail_bvK.dims());
609 Tensor::cat(vec![tail_bvK, xbc_bvs], 2)
610 } else {
611 // conv_kernel == 1: no causal padding needed.
612 xbc_bvs
613 };
614 assert_eq!(
615 [batch, conv_dim, (conv_kernel - 1) + sequence],
616 xbc_padded_bvS.dims()
617 );
618 san(&xbc_padded_bvS);
619
620 // Update the cache: save the last `conv_kernel` columns of the padded
621 // input (i.e. starting at position `sequence - 1` from the new input).
622 cache.conv_bvk = xbc_padded_bvS.clone().slice(s![.., .., (sequence - 1)..]);
623 assert_eq!([batch, conv_dim, conv_kernel], cache.conv_bvk.dims());
624
625 // Apply the depthwise convolution and transpose back to [B, T, conv_dim].
626 let xbc_bvs = self.conv1d.forward(xbc_padded_bvS);
627 assert_eq!([batch, conv_dim, sequence], xbc_bvs.dims());
628 san(&xbc_bvs);
629
630 let xbc_bsv = xbc_bvs.permute([0, 2, 1]); // [B, T, conv_dim]
631 assert_eq!([batch, sequence, conv_dim], xbc_bsv.dims());
632
633 // SiLU activation (element-wise).
634 let xbc_bsv = Silu::new().forward(xbc_bsv);
635 assert_eq!([batch, sequence, conv_dim], xbc_bsv.dims());
636 san(&xbc_bsv);
637
638 // ── Step 3: Split xbc into (x, B, C) ──────────────────────────────────
639 // After activation, xbc is partitioned along the channel dimension:
640 // x [B, T, d_inner] → reshaped to [B, T, H, P] (input)
641 // B [B, T, ngroups·N] → reshaped to [B, T, G, N] (state input proj)
642 // C [B, T, ngroups·N] → reshaped to [B, T, G, N] (state output proj)
643 //
644 // Note: in the SSM/attention duality, C ↔ Q, B ↔ K, x ↔ V.
645 let (x_bshp, b_bsgr, c_bsgr) = {
646 let mut parts = xbc_bsv
647 .split_with_sizes(vec![d_inner, ngroups * state_rank, ngroups * state_rank], 2)
648 .into_iter();
649 (
650 parts
651 .next()
652 .unwrap() // [B, T, d_inner]
653 .reshape([batch, sequence, nheads, per_head_dim]),
654 parts
655 .next()
656 .unwrap() // [B, T, ngroups·N]
657 .reshape([batch, sequence, ngroups, state_rank]),
658 parts
659 .next()
660 .unwrap() // [B, T, ngroups·N]
661 .reshape([batch, sequence, ngroups, state_rank]),
662 )
663 };
664 // No shape assertions on reshapes (shapes are algebraically guaranteed).
665
666 // ── Step 4: Discretisation ────────────────────────────────────────────
667 // Compute the discrete-time parameters from the continuous-time A
668 // and input-dependent step size Δ (ZOH discretisation, §4.5):
669 //
670 // Δₜ = softplus(dt_raw_t + dt_bias) ∈ (0, ∞)
671 // Āₜ = exp(Δₜ · A) ∈ (0, 1) [scalar per head]
672 // B̄ₜ = Δₜ · Bₜ ∈ ℝᴺ [Euler approx]
673 //
674 // The Euler approximation B̄ ≈ ΔB (instead of the exact ZOH formula)
675 // is standard in Mamba-1 and Mamba-2 (see §4.5 of the reference).
676 //
677 // `a_head_decay_h` = A = -exp(a_log) < 0 (negative, one scalar per head).
678 // Note: we pass this negative value to `chunked_selective_scan`; inside
679 // that function it is multiplied by Δ > 0, giving a negative exponent
680 // which produces Āₜ = exp(Δₜ·A) ∈ (0,1) as required.
681 let dt_bias_11h = self.dt_bias_h.val().unsqueeze_dims(&[0, 1]);
682 assert_eq!([1, 1, nheads], dt_bias_11h.dims());
683
684 let dt_bsh = softplus(dt_raw_bsh + dt_bias_11h).clamp(self.dt_limit.0, self.dt_limit.1);
685 assert_eq!([batch, sequence, nheads], dt_bsh.dims());
686 san(&dt_bsh);
687
688 let a_head_decay_h = -self.a_log_h.val().exp(); // A = -exp(a_log) < 0
689 assert_eq!([nheads], a_head_decay_h.dims());
690 san(&a_head_decay_h);
691
692 // ── Step 5: Pad sequence to a multiple of chunk_len ───────────────────
693 // Zeros are the correct pad: Δ=0 ⇒ Ā=exp(0·A)=1, B̄=0·B=0.
694 // The state is thus carried through unchanged, so the final state of
695 // the padded last chunk equals the state after the last real token.
696 let chunk_len = ssd_path.chunk_len_or_optimal(state_rank, per_head_dim);
697 assert!(chunk_len > 0);
698 let sequence_padded = sequence.next_multiple_of(chunk_len);
699 let pad = sequence_padded - sequence;
700 let (x_bShp, dt_bSh, b_bSgr, c_bSgr) = if pad == 0 {
701 (x_bshp.clone(), dt_bsh, b_bsgr, c_bsgr)
702 } else {
703 let x_bshp = Tensor::cat(
704 vec![
705 x_bshp.clone(),
706 Tensor::zeros(Shape::new([batch, pad, nheads, per_head_dim]), &device),
707 ],
708 1,
709 );
710 let dt_bsh = Tensor::cat(
711 vec![
712 dt_bsh,
713 Tensor::zeros(Shape::new([batch, pad, nheads]), &device),
714 ],
715 1,
716 );
717 let b_bsgr = Tensor::cat(
718 vec![
719 b_bsgr,
720 Tensor::zeros(Shape::new([batch, pad, ngroups, state_rank]), &device),
721 ],
722 1,
723 );
724 let c_bsgr = Tensor::cat(
725 vec![
726 c_bsgr,
727 Tensor::zeros(Shape::new([batch, pad, ngroups, state_rank]), &device),
728 ],
729 1,
730 );
731 (x_bshp.clone(), dt_bsh, b_bsgr, c_bsgr)
732 };
733
734 // ── Reshapes into chunks ───────────────────────────────────────────────
735 let nchunks = sequence_padded / chunk_len;
736 let x_bnlhp = x_bShp.reshape([batch, nchunks, chunk_len, nheads, per_head_dim]);
737 let dt_bnlh = dt_bSh.reshape([batch, nchunks, chunk_len, nheads]);
738 let b_bnlgr = b_bSgr.reshape([batch, nchunks, chunk_len, ngroups, state_rank]);
739 let c_bnlgr = c_bSgr.reshape([batch, nchunks, chunk_len, ngroups, state_rank]);
740
741 // ── Step 6: Selective Scan ────────────────────────────────────────────
742 let ssd_input = crate::mamba2::ssd::Mamba2SsdInput {
743 x_bnlhp,
744 dt_bnlh,
745 a_decay_h: a_head_decay_h,
746 b_bnlgr,
747 c_bnlgr,
748 d_h: self.d_h.val(),
749 initial_state_bhpr: cache.ssm_bhpr,
750 init_state_hpr: self.init_state_hpr.as_ref().map(|s| s.val()),
751 };
752 ssd_input.sanity();
753 let (y_bnlhp, final_state_bhpr) = match ssd_path {
754 Mamba2SsdPath::Minimal(_chunk_len) => Self::ssd_minimal(ssd_input),
755 Mamba2SsdPath::Serial(_chunk_len) => Self::ssd_serial(ssd_input),
756 Mamba2SsdPath::SerialRecalculated(_chunk_len) => {
757 Self::ssd_serial_recalculated(ssd_input)
758 }
759 };
760 assert_eq!(
761 [batch, nchunks, chunk_len, nheads, per_head_dim],
762 y_bnlhp.dims()
763 );
764 san(&y_bnlhp);
765 san(&final_state_bhpr);
766
767 // Update the SSM state in the cache for the next call.
768 cache.ssm_bhpr = final_state_bhpr;
769 assert_eq!(
770 [batch, nheads, per_head_dim, state_rank],
771 cache.ssm_bhpr.dims()
772 );
773
774 // Remove zero-pad columns that were added at Step 5.
775 let y_bShp = y_bnlhp.reshape([batch, sequence_padded, nheads, per_head_dim]);
776 let y_bshp = if pad == 0 {
777 y_bShp
778 } else {
779 y_bShp.slice(s![.., 0..sequence, .., ..])
780 };
781
782 // Reshape into sequence.
783 let y_bsi = y_bshp.reshape([batch, sequence, d_inner]);
784 assert_eq!([batch, sequence, d_inner], y_bsi.dims());
785
786 // ── Step 7: Gated RMSNorm ─────────────────────────────────────────────
787 let y_bsi = self.norm.forward(y_bsi, z_gate_bsi);
788 assert_eq!([batch, sequence, d_inner], y_bsi.dims());
789 san(&y_bsi);
790
791 // ── Step 8: Out-projection ────────────────────────────────────────────
792 let out_bsm = self.out_proj.forward(y_bsi);
793 assert_eq!([batch, sequence, _d_model], out_bsm.dims());
794 san(&out_bsm);
795
796 (out_bsm, cache)
797 }
798}
799
800// ---------------------------------------------------------------------------
801// Mamba2::step (recurrent SSM — token-by-token decoding)
802// ---------------------------------------------------------------------------
803
804mod step {
805 use super::*;
806
807 impl<B: Backend> Mamba2<B> {
808 /// Process a **single token** using the pure recurrent SSM form.
809 ///
810 /// This is the O(H·P·N)-per-token decoding path. It runs one tick of
811 /// the discretised Mamba-2 recurrence:
812 ///
813 /// ```text
814 /// Āₜ = exp(Δₜ · A) scalar per head, ∈ (0, 1)
815 /// B̄ₜ = Δₜ · Bₜ ∈ ℝᴺ (Euler discretisation)
816 /// hₜ = Āₜ · hₜ₋₁ + B̄ₜ · xₜᵀ ∈ ℝ^{P×N} (outer product update)
817 /// yₜ = Cₜᵀ · hₜ + D · xₜ ∈ ℝᴾ (output)
818 /// ```
819 ///
820 /// The convolution is handled by manually sliding the cache window:
821 /// the oldest input column is dropped and the new token's projection
822 /// is appended.
823 ///
824 /// The SSM hidden state `cache.ssm_bhpr` is updated in-place via
825 /// the recurrence above.
826 ///
827 /// # Shapes
828 /// - `input_bm` : `[batch, d_model]`
829 /// - output : `[batch, d_model]`
830 #[allow(non_snake_case)]
831 pub fn step(
832 &self,
833 input_bm: Tensor<B, 2>,
834 cache: Option<Mamba2Cache<B>>,
835 ) -> (Tensor<B, 2>, Mamba2Cache<B>) {
836 let [batch, d_model] = input_bm.dims();
837 let d_inner = self.d_inner();
838 let ngroups = self.ngroups;
839 let nheads = self.nheads();
840 let per_head_dim = self.per_head_dim();
841 let conv_dim = self.conv_dim();
842 let state_rank = self.state_rank;
843 let [_conv_dim, _, conv_kernel] = self.conv1d.weight.dims();
844 let [_d_model, d_in_proj_out] = self.in_proj.weight.dims();
845
846 assert_eq!(conv_dim, _conv_dim);
847 assert_eq!(nheads % ngroups, 0);
848
849 let mut cache = cache.unwrap_or_else(|| {
850 let device = &input_bm.device();
851 let conv_bvk = Tensor::zeros(Shape::new([batch, conv_dim, conv_kernel]), device);
852 let ssm_bhpr = Tensor::zeros(
853 Shape::new([batch, nheads, per_head_dim, state_rank]),
854 device,
855 );
856 Mamba2Cache { conv_bvk, ssm_bhpr }
857 });
858
859 // ── In-projection ─────────────────────────────────────────────────
860 let (z_gate_bi, xbc_bv, dt_raw_bh) = {
861 let z_xbc_dt_bd = self.in_proj.forward(input_bm);
862 assert_eq!([batch, d_in_proj_out], z_xbc_dt_bd.dims());
863 assert_eq!([batch, d_inner + conv_dim + nheads], z_xbc_dt_bd.dims());
864
865 let mut parts = z_xbc_dt_bd
866 .split_with_sizes(vec![d_inner, conv_dim, nheads], 1)
867 .into_iter();
868 (
869 parts.next().unwrap(), // z [B, d_inner]
870 parts.next().unwrap(), // xbc[B, conv_dim]
871 parts.next().unwrap(), // dt [B, nheads]
872 )
873 };
874 assert_eq!([batch, d_inner], z_gate_bi.dims());
875 assert_eq!([batch, conv_dim], xbc_bv.dims());
876 assert_eq!([batch, nheads], dt_raw_bh.dims());
877
878 // ── Causal convolution (single step) ──────────────────────────────
879 // Slide the cache window left by one position, then insert the new
880 // token's projection `xbc_bv` as the rightmost column.
881 cache.conv_bvk = {
882 let conv_bvk = cache.conv_bvk;
883 assert_eq!([batch, conv_dim, conv_kernel], conv_bvk.dims());
884
885 // Drop the oldest (leftmost) column.
886 let tail_bvK = conv_bvk.slice(s![.., .., 1..]);
887 assert_eq!([batch, conv_dim, conv_kernel - 1], tail_bvK.dims());
888
889 // Append the new token as the rightmost column.
890 let updated_bvk = Tensor::cat([tail_bvK, xbc_bv.unsqueeze_dim(2)].to_vec(), 2);
891 assert_eq!([batch, conv_dim, conv_kernel], updated_bvk.dims());
892 updated_bvk
893 };
894
895 // Apply the depthwise convolution manually (one step = dot product
896 // of the cached window with the conv weight along the kernel axis).
897 let xbc_bv = {
898 let conv1d_v1k = self.conv1d.weight.val(); // [conv_dim, 1, conv_kernel]
899 assert_eq!([conv_dim, 1, conv_kernel], conv1d_v1k.dims());
900
901 let conv1d_bvk = conv1d_v1k
902 .permute([1, 0, 2]) // [1, conv_dim, conv_kernel]
903 .expand([batch, conv_dim, conv_kernel]);
904 assert_eq!([batch, conv_dim, conv_kernel], conv1d_bvk.dims());
905
906 // Element-wise multiply and sum over the kernel axis.
907 let product_bvk = cache.conv_bvk.clone() * conv1d_bvk;
908 let mut xbc_bv = product_bvk.sum_dim(2).squeeze_dim(2);
909 assert_eq!([batch, conv_dim], xbc_bv.dims());
910
911 // Add the (optional) bias.
912 if let Some(bias_v) = &self.conv1d.bias {
913 assert_eq!([conv_dim], bias_v.dims());
914 xbc_bv = xbc_bv + bias_v.val().unsqueeze();
915 }
916
917 Silu::new().forward(xbc_bv)
918 };
919 assert_eq!([batch, conv_dim], xbc_bv.dims());
920
921 // ── Split (x, B, C) ───────────────────────────────────────────────
922 assert_eq!(d_inner, nheads * per_head_dim);
923 let (x_bhp, b_bgr, c_bgr) = {
924 let mut parts = xbc_bv
925 .split_with_sizes(vec![d_inner, ngroups * state_rank, ngroups * state_rank], 1)
926 .into_iter();
927 (
928 parts
929 .next()
930 .unwrap() // [B, d_inner]
931 .reshape([batch, nheads, per_head_dim]),
932 parts
933 .next()
934 .unwrap() // [B, ngroups·N]
935 .reshape([batch, ngroups, state_rank]),
936 parts
937 .next()
938 .unwrap() // [B, ngroups·N]
939 .reshape([batch, ngroups, state_rank]),
940 )
941 };
942
943 // ── Discretisation ────────────────────────────────────────────────
944 // Δₜ = softplus(dt_raw + dt_bias)
945 let dt_bias_1h = self.dt_bias_h.val().unsqueeze_dim(0);
946 assert_eq!([1, nheads], dt_bias_1h.dims());
947 let dt_bh = softplus(dt_raw_bh + dt_bias_1h).clamp(self.dt_limit.0, self.dt_limit.1);
948 assert_eq!([batch, nheads], dt_bh.dims());
949
950 // A = -exp(a_log) < 0 (negative, decaying)
951 let a_head_decay_h = -self.a_log_h.val().exp();
952 assert_eq!([nheads], a_head_decay_h.dims());
953
954 // Āₜ = exp(Δₜ · A) ∈ (0, 1) scalar per [B, H]
955 let dta_bh = (dt_bh.clone() * a_head_decay_h.unsqueeze()).exp();
956 assert_eq!([batch, nheads], dta_bh.dims());
957
958 // ── SSM state update: hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜᵀ ───────────────────
959 // The cache holds h_{t-1} with shape [B, H, P, N].
960 // Āₜ is a scalar per head, so we broadcast it over P and N.
961 // B̄ₜ xₜᵀ is an outer product producing a [P, N] matrix per [B, H].
962
963 let ssm_shape_bhpr = [batch, nheads, per_head_dim, state_rank];
964
965 let dta_bhpr = dta_bh.unsqueeze_dims::<4>(&[2, 3]).expand(ssm_shape_bhpr); // [B, H, P, N]
966
967 // B̄ₜ xₜᵀ = (Δₜ Bₜ) xₜᵀ:
968 // x: [B, H, P] → broadcast to [B, H, P, N]
969 // B: [B, G, N] → expand to [B, H, N] → broadcast to [B, H, P, N]
970 // Δ: [B, H] → broadcast to [B, H, P, N]
971 let heads_per_group = nheads / ngroups;
972 let dtbx_bhpr = {
973 let x_bhpr = x_bhp.clone().unsqueeze_dim::<4>(3).expand(ssm_shape_bhpr);
974
975 // Expand B from [B, G, N] → [B, H, P, N], matching the SSD forward path:
976 // each group's projection is replicated across the heads_per_group heads of
977 // that group so that heads 0..(H/G) belong to group 0, etc.
978 let b_bhpr = b_bgr
979 .unsqueeze_dim::<4>(2) // [B, G, 1, N]
980 .expand([batch, ngroups, heads_per_group, state_rank]) // [B, G, H/G, N]
981 .reshape([batch, nheads, state_rank]) // [B, H, N]
982 .unsqueeze_dim::<4>(2) // [B, H, 1, N]
983 .expand(ssm_shape_bhpr); // [B, H, P, N]
984
985 let dt_bhpr = dt_bh.unsqueeze_dims::<4>(&[2, 3]).expand(ssm_shape_bhpr);
986
987 dt_bhpr * b_bhpr * x_bhpr // B̄ₜ xₜᵀ [B, H, P, N]
988 };
989 assert_eq!(ssm_shape_bhpr, dtbx_bhpr.dims());
990
991 // hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜᵀ
992 cache.ssm_bhpr = cache.ssm_bhpr * dta_bhpr + dtbx_bhpr;
993 assert_eq!(ssm_shape_bhpr, cache.ssm_bhpr.dims());
994
995 // ── Output: yₜ = Cₜ hₜ + D xₜ ──────────────────────────────────
996 let y_bi = {
997 // Cₜ hₜ: element-wise multiply C (broadcast to [B, H, P, N])
998 // with h_t, then sum over N.
999 let c_bhpr = c_bgr
1000 .unsqueeze_dim::<4>(2) // [B, G, 1, N]
1001 .expand([batch, ngroups, heads_per_group, state_rank]) // [B, G, H/G, N]
1002 .reshape([batch, nheads, state_rank]) // [B, H, N]
1003 .unsqueeze_dim::<4>(2) // [B, H, 1, N]
1004 .expand(ssm_shape_bhpr); // [B, H, P, N]
1005 assert_eq!(ssm_shape_bhpr, c_bhpr.dims());
1006
1007 let ch_bhp = (cache.ssm_bhpr.clone() * c_bhpr).sum_dim(3).squeeze_dim(3); // sum over N → [B, H, P]
1008 assert_eq!([batch, nheads, per_head_dim], ch_bhp.dims());
1009
1010 // D xₜ: per-head scalar skip.
1011 let d_1h1 = self.d_h.val().unsqueeze_dims(&[0, 2]);
1012 assert_eq!([1, nheads, 1], d_1h1.dims());
1013 let skip_bhp = d_1h1.expand([batch, nheads, per_head_dim]) * x_bhp;
1014 assert_eq!([batch, nheads, per_head_dim], skip_bhp.dims());
1015
1016 let y_bhp = ch_bhp + skip_bhp;
1017 assert_eq!([batch, nheads, per_head_dim], y_bhp.dims());
1018
1019 // Flatten heads → [B, d_inner], then apply gated RMSNorm.
1020 let y_bi = y_bhp.reshape([batch, d_inner]);
1021 self.norm.forward(y_bi, z_gate_bi)
1022 };
1023 assert_eq!([batch, d_inner], y_bi.dims());
1024
1025 // ── Out-projection ────────────────────────────────────────────────
1026 let out_bm = self.out_proj.forward(y_bi);
1027 assert_eq!([batch, d_model], out_bm.dims());
1028
1029 (out_bm, cache)
1030 }
1031 }
1032}
1033
1034// ---------------------------------------------------------------------------
1035// Tests
1036// ---------------------------------------------------------------------------
1037
1038#[cfg(all(test, feature = "backend-flex"))]
1039mod tests {
1040 use super::*;
1041 use burn::backend::{Autodiff, Flex};
1042 use burn::tensor::Distribution;
1043
1044 /// Inner (non-autodiff) backend used for materialising values and
1045 /// extracted gradients.
1046 type InnerB = Flex;
1047 /// Autodiff-wrapped backend used to drive `.backward()`.
1048 type B = Autodiff<InnerB>;
1049
1050 type Device = <InnerB as burn::tensor::backend::BackendTypes>::Device;
1051
1052 fn small_config() -> Mamba2Config {
1053 Mamba2Config::new(32)
1054 .with_state_rank(8)
1055 .with_expand(2)
1056 .with_per_head_dim(8)
1057 }
1058
1059 /// A bundle of input + model-parameter gradients extracted from one
1060 /// forward+backward run. Each `check_grads_match` call compares these
1061 /// across two runs that should be mathematically equivalent.
1062 struct RunGrads {
1063 out: Tensor<InnerB, 3>,
1064 d_input: Tensor<InnerB, 3>,
1065 d_in_proj_w: Tensor<InnerB, 2>,
1066 d_conv1d_w: Tensor<InnerB, 3>,
1067 d_dt_bias: Tensor<InnerB, 1>,
1068 d_a_log: Tensor<InnerB, 1>,
1069 d_d: Tensor<InnerB, 1>,
1070 d_norm_gamma: Tensor<InnerB, 1>,
1071 d_out_proj_w: Tensor<InnerB, 2>,
1072 }
1073
1074 /// Run a closure that produces an output tensor from a model and an input
1075 /// (wrapped as a `Param` so it has its own autodiff leaf), then derive a
1076 /// scalar loss with a fixed (non-tracked) random "head" and return the
1077 /// gradients of the input and a representative set of model parameters.
1078 fn run_with_grads(
1079 model: &Mamba2<B>,
1080 input: &Param<Tensor<B, 3>>,
1081 head: &Tensor<InnerB, 3>,
1082 forward: impl FnOnce(&Mamba2<B>, Tensor<B, 3>) -> Tensor<B, 3>,
1083 ) -> RunGrads {
1084 let out = forward(model, input.val());
1085 let out_inner = out.clone().inner();
1086
1087 let head = Tensor::from_inner(head.clone());
1088 let loss = (out * head).sum();
1089 let grads = loss.backward();
1090
1091 RunGrads {
1092 out: out_inner,
1093 d_input: input.val().grad(&grads).expect("grad input"),
1094 d_in_proj_w: model
1095 .in_proj
1096 .weight
1097 .val()
1098 .grad(&grads)
1099 .expect("grad in_proj.weight"),
1100 d_conv1d_w: model
1101 .conv1d
1102 .weight
1103 .val()
1104 .grad(&grads)
1105 .expect("grad conv1d.weight"),
1106 d_dt_bias: model.dt_bias_h.val().grad(&grads).expect("grad dt_bias_h"),
1107 d_a_log: model.a_log_h.val().grad(&grads).expect("grad a_log_h"),
1108 d_d: model.d_h.val().grad(&grads).expect("grad d_h"),
1109 d_norm_gamma: model
1110 .norm
1111 .gamma
1112 .val()
1113 .grad(&grads)
1114 .expect("grad norm.gamma"),
1115 d_out_proj_w: model
1116 .out_proj
1117 .weight
1118 .val()
1119 .grad(&grads)
1120 .expect("grad out_proj.weight"),
1121 }
1122 }
1123
1124 /// Assert that every entry in `a` and `b` agrees to within `grad_tol`,
1125 /// printing every comparison so a failure dump shows the full picture
1126 /// (instead of stopping at the first mismatch).
1127 fn check_grads_match(label: &str, a: &RunGrads, b: &RunGrads, grad_tol: f32) {
1128 let mut failures: Vec<String> = Vec::new();
1129 macro_rules! check {
1130 ($field:ident, $name:expr) => {{
1131 let d = (a.$field.clone() - b.$field.clone())
1132 .abs()
1133 .max()
1134 .into_scalar();
1135 eprintln!("{:>40} {:>16} | max abs diff = {:>10.6}", label, $name, d);
1136 if d >= grad_tol {
1137 failures.push(format!(
1138 "{}: grad of {} max abs diff = {:.6} (tol {})",
1139 label, $name, d, grad_tol
1140 ));
1141 }
1142 }};
1143 }
1144 check!(d_input, "input");
1145 check!(d_in_proj_w, "in_proj.weight");
1146 check!(d_conv1d_w, "conv1d.weight");
1147 check!(d_dt_bias, "dt_bias_h");
1148 check!(d_a_log, "a_log_h");
1149 check!(d_d, "d_h");
1150 check!(d_norm_gamma, "norm.gamma");
1151 check!(d_out_proj_w, "out_proj.weight");
1152 assert!(
1153 failures.is_empty(),
1154 "gradient mismatches:\n {}",
1155 failures.join("\n ")
1156 );
1157 }
1158
1159 /// Helper that builds a fresh `Param<Tensor>` from a stable inner tensor.
1160 /// A new Param is needed per run so that the autodiff leaf has a fresh
1161 /// node, isolating each backward pass to its own forward graph.
1162 fn param_input(input: &Tensor<InnerB, 3>) -> Param<Tensor<B, 3>> {
1163 Param::from_tensor(Tensor::from_inner(input.clone()))
1164 }
1165
1166 fn run_step_matches_forward(cfg: Mamba2Config, ssd_path: Mamba2SsdPath) {
1167 let device: Device = Default::default();
1168 let model = cfg.init::<B>(&device);
1169
1170 let batch = 2;
1171 let seq_len = 5;
1172 let d_model = cfg.d_model;
1173
1174 let input = Tensor::<InnerB, 3>::random(
1175 [batch, seq_len, d_model],
1176 Distribution::Normal(0.0, 1.0),
1177 &device,
1178 );
1179 let head = Tensor::<InnerB, 3>::random(
1180 [batch, seq_len, d_model],
1181 Distribution::Normal(0.0, 1.0),
1182 &device,
1183 );
1184
1185 let input_fwd = param_input(&input);
1186 let r_fwd = run_with_grads(&model, &input_fwd, &head, |m, x| {
1187 let (out, _) = m.forward(x, None, ssd_path.clone());
1188 out
1189 });
1190
1191 let input_step = param_input(&input);
1192 let r_step = run_with_grads(&model, &input_step, &head, |m, x| {
1193 let mut cache: Option<Mamba2Cache<B>> = None;
1194 let mut outs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
1195 for t in 0..seq_len {
1196 let token = x.clone().narrow(1, t, 1).squeeze_dim(1);
1197 let (out_t, new_cache) = m.step(token, cache);
1198 cache = Some(new_cache);
1199 outs.push(out_t);
1200 }
1201 Tensor::stack(outs, 1)
1202 });
1203
1204 // ── Forward agreement (existing check) ───────────────────────────
1205 let diff = (r_fwd.out.clone() - r_step.out.clone())
1206 .abs()
1207 .max()
1208 .into_scalar();
1209 assert!(
1210 diff < 1e-4,
1211 "step() vs forward() max absolute difference = {diff:.6} (expected < 1e-4)"
1212 );
1213
1214 // ── Gradient agreement ───────────────────────────────────────────
1215 // step() and forward() are different reductions of the same SSM,
1216 // so their per-parameter gradients should also agree, modulo
1217 // float-summation order noise.
1218 check_grads_match("step vs forward", &r_fwd, &r_step, 1e-3);
1219 }
1220
/// `step()` vs `forward()` on the baseline configuration.
#[test]
fn step_matches_forward() {
    let cfg = small_config();
    let path = Mamba2SsdPath::Minimal(Some(4));
    run_step_matches_forward(cfg, path);
}
1225
/// `step()` vs `forward()` with two B/C groups (`ngroups = 2`).
#[test]
fn step_matches_forward_ngroups2() {
    let base = Mamba2Config::new(32);
    let cfg = base
        .with_state_rank(8)
        .with_expand(2)
        .with_per_head_dim(16)
        .with_ngroups(2);
    let path = Mamba2SsdPath::Minimal(Some(4));
    run_step_matches_forward(cfg, path);
}
1235
1236 /// forward(full) ≡ forward(prefix) then forward(suffix, cache_from_prefix).
1237 ///
1238 /// Verifies stateful chunked-prefill: the convolution window carried in the
1239 /// cache must replay correctly at the start of the second segment.
1240 fn run_split_matches_full(cfg: Mamba2Config, ssd_path: Mamba2SsdPath) {
1241 let device: Device = Default::default();
1242 let model = cfg.init::<B>(&device);
1243
1244 let batch = 2;
1245 let seq_len = 6;
1246 let split = 2;
1247 let d_model = cfg.d_model;
1248
1249 let input = Tensor::<InnerB, 3>::random(
1250 [batch, seq_len, d_model],
1251 Distribution::Normal(0.0, 1.0),
1252 &device,
1253 );
1254 let head = Tensor::<InnerB, 3>::random(
1255 [batch, seq_len, d_model],
1256 Distribution::Normal(0.0, 1.0),
1257 &device,
1258 );
1259
1260 let input_full = param_input(&input);
1261 let r_full = run_with_grads(&model, &input_full, &head, |m, x| {
1262 let (out, _) = m.forward(x, None, ssd_path.clone());
1263 out
1264 });
1265
1266 let input_split = param_input(&input);
1267 let r_split = run_with_grads(&model, &input_split, &head, |m, x| {
1268 let prefix = x.clone().narrow(1, 0, split);
1269 let suffix = x.narrow(1, split, seq_len - split);
1270 let (out_prefix, cache) = m.forward(prefix, None, ssd_path.clone());
1271 let (out_suffix, _) = m.forward(suffix, Some(cache), ssd_path.clone());
1272 Tensor::cat(vec![out_prefix, out_suffix], 1)
1273 });
1274
1275 // ── Forward agreement (existing check) ───────────────────────────
1276 let diff = (r_full.out.clone() - r_split.out.clone())
1277 .abs()
1278 .max()
1279 .into_scalar();
1280 assert!(
1281 diff < 1e-4,
1282 "split forward vs full forward max absolute difference = {diff:.6} (expected < 1e-4)"
1283 );
1284
1285 // ── Gradient agreement ───────────────────────────────────────────
1286 check_grads_match("split vs full", &r_full, &r_split, 1e-3);
1287 }
1288
/// Split prefill vs full prefill on the baseline configuration.
#[test]
fn split_matches_full() {
    let cfg = small_config();
    let path = Mamba2SsdPath::Minimal(Some(4));
    run_split_matches_full(cfg, path);
}
1293
/// Split prefill vs full prefill with two B/C groups (`ngroups = 2`).
#[test]
fn split_matches_full_ngroups2() {
    let base = Mamba2Config::new(32);
    let cfg = base
        .with_state_rank(8)
        .with_expand(2)
        .with_per_head_dim(16)
        .with_ngroups(2);
    let path = Mamba2SsdPath::Minimal(Some(4));
    run_split_matches_full(cfg, path);
}
1303
1304 // ── is_norm_before_gate = true ───────────────────────────────────────────
1305
/// `step()` vs `forward()` with `is_norm_before_gate` enabled.
#[test]
fn step_matches_forward_norm_before_gate() {
    let base = Mamba2Config::new(32);
    let cfg = base
        .with_state_rank(8)
        .with_expand(2)
        .with_per_head_dim(8)
        .with_is_norm_before_gate(true);
    let path = Mamba2SsdPath::Minimal(Some(4));
    run_step_matches_forward(cfg, path);
}
1315
/// Split prefill vs full prefill with `is_norm_before_gate` enabled.
#[test]
fn split_matches_full_norm_before_gate() {
    let base = Mamba2Config::new(32);
    let cfg = base
        .with_state_rank(8)
        .with_expand(2)
        .with_per_head_dim(8)
        .with_is_norm_before_gate(true);
    let path = Mamba2SsdPath::Minimal(Some(4));
    run_split_matches_full(cfg, path);
}
1325
1326 // ── SSD path agreement ───────────────────────────────────────────────────
1327
1328 fn run_ssd_paths_agree(cfg: Mamba2Config) {
1329 let device: Device = Default::default();
1330 let model = cfg.init::<B>(&device);
1331
1332 let batch = 2;
1333 let seq_len = 8;
1334 let d_model = cfg.d_model;
1335
1336 let input = Tensor::<InnerB, 3>::random(
1337 [batch, seq_len, d_model],
1338 Distribution::Normal(0.0, 1.0),
1339 &device,
1340 );
1341 let head = Tensor::<InnerB, 3>::random(
1342 [batch, seq_len, d_model],
1343 Distribution::Normal(0.0, 1.0),
1344 &device,
1345 );
1346
1347 // Each path gets its own input Param so the autodiff leaves are
1348 // independent across the three backward passes.
1349 let input_min = param_input(&input);
1350 let input_ser = param_input(&input);
1351 let input_rec = param_input(&input);
1352
1353 let r_min = run_with_grads(&model, &input_min, &head, |m, x| {
1354 let (out, _) = m.forward(x, None, Mamba2SsdPath::Minimal(Some(4)));
1355 out
1356 });
1357 let r_ser = run_with_grads(&model, &input_ser, &head, |m, x| {
1358 let (out, _) = m.forward(x, None, Mamba2SsdPath::Serial(Some(4)));
1359 out
1360 });
1361 let r_rec = run_with_grads(&model, &input_rec, &head, |m, x| {
1362 let (out, _) = m.forward(x, None, Mamba2SsdPath::SerialRecalculated(Some(4)));
1363 out
1364 });
1365
1366 // ── Forward agreement ────────────────────────────────────────────
1367 let tol = 1e-4;
1368 let d_ser = (r_min.out.clone() - r_ser.out.clone())
1369 .abs()
1370 .max()
1371 .into_scalar();
1372 let d_rec = (r_min.out.clone() - r_rec.out.clone())
1373 .abs()
1374 .max()
1375 .into_scalar();
1376 assert!(
1377 d_ser < tol,
1378 "Minimal vs Serial: forward max abs diff = {d_ser:.6} (tol {tol})"
1379 );
1380 assert!(
1381 d_rec < tol,
1382 "Minimal vs SerialRecalculated: forward max abs diff = {d_rec:.6} (tol {tol})"
1383 );
1384
1385 // ── Gradient agreement ───────────────────────────────────────────
1386 check_grads_match("Minimal vs Serial", &r_min, &r_ser, 1e-3);
1387 check_grads_match("Minimal vs SerialRecalculated", &r_min, &r_rec, 1e-3);
1388 }
1389
/// SSD path agreement on the baseline configuration.
#[test]
fn ssd_paths_agree() {
    let cfg = small_config();
    run_ssd_paths_agree(cfg);
}
1394
/// SSD path agreement with two B/C groups (`ngroups = 2`).
#[test]
fn ssd_paths_agree_ngroups2() {
    let base = Mamba2Config::new(32);
    let cfg = base
        .with_state_rank(8)
        .with_expand(2)
        .with_per_head_dim(16)
        .with_ngroups(2);
    run_ssd_paths_agree(cfg);
}
1404}