Skip to main content

burn_mamba/mamba2/
mamba2.rs

1//! # Mamba-2 SSM Block — Structured State Space Duality (SSD)
2//!
3//! This module implements the core **SSD layer** from the paper
4//! *"Transformers are SSMs: Generalized Models and Efficient Algorithms
5//! through Structured State Space Duality"* (Dao & Gu, 2024).
6//!
7//! ## The SSD Model
8//!
9//! The SSD layer is a multi-head selective SSM.  Each head processes a
10//! sequence of `P`-dimensional inputs `X ∈ ℝ^{T×P}` through the recurrence
11//! (Eq. 1–2 of the paper):
12//!
13//! ```text
14//!   hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ          (state update)
15//!   yₜ = Cₜᵀ hₜ                     (output readout)
16//! ```
17//!
18//! where:
//! - `hₜ ∈ ℝ^{N×P}` is the hidden state (N = `state_rank`, P = `per_head_dim`);
//!   tensors in this file store it transposed, with trailing axes `[…, P, N]`
20//! - `Āₜ = exp(Δₜ A) ∈ ℝ` is a scalar decay (the key SSD constraint:
21//!   Āₜ = αₜ · I, i.e. scalar times identity, rather than a diagonal matrix)
22//! - `B̄ₜ = Δₜ Bₜ ∈ ℝᴺ`  is the (discretised) input projection
23//! - `Cₜ ∈ ℝᴺ`           is the output projection
24//! - `Δₜ > 0`             is the (input-dependent) discretisation step size
25//! - `A < 0`              is a learnable scalar decay rate per head
26//!
27//! ## State Space Duality
28//!
29//! Unrolling the recurrence yields an equivalent **attention-like** form
30//! (Eq. 6–7):
31//!
32//! ```text
33//!   M = L ∘ (C Bᵀ)      ∈ ℝ^{T×T}
34//!   Y = M · X
35//! ```
36//!
37//! where `L` is the **1-semiseparable mask** (Eq. 4–5):
38//!
39//! ```text
40//!   Lᵢⱼ = āᵢ · āᵢ₋₁ · ... · āⱼ₊₁    (i ≥ j)
41//!   Lᵢⱼ = 0                           (i < j)
42//! ```
43//!
44//! This makes the layer equivalent to causal linear attention
45//! `Y = (L ∘ QKᵀ) V` under the renaming `(C, B, X) ↦ (Q, K, V)`.
46//!
47//! ## The Chunkwise SSD Algorithm
48//!
49//! See [`minimal`] for more information.
50//!
51//! ## Notation / Dimension Keys
52//!
//! Throughout this file, tensor names carry a suffix encoding their shape,
//! one letter per axis (e.g. `input_bsm` is `[batch, sequence, d_model]`).
54//! The letters used are:
55//!
56//! | Letter | Dimension | Typical value |
57//! |--------|-----------|---------------|
58//! | `b`    | batch     | varies        |
59//! | `s`    | sequence length T | varies |
60//! | `m`    | d_model   | 768, 1024 … |
61//! | `i`    | d_inner = expand·d_model | 2·d_model |
62//! | `h`    | nheads H  | d_inner / P  |
63//! | `p`    | per_head_dim P | 64, 128 |
64//! | `r`    | state_rank N   | 64–256  |
65//! | `v`    | conv_dim  | d_inner + 2·G·N |
66//! | `k`    | conv_kernel    | 4       |
67//! | `g`    | ngroups G      | 1–H     |
68//! | `n`    | nchunks = T/Q  | varies  |
69//! | `l`    | chunk_len Q    | 64–256  |
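//! | `N`    | 1+nchunks (padded for state scan) | — |
//! | `f`    | P·N (flattened state for matmul)   | — |
//! | `S`    | a padded sequence length: `sequence_padded`, or `(conv_kernel-1)+sequence` for the conv input | — |
//! | `K`    | conv_kernel − 1 (cached conv tail) | — |
//!
//! Upper-case `N` in a suffix is *not* the paper's state rank N; the state
//! rank is written `r` in suffixes.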
72
73use crate::mamba2::prelude::*;
74use crate::utils::sanity::sanity as san;
75use crate::utils::{
76    rms_norm_gated::{RmsNormGated, RmsNormGatedConfig},
77    silu::Silu,
78    softplus::softplus,
79};
80use burn::prelude::*;
81use burn::{
82    module::{Module, Param},
83    nn::conv::{Conv1d, Conv1dConfig},
84    nn::{Initializer, Linear, LinearConfig},
85};
86
87// ---------------------------------------------------------------------------
88// Mamba2  (the SSM block)
89// ---------------------------------------------------------------------------
90
/// The Mamba-2 SSM block.
///
/// Implements the full SSD layer as described in §5 of the paper.  Supports
/// two execution modes:
///
/// - [`Self::forward`] — chunkwise SSD for training / prefill
///   (exploits tensor cores; linear in sequence length T)
/// - [`Self::step`]    — pure recurrent form for token-by-token decoding
///   (O(H·P·N) per step; a fixed-size conv + SSM state instead of a
///   growing KV-cache)
///
/// Derived sizes that are not stored as fields (`d_inner`, `nheads`,
/// `per_head_dim`, `conv_dim`) are recovered from parameter shapes by the
/// accessor methods on this type.
///
/// ## Architecture (one forward pass through the block)
///
/// ```text
///   u  [B, T, D]
///   ├─ in_proj ──────────────────────────────────┐
///   │                                            │
///   │  z [B,T,I]   xbc [B,T,V]   dt_raw [B,T,H] │
///   │                │                           │
///   │            causal Conv1d                   │
///   │                │ SiLU                      │
///   │           split into                       │
///   │       x [B,T,H,P]  B [B,T,G,N]  C [B,T,G,N]
///   │                                            │
///   │     Δ = softplus(dt_raw + dt_bias)         │
///   │     Ā = exp(Δ · A)   [scalar per head]     │
///   │     B̄ = Δ · B                              │
///   │                                            │
///   │  ┌──── chunked_selective_scan ─────────┐   │
///   │  │  (Steps 1–4, see below)             │   │
///   │  └─────────────────────────────────────┘   │
///   │     y [B,T,H,P]                            │
///   │     + D skip                               │
///   │     RmsNormGated(·, z)                     │
///   └─ out_proj ─────────────────────────────────┘
///   output  [B, T, D]
/// ```
#[derive(Module, Debug)]
pub struct Mamba2<B: Backend> {
    /// Input projection: maps `d_model → d_inner + conv_dim + nheads`.
    ///
    /// The output is split into three parts:
    /// - `z      [B, T, d_inner]`  — multiplicative gate for the output norm
    /// - `xbc    [B, T, conv_dim]` — input to the causal convolution, which
    ///   is then split into (x, B, C) after activation
    /// - `dt_raw [B, T, nheads]`   — raw (pre-softplus) discretisation step Δ
    pub in_proj: Linear<B>,

    /// Causal depthwise Conv1d applied to the `xbc` projection.
    ///
    /// - Input/output channels: `conv_dim`
    /// - Kernel size: `conv_kernel` (typically 4)
    /// - Groups: `conv_dim` (fully depthwise — each channel is independent)
    /// - Padding: **none** (the input is manually left-padded with the cached
    ///   tail of the previous call, so the convolution is strictly causal)
    ///
    /// The convolution provides a local `conv_kernel`-token context window
    /// before the SSM, which helps the model capture short-range dependencies
    /// that the SSM's recurrent form handles less efficiently.
    pub conv1d: Conv1d<B>,

    /// Per-head bias for the discretisation step size Δ.
    ///
    /// Shape: `[nheads]`
    ///
    /// On every call (training and inference alike) the step size is
    /// `Δₜ = softplus(dt_raw_t + dt_bias)`, then clamped to `dt_limit`.
    /// Initialised such that `softplus(dt_bias)` is log-uniformly
    /// distributed in `[dt_min, dt_max]`.
    pub dt_bias_h: Param<Tensor<B, 1>>,

    /// Hard clamp applied to Δ after softplus:  `Δ ∈ [dt_limit.0, dt_limit.1]`.
    ///
    /// Prevents degenerate discretisations (e.g. Δ → 0 causes Ā → 1, meaning
    /// the state never decays; Δ → ∞ causes Ā → 0, meaning the state is
    /// immediately wiped each step).
    pub dt_limit: (f64, f64),

    /// Per-head log-magnitude of the continuous-time decay parameter A.
    ///
    /// Shape: `[nheads]`
    ///
    /// The actual (negative) decay rate is `A = -exp(a_log)`.  The discrete
    /// decay is `Āₜ = exp(Δₜ · A) = exp(-Δₜ · exp(a_log)) ∈ (0, 1)`.
    ///
    /// Storing the *log* of the magnitude and negating ensures A < 0
    /// (decaying system) unconditionally and avoids any sign-constraint
    /// during gradient descent.
    pub a_log_h: Param<Tensor<B, 1>>,

    /// Per-head skip (D) coefficient.
    ///
    /// Shape: `[nheads]`
    ///
    /// Adds a direct path from the (post-convolution, pre-SSM) input to the
    /// output:  `yₜ += D · xₜ`.  Initialised to ones.
    pub d_h: Param<Tensor<B, 1>>,

    /// Gated RMSNorm applied to the SSM output, conditioned on the gate `z`.
    ///
    /// Input channel dimension: `d_inner` (the length of its `gamma` is also
    /// how [`Self::d_inner`] recovers that size).
    ///
    /// This combines the multiplicative gate (from `z`) and a normalisation
    /// step into a single fused operation, matching the architecture in §5.2
    /// of the paper.
    pub norm: RmsNormGated<B>,

    /// Output projection: maps `d_inner → d_model`.
    pub out_proj: Linear<B>,

    /// Optional learnable initial hidden state `h₀`.
    ///
    /// Shape: `[nheads, per_head_dim, state_rank]` (i.e. `[H, P, N]`)
    ///
    /// When `None`, the initial state is zero (the standard default).
    /// When `Some`, the stored tensor is used as the initial condition for
    /// *every* forward call (not per-batch; it is broadcast over the batch
    /// dimension).  The config factory allocates it (zero-initialised) only
    /// when `has_learnable_init_state` is set.
    /// NOTE(review): `forward` hands both this tensor and the cache's SSM
    /// state to the SSD kernels; how the two are combined is decided there —
    /// confirm against the SSD implementation.
    pub init_state_hpr: Option<Param<Tensor<B, 3>>>,

    /// State rank `N` — the number of latent dimensions in the SSM hidden
    /// state `h ∈ ℝ^{N×P}` per head.  Corresponds to the paper's `N`.
    /// In tensors the state is stored with trailing axes `[P, N]`.
    pub state_rank: usize,

    /// Number of B/C groups `G` for grouped SSM heads (analogous to
    /// grouped-query attention).  `G` divides `nheads`; all `nheads/G` heads
    /// within a group share the same B and C projections while having
    /// independent x, Δ, A, and z channels.
    pub ngroups: usize,
}
219
220impl<B: Backend> Mamba2<B> {
221    /// `d_inner = expand · d_model`.  Inferred from the norm's weight shape.
222    pub fn d_inner(&self) -> usize {
223        let [d_inner] = self.norm.gamma.dims();
224        d_inner
225    }
226
227    /// `nheads = d_inner / per_head_dim`.  Inferred from `a_log_h`.
228    pub fn nheads(&self) -> usize {
229        let [nheads] = self.a_log_h.dims();
230        nheads
231    }
232
233    /// `per_head_dim P = d_inner / nheads`.
234    pub fn per_head_dim(&self) -> usize {
235        self.d_inner() / self.nheads()
236    }
237
238    /// `conv_dim = d_inner + 2 · ngroups · state_rank`.
239    pub fn conv_dim(&self) -> usize {
240        self.d_inner() + 2 * self.ngroups * self.state_rank
241    }
242}
243
244// ---------------------------------------------------------------------------
245// Mamba2Config  (hyperparameters and factory)
246// ---------------------------------------------------------------------------
247
/// Hyperparameters for the Mamba-2 SSM block.
///
/// All computed quantities (e.g. `nheads`, `d_inner`, `conv_dim`) are derived
/// from the stored fields; see the helper methods on [`Mamba2Config`].
///
/// Construction-time constraints (checked when building the block):
/// `per_head_dim` must divide `expand · d_model`, `ngroups` must divide the
/// resulting `nheads`, and `a_init_range` must satisfy `0 < lo < hi`.
#[derive(Config, Debug)]
pub struct Mamba2Config {
    /// Model (hidden) dimension D.  Every token is represented as a
    /// D-dimensional vector at the input and output of the block.
    pub d_model: usize,

    /// State rank N — the latent dimension of the SSM hidden state.
    ///
    /// Larger N gives a more expressive state but increases memory and compute
    /// per step.  The paper uses N ∈ {64, 128, 256} for most experiments.
    #[config(default = 128)]
    pub state_rank: usize,

    /// Causal convolution window length.  Typically 4; should be at least 1
    /// (the forward pass caches `conv_kernel - 1` trailing columns).
    #[config(default = 4)]
    pub conv_kernel: usize,

    /// Expansion factor for `d_inner = expand · d_model`.
    ///
    /// An expansion of 2 doubles the internal width, keeping the parameter
    /// count of the SSM block comparable to a standard attention layer.
    #[config(default = 2)]
    pub expand: usize,

    /// Head dimension P.  The total `d_inner` is split into
    /// `nheads = d_inner / P` independent SSM heads.
    ///
    /// Must be positive and must divide `d_inner`.
    ///
    /// Typical values: 64 or 128.  Smaller P means more heads and a larger
    /// hidden state per model dimension.
    #[config(default = 64)]
    pub per_head_dim: usize,

    /// Number of B/C groups G.  Must divide `nheads`.
    ///
    /// Setting G < nheads reduces the B and C projection sizes (analogous to
    /// GQA in attention), saving memory without a large accuracy cost.
    #[config(default = 1)]
    pub ngroups: usize,

    /// Range `[lo, hi]` for the uniform initialisation of the magnitude of A.
    ///
    /// `A = -Uniform(lo, hi)`, stored as `a_log = log(Uniform(lo, hi))`, so
    /// the range must satisfy `0 < lo < hi`.
    /// The paper uses `[1, 16]` by default.
    #[config(default = "(1., 16.)")]
    pub a_init_range: (f64, f64),

    /// Gated RMSNorm mode, forwarded to the norm's `norm_before_gate` flag:
    /// when `true` the norm is applied *before* the gate; when `false`
    /// (default) the gate is applied first (SiLU-gated norm).
    #[config(default = false)]
    pub is_norm_before_gate: bool,

    /// Minimum value of the initial Δ distribution.  Used to set `dt_bias`.
    ///
    /// Initial Δ is sampled log-uniformly, i.e. `ln(dt_min)` is taken, so
    /// this should be strictly positive.
    #[config(default = 1e-3)]
    pub dt_min: f64,

    /// Maximum value of the initial Δ distribution.  Used to set `dt_bias`.
    /// Should be strictly positive and not below `dt_min`.
    #[config(default = 0.1)]
    pub dt_max: f64,

    /// Floor clamped onto the sampled initial Δ values before inverting to
    /// obtain `dt_bias`.  Prevents numerical issues with very small Δ.
    #[config(default = 1e-4)]
    pub dt_init_floor: f64,

    /// Hard clamp limits for Δ at runtime: `Δ ∈ [dt_limit.0, dt_limit.1]`.
    ///
    /// Defaults to `(0, f16::MAX ≈ 65504)`, effectively only clamping at 0.
    #[config(default = "(0., 6.5504e+4)")]
    pub dt_limit: (f64, f64),

    /// Whether to add a bias term to the `in_proj` and `out_proj` projections.
    #[config(default = false)]
    pub has_proj_bias: bool,

    /// Whether to add a bias term to the depthwise convolution.
    #[config(default = true)]
    pub has_conv_bias: bool,

    /// Whether to allocate a learnable initial SSM state `h₀`.
    ///
    /// When `false` (default), the hidden state starts at zero for every
    /// sequence.  When `true`, `init_state_hpr` is allocated as a trainable
    /// parameter of shape `[nheads, per_head_dim, state_rank]`, initialised
    /// to zeros.
    #[config(default = false)]
    pub has_learnable_init_state: bool,
}
338
339impl Mamba2Config {
340    // -----------------------------------------------------------------------
341    // Computed dimensions
342    // -----------------------------------------------------------------------
343
344    /// `d_inner = expand · d_model`.
345    pub fn d_inner(&self) -> usize {
346        self.expand * self.d_model
347    }
348
349    /// `nheads H = d_inner / per_head_dim`.
350    pub fn nheads(&self) -> usize {
351        self.d_inner() / self.per_head_dim
352    }
353
354    /// `conv_dim = d_inner + 2 · ngroups · state_rank`.
355    ///
356    /// The depthwise convolution processes `x`, `B`, and `C` concatenated:
357    /// x contributes `d_inner` channels, B and C each contribute
358    /// `ngroups · state_rank` channels.
359    pub fn conv_dim(&self) -> usize {
360        self.d_inner() + 2 * self.ngroups * self.state_rank
361    }
362
363    // -----------------------------------------------------------------------
364    // Initialisation
365    // -----------------------------------------------------------------------
366
367    /// Allocate and initialise all Mamba-2 block parameters on `device`.
368    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2<B> {
369        let d_inner = self.d_inner();
370        let nheads = self.nheads();
371        let conv_dim = self.conv_dim();
372
373        assert!(self.per_head_dim > 0, "per_head_dim must be positive");
374        assert_eq!(
375            nheads * self.per_head_dim,
376            d_inner,
377            "d_inner must be divisible by per_head_dim"
378        );
379        assert_eq!(
380            nheads % self.ngroups,
381            0,
382            "nheads must be divisible by ngroups"
383        );
384
385        // Uniform initialiser matching PyTorch's default: U(-1/√fan_in, 1/√fan_in).
386        let uniform_init = |fan_in: usize| {
387            let bound = 1.0 / (fan_in as f64).sqrt();
388            Initializer::Uniform {
389                min: -bound,
390                max: bound,
391            }
392        };
393
394        // ── in_proj ──────────────────────────────────────────────────────────
395        // Projects d_model → (z, xbc, dt_raw).
396        // Size:  d_inner  +  conv_dim  +  nheads
397        let d_in_proj_out = d_inner + conv_dim + nheads;
398        let in_proj = LinearConfig::new(self.d_model, d_in_proj_out)
399            .with_bias(self.has_proj_bias)
400            .with_initializer(uniform_init(self.d_model))
401            .init::<B>(device);
402
403        // ── conv1d ───────────────────────────────────────────────────────────
404        // Causal depthwise convolution.  Left-padding is applied manually in
405        // `forward` and `step`, so we request "Valid" (no automatic padding).
406        // The initialiser fan_in is `in_channels / groups * kernel_size = 1 * conv_kernel`.
407        let conv1d = Conv1dConfig::new(conv_dim, conv_dim, self.conv_kernel)
408            .with_padding(burn::nn::PaddingConfig1d::Valid)
409            .with_groups(conv_dim)
410            .with_bias(self.has_conv_bias)
411            .with_initializer(uniform_init(self.conv_kernel))
412            .init::<B>(device);
413
414        // ── dt_bias ──────────────────────────────────────────────────────────
415        // We want the initial Δ values (after softplus) to be log-uniformly
416        // distributed in [dt_min, dt_max].  The inverse softplus (inverse of
417        // softplus(x) = ln(1 + exp(x))) is used to back-solve for the bias:
418        //   dt_bias = softplus⁻¹(dt) = dt + ln(1 - exp(-dt)) ≈ dt + ln(dt)
419        // which simplifies to `dt + log(exp(dt) - 1)` in the formula below.
420        let expm1 = |t: Tensor<B, 1>| t.exp() - 1.;
421        let dt_h = Tensor::random(
422            [nheads],
423            burn::tensor::Distribution::Uniform(self.dt_min.ln(), self.dt_max.ln()),
424            device,
425        )
426        .exp();
427        let dt_h = dt_h.clamp(self.dt_init_floor, f64::INFINITY);
428        // Inverse softplus: softplus⁻¹(y) = y + log(1 - e^{-y}) = y + log(e^y - 1) - y = log(e^y - 1)
429        let inv_dt_h = dt_h.clone() + (-expm1(-dt_h)).log();
430        let dt_bias_h = Param::from_tensor(inv_dt_h);
431
432        // ── a_log ─────────────────────────────────────────────────────────────
433        // A is constrained to be negative (decaying system).
434        // We store a_log = log(|A|) and recover A = -exp(a_log) at runtime.
435        // This parameterisation ensures A < 0 unconditionally.
436        assert!(
437            self.a_init_range.0 > 0.0,
438            "a_init_range lower bound must be > 0"
439        );
440        assert!(
441            self.a_init_range.0 < self.a_init_range.1,
442            "a_init_range must satisfy lo < hi"
443        );
444        let a_h = Tensor::random(
445            [nheads],
446            burn::tensor::Distribution::Uniform(self.a_init_range.0, self.a_init_range.1),
447            device,
448        );
449        let a_log_h = Param::from_tensor(a_h.log());
450
451        // ── D (skip connection) ───────────────────────────────────────────────
452        // Initialised to ones, adding a direct residual path from input to output.
453        let d_h = Initializer::Ones.init::<B, 1, _>([nheads], device);
454
455        // ── norm (gated RMSNorm) and out_proj ─────────────────────────────────
456        let norm = RmsNormGatedConfig::new(d_inner)
457            .with_norm_before_gate(self.is_norm_before_gate)
458            .init(device);
459        let out_proj = LinearConfig::new(d_inner, self.d_model)
460            .with_bias(self.has_proj_bias)
461            .with_initializer(uniform_init(d_inner))
462            .init(device);
463
464        // ── learnable initial state (optional) ────────────────────────────────
465        let init_state_hpr = self.has_learnable_init_state.then(|| {
466            Initializer::Zeros.init::<B, 3, _>([nheads, self.per_head_dim, self.state_rank], device)
467        });
468
469        Mamba2 {
470            in_proj,
471            conv1d,
472            dt_bias_h,
473            dt_limit: self.dt_limit,
474            a_log_h,
475            d_h,
476            norm,
477            out_proj,
478            init_state_hpr,
479            state_rank: self.state_rank,
480            ngroups: self.ngroups,
481        }
482    }
483}
484
485// ---------------------------------------------------------------------------
486// Mamba2::forward  (chunkwise SSD — training / prefill)
487// ---------------------------------------------------------------------------
488
489impl<B: Backend + Mamba2BackendExt> Mamba2<B> {
490    /// Process a full input sequence using the chunkwise SSD algorithm.
491    ///
492    /// This is the primary training and prefill path.  The computation is
493    /// **linear in T** but uses batched matrix multiplications (GEMMs) that
494    /// can exploit GPU tensor cores — unlike the naive sequential recurrence,
495    /// which requires O(T) serial steps.
496    ///
497    /// ## Full dataflow
498    ///
499    /// 1. **In-projection**: `u → (z, xbc, dt_raw)` via a single linear layer.
500    /// 2. **Causal Conv1d + SiLU**: local context mixing over `xbc`.
501    /// 3. **Split**: `xbc → (x, B, C)`.
502    /// 4. **Discretise**: `Δ = softplus(dt_raw + dt_bias)`;
503    ///    `Ā = exp(Δ · A)`;  `B̄ = Δ · B`.
504    /// 5. **Padding**: sequence padding.
505    /// 6. **Chunked SSD**: four-step chunkwise algorithm (see
506    ///    [`Self::chunked_selective_scan`]).
507    /// 7. **Gated RMSNorm**: `y = RMSNorm(y) · σ(z)`.
508    /// 8. **Out-projection**: `y → output`.
509    ///
510    /// ## Sequence padding
511    ///
512    /// If `sequence_unpadded % chunk_len ≠ 0`, the sequence is zero-padded
513    /// to the next multiple of Q.  Zero-padding is equivalent to inserting
514    /// identity steps (`Δ = 0  ⇒  Ā = exp(0) = 1,  B̄ = 0`), so the SSM
515    /// state is carried forward unchanged through the pad — making it safe to
516    /// read the final state of the padded last chunk as the true final state.
517    ///
518    /// # Shapes
519    /// - `input_bsm` : `[batch, sequence, d_model]`
520    /// - output      : `[batch, sequence, d_model]`
521    /// - cache (out) : updated convolution window and SSM state
522    #[allow(non_snake_case)]
523    pub fn forward(
524        &self,
525        input_bsm: Tensor<B, 3>,
526        cache: Option<Mamba2Cache<B>>,
527        ssd_path: Mamba2SsdPath,
528    ) -> (Tensor<B, 3>, Mamba2Cache<B>) {
529        let [batch, sequence, _d_model] = input_bsm.dims();
530        let d_inner = self.d_inner();
531        let ngroups = self.ngroups;
532        let nheads = self.nheads();
533        let per_head_dim = self.per_head_dim();
534        let conv_dim = self.conv_dim();
535        let state_rank = self.state_rank;
536        let [_conv_dim, _, conv_kernel] = self.conv1d.weight.dims();
537        let [_d_model, d_in_proj_out] = self.in_proj.weight.dims();
538        let device = input_bsm.device();
539        assert_eq!(conv_dim, _conv_dim);
540        assert_ne!(ngroups, 0);
541        assert_eq!(conv_dim, _conv_dim);
542        assert_eq!(nheads % ngroups, 0);
543        assert!(sequence > 0, "sequence length must be at least 1");
544        san(&input_bsm);
545
546        // ── Initialise cache if not provided ──────────────────────────────────
547        let mut cache = cache.unwrap_or_else(|| {
548            let conv_bvk = Tensor::zeros(Shape::new([batch, conv_dim, conv_kernel]), &device);
549            let ssm_bhpr = Tensor::zeros(
550                Shape::new([batch, nheads, per_head_dim, state_rank]),
551                &device,
552            );
553            Mamba2Cache { conv_bvk, ssm_bhpr }
554        });
555        cache.sanity();
556
557        // ── Step 1: In-projection ─────────────────────────────────────────────
558        // One linear layer projects the input to all SSM parameters at once.
559        // This "parallel projection" structure (vs. Mamba-1's sequential
560        // projections) enables tensor parallelism with only 1 all-reduce per
561        // layer instead of 2.
562        //
563        // Projection output:  [z | xbc | dt_raw]
564        //   z      [B, T, d_inner]  — gate for the output RMSNorm
565        //   xbc    [B, T, conv_dim] — will become (x, B, C) after conv + split
566        //   dt_raw [B, T, nheads]   — raw discretisation step (pre-softplus)
567        let (z_gate_bsi, xbc_bsv, dt_raw_bsh) = {
568            let z_xbc_dt_bsd = self.in_proj.forward(input_bsm);
569            assert_eq!([batch, sequence, d_in_proj_out], z_xbc_dt_bsd.dims());
570            assert_eq!(
571                [batch, sequence, d_inner + conv_dim + nheads],
572                z_xbc_dt_bsd.dims(),
573            );
574
575            let mut parts = z_xbc_dt_bsd
576                .split_with_sizes(vec![d_inner, conv_dim, nheads], 2)
577                .into_iter();
578            (
579                parts.next().unwrap(), // z      [B, T, d_inner]
580                parts.next().unwrap(), // xbc    [B, T, conv_dim]
581                parts.next().unwrap(), // dt_raw [B, T, nheads]
582            )
583        };
584        assert_eq!([batch, sequence, d_inner], z_gate_bsi.dims());
585        assert_eq!([batch, sequence, conv_dim], xbc_bsv.dims());
586        assert_eq!([batch, sequence, nheads], dt_raw_bsh.dims());
587        san(&z_gate_bsi);
588        san(&xbc_bsv);
589        san(&dt_raw_bsh);
590
591        // ── Step 2: Causal depthwise Conv1d ───────────────────────────────────
592        // Apply the causal 1-D depthwise convolution to `xbc`.  To maintain
593        // strict causality, the input is left-padded with the last
594        // `(conv_kernel - 1)` columns from the cache (the tail of the previous
595        // chunk), giving a padded input of length `(conv_kernel-1) + sequence`.
596        // After the convolution the output has length `sequence` (Valid padding).
597        //
598        // The right-most `conv_kernel` columns of the padded input become the
599        // new convolution cache for the next call.
600        let xbc_bvs = xbc_bsv.permute([0, 2, 1]); // [B, conv_dim, T]
601        assert_eq!([batch, conv_dim, sequence], xbc_bvs.dims());
602
603        // Build the causally-padded input: [cached tail | new input]
604        let xbc_padded_bvS = if conv_kernel >= 2 {
605            // Drop the oldest (leftmost) element of the cache, keeping the
606            // last (conv_kernel - 1) columns.
607            let tail_bvK = cache.conv_bvk.slice(s![.., .., 1..]);
608            assert_eq!([batch, conv_dim, conv_kernel - 1], tail_bvK.dims());
609            Tensor::cat(vec![tail_bvK, xbc_bvs], 2)
610        } else {
611            // conv_kernel == 1: no causal padding needed.
612            xbc_bvs
613        };
614        assert_eq!(
615            [batch, conv_dim, (conv_kernel - 1) + sequence],
616            xbc_padded_bvS.dims()
617        );
618        san(&xbc_padded_bvS);
619
620        // Update the cache: save the last `conv_kernel` columns of the padded
621        // input (i.e. starting at position `sequence - 1` from the new input).
622        cache.conv_bvk = xbc_padded_bvS.clone().slice(s![.., .., (sequence - 1)..]);
623        assert_eq!([batch, conv_dim, conv_kernel], cache.conv_bvk.dims());
624
625        // Apply the depthwise convolution and transpose back to [B, T, conv_dim].
626        let xbc_bvs = self.conv1d.forward(xbc_padded_bvS);
627        assert_eq!([batch, conv_dim, sequence], xbc_bvs.dims());
628        san(&xbc_bvs);
629
630        let xbc_bsv = xbc_bvs.permute([0, 2, 1]); // [B, T, conv_dim]
631        assert_eq!([batch, sequence, conv_dim], xbc_bsv.dims());
632
633        // SiLU activation (element-wise).
634        let xbc_bsv = Silu::new().forward(xbc_bsv);
635        assert_eq!([batch, sequence, conv_dim], xbc_bsv.dims());
636        san(&xbc_bsv);
637
638        // ── Step 3: Split xbc into (x, B, C) ──────────────────────────────────
639        // After activation, xbc is partitioned along the channel dimension:
640        //   x [B, T, d_inner]          → reshaped to [B, T, H, P]   (input)
641        //   B [B, T, ngroups·N]        → reshaped to [B, T, G, N]   (state input proj)
642        //   C [B, T, ngroups·N]        → reshaped to [B, T, G, N]   (state output proj)
643        //
644        // Note: in the SSM/attention duality, C ↔ Q, B ↔ K, x ↔ V.
645        let (x_bshp, b_bsgr, c_bsgr) = {
646            let mut parts = xbc_bsv
647                .split_with_sizes(vec![d_inner, ngroups * state_rank, ngroups * state_rank], 2)
648                .into_iter();
649            (
650                parts
651                    .next()
652                    .unwrap() // [B, T, d_inner]
653                    .reshape([batch, sequence, nheads, per_head_dim]),
654                parts
655                    .next()
656                    .unwrap() // [B, T, ngroups·N]
657                    .reshape([batch, sequence, ngroups, state_rank]),
658                parts
659                    .next()
660                    .unwrap() // [B, T, ngroups·N]
661                    .reshape([batch, sequence, ngroups, state_rank]),
662            )
663        };
664        // No shape assertions on reshapes (shapes are algebraically guaranteed).
665
666        // ── Step 4: Discretisation ────────────────────────────────────────────
667        // Compute the discrete-time parameters from the continuous-time A
668        // and input-dependent step size Δ (ZOH discretisation, §4.5):
669        //
670        //   Δₜ = softplus(dt_raw_t + dt_bias)     ∈ (0, ∞)
671        //   Āₜ = exp(Δₜ · A)                       ∈ (0, 1)   [scalar per head]
672        //   B̄ₜ = Δₜ · Bₜ                           ∈ ℝᴺ       [Euler approx]
673        //
674        // The Euler approximation B̄ ≈ ΔB (instead of the exact ZOH formula)
675        // is standard in Mamba-1 and Mamba-2 (see §4.5 of the reference).
676        //
677        // `a_head_decay_h` = A = -exp(a_log) < 0 (negative, one scalar per head).
678        // Note: we pass this negative value to `chunked_selective_scan`; inside
679        // that function it is multiplied by Δ > 0, giving a negative exponent
680        // which produces Āₜ = exp(Δₜ·A) ∈ (0,1) as required.
681        let dt_bias_11h = self.dt_bias_h.val().unsqueeze_dims(&[0, 1]);
682        assert_eq!([1, 1, nheads], dt_bias_11h.dims());
683
684        let dt_bsh = softplus(dt_raw_bsh + dt_bias_11h).clamp(self.dt_limit.0, self.dt_limit.1);
685        assert_eq!([batch, sequence, nheads], dt_bsh.dims());
686        san(&dt_bsh);
687
688        let a_head_decay_h = -self.a_log_h.val().exp(); // A = -exp(a_log) < 0
689        assert_eq!([nheads], a_head_decay_h.dims());
690        san(&a_head_decay_h);
691
692        // ── Step 5: Pad sequence to a multiple of chunk_len ───────────────────
693        // Zeros are the correct pad: Δ=0  ⇒  Ā=exp(0·A)=1, B̄=0·B=0.
694        // The state is thus carried through unchanged, so the final state of
695        // the padded last chunk equals the state after the last real token.
696        let chunk_len = ssd_path.chunk_len_or_optimal(state_rank, per_head_dim);
697        assert!(chunk_len > 0);
698        let sequence_padded = sequence.next_multiple_of(chunk_len);
699        let pad = sequence_padded - sequence;
700        let (x_bShp, dt_bSh, b_bSgr, c_bSgr) = if pad == 0 {
701            (x_bshp.clone(), dt_bsh, b_bsgr, c_bsgr)
702        } else {
703            let x_bshp = Tensor::cat(
704                vec![
705                    x_bshp.clone(),
706                    Tensor::zeros(Shape::new([batch, pad, nheads, per_head_dim]), &device),
707                ],
708                1,
709            );
710            let dt_bsh = Tensor::cat(
711                vec![
712                    dt_bsh,
713                    Tensor::zeros(Shape::new([batch, pad, nheads]), &device),
714                ],
715                1,
716            );
717            let b_bsgr = Tensor::cat(
718                vec![
719                    b_bsgr,
720                    Tensor::zeros(Shape::new([batch, pad, ngroups, state_rank]), &device),
721                ],
722                1,
723            );
724            let c_bsgr = Tensor::cat(
725                vec![
726                    c_bsgr,
727                    Tensor::zeros(Shape::new([batch, pad, ngroups, state_rank]), &device),
728                ],
729                1,
730            );
731            (x_bshp.clone(), dt_bsh, b_bsgr, c_bsgr)
732        };
733
734        // ── Reshapes into chunks ───────────────────────────────────────────────
735        let nchunks = sequence_padded / chunk_len;
736        let x_bnlhp = x_bShp.reshape([batch, nchunks, chunk_len, nheads, per_head_dim]);
737        let dt_bnlh = dt_bSh.reshape([batch, nchunks, chunk_len, nheads]);
738        let b_bnlgr = b_bSgr.reshape([batch, nchunks, chunk_len, ngroups, state_rank]);
739        let c_bnlgr = c_bSgr.reshape([batch, nchunks, chunk_len, ngroups, state_rank]);
740
741        // ── Step 6: Selective Scan ────────────────────────────────────────────
742        let ssd_input = crate::mamba2::ssd::Mamba2SsdInput {
743            x_bnlhp,
744            dt_bnlh,
745            a_decay_h: a_head_decay_h,
746            b_bnlgr,
747            c_bnlgr,
748            d_h: self.d_h.val(),
749            initial_state_bhpr: cache.ssm_bhpr,
750            init_state_hpr: self.init_state_hpr.as_ref().map(|s| s.val()),
751        };
752        ssd_input.sanity();
753        let (y_bnlhp, final_state_bhpr) = match ssd_path {
754            Mamba2SsdPath::Minimal(_chunk_len) => Self::ssd_minimal(ssd_input),
755            Mamba2SsdPath::Serial(_chunk_len) => Self::ssd_serial(ssd_input),
756            Mamba2SsdPath::SerialRecalculated(_chunk_len) => {
757                Self::ssd_serial_recalculated(ssd_input)
758            }
759        };
760        assert_eq!(
761            [batch, nchunks, chunk_len, nheads, per_head_dim],
762            y_bnlhp.dims()
763        );
764        san(&y_bnlhp);
765        san(&final_state_bhpr);
766
767        // Update the SSM state in the cache for the next call.
768        cache.ssm_bhpr = final_state_bhpr;
769        assert_eq!(
770            [batch, nheads, per_head_dim, state_rank],
771            cache.ssm_bhpr.dims()
772        );
773
774        // Remove zero-pad columns that were added at Step 5.
775        let y_bShp = y_bnlhp.reshape([batch, sequence_padded, nheads, per_head_dim]);
776        let y_bshp = if pad == 0 {
777            y_bShp
778        } else {
779            y_bShp.slice(s![.., 0..sequence, .., ..])
780        };
781
782        // Reshape into sequence.
783        let y_bsi = y_bshp.reshape([batch, sequence, d_inner]);
784        assert_eq!([batch, sequence, d_inner], y_bsi.dims());
785
786        // ── Step 7: Gated RMSNorm ─────────────────────────────────────────────
787        let y_bsi = self.norm.forward(y_bsi, z_gate_bsi);
788        assert_eq!([batch, sequence, d_inner], y_bsi.dims());
789        san(&y_bsi);
790
791        // ── Step 8: Out-projection ────────────────────────────────────────────
792        let out_bsm = self.out_proj.forward(y_bsi);
793        assert_eq!([batch, sequence, _d_model], out_bsm.dims());
794        san(&out_bsm);
795
796        (out_bsm, cache)
797    }
798}
799
800// ---------------------------------------------------------------------------
801// Mamba2::step  (recurrent SSM — token-by-token decoding)
802// ---------------------------------------------------------------------------
803
804mod step {
805    use super::*;
806
807    impl<B: Backend> Mamba2<B> {
808        /// Process a **single token** using the pure recurrent SSM form.
809        ///
810        /// This is the O(H·P·N)-per-token decoding path.  It runs one tick of
811        /// the discretised Mamba-2 recurrence:
812        ///
813        /// ```text
814        ///   Āₜ  = exp(Δₜ · A)           scalar per head, ∈ (0, 1)
815        ///   B̄ₜ  = Δₜ · Bₜ               ∈ ℝᴺ   (Euler discretisation)
816        ///   hₜ  = Āₜ · hₜ₋₁ + B̄ₜ · xₜᵀ  ∈ ℝ^{P×N}   (outer product update)
817        ///   yₜ  = Cₜᵀ · hₜ + D · xₜ     ∈ ℝᴾ   (output)
818        /// ```
819        ///
820        /// The convolution is handled by manually sliding the cache window:
821        /// the oldest input column is dropped and the new token's projection
822        /// is appended.
823        ///
824        /// The SSM hidden state `cache.ssm_bhpr` is updated in-place via
825        /// the recurrence above.
826        ///
827        /// # Shapes
828        /// - `input_bm` : `[batch, d_model]`
829        /// - output     : `[batch, d_model]`
830        #[allow(non_snake_case)]
831        pub fn step(
832            &self,
833            input_bm: Tensor<B, 2>,
834            cache: Option<Mamba2Cache<B>>,
835        ) -> (Tensor<B, 2>, Mamba2Cache<B>) {
836            let [batch, d_model] = input_bm.dims();
837            let d_inner = self.d_inner();
838            let ngroups = self.ngroups;
839            let nheads = self.nheads();
840            let per_head_dim = self.per_head_dim();
841            let conv_dim = self.conv_dim();
842            let state_rank = self.state_rank;
843            let [_conv_dim, _, conv_kernel] = self.conv1d.weight.dims();
844            let [_d_model, d_in_proj_out] = self.in_proj.weight.dims();
845
846            assert_eq!(conv_dim, _conv_dim);
847            assert_eq!(nheads % ngroups, 0);
848
849            let mut cache = cache.unwrap_or_else(|| {
850                let device = &input_bm.device();
851                let conv_bvk = Tensor::zeros(Shape::new([batch, conv_dim, conv_kernel]), device);
852                let ssm_bhpr = Tensor::zeros(
853                    Shape::new([batch, nheads, per_head_dim, state_rank]),
854                    device,
855                );
856                Mamba2Cache { conv_bvk, ssm_bhpr }
857            });
858
859            // ── In-projection ─────────────────────────────────────────────────
860            let (z_gate_bi, xbc_bv, dt_raw_bh) = {
861                let z_xbc_dt_bd = self.in_proj.forward(input_bm);
862                assert_eq!([batch, d_in_proj_out], z_xbc_dt_bd.dims());
863                assert_eq!([batch, d_inner + conv_dim + nheads], z_xbc_dt_bd.dims());
864
865                let mut parts = z_xbc_dt_bd
866                    .split_with_sizes(vec![d_inner, conv_dim, nheads], 1)
867                    .into_iter();
868                (
869                    parts.next().unwrap(), // z  [B, d_inner]
870                    parts.next().unwrap(), // xbc[B, conv_dim]
871                    parts.next().unwrap(), // dt [B, nheads]
872                )
873            };
874            assert_eq!([batch, d_inner], z_gate_bi.dims());
875            assert_eq!([batch, conv_dim], xbc_bv.dims());
876            assert_eq!([batch, nheads], dt_raw_bh.dims());
877
878            // ── Causal convolution (single step) ──────────────────────────────
879            // Slide the cache window left by one position, then insert the new
880            // token's projection `xbc_bv` as the rightmost column.
881            cache.conv_bvk = {
882                let conv_bvk = cache.conv_bvk;
883                assert_eq!([batch, conv_dim, conv_kernel], conv_bvk.dims());
884
885                // Drop the oldest (leftmost) column.
886                let tail_bvK = conv_bvk.slice(s![.., .., 1..]);
887                assert_eq!([batch, conv_dim, conv_kernel - 1], tail_bvK.dims());
888
889                // Append the new token as the rightmost column.
890                let updated_bvk = Tensor::cat([tail_bvK, xbc_bv.unsqueeze_dim(2)].to_vec(), 2);
891                assert_eq!([batch, conv_dim, conv_kernel], updated_bvk.dims());
892                updated_bvk
893            };
894
895            // Apply the depthwise convolution manually (one step = dot product
896            // of the cached window with the conv weight along the kernel axis).
897            let xbc_bv = {
898                let conv1d_v1k = self.conv1d.weight.val(); // [conv_dim, 1, conv_kernel]
899                assert_eq!([conv_dim, 1, conv_kernel], conv1d_v1k.dims());
900
901                let conv1d_bvk = conv1d_v1k
902                    .permute([1, 0, 2]) // [1, conv_dim, conv_kernel]
903                    .expand([batch, conv_dim, conv_kernel]);
904                assert_eq!([batch, conv_dim, conv_kernel], conv1d_bvk.dims());
905
906                // Element-wise multiply and sum over the kernel axis.
907                let product_bvk = cache.conv_bvk.clone() * conv1d_bvk;
908                let mut xbc_bv = product_bvk.sum_dim(2).squeeze_dim(2);
909                assert_eq!([batch, conv_dim], xbc_bv.dims());
910
911                // Add the (optional) bias.
912                if let Some(bias_v) = &self.conv1d.bias {
913                    assert_eq!([conv_dim], bias_v.dims());
914                    xbc_bv = xbc_bv + bias_v.val().unsqueeze();
915                }
916
917                Silu::new().forward(xbc_bv)
918            };
919            assert_eq!([batch, conv_dim], xbc_bv.dims());
920
921            // ── Split (x, B, C) ───────────────────────────────────────────────
922            assert_eq!(d_inner, nheads * per_head_dim);
923            let (x_bhp, b_bgr, c_bgr) = {
924                let mut parts = xbc_bv
925                    .split_with_sizes(vec![d_inner, ngroups * state_rank, ngroups * state_rank], 1)
926                    .into_iter();
927                (
928                    parts
929                        .next()
930                        .unwrap() // [B, d_inner]
931                        .reshape([batch, nheads, per_head_dim]),
932                    parts
933                        .next()
934                        .unwrap() // [B, ngroups·N]
935                        .reshape([batch, ngroups, state_rank]),
936                    parts
937                        .next()
938                        .unwrap() // [B, ngroups·N]
939                        .reshape([batch, ngroups, state_rank]),
940                )
941            };
942
943            // ── Discretisation ────────────────────────────────────────────────
944            // Δₜ = softplus(dt_raw + dt_bias)
945            let dt_bias_1h = self.dt_bias_h.val().unsqueeze_dim(0);
946            assert_eq!([1, nheads], dt_bias_1h.dims());
947            let dt_bh = softplus(dt_raw_bh + dt_bias_1h).clamp(self.dt_limit.0, self.dt_limit.1);
948            assert_eq!([batch, nheads], dt_bh.dims());
949
950            // A = -exp(a_log) < 0   (negative, decaying)
951            let a_head_decay_h = -self.a_log_h.val().exp();
952            assert_eq!([nheads], a_head_decay_h.dims());
953
954            // Āₜ = exp(Δₜ · A) ∈ (0, 1)   scalar per [B, H]
955            let dta_bh = (dt_bh.clone() * a_head_decay_h.unsqueeze()).exp();
956            assert_eq!([batch, nheads], dta_bh.dims());
957
958            // ── SSM state update:  hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜᵀ ───────────────────
959            // The cache holds h_{t-1} with shape [B, H, P, N].
960            // Āₜ is a scalar per head, so we broadcast it over P and N.
961            // B̄ₜ xₜᵀ is an outer product producing a [P, N] matrix per [B, H].
962
963            let ssm_shape_bhpr = [batch, nheads, per_head_dim, state_rank];
964
965            let dta_bhpr = dta_bh.unsqueeze_dims::<4>(&[2, 3]).expand(ssm_shape_bhpr); // [B, H, P, N]
966
967            // B̄ₜ xₜᵀ = (Δₜ Bₜ) xₜᵀ:
968            //   x:  [B, H, P]     → broadcast to [B, H, P, N]
969            //   B:  [B, G, N]     → expand to [B, H, N]  → broadcast to [B, H, P, N]
970            //   Δ:  [B, H]        → broadcast to [B, H, P, N]
971            let heads_per_group = nheads / ngroups;
972            let dtbx_bhpr = {
973                let x_bhpr = x_bhp.clone().unsqueeze_dim::<4>(3).expand(ssm_shape_bhpr);
974
975                // Expand B from [B, G, N] → [B, H, P, N], matching the SSD forward path:
976                // each group's projection is replicated across the heads_per_group heads of
977                // that group so that heads 0..(H/G) belong to group 0, etc.
978                let b_bhpr = b_bgr
979                    .unsqueeze_dim::<4>(2) // [B, G, 1, N]
980                    .expand([batch, ngroups, heads_per_group, state_rank]) // [B, G, H/G, N]
981                    .reshape([batch, nheads, state_rank]) // [B, H, N]
982                    .unsqueeze_dim::<4>(2) // [B, H, 1, N]
983                    .expand(ssm_shape_bhpr); // [B, H, P, N]
984
985                let dt_bhpr = dt_bh.unsqueeze_dims::<4>(&[2, 3]).expand(ssm_shape_bhpr);
986
987                dt_bhpr * b_bhpr * x_bhpr // B̄ₜ xₜᵀ  [B, H, P, N]
988            };
989            assert_eq!(ssm_shape_bhpr, dtbx_bhpr.dims());
990
991            // hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜᵀ
992            cache.ssm_bhpr = cache.ssm_bhpr * dta_bhpr + dtbx_bhpr;
993            assert_eq!(ssm_shape_bhpr, cache.ssm_bhpr.dims());
994
995            // ── Output:  yₜ = Cₜ hₜ + D xₜ ──────────────────────────────────
996            let y_bi = {
997                // Cₜ hₜ:  element-wise multiply C (broadcast to [B, H, P, N])
998                // with h_t, then sum over N.
999                let c_bhpr = c_bgr
1000                    .unsqueeze_dim::<4>(2) // [B, G, 1, N]
1001                    .expand([batch, ngroups, heads_per_group, state_rank]) // [B, G, H/G, N]
1002                    .reshape([batch, nheads, state_rank]) // [B, H, N]
1003                    .unsqueeze_dim::<4>(2) // [B, H, 1, N]
1004                    .expand(ssm_shape_bhpr); // [B, H, P, N]
1005                assert_eq!(ssm_shape_bhpr, c_bhpr.dims());
1006
1007                let ch_bhp = (cache.ssm_bhpr.clone() * c_bhpr).sum_dim(3).squeeze_dim(3); // sum over N → [B, H, P]
1008                assert_eq!([batch, nheads, per_head_dim], ch_bhp.dims());
1009
1010                // D xₜ:  per-head scalar skip.
1011                let d_1h1 = self.d_h.val().unsqueeze_dims(&[0, 2]);
1012                assert_eq!([1, nheads, 1], d_1h1.dims());
1013                let skip_bhp = d_1h1.expand([batch, nheads, per_head_dim]) * x_bhp;
1014                assert_eq!([batch, nheads, per_head_dim], skip_bhp.dims());
1015
1016                let y_bhp = ch_bhp + skip_bhp;
1017                assert_eq!([batch, nheads, per_head_dim], y_bhp.dims());
1018
1019                // Flatten heads → [B, d_inner], then apply gated RMSNorm.
1020                let y_bi = y_bhp.reshape([batch, d_inner]);
1021                self.norm.forward(y_bi, z_gate_bi)
1022            };
1023            assert_eq!([batch, d_inner], y_bi.dims());
1024
1025            // ── Out-projection ────────────────────────────────────────────────
1026            let out_bm = self.out_proj.forward(y_bi);
1027            assert_eq!([batch, d_model], out_bm.dims());
1028
1029            (out_bm, cache)
1030        }
1031    }
1032}
1033
1034// ---------------------------------------------------------------------------
1035// Tests
1036// ---------------------------------------------------------------------------
1037
#[cfg(all(test, feature = "backend-flex"))]
mod tests {
    use super::*;
    use burn::backend::{Autodiff, Flex};
    use burn::tensor::Distribution;

    /// Inner (non-autodiff) backend used for materialising values and
    /// extracted gradients.
    type InnerB = Flex;
    /// Autodiff-wrapped backend used to drive `.backward()`.
    type B = Autodiff<InnerB>;

    type Device = <InnerB as burn::tensor::backend::BackendTypes>::Device;

    /// Max-abs-diff tolerance for forward outputs of equivalent code paths.
    const FORWARD_TOL: f32 = 1e-4;
    /// Max-abs-diff tolerance for gradients; looser than [`FORWARD_TOL`] to
    /// absorb float-summation-order noise in the backward pass.
    const GRAD_TOL: f32 = 1e-3;

    /// Baseline test config (`per_head_dim = 8`, `ngroups` at its default).
    fn small_config() -> Mamba2Config {
        Mamba2Config::new(32)
            .with_state_rank(8)
            .with_expand(2)
            .with_per_head_dim(8)
    }

    /// Config exercising grouped B/C projections (`ngroups = 2`).
    fn ngroups2_config() -> Mamba2Config {
        Mamba2Config::new(32)
            .with_state_rank(8)
            .with_expand(2)
            .with_per_head_dim(16)
            .with_ngroups(2)
    }

    /// [`small_config`] with `is_norm_before_gate = true`.
    fn norm_before_gate_config() -> Mamba2Config {
        small_config().with_is_norm_before_gate(true)
    }

    /// Standard-normal tensor of the given `[batch, seq_len, d_model]` shape
    /// on the inner (non-autodiff) backend.
    fn random_normal(shape: [usize; 3], device: &Device) -> Tensor<InnerB, 3> {
        Tensor::random(shape, Distribution::Normal(0.0, 1.0), device)
    }

    /// Largest element-wise absolute difference between `a` and `b`.
    fn max_abs_diff<const D: usize>(a: &Tensor<InnerB, D>, b: &Tensor<InnerB, D>) -> f32 {
        (a.clone() - b.clone()).abs().max().into_scalar()
    }

    /// A bundle of input + model-parameter gradients extracted from one
    /// forward+backward run.  Each `check_grads_match` call compares these
    /// across two runs that should be mathematically equivalent.
    struct RunGrads {
        out: Tensor<InnerB, 3>,
        d_input: Tensor<InnerB, 3>,
        d_in_proj_w: Tensor<InnerB, 2>,
        d_conv1d_w: Tensor<InnerB, 3>,
        d_dt_bias: Tensor<InnerB, 1>,
        d_a_log: Tensor<InnerB, 1>,
        d_d: Tensor<InnerB, 1>,
        d_norm_gamma: Tensor<InnerB, 1>,
        d_out_proj_w: Tensor<InnerB, 2>,
    }

    /// Run a closure that produces an output tensor from a model and an input
    /// (wrapped as a `Param` so it has its own autodiff leaf), then derive a
    /// scalar loss with a fixed (non-tracked) random "head" and return the
    /// gradients of the input and a representative set of model parameters.
    fn run_with_grads(
        model: &Mamba2<B>,
        input: &Param<Tensor<B, 3>>,
        head: &Tensor<InnerB, 3>,
        forward: impl FnOnce(&Mamba2<B>, Tensor<B, 3>) -> Tensor<B, 3>,
    ) -> RunGrads {
        let out = forward(model, input.val());
        let out_inner = out.clone().inner();

        let head = Tensor::from_inner(head.clone());
        let loss = (out * head).sum();
        let grads = loss.backward();

        RunGrads {
            out: out_inner,
            d_input: input.val().grad(&grads).expect("grad input"),
            d_in_proj_w: model
                .in_proj
                .weight
                .val()
                .grad(&grads)
                .expect("grad in_proj.weight"),
            d_conv1d_w: model
                .conv1d
                .weight
                .val()
                .grad(&grads)
                .expect("grad conv1d.weight"),
            d_dt_bias: model.dt_bias_h.val().grad(&grads).expect("grad dt_bias_h"),
            d_a_log: model.a_log_h.val().grad(&grads).expect("grad a_log_h"),
            d_d: model.d_h.val().grad(&grads).expect("grad d_h"),
            d_norm_gamma: model
                .norm
                .gamma
                .val()
                .grad(&grads)
                .expect("grad norm.gamma"),
            d_out_proj_w: model
                .out_proj
                .weight
                .val()
                .grad(&grads)
                .expect("grad out_proj.weight"),
        }
    }

    /// Assert that every entry in `a` and `b` agrees to within `grad_tol`,
    /// printing every comparison so a failure dump shows the full picture
    /// (instead of stopping at the first mismatch).  A NaN difference counts
    /// as a mismatch.
    fn check_grads_match(label: &str, a: &RunGrads, b: &RunGrads, grad_tol: f32) {
        let mut failures: Vec<String> = Vec::new();
        macro_rules! check {
            ($field:ident, $name:expr) => {{
                let d = max_abs_diff(&a.$field, &b.$field);
                eprintln!("{:>40} {:>16} | max abs diff = {:>10.6}", label, $name, d);
                // `d >= grad_tol` alone is false for NaN, which would let a
                // NaN gradient slip through as a "match".
                if d.is_nan() || d >= grad_tol {
                    failures.push(format!(
                        "{}: grad of {} max abs diff = {:.6} (tol {})",
                        label, $name, d, grad_tol
                    ));
                }
            }};
        }
        check!(d_input, "input");
        check!(d_in_proj_w, "in_proj.weight");
        check!(d_conv1d_w, "conv1d.weight");
        check!(d_dt_bias, "dt_bias_h");
        check!(d_a_log, "a_log_h");
        check!(d_d, "d_h");
        check!(d_norm_gamma, "norm.gamma");
        check!(d_out_proj_w, "out_proj.weight");
        assert!(
            failures.is_empty(),
            "gradient mismatches:\n  {}",
            failures.join("\n  ")
        );
    }

    /// Helper that builds a fresh `Param<Tensor>` from a stable inner tensor.
    /// A new Param is needed per run so that the autodiff leaf has a fresh
    /// node, isolating each backward pass to its own forward graph.
    fn param_input(input: &Tensor<InnerB, 3>) -> Param<Tensor<B, 3>> {
        Param::from_tensor(Tensor::from_inner(input.clone()))
    }

    /// `step()` applied token-by-token must reproduce `forward()` on the full
    /// sequence, both in outputs and in gradients.
    fn run_step_matches_forward(cfg: Mamba2Config, ssd_path: Mamba2SsdPath) {
        let device: Device = Default::default();
        let model = cfg.init::<B>(&device);

        let batch = 2;
        let seq_len = 5;
        let shape = [batch, seq_len, cfg.d_model];

        let input = random_normal(shape, &device);
        let head = random_normal(shape, &device);

        let input_fwd = param_input(&input);
        let r_fwd = run_with_grads(&model, &input_fwd, &head, |m, x| {
            let (out, _) = m.forward(x, None, ssd_path.clone());
            out
        });

        let input_step = param_input(&input);
        let r_step = run_with_grads(&model, &input_step, &head, |m, x| {
            let mut cache: Option<Mamba2Cache<B>> = None;
            let mut outs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
            for t in 0..seq_len {
                let token = x.clone().narrow(1, t, 1).squeeze_dim(1);
                let (out_t, new_cache) = m.step(token, cache);
                cache = Some(new_cache);
                outs.push(out_t);
            }
            Tensor::stack(outs, 1)
        });

        // ── Forward agreement ────────────────────────────────────────────
        let diff = max_abs_diff(&r_fwd.out, &r_step.out);
        assert!(
            diff < FORWARD_TOL,
            "step() vs forward() max absolute difference = {diff:.6} (expected < 1e-4)"
        );

        // ── Gradient agreement ───────────────────────────────────────────
        // step() and forward() are different reductions of the same SSM,
        // so their per-parameter gradients should also agree, modulo
        // float-summation order noise.
        check_grads_match("step vs forward", &r_fwd, &r_step, GRAD_TOL);
    }

    #[test]
    fn step_matches_forward() {
        run_step_matches_forward(small_config(), Mamba2SsdPath::Minimal(Some(4)));
    }

    #[test]
    fn step_matches_forward_ngroups2() {
        run_step_matches_forward(ngroups2_config(), Mamba2SsdPath::Minimal(Some(4)));
    }

    /// forward(full) ≡ forward(prefix) then forward(suffix, cache_from_prefix).
    ///
    /// Verifies stateful chunked-prefill: the convolution window carried in the
    /// cache must replay correctly at the start of the second segment.
    fn run_split_matches_full(cfg: Mamba2Config, ssd_path: Mamba2SsdPath) {
        let device: Device = Default::default();
        let model = cfg.init::<B>(&device);

        let batch = 2;
        let seq_len = 6;
        let split = 2;
        let shape = [batch, seq_len, cfg.d_model];

        let input = random_normal(shape, &device);
        let head = random_normal(shape, &device);

        let input_full = param_input(&input);
        let r_full = run_with_grads(&model, &input_full, &head, |m, x| {
            let (out, _) = m.forward(x, None, ssd_path.clone());
            out
        });

        let input_split = param_input(&input);
        let r_split = run_with_grads(&model, &input_split, &head, |m, x| {
            let prefix = x.clone().narrow(1, 0, split);
            let suffix = x.narrow(1, split, seq_len - split);
            let (out_prefix, cache) = m.forward(prefix, None, ssd_path.clone());
            let (out_suffix, _) = m.forward(suffix, Some(cache), ssd_path.clone());
            Tensor::cat(vec![out_prefix, out_suffix], 1)
        });

        // ── Forward agreement ────────────────────────────────────────────
        let diff = max_abs_diff(&r_full.out, &r_split.out);
        assert!(
            diff < FORWARD_TOL,
            "split forward vs full forward max absolute difference = {diff:.6} (expected < 1e-4)"
        );

        // ── Gradient agreement ───────────────────────────────────────────
        check_grads_match("split vs full", &r_full, &r_split, GRAD_TOL);
    }

    #[test]
    fn split_matches_full() {
        run_split_matches_full(small_config(), Mamba2SsdPath::Minimal(Some(4)));
    }

    #[test]
    fn split_matches_full_ngroups2() {
        run_split_matches_full(ngroups2_config(), Mamba2SsdPath::Minimal(Some(4)));
    }

    // ── is_norm_before_gate = true ───────────────────────────────────────────

    #[test]
    fn step_matches_forward_norm_before_gate() {
        run_step_matches_forward(norm_before_gate_config(), Mamba2SsdPath::Minimal(Some(4)));
    }

    #[test]
    fn split_matches_full_norm_before_gate() {
        run_split_matches_full(norm_before_gate_config(), Mamba2SsdPath::Minimal(Some(4)));
    }

    // ── SSD path agreement ───────────────────────────────────────────────────

    /// The Minimal, Serial and SerialRecalculated SSD paths must agree in
    /// outputs and gradients for the same model and input.
    fn run_ssd_paths_agree(cfg: Mamba2Config) {
        let device: Device = Default::default();
        let model = cfg.init::<B>(&device);

        let batch = 2;
        let seq_len = 8;
        let shape = [batch, seq_len, cfg.d_model];

        let input = random_normal(shape, &device);
        let head = random_normal(shape, &device);

        // Each path gets its own input Param so the autodiff leaves are
        // independent across the three backward passes.
        let input_min = param_input(&input);
        let input_ser = param_input(&input);
        let input_rec = param_input(&input);

        let r_min = run_with_grads(&model, &input_min, &head, |m, x| {
            let (out, _) = m.forward(x, None, Mamba2SsdPath::Minimal(Some(4)));
            out
        });
        let r_ser = run_with_grads(&model, &input_ser, &head, |m, x| {
            let (out, _) = m.forward(x, None, Mamba2SsdPath::Serial(Some(4)));
            out
        });
        let r_rec = run_with_grads(&model, &input_rec, &head, |m, x| {
            let (out, _) = m.forward(x, None, Mamba2SsdPath::SerialRecalculated(Some(4)));
            out
        });

        // ── Forward agreement ────────────────────────────────────────────
        let tol = FORWARD_TOL;
        let d_ser = max_abs_diff(&r_min.out, &r_ser.out);
        let d_rec = max_abs_diff(&r_min.out, &r_rec.out);
        assert!(
            d_ser < tol,
            "Minimal vs Serial: forward max abs diff = {d_ser:.6} (tol {tol})"
        );
        assert!(
            d_rec < tol,
            "Minimal vs SerialRecalculated: forward max abs diff = {d_rec:.6} (tol {tol})"
        );

        // ── Gradient agreement ───────────────────────────────────────────
        check_grads_match("Minimal vs Serial", &r_min, &r_ser, GRAD_TOL);
        check_grads_match("Minimal vs SerialRecalculated", &r_min, &r_rec, GRAD_TOL);
    }

    #[test]
    fn ssd_paths_agree() {
        run_ssd_paths_agree(small_config());
    }

    #[test]
    fn ssd_paths_agree_ngroups2() {
        run_ssd_paths_agree(ngroups2_config());
    }
}