Skip to main content

burn_mamba/mamba3/
mamba3.rs

1//! # Mamba-3 SSM Block — Exponential-Trapezoidal SSD with Data-Dependent RoPE
2//!
3//! This module implements the core **Mamba-3 layer** from the paper
4//! *"The Mamba-3 Framework: Structured State Spaces with Trapezoidal
5//! Discretization and Data-Dependent Rotary Embeddings"*.
6//!
7//! ## The Mamba-3 Recurrence (SISO, Proposition 1)
8//!
9//! ```text
10//!   hₜ = αₜ hₜ₋₁ + βₜ B_{t-1} x_{t-1}ᵀ + γₜ Bₜ xₜᵀ   (state update)
11//!   yₜ = Cₜᵀ hₜ + D xₜ                                  (output)
12//! ```
13//!
14//! ## MIMO Extension (mimo_rank = R > 1)
15//!
16//! With MIMO, the state update becomes a sum of R outer-product contributions:
17//!
18//! ```text
19//!   hₜ = αₜ hₜ₋₁ + βₜ Σ_r B_{t-1}[r] ⊗ (x_{t-1} ⊙ mimo_x[r])
20//!                   + γₜ Σ_r Bₜ[r] ⊗ (xₜ ⊙ mimo_x[r])
21//!   yₜ[r] = Cₜ[r]ᵀ hₜ + D xₜ ⊙ mimo_x[r]
22//!   outₜ = Σ_r mimo_o[r] ⊙ silu(zₜ ⊙ mimo_z[r]) ⊙ yₜ[r]
23//! ```
24//!
25//! The hidden state hₜ is shared across ranks; each rank contributes to it
26//! independently but reads the full shared state when producing its output.
27//!
28//! ## Notation / Dimension Keys
29//!
30//! | Letter | Dimension | Typical value |
31//! |--------|-----------|---------------|
32//! | `b`    | batch     | varies        |
33//! | `s`    | sequence length T | varies |
34//! | `m`    | d_model   | 768, 1024 … |
35//! | `i`    | d_inner = expand·d_model | 2·d_model |
36//! | `h`    | nheads H  | d_inner / P  |
37//! | `p`    | per_head_dim P | 64, 128 |
38//! | `r`    | state_rank N   | 64–256  |
39//! | `R`    | mimo_rank      | 1–8     |
40//! | `n`    | nchunks = T/Q  | varies  |
41//! | `l`    | chunk_len Q    | 64–256  |
//! | `a`    | num_rope_angles = rope_dim / 2, where rope_dim = rope_fraction · state_rank | varies |
43
44use crate::mamba3::prelude::*;
45use crate::utils::sanity::sanity as san;
46use crate::utils::{
47    rms_norm::{RmsNorm, RmsNormConfig},
48    rms_norm_gated::{RmsNormGated, RmsNormGatedConfig},
49    silu::Silu,
50    softplus::softplus,
51};
52use burn::prelude::*;
53use burn::{
54    module::{Module, Param},
55    nn::{Initializer, Linear, LinearConfig},
56};
57
58/// Element-wise sigmoid: σ(x) = 1 / (1 + exp(-x)).
59fn sigmoid<B: Backend, const D: usize>(x: Tensor<B, D>) -> Tensor<B, D> {
60    ((-x).exp() + 1.).recip()
61}
62
63// ---------------------------------------------------------------------------
64// Mamba3  (the SSM block)
65// ---------------------------------------------------------------------------
66
/// The Mamba-3 SSM block.
///
/// Implements the full Mamba-3 layer with exponential-trapezoidal discretization
/// and data-dependent RoPE.  Supports SISO (mimo_rank=1) and MIMO (mimo_rank>1).
/// Supports two execution modes:
///
/// - [`Self::forward`] — chunkwise two-SSD algorithm for training / prefill
/// - [`Self::step`]    — recurrent form for token-by-token decoding
///
/// Instances are normally built with [`Mamba3Config::init`], which validates
/// the hyperparameters and allocates every parameter listed below.
#[derive(Module, Debug)]
pub struct Mamba3<B: Backend> {
    /// Input projection.
    ///
    /// For SISO (R=1): maps `d_model → 2·d_inner + 2·ngroups·state_rank + 3·nheads + num_rope_angles`.
    /// For MIMO (R>1): maps `d_model → 2·d_inner + 2·ngroups·state_rank·R + 3·nheads + num_rope_angles`.
    ///
    /// Output splits: `[z | x | B_raw | C_raw | dd_dt | dd_A | lam_raw | theta_raw]`
    /// (the forward pass splits the last axis in exactly this order).
    pub in_proj: Linear<B>,

    /// Per-head bias for the discretisation step size Δ.
    /// Shape: `[nheads]`
    ///
    /// Initialised with an inverse-softplus transform of sampled Δ values, so
    /// that `softplus(dt_bias)` reproduces the sampled Δ.
    pub dt_bias_h: Param<Tensor<B, 1>>,

    /// Hard clamp `(min, max)` applied to Δ after softplus.
    pub dt_limit: (f64, f64),

    /// Minimum absolute value of A: `A ∈ (−∞, −a_floor]`.
    ///
    /// In the forward pass `A = −softplus(dd_A)` is clamped from above at
    /// `−a_floor`.
    pub a_floor: f64,

    /// Per-head skip (D) coefficient.
    /// Shape: `[nheads]`; initialised to ones.
    pub d_h: Param<Tensor<B, 1>>,

    /// RMSNorm applied to the B projection (QK-Norm, no gating).
    /// Normalises over the `state_rank` dimension.
    pub b_norm: RmsNorm<B>,

    /// RMSNorm applied to the C projection (QK-Norm, no gating).
    /// Normalises over the `state_rank` dimension.
    pub c_norm: RmsNorm<B>,

    /// Learnable per-head, per-rank bias for B, added after QK-norm.
    /// Shape: `[nheads, mimo_rank, state_rank]`; initialised to ones.
    ///
    /// For SISO (mimo_rank=1) this has shape `[nheads, 1, state_rank]`.
    pub b_bias_hrn: Param<Tensor<B, 3>>,

    /// Learnable per-head, per-rank bias for C, added after QK-norm.
    /// Shape: `[nheads, mimo_rank, state_rank]`; initialised to ones.
    pub c_bias_hrn: Param<Tensor<B, 3>>,

    /// MIMO up-projection for x (values).
    /// Shape: `[nheads, mimo_rank, per_head_dim]`; initialised to `1 / mimo_rank`.
    /// Only present when `mimo_rank > 1`.  When SISO, this is `None`.
    pub mimo_x: Option<Param<Tensor<B, 3>>>,

    /// MIMO up-projection for z (gate).
    /// Shape: `[nheads, mimo_rank, per_head_dim]`; initialised to ones.
    /// Only present when `mimo_rank > 1`.
    pub mimo_z: Option<Param<Tensor<B, 3>>>,

    /// MIMO down-projection for the output.
    /// Shape: `[nheads, mimo_rank, per_head_dim]`; initialised to `1 / mimo_rank`.
    /// Only present when `mimo_rank > 1`.
    pub mimo_o: Option<Param<Tensor<B, 3>>>,

    /// Optional gated RMSNorm applied before the output projection.
    ///
    /// When `Some`, the SiLU gate at the block tail is replaced by
    /// `RmsNormGated(y, z)` which normalises `y` over `per_head_dim` and
    /// gates with `SiLU(z)`. Created when `has_outproj_norm = true`.
    pub out_norm: Option<RmsNormGated<B>>,

    /// Output projection: maps `d_inner → d_model`.
    ///
    /// [`Self::d_inner`] reads `d_inner` back from this layer's weight shape.
    pub out_proj: Linear<B>,

    /// Optional learnable initial hidden state `h₀`.
    /// Shape: `[nheads, per_head_dim, state_rank]`; initialised to zeros when
    /// `has_learnable_init_state = true`, otherwise `None`.
    pub init_state_hpr: Option<Param<Tensor<B, 3>>>,

    /// State rank N.
    pub state_rank: usize,

    /// Number of B/C groups G.  Must divide `nheads`.
    pub ngroups: usize,

    /// Number of RoPE angle pairs (`rope_dim / 2`).
    pub num_rope_angles: usize,

    /// Effective RoPE dimension (= `2 · num_rope_angles`). Always even and
    /// `≤ state_rank`. Only the first `rope_dim` entries of B/C are rotated.
    pub rope_dim: usize,

    /// MIMO rank R.  1 = SISO (standard Mamba-3).
    pub mimo_rank: usize,
}
162
163impl<B: Backend> Mamba3<B> {
164    /// `d_inner = expand · d_model`.  Inferred from `out_proj`.
165    pub fn d_inner(&self) -> usize {
166        let [d_inner, _d_model] = self.out_proj.weight.dims();
167        d_inner
168    }
169
170    /// `nheads = d_inner / per_head_dim`.  Inferred from `d_h`.
171    pub fn nheads(&self) -> usize {
172        let [nheads] = self.d_h.dims();
173        nheads
174    }
175
176    /// `per_head_dim P = d_inner / nheads`.
177    pub fn per_head_dim(&self) -> usize {
178        self.d_inner() / self.nheads()
179    }
180}
181
182// ---------------------------------------------------------------------------
183// Mamba3Config  (hyperparameters and factory)
184// ---------------------------------------------------------------------------
185
/// Hyperparameters for the Mamba-3 SSM block.
///
/// Call [`Mamba3Config::init`] to validate the settings and build a
/// [`Mamba3`] module on a device.
#[derive(Config, Debug)]
pub struct Mamba3Config {
    /// Model (hidden) dimension D.
    pub d_model: usize,

    /// State rank N — the latent dimension of the SSM hidden state.
    /// **Must be even** (required for RoPE pairing); `init` asserts this.
    #[config(default = 128)]
    pub state_rank: usize,

    /// Expansion factor for `d_inner = expand · d_model`.
    #[config(default = 2)]
    pub expand: usize,

    /// Head dimension P.  `nheads = d_inner / P`.
    /// Must be positive and divide `d_inner` exactly.
    #[config(default = 64)]
    pub per_head_dim: usize,

    /// Number of B/C groups G.  Must be at least 1 and divide `nheads`.
    #[config(default = 1)]
    pub ngroups: usize,

    /// MIMO rank R.  `1` = standard SISO Mamba-3.
    ///
    /// When `R > 1`, the B/C projections have `R` parallel rank channels.
    /// Three extra weight matrices (`mimo_x`, `mimo_z`, `mimo_o`) provide
    /// element-wise up/down projections in head-space across ranks.
    #[config(default = 1)]
    pub mimo_rank: usize,

    /// Minimum absolute value of A after clamping.  Must be strictly positive.
    #[config(default = "1e-4")]
    pub a_floor: f64,

    /// Minimum value of the initial Δ distribution.
    ///
    /// Initial Δ values are sampled log-uniformly between `dt_min` and `dt_max`.
    #[config(default = 1e-3)]
    pub dt_min: f64,

    /// Maximum value of the initial Δ distribution.
    #[config(default = 0.1)]
    pub dt_max: f64,

    /// Floor clamped onto sampled initial Δ values.
    #[config(default = 1e-4)]
    pub dt_init_floor: f64,

    /// Hard clamp limits `(min, max)` for Δ at runtime.
    // NOTE(review): the default upper bound 6.5504e+4 looks like the largest
    // finite f16 value — confirm that is the intent.
    #[config(default = "(0., 6.5504e+4)")]
    pub dt_limit: (f64, f64),

    /// Whether to add a bias term to the `in_proj` and `out_proj`.
    #[config(default = false)]
    pub has_proj_bias: bool,

    /// Whether to allocate a learnable initial SSM state `h₀`.
    #[config(default = false)]
    pub has_learnable_init_state: bool,

    /// Fraction of `state_rank` to which RoPE is applied (must be `0.5` or `1.0`).
    ///
    /// - `0.5` (default): partial RoPE — only `state_rank / 2` dimensions are
    ///   rotated; the rest pass through unchanged.
    /// - `1.0`: full RoPE — every B/C dimension is rotated.
    ///
    /// Default matches the reference's `rope_fraction` argument in `mamba3.py:36`.
    #[config(default = 0.5)]
    pub rope_fraction: f64,

    /// Whether to apply a gated RMSNorm before the output projection.
    ///
    /// When `true`, the SiLU gate at the end of the block is replaced by a
    /// per-head [`RmsNormGated`] (group size = `per_head_dim`) which both
    /// normalises `y` and gates it with `SiLU(z)`. Matches the reference's
    /// `is_outproj_norm` argument in `mamba3.py:41`.
    #[config(default = false)]
    pub has_outproj_norm: bool,
}
264
265impl Mamba3Config {
266    pub fn d_inner(&self) -> usize {
267        self.expand * self.d_model
268    }
269    pub fn nheads(&self) -> usize {
270        self.d_inner() / self.per_head_dim
271    }
272
273    /// Effective RoPE dimension: `2 · num_rope_angles`. Equals `state_rank`
274    /// for full RoPE, `state_rank / 2` for `rope_fraction = 0.5`.
275    pub fn rope_dim(&self) -> usize {
276        let mut d = (self.state_rank as f64 * self.rope_fraction) as usize;
277        if d % 2 != 0 {
278            d -= 1;
279        }
280        d
281    }
282
283    pub fn num_rope_angles(&self) -> usize {
284        self.rope_dim() / 2
285    }
286
287    /// Total input projection output size.
288    ///
289    /// `d_in_proj = 2·d_inner + 2·ngroups·state_rank·mimo_rank + 3·nheads + num_rope_angles`
290    pub fn d_in_proj(&self) -> usize {
291        2 * self.d_inner()
292            + 2 * self.ngroups * self.state_rank * self.mimo_rank
293            + 3 * self.nheads()
294            + self.num_rope_angles()
295    }
296
297    /// Allocate and initialise all Mamba-3 block parameters on `device`.
298    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3<B> {
299        let d_inner = self.d_inner();
300        let nheads = self.nheads();
301        let ngroups = self.ngroups;
302        let state_rank = self.state_rank;
303        let mimo_rank = self.mimo_rank;
304        let num_rope_angles = self.num_rope_angles();
305
306        assert!(
307            state_rank % 2 == 0,
308            "state_rank must be even for RoPE pairing"
309        );
310        assert!(self.per_head_dim > 0, "per_head_dim must be positive");
311        assert_eq!(
312            nheads * self.per_head_dim,
313            d_inner,
314            "d_inner must be divisible by per_head_dim"
315        );
316        assert_ne!(ngroups, 0, "ngroups must be at least 1");
317        assert_eq!(nheads % ngroups, 0, "nheads must be divisible by ngroups");
318        assert!(self.a_floor > 0.0, "a_floor must be positive");
319        assert!(mimo_rank >= 1, "mimo_rank must be at least 1");
320        assert!(
321            self.rope_fraction == 0.5 || self.rope_fraction == 1.0,
322            "rope_fraction must be 0.5 or 1.0"
323        );
324        assert!(num_rope_angles > 0, "num_rope_angles must be at least 1");
325
326        let uniform_init = |fan_in: usize| {
327            let bound = 1.0 / (fan_in as f64).sqrt();
328            Initializer::Uniform {
329                min: -bound,
330                max: bound,
331            }
332        };
333
334        let in_proj = LinearConfig::new(self.d_model, self.d_in_proj())
335            .with_bias(self.has_proj_bias)
336            .with_initializer(uniform_init(self.d_model))
337            .init::<B>(device);
338
339        // dt_bias: inverse-softplus initialisation
340        let expm1 = |t: Tensor<B, 1>| t.exp() - 1.;
341        let dt_h = Tensor::random(
342            [nheads],
343            burn::tensor::Distribution::Uniform(self.dt_min.ln(), self.dt_max.ln()),
344            device,
345        )
346        .exp();
347        let dt_h = dt_h.clamp(self.dt_init_floor, f64::INFINITY);
348        let inv_dt_h = dt_h.clone() + (-expm1(-dt_h)).log();
349        let dt_bias_h = Param::from_tensor(inv_dt_h);
350
351        let d_h = Initializer::Ones.init::<B, 1, _>([nheads], device);
352
353        let b_norm = RmsNormConfig::new(state_rank).init(device);
354        let c_norm = RmsNormConfig::new(state_rank).init(device);
355
356        // B/C biases: [nheads, mimo_rank, state_rank], init to ones
357        let b_bias_hrn = Initializer::Ones.init::<B, 3, _>([nheads, mimo_rank, state_rank], device);
358        let c_bias_hrn = Initializer::Ones.init::<B, 3, _>([nheads, mimo_rank, state_rank], device);
359
360        // MIMO projections (only for R > 1)
361        let (mimo_x, mimo_z, mimo_o) = if mimo_rank > 1 {
362            let per_head_dim = self.per_head_dim;
363            // Init: mimo_x and mimo_o to 1/R, mimo_z to 1
364            let mx = Param::from_tensor(Tensor::full(
365                [nheads, mimo_rank, per_head_dim],
366                1.0 / mimo_rank as f64,
367                device,
368            ));
369            let mz = Param::from_tensor(Tensor::ones([nheads, mimo_rank, per_head_dim], device));
370            let mo = Param::from_tensor(Tensor::full(
371                [nheads, mimo_rank, per_head_dim],
372                1.0 / mimo_rank as f64,
373                device,
374            ));
375            (Some(mx), Some(mz), Some(mo))
376        } else {
377            (None, None, None)
378        };
379
380        // Gated RMSNorm applied per-head (group size = per_head_dim).
381        let out_norm = self.has_outproj_norm.then(|| {
382            RmsNormGatedConfig::new(self.per_head_dim)
383                .with_norm_before_gate(true)
384                .init(device)
385        });
386
387        let out_proj = LinearConfig::new(d_inner, self.d_model)
388            .with_bias(self.has_proj_bias)
389            .with_initializer(uniform_init(d_inner))
390            .init(device);
391
392        let init_state_hpr = self.has_learnable_init_state.then(|| {
393            Initializer::Zeros.init::<B, 3, _>([nheads, self.per_head_dim, state_rank], device)
394        });
395
396        Mamba3 {
397            in_proj,
398            dt_bias_h,
399            dt_limit: self.dt_limit,
400            a_floor: self.a_floor,
401            d_h,
402            b_norm,
403            c_norm,
404            b_bias_hrn,
405            c_bias_hrn,
406            mimo_x,
407            mimo_z,
408            mimo_o,
409            out_norm,
410            out_proj,
411            init_state_hpr,
412            state_rank,
413            ngroups,
414            rope_dim: self.rope_dim(),
415            num_rope_angles,
416            mimo_rank,
417        }
418    }
419}
420
421// ---------------------------------------------------------------------------
422// RoPE utility
423// ---------------------------------------------------------------------------
424
425/// Apply rotary position embeddings to `x` along its last dimension.
426///
427/// Two pairing conventions are supported, selected by `rotate_pairwise`:
428///
429/// - `rotate_pairwise = true` — **interleaved** (NeoX / Triton style): adjacent
430///   pairs `(0,1)`, `(2,3)`, … are rotated together. Used by the SISO Triton
431///   kernel (`mamba3_siso_*.py`).
432/// - `rotate_pairwise = false` — **half-and-half** (GPT-J style): position `n`
433///   is paired with `n + N/2`. Used by the MIMO Tilelang kernel
434///   (`mamba3_mimo_fwd.py`).
435///
436/// Reference: `mamba3.py:335` sets `rotate_pairwise = not self.is_mimo`.
437///
438/// # Shapes
439/// - `x`:      `[..., state_rank]` where `state_rank` is even
440/// - `angles`: `[..., state_rank / 2]`  (one angle per pair)
441/// - output:   same shape as `x`
442pub fn apply_rope<B: Backend, const D: usize>(
443    x: Tensor<B, D>,
444    angles: Tensor<B, D>,
445    rotate_pairwise: bool,
446) -> Tensor<B, D> {
447    let dims = x.dims();
448    let n = dims[D - 1];
449    let n2 = n / 2;
450    let leading: usize = dims[..D - 1].iter().product();
451
452    let angles_flat = angles.reshape([leading, n2]);
453    let cos = angles_flat.clone().cos();
454    let sin = angles_flat.sin();
455
456    if rotate_pairwise {
457        // Interleaved: reshape to [leading, n2, 2], pairs along last axis.
458        let x_pairs = x.reshape([leading, n2, 2]);
459        let x0 = x_pairs.clone().narrow(2, 0, 1).squeeze_dim(2);
460        let x1 = x_pairs.narrow(2, 1, 1).squeeze_dim(2);
461
462        let x0r = cos.clone() * x0.clone() - sin.clone() * x1.clone();
463        let x1r = sin * x0 + cos * x1;
464
465        Tensor::cat(
466            vec![x0r.unsqueeze_dim::<3>(2), x1r.unsqueeze_dim::<3>(2)],
467            2,
468        )
469        .reshape(dims)
470    } else {
471        // Half-and-half: reshape to [leading, 2, n2], halves along middle axis.
472        let x_halves = x.reshape([leading, 2, n2]);
473        let x0 = x_halves.clone().narrow(1, 0, 1).squeeze_dim(1);
474        let x1 = x_halves.narrow(1, 1, 1).squeeze_dim(1);
475
476        let x0r = cos.clone() * x0.clone() - sin.clone() * x1.clone();
477        let x1r = sin * x0 + cos * x1;
478
479        Tensor::cat(
480            vec![x0r.unsqueeze_dim::<3>(1), x1r.unsqueeze_dim::<3>(1)],
481            1,
482        )
483        .reshape(dims)
484    }
485}
486
487/// Apply RoPE to only the rotation-active entries of the last dimension; the
488/// remainder passes through unchanged. Falls back to [`apply_rope`] when
489/// `rope_dim == state_rank` (full RoPE).
490///
491/// Pairing scheme (must match the reference kernels — see Section
492/// "Data-Dependent RoPE" in the paper, and `mamba3_siso_fwd.py` /
493/// `mamba3_mimo_fwd.py`):
494///
495/// - `rotate_pairwise = true` (SISO, interleaved/NeoX): pairs `(0,1), (2,3), …`.
496///   Only pairs `0..num_rope_angles` are rotated; pairs beyond are passed
497///   through. Equivalent to slicing the first `rope_dim` entries and rotating
498///   them.
499/// - `rotate_pairwise = false` (MIMO, half-and-half/GPT-J): pair distance is
500///   always `state_rank/2`, i.e. element `n` is paired with element
501///   `state_rank/2 + n`. With partial RoPE only the first `num_rope_angles`
502///   pairs are rotated; the remaining elements in both halves pass through.
503fn apply_rope_partial<B: Backend, const D: usize>(
504    x: Tensor<B, D>,
505    angles: Tensor<B, D>,
506    rope_dim: usize,
507    rotate_pairwise: bool,
508) -> Tensor<B, D> {
509    let state_rank = x.dims()[D - 1];
510    if rope_dim == state_rank {
511        return apply_rope::<B, D>(x, angles, rotate_pairwise);
512    }
513
514    if rotate_pairwise {
515        // Pairs are local — slicing the first rope_dim entries gives the same
516        // result as the reference (which rotates the whole headdim but with
517        // identity cos/sin for the tail pairs).
518        let x_rope = x.clone().narrow(D - 1, 0, rope_dim);
519        let x_rest = x.narrow(D - 1, rope_dim, state_rank - rope_dim);
520        let x_rope_rotated = apply_rope::<B, D>(x_rope, angles, true);
521        return Tensor::cat(vec![x_rope_rotated, x_rest], D - 1);
522    }
523
524    // Half-and-half partial RoPE: pair distance must be `state_rank/2`, not
525    // `rope_dim/2`. Slicing the first `rope_dim` entries and calling
526    // `apply_rope` would pair within the slice and produce the wrong rotation.
527    let half = state_rank / 2;
528    let num_rope_angles = rope_dim / 2;
529    debug_assert!(
530        num_rope_angles < half,
531        "partial RoPE requires rope_dim < state_rank here"
532    );
533
534    // Split x into the two halves, then within each half separate the
535    // rotation-active prefix from the pass-through suffix.
536    let x_h1 = x.clone().narrow(D - 1, 0, half);
537    let x_h2 = x.narrow(D - 1, half, half);
538    let x_h1_rope = x_h1.clone().narrow(D - 1, 0, num_rope_angles);
539    let x_h1_pass = x_h1.narrow(D - 1, num_rope_angles, half - num_rope_angles);
540    let x_h2_rope = x_h2.clone().narrow(D - 1, 0, num_rope_angles);
541    let x_h2_pass = x_h2.narrow(D - 1, num_rope_angles, half - num_rope_angles);
542
543    // angles: [..., num_rope_angles] — broadcasts element-wise against the rope-active slices.
544    let cos = angles.clone().cos();
545    let sin = angles.sin();
546    let x_h1_rot = cos.clone() * x_h1_rope.clone() - sin.clone() * x_h2_rope.clone();
547    let x_h2_rot = sin * x_h1_rope + cos * x_h2_rope;
548
549    // Reassemble: [ first-half-rotated | first-half-passthrough | second-half-rotated | second-half-passthrough ]
550    let x_h1_out = Tensor::cat(vec![x_h1_rot, x_h1_pass], D - 1);
551    let x_h2_out = Tensor::cat(vec![x_h2_rot, x_h2_pass], D - 1);
552    Tensor::cat(vec![x_h1_out, x_h2_out], D - 1)
553}
554
555// ---------------------------------------------------------------------------
556// MIMO helpers
557// ---------------------------------------------------------------------------
558
559/// Build the V (value) tensor for MIMO by expanding x over ranks.
560///
561/// # Shapes
562/// - `x_bShp`:    `[batch, S, nheads, per_head_dim]`
563/// - `mimo_x_hrp`: `[nheads, mimo_rank, per_head_dim]`
564/// - output:       `[batch, S, mimo_rank, nheads, per_head_dim]`
565///
566/// When `mimo_x_hrp` is `None` (SISO), wraps `x` in a rank-1 dim.
567fn build_v_mimo<B: Backend>(
568    x_bshp: Tensor<B, 4>,
569    mimo_x_hrp: Option<&Tensor<B, 3>>,
570) -> Tensor<B, 5> {
571    let [batch, seq, nheads, per_head_dim] = x_bshp.dims();
572    match mimo_x_hrp {
573        None => {
574            // SISO: just add a rank dimension of size 1
575            x_bshp.unsqueeze_dim::<5>(2) // [b, s, 1, h, p]
576        }
577        Some(mimo_x) => {
578            let [_, mimo_rank, _] = mimo_x.dims();
579            // x_bshp:  [b, s, h, p] → [b, s, 1, h, p]
580            let x_exp =
581                x_bshp
582                    .unsqueeze_dim::<5>(2)
583                    .expand([batch, seq, mimo_rank, nheads, per_head_dim]);
584            // mimo_x: [h, r, p] → [1, 1, r, h, p]
585            let mx_exp = mimo_x
586                .clone()
587                .permute([1, 0, 2]) // [r, h, p]
588                .unsqueeze_dim::<4>(0)
589                .unsqueeze_dim::<5>(0)
590                .expand([batch, seq, mimo_rank, nheads, per_head_dim]);
591            x_exp * mx_exp // [b, s, r, h, p]
592        }
593    }
594}
595
596// ---------------------------------------------------------------------------
597// Mamba3::forward  (chunkwise two-SSD — training / prefill)
598// ---------------------------------------------------------------------------
599
600impl<B: Backend + Mamba3BackendExt> Mamba3<B> {
601    /// Process a full input sequence using the trapezoidal two-SSD algorithm.
602    ///
603    /// For SISO (mimo_rank=1), this is the standard two-SSD decomposition.
604    /// For MIMO (mimo_rank=R>1), B/C have R parallel rank channels. The hidden
605    /// state is shared across ranks; each rank contributes independently.
606    ///
607    /// # Shapes
608    /// - `input_bsm` : `[batch, sequence, d_model]`
609    /// - output      : `[batch, sequence, d_model]`
610    #[allow(non_snake_case)]
611    pub fn forward(
612        &self,
613        input_bsm: Tensor<B, 3>,
614        cache: Option<Mamba3Cache<B>>,
615        ssd_path: Mamba3SsdPath,
616    ) -> (Tensor<B, 3>, Mamba3Cache<B>) {
617        let [batch, sequence, _d_model] = input_bsm.dims();
618        let d_inner = self.d_inner();
619        let nheads = self.nheads();
620        let ngroups = self.ngroups;
621        let per_head_dim = self.per_head_dim();
622        let state_rank = self.state_rank;
623        let num_rope_angles = self.num_rope_angles;
624        let heads_per_group = nheads / ngroups;
625        let mimo_rank = self.mimo_rank;
626        let device = input_bsm.device();
627
628        assert!(sequence > 0, "sequence length must be at least 1");
629        assert_eq!(nheads % ngroups, 0);
630        san(&input_bsm);
631
632        // ── Initialise cache if not provided ──────────────────────────────────
633        let mut cache = cache.unwrap_or_else(|| {
634            let ssm_bhpr = Tensor::zeros([batch, nheads, per_head_dim, state_rank], &device);
635            let k_state_brhn = Tensor::zeros([batch, mimo_rank, nheads, state_rank], &device);
636            let v_state_bhp = Tensor::zeros([batch, nheads, per_head_dim], &device);
637            let cum_angle_bhr = Tensor::zeros([batch, nheads, num_rope_angles], &device);
638            Mamba3Cache {
639                ssm_bhpr,
640                k_state_brhn,
641                v_state_bhp,
642                cum_angle_bhr,
643            }
644        });
645
646        // ── Step 1: In-projection ─────────────────────────────────────────────
647        let proj_bsd = self.in_proj.forward(input_bsm);
648        let bc_size = ngroups * state_rank * mimo_rank;
649
650        let mut parts = proj_bsd
651            .split_with_sizes(
652                vec![
653                    d_inner,
654                    d_inner,
655                    bc_size,
656                    bc_size,
657                    nheads,
658                    nheads,
659                    nheads,
660                    num_rope_angles,
661                ],
662                2,
663            )
664            .into_iter();
665        let z_bsi = parts.next().unwrap(); // [B, T, d_inner]
666        let x_bsi = parts.next().unwrap(); // [B, T, d_inner]
667        let b_raw_bsd = parts.next().unwrap(); // [B, T, ngroups*state_rank*mimo_rank]
668        let c_raw_bsd = parts.next().unwrap(); // [B, T, ngroups*state_rank*mimo_rank]
669        let dd_dt_bsh = parts.next().unwrap(); // [B, T, nheads]
670        let dd_A_raw_bsh = parts.next().unwrap(); // [B, T, nheads]
671        let lam_raw_bsh = parts.next().unwrap(); // [B, T, nheads]
672        let theta_bsa = parts.next().unwrap(); // [B, T, num_rope_angles]
673
674        san(&z_bsi);
675        san(&x_bsi);
676        san(&dd_dt_bsh);
677
678        // ── Step 2: Discretisation + trapezoidal coefficients ─────────────────
679        let dt_bias_11h = self.dt_bias_h.val().unsqueeze_dims(&[0, 1]);
680        let dt_bsh = softplus(dd_dt_bsh + dt_bias_11h).clamp(self.dt_limit.0, self.dt_limit.1);
681
682        let a_bsh = -softplus(dd_A_raw_bsh).clamp(f64::NEG_INFINITY, -self.a_floor);
683        let da_bsh = dt_bsh.clone() * a_bsh;
684
685        let alpha_bsh = da_bsh.clone().exp();
686        let lam_bsh = sigmoid(lam_raw_bsh);
687        let gamma_bsh = lam_bsh.clone() * dt_bsh.clone();
688        let beta_bsh = (-lam_bsh.clone() + 1.0) * dt_bsh.clone() * alpha_bsh.clone();
689
690        san(&dt_bsh);
691        san(&da_bsh);
692        san(&gamma_bsh);
693        san(&beta_bsh);
694
695        // ── Step 3: Reshape x ─────────────────────────────────────────────────
696        let x_bshp = x_bsi.reshape([batch, sequence, nheads, per_head_dim]);
697
698        // ── Step 4: QK-Norm on B and C → [b, T, R, H, N] ─────────────────────
699        // Reshape: [b, T, R*G*N] → [b, T, R, G, N]
700        // QK-Norm over N, then expand G→H, then add per-head+rank bias [H, R, N].
701        let b_bsrhr = {
702            let b_bsrgr = b_raw_bsd.reshape([batch, sequence, mimo_rank, ngroups, state_rank]);
703            // Norm over last dim (state_rank) for each (b, s, r, g) slice:
704            // Flatten leading dims so RmsNorm operates on last dim only.
705            let b_norm = self
706                .b_norm
707                .forward(b_bsrgr.reshape([batch * sequence * mimo_rank, ngroups, state_rank]))
708                .reshape([batch, sequence, mimo_rank, ngroups, state_rank]);
709            // Expand groups → heads: [b, T, R, G, N] → [b, T, R, G, H/G, N] → [b, T, R, H, N]
710            let b_exp = b_norm
711                .unsqueeze_dim::<6>(4) // [b, T, R, G, 1, N]
712                .expand([
713                    batch,
714                    sequence,
715                    mimo_rank,
716                    ngroups,
717                    heads_per_group,
718                    state_rank,
719                ])
720                .reshape([batch, sequence, mimo_rank, nheads, state_rank]);
721            // Add bias [H, R, N] → broadcast as [1, 1, R, H, N]
722            // b_bias_hrn: [H, R, N] → permute to [R, H, N] → unsqueeze → [1, 1, R, H, N]
723            let bias = self
724                .b_bias_hrn
725                .val()
726                .permute([1, 0, 2]) // [R, H, N]
727                .unsqueeze_dim::<4>(0)
728                .unsqueeze_dim::<5>(0); // [1, 1, R, H, N]
729            b_exp + bias
730        };
731        let c_bsrhr = {
732            let c_bsrgr = c_raw_bsd.reshape([batch, sequence, mimo_rank, ngroups, state_rank]);
733            let c_norm = self
734                .c_norm
735                .forward(c_bsrgr.reshape([batch * sequence * mimo_rank, ngroups, state_rank]))
736                .reshape([batch, sequence, mimo_rank, ngroups, state_rank]);
737            let c_exp = c_norm
738                .unsqueeze_dim::<6>(4)
739                .expand([
740                    batch,
741                    sequence,
742                    mimo_rank,
743                    ngroups,
744                    heads_per_group,
745                    state_rank,
746                ])
747                .reshape([batch, sequence, mimo_rank, nheads, state_rank]);
748            let bias = self
749                .c_bias_hrn
750                .val()
751                .permute([1, 0, 2])
752                .unsqueeze_dim::<4>(0)
753                .unsqueeze_dim::<5>(0);
754            c_exp + bias
755        };
756        // b_bsrhr: [b, T, R, H, N]
757        assert_eq!(
758            [batch, sequence, mimo_rank, nheads, state_rank],
759            b_bsrhr.dims()
760        );
761        assert_eq!(
762            [batch, sequence, mimo_rank, nheads, state_rank],
763            c_bsrhr.dims()
764        );
765
766        // ── Step 5: Data-dependent cumulative RoPE angles ─────────────────────
767        let theta_scaled_bsa = theta_bsa.tanh() * std::f32::consts::PI;
768        let raw_angles_bsha =
769            dt_bsh.clone().unsqueeze_dim::<4>(3) * theta_scaled_bsa.unsqueeze_dim::<4>(2);
770
771        let cumsum_bsha = raw_angles_bsha.cumsum(1);
772        let cum_angles_bsha = cache.cum_angle_bhr.clone().unsqueeze_dim::<4>(1) + cumsum_bsha;
773        assert_eq!(
774            [batch, sequence, nheads, num_rope_angles],
775            cum_angles_bsha.dims()
776        );
777        san(&cum_angles_bsha);
778
779        // Apply RoPE to B and C: angles broadcast over the R dim.
780        // b_bsrhr: [b, T, R, H, N], angles: [b, T, H, A] → [b, T, 1, H, A]
781        let angles_exp_bsrha = cum_angles_bsha.clone().unsqueeze_dim::<5>(2).expand([
782            batch,
783            sequence,
784            mimo_rank,
785            nheads,
786            num_rope_angles,
787        ]);
788        // SISO uses interleaved (pairwise) pairing; MIMO uses half-and-half.
789        // Partial RoPE: rotate only the first `rope_dim` entries of B/C.
790        let rotate_pairwise = mimo_rank == 1;
791        let rope_dim = self.rope_dim;
792        let b_bsrhn = apply_rope_partial::<B, 5>(
793            b_bsrhr,
794            angles_exp_bsrha.clone(),
795            rope_dim,
796            rotate_pairwise,
797        );
798        let c_bsrhn =
799            apply_rope_partial::<B, 5>(c_bsrhr, angles_exp_bsrha, rope_dim, rotate_pairwise);
800        san(&b_bsrhn);
801        san(&c_bsrhn);
802
803        // ── Step 6: Build shifted inputs for β term ───────────────────────────
804        //
805        // "Shift-Before-Chunking": prepend the cached x_{-1} / B_{-1} at the
806        // sequence level (before SSD chunking) so the β term at t=0 sees the
807        // prior token from a continued cache. For a fresh (zero) cache this is
808        // equivalent to zero-padding.
809        let x_prev_first_b1hp = cache.v_state_bhp.clone().unsqueeze_dim::<4>(1);
810        let x_prev_bshp = if sequence == 1 {
811            x_prev_first_b1hp
812        } else {
813            Tensor::cat(
814                vec![x_prev_first_b1hp, x_bshp.clone().narrow(1, 0, sequence - 1)],
815                1,
816            )
817        };
818        // b_prev: [b, T, R, H, N]
819        let b_prev_first_b1rhn = cache.k_state_brhn.clone().unsqueeze_dim::<5>(1);
820        let b_prev_bsrhn = if sequence == 1 {
821            b_prev_first_b1rhn
822        } else {
823            Tensor::cat(
824                vec![
825                    b_prev_first_b1rhn,
826                    b_bsrhn.clone().narrow(1, 0, sequence - 1),
827                ],
828                1,
829            )
830        };
831
832        // ── Step 7: Scale inputs by trapezoidal coefficients ──────────────────
833        // gamma and beta are per-head scalars, broadcast over R and P:
834        let gamma_bsh1 = gamma_bsh.unsqueeze_dim::<4>(3);
835        let beta_bsh1 = beta_bsh.unsqueeze_dim::<4>(3);
836        let x_gamma_bshp = x_bshp.clone() * gamma_bsh1; // γ_t · x_t
837        let x_beta_bshp = x_prev_bshp * beta_bsh1; // β_t · x_{t-1}
838
839        // ── Save last-token B for cache ───────────────────────────────────────
840        let b_last_brhn = b_bsrhn
841            .clone()
842            .narrow(1, sequence - 1, 1)
843            .reshape([batch, mimo_rank, nheads, state_rank]);
844
845        // ── Step 8: Pad sequence to multiple of chunk_len ─────────────────────
846        let chunk_len = ssd_path.chunk_len_or_optimal(state_rank, per_head_dim);
847        let sequence_padded = sequence.next_multiple_of(chunk_len);
848        let pad = sequence_padded - sequence;
849
850        let (x_gamma_bShp, x_beta_bShp, da_bSh, b_bSrhn, b_prev_bSrhn, c_bSrhn) = if pad == 0 {
851            (
852                x_gamma_bshp,
853                x_beta_bshp,
854                da_bsh,
855                b_bsrhn,
856                b_prev_bsrhn,
857                c_bsrhn,
858            )
859        } else {
860            let pad_hp = Tensor::zeros([batch, pad, nheads, per_head_dim], &device);
861            let pad_h = Tensor::zeros([batch, pad, nheads], &device);
862            let pad_rhn = Tensor::zeros([batch, pad, mimo_rank, nheads, state_rank], &device);
863            (
864                Tensor::cat(vec![x_gamma_bshp, pad_hp.clone()], 1),
865                Tensor::cat(vec![x_beta_bshp, pad_hp], 1),
866                Tensor::cat(vec![da_bsh, pad_h], 1),
867                Tensor::cat(vec![b_bsrhn, pad_rhn.clone()], 1),
868                Tensor::cat(vec![b_prev_bsrhn, pad_rhn.clone()], 1),
869                Tensor::cat(vec![c_bsrhn, pad_rhn], 1),
870            )
871        };
872
873        // ── Reshape into chunks ───────────────────────────────────────────────
874        let nchunks = sequence_padded / chunk_len;
875        let x_gamma_bnlhp = x_gamma_bShp.reshape([batch, nchunks, chunk_len, nheads, per_head_dim]);
876        let x_beta_bnlhp = x_beta_bShp.reshape([batch, nchunks, chunk_len, nheads, per_head_dim]);
877        let da_bnlh = da_bSh.reshape([batch, nchunks, chunk_len, nheads]);
878        // [b, S, R, H, N] → [b, n, l, R, H, N]
879        let b_bnlrhn = b_bSrhn.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
880        let b_prev_bnlrhn =
881            b_prev_bSrhn.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
882        let c_bnlrhn = c_bSrhn.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, state_rank]);
883
884        // ── Step 9: Two MIMO-SSD calls ────────────────────────────────────────
885        // Build V tensors: [b, n, l, R, H, P]
886        let mimo_x_val = self.mimo_x.as_ref().map(|p| p.val());
887        let v_gamma_bnlrhp = build_v_mimo_chunked(
888            x_gamma_bnlhp.clone(),
889            mimo_x_val.as_ref(),
890            batch,
891            nchunks,
892            chunk_len,
893            mimo_rank,
894            nheads,
895            per_head_dim,
896        );
897        let v_beta_bnlrhp = build_v_mimo_chunked(
898            x_beta_bnlhp,
899            mimo_x_val.as_ref(),
900            batch,
901            nchunks,
902            chunk_len,
903            mimo_rank,
904            nheads,
905            per_head_dim,
906        );
907
908        let input_gamma = Mamba3SsdInput {
909            v_bnlrhp: v_gamma_bnlrhp,
910            da_bnlh: da_bnlh.clone(),
911            b_bnlrhn: b_bnlrhn.clone(),
912            c_bnlrhn: c_bnlrhn.clone(),
913            initial_state_bhpr: cache.ssm_bhpr,
914            init_state_hpr: self.init_state_hpr.as_ref().map(|s| s.val()),
915        };
916        let (y_gamma_bnlrhp, final_state_gamma) = ssd_path.clone().run(input_gamma);
917
918        let input_beta = Mamba3SsdInput {
919            v_bnlrhp: v_beta_bnlrhp,
920            da_bnlh,
921            b_bnlrhn: b_prev_bnlrhn,
922            c_bnlrhn,
923            initial_state_bhpr: Tensor::zeros([batch, nheads, per_head_dim, state_rank], &device),
924            init_state_hpr: None,
925        };
926        let (y_beta_bnlrhp, final_state_beta) = ssd_path.run(input_beta);
927
928        // y_bnlrhp: [b, n, l, R, H, P]
929        let y_bnlrhp = y_gamma_bnlrhp + y_beta_bnlrhp;
930        let final_state_bhpr = final_state_gamma + final_state_beta;
931
932        san(&y_bnlrhp);
933        san(&final_state_bhpr);
934
935        cache.ssm_bhpr = final_state_bhpr;
936
937        // ── Step 10: Unpad ────────────────────────────────────────────────────
938        let y_bSrhp = y_bnlrhp.reshape([batch, sequence_padded, mimo_rank, nheads, per_head_dim]);
939        let y_bsrhp = if pad == 0 {
940            y_bSrhp
941        } else {
942            y_bSrhp.narrow(1, 0, sequence)
943        };
944
945        // ── Step 11: D skip + gate + aggregate ranks ──────────────────────────
946        // D skip uses raw x * mimo_x (not gamma-scaled)
947        let v_raw_bsrhp = build_v_mimo::<B>(x_bshp.clone(), mimo_x_val.as_ref());
948
949        let d_11_h1 = self.d_h.val().unsqueeze_dims::<5>(&[0, 1, 2, 4]); // [1, 1, 1, H, 1]
950        let y_bsrhp = y_bsrhp + d_11_h1 * v_raw_bsrhp.clone();
951
952        // ── Gate (or gated norm) and rank aggregation ─────────────────────────
953        // When `out_norm` is set, the SiLU gate is replaced by a per-head
954        // gated RMSNorm: `RmsNormGated(y, z) = norm(y) * silu(z)`.
955        let y_bsi = if mimo_rank > 1 {
956            let mimo_z_val = self.mimo_z.as_ref().map(|p| p.val()).unwrap();
957            let mimo_o_val = self.mimo_o.as_ref().map(|p| p.val()).unwrap();
958
959            // z_r = z * mimo_z[r]: [b, s, h, p] * [h, r, p] → [b, s, r, h, p]
960            let z_bshp = z_bsi
961                .clone()
962                .reshape([batch, sequence, nheads, per_head_dim]);
963            let z_bsrhp = {
964                // z: [b, s, h, p] → [b, s, 1, h, p]
965                // mimo_z: [h, r, p] → [r, h, p] → [1, 1, r, h, p]
966                let z_exp = z_bshp.unsqueeze_dim::<5>(2).expand([
967                    batch,
968                    sequence,
969                    mimo_rank,
970                    nheads,
971                    per_head_dim,
972                ]);
973                let mz = mimo_z_val
974                    .permute([1, 0, 2]) // [r, h, p]
975                    .unsqueeze_dim::<4>(0)
976                    .unsqueeze_dim::<5>(0)
977                    .expand([batch, sequence, mimo_rank, nheads, per_head_dim]);
978                z_exp * mz
979            };
980
981            // Per-rank gate or gated norm:
982            //   without out_norm: y_r * silu(z_r)
983            //   with    out_norm: norm(y_r) * silu(z_r)  (norm over per_head_dim)
984            let y_combined_bsrhp = match &self.out_norm {
985                Some(norm) => norm.forward(y_bsrhp, z_bsrhp),
986                None => y_bsrhp * Silu::new().forward(z_bsrhp),
987            };
988
989            // Down-project with mimo_o: out = sum_r mimo_o[h, r, p] * y_r
990            // mimo_o: [h, r, p] → [r, h, p] → [1, 1, r, h, p]
991            let mo = mimo_o_val
992                .permute([1, 0, 2]) // [r, h, p]
993                .unsqueeze_dim::<4>(0)
994                .unsqueeze_dim::<5>(0)
995                .expand([batch, sequence, mimo_rank, nheads, per_head_dim]);
996            // sum over rank dim (dim=2): [b, s, r, h, p] → [b, s, h, p]
997            let y_bhp: Tensor<B, 4> = (y_combined_bsrhp * mo).sum_dim(2).squeeze_dim(2);
998            y_bhp.reshape([batch, sequence, d_inner])
999        } else {
1000            // SISO: squeeze rank dim, apply gate (or gated norm) over per_head_dim.
1001            let y_bshp: Tensor<B, 4> = y_bsrhp.squeeze_dim(2); // [b, s, h, p]
1002            let z_bshp = z_bsi.reshape([batch, sequence, nheads, per_head_dim]);
1003            let y_combined_bshp = match &self.out_norm {
1004                Some(norm) => norm.forward(y_bshp, z_bshp),
1005                None => y_bshp * Silu::new().forward(z_bshp),
1006            };
1007            y_combined_bshp.reshape([batch, sequence, d_inner])
1008        };
1009        san(&y_bsi);
1010
1011        // ── Out-projection ────────────────────────────────────────────────────
1012        let out_bsm = self.out_proj.forward(y_bsi);
1013        san(&out_bsm);
1014
1015        // ── Update remaining cache fields ─────────────────────────────────────
1016        // k_state = B at last token: [b, R, H, N]
1017        cache.k_state_brhn = b_last_brhn;
1018
1019        // v_state = x at last token: [b, H, P]
1020        cache.v_state_bhp =
1021            x_bshp
1022                .narrow(1, sequence - 1, 1)
1023                .reshape([batch, nheads, per_head_dim]);
1024
1025        // cum_angle at last token
1026        cache.cum_angle_bhr =
1027            cum_angles_bsha
1028                .narrow(1, sequence - 1, 1)
1029                .reshape([batch, nheads, num_rope_angles]);
1030
1031        (out_bsm, cache)
1032    }
1033}
1034
1035fn build_v_mimo_chunked<B: Backend>(
1036    x_bnlhp: Tensor<B, 5>,
1037    mimo_x: Option<&Tensor<B, 3>>,
1038    batch: usize,
1039    nchunks: usize,
1040    chunk_len: usize,
1041    mimo_rank: usize,
1042    nheads: usize,
1043    per_head_dim: usize,
1044) -> Tensor<B, 6> {
1045    match mimo_x {
1046        None => x_bnlhp.unsqueeze_dim::<6>(3),
1047        Some(mx) => {
1048            let x_exp = x_bnlhp.unsqueeze_dim::<6>(3).expand([
1049                batch,
1050                nchunks,
1051                chunk_len,
1052                mimo_rank,
1053                nheads,
1054                per_head_dim,
1055            ]);
1056            let mx_exp = mx
1057                .clone()
1058                .permute([1, 0, 2])
1059                .unsqueeze_dim::<4>(0)
1060                .unsqueeze_dim::<5>(0)
1061                .unsqueeze_dim::<6>(0)
1062                .expand([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
1063            x_exp * mx_exp
1064        }
1065    }
1066}
1067
1068// ---------------------------------------------------------------------------
1069// Mamba3::step  (recurrent SSM — token-by-token decoding)
1070// ---------------------------------------------------------------------------
1071
1072mod step {
1073    use super::*;
1074
1075    impl<B: Backend> Mamba3<B> {
1076        /// Process a **single token** using the pure recurrent form.
1077        ///
1078        /// For SISO (mimo_rank=1):
1079        /// ```text
1080        ///   hₜ = αₜ hₜ₋₁ + βₜ B_{t-1} ⊗ x_{t-1} + γₜ Bₜ ⊗ xₜ
1081        ///   yₜ = Cₜᵀ hₜ + D xₜ
1082        /// ```
1083        ///
1084        /// For MIMO (mimo_rank=R>1):
1085        /// ```text
1086        ///   hₜ = αₜ hₜ₋₁ + Σ_r βₜ B_{t-1}[r] ⊗ (x_{t-1} ⊙ mimo_x[r])
1087        ///                  + Σ_r γₜ Bₜ[r] ⊗ (xₜ ⊙ mimo_x[r])
1088        ///   yₜ[r] = Cₜ[r]ᵀ hₜ + D xₜ ⊙ mimo_x[r]
1089        ///   outₜ = Σ_r mimo_o[r] ⊙ silu(zₜ ⊙ mimo_z[r]) ⊙ yₜ[r]
1090        /// ```
1091        ///
1092        /// # Shapes
1093        /// - `input_bm` : `[batch, d_model]`
1094        /// - output     : `[batch, d_model]`
1095        #[allow(non_snake_case)]
1096        pub fn step(
1097            &self,
1098            input_bm: Tensor<B, 2>,
1099            cache: Option<Mamba3Cache<B>>,
1100        ) -> (Tensor<B, 2>, Mamba3Cache<B>) {
1101            let [batch, d_model] = input_bm.dims();
1102            let d_inner = self.d_inner();
1103            let nheads = self.nheads();
1104            let ngroups = self.ngroups;
1105            let per_head_dim = self.per_head_dim();
1106            let state_rank = self.state_rank;
1107            let num_rope_angles = self.num_rope_angles;
1108            let heads_per_group = nheads / ngroups;
1109            let mimo_rank = self.mimo_rank;
1110            let device = &input_bm.device();
1111            let ssm_shape = [batch, nheads, per_head_dim, state_rank];
1112
1113            assert_eq!(nheads % ngroups, 0);
1114
1115            let mut cache = cache.unwrap_or_else(|| {
1116                let ssm_bhpr = Tensor::zeros(ssm_shape, device);
1117                let k_state_brhn = Tensor::zeros([batch, mimo_rank, nheads, state_rank], device);
1118                let v_state_bhp = Tensor::zeros([batch, nheads, per_head_dim], device);
1119                let cum_angle_bhr = Tensor::zeros([batch, nheads, num_rope_angles], device);
1120                Mamba3Cache {
1121                    ssm_bhpr,
1122                    k_state_brhn,
1123                    v_state_bhp,
1124                    cum_angle_bhr,
1125                }
1126            });
1127
1128            // ── In-projection ─────────────────────────────────────────────────
1129            let proj_bd = self.in_proj.forward(input_bm);
1130            let bc_size = ngroups * state_rank * mimo_rank;
1131            let mut parts = proj_bd
1132                .split_with_sizes(
1133                    vec![
1134                        d_inner,
1135                        d_inner,
1136                        bc_size,
1137                        bc_size,
1138                        nheads,
1139                        nheads,
1140                        nheads,
1141                        num_rope_angles,
1142                    ],
1143                    1,
1144                )
1145                .into_iter();
1146            let z_bi = parts.next().unwrap(); // [B, d_inner]
1147            let x_bi = parts.next().unwrap(); // [B, d_inner]
1148            let b_raw_bd = parts.next().unwrap(); // [B, ngroups*state_rank*mimo_rank]
1149            let c_raw_bd = parts.next().unwrap();
1150            let dd_dt_bh = parts.next().unwrap(); // [B, nheads]
1151            let dd_A_raw_bh = parts.next().unwrap();
1152            let lam_raw_bh = parts.next().unwrap();
1153            let theta_ba = parts.next().unwrap(); // [B, num_rope_angles]
1154
1155            // ── Reshape x ─────────────────────────────────────────────────────
1156            let x_bhp = x_bi.reshape([batch, nheads, per_head_dim]);
1157
1158            // ── Discretisation ─────────────────────────────────────────────────
1159            let dt_bias_1h = self.dt_bias_h.val().unsqueeze_dim(0);
1160            let dt_bh = softplus(dd_dt_bh + dt_bias_1h).clamp(self.dt_limit.0, self.dt_limit.1);
1161            let a_bh = -softplus(dd_A_raw_bh).clamp(f64::NEG_INFINITY, -self.a_floor);
1162            let da_bh = dt_bh.clone() * a_bh;
1163            let alpha_bh = da_bh.exp();
1164            let lam_bh = sigmoid(lam_raw_bh);
1165            let gamma_bh = lam_bh.clone() * dt_bh.clone();
1166            let beta_bh = (-lam_bh.clone() + 1.0) * dt_bh.clone() * alpha_bh.clone();
1167
1168            // ── QK-Norm on B and C → [B, R, H, N] ────────────────────────────
1169            // b_raw: [B, R*G*N] → [B, R, G, N] → norm → expand → [B, R, H, N] → add bias
1170            let b_brhn = {
1171                let b_brgn = b_raw_bd.reshape([batch, mimo_rank, ngroups, state_rank]);
1172                let b_norm = self
1173                    .b_norm
1174                    .forward(b_brgn.reshape([batch * mimo_rank, ngroups, state_rank]))
1175                    .reshape([batch, mimo_rank, ngroups, state_rank]);
1176                let b_exp = b_norm
1177                    .unsqueeze_dim::<5>(3) // [B, R, G, 1, N]
1178                    .expand([batch, mimo_rank, ngroups, heads_per_group, state_rank])
1179                    .reshape([batch, mimo_rank, nheads, state_rank]);
1180                // bias: [H, R, N] → [R, H, N] → [1, R, H, N]
1181                let bias = self
1182                    .b_bias_hrn
1183                    .val()
1184                    .permute([1, 0, 2])
1185                    .unsqueeze_dim::<4>(0);
1186                b_exp + bias
1187            };
1188            let c_brhn = {
1189                let c_brgn = c_raw_bd.reshape([batch, mimo_rank, ngroups, state_rank]);
1190                let c_norm = self
1191                    .c_norm
1192                    .forward(c_brgn.reshape([batch * mimo_rank, ngroups, state_rank]))
1193                    .reshape([batch, mimo_rank, ngroups, state_rank]);
1194                let c_exp = c_norm
1195                    .unsqueeze_dim::<5>(3)
1196                    .expand([batch, mimo_rank, ngroups, heads_per_group, state_rank])
1197                    .reshape([batch, mimo_rank, nheads, state_rank]);
1198                let bias = self
1199                    .c_bias_hrn
1200                    .val()
1201                    .permute([1, 0, 2])
1202                    .unsqueeze_dim::<4>(0);
1203                c_exp + bias
1204            };
1205            assert_eq!([batch, mimo_rank, nheads, state_rank], b_brhn.dims());
1206
1207            // ── RoPE: update cumulative angle, rotate B and C ──────────────────
1208            let theta_scaled_ba = theta_ba.tanh() * std::f32::consts::PI;
1209            let raw_angle_bha = dt_bh.unsqueeze_dim::<3>(2) * theta_scaled_ba.unsqueeze_dim::<3>(1);
1210            let new_cum_angle_bha = cache.cum_angle_bhr.clone() + raw_angle_bha;
1211
1212            // Broadcast angles over R: [b, H, A] → [b, 1, H, A] → [b, R, H, A]
1213            let angles_brha = new_cum_angle_bha.clone().unsqueeze_dim::<4>(1).expand([
1214                batch,
1215                mimo_rank,
1216                nheads,
1217                num_rope_angles,
1218            ]);
1219            // SISO uses interleaved (pairwise) pairing; MIMO uses half-and-half.
1220            // Partial RoPE: rotate only the first `rope_dim` entries of B/C.
1221            let rotate_pairwise = mimo_rank == 1;
1222            let rope_dim = self.rope_dim;
1223            let b_brhn =
1224                apply_rope_partial::<B, 4>(b_brhn, angles_brha.clone(), rope_dim, rotate_pairwise);
1225            let c_brhn = apply_rope_partial::<B, 4>(c_brhn, angles_brha, rope_dim, rotate_pairwise);
1226
1227            // ── Build MIMO value tensors ───────────────────────────────────────
1228            // x_vals[b, r, h, p] = x[b, h, p] * mimo_x[h, r, p]
1229            // xs_vals[b, r, h, p] = x_state[b, h, p] * mimo_x[h, r, p]
1230            let mimo_x_val = self.mimo_x.as_ref().map(|p| p.val());
1231            let (x_vals_brhp, xs_vals_brhp) = build_mimo_vals(
1232                x_bhp.clone(),
1233                cache.v_state_bhp.clone(),
1234                mimo_x_val.as_ref(),
1235                batch,
1236                mimo_rank,
1237                nheads,
1238                per_head_dim,
1239                device,
1240            );
1241
1242            // ── SSM state update ───────────────────────────────────────────────
1243            // new_state[b, h, p, n] = alpha * state
1244            //   + sum_r gamma * x_vals[r] ⊗ B_cur[r]
1245            //   + sum_r beta  * xs_vals[r] ⊗ B_state[r]
1246            //
1247            // For the outer product sum:
1248            //   xBt[b, h, p, n] = sum_r coeff[r, h, p] * B[r, h, n]
1249            //   = einsum('brhp,brhn->bhpn', coeff*x_vals, B)
1250            //   = matmul over r: [b, h, p, r] @ [b, h, r, n]
1251            // x_vals_brhp * gamma_b1h1 (broadcast: [b, r, h, p] * [b, 1, h, 1]):
1252            // Need gamma as [b, 1, h, 1] to broadcast over r and p:
1253            let gamma_b1h1 = gamma_bh.clone().unsqueeze_dim::<3>(1).unsqueeze_dim::<4>(3); // [b, 1, h, 1]
1254            let beta_b1h1 = beta_bh.clone().unsqueeze_dim::<3>(1).unsqueeze_dim::<4>(3);
1255
1256            let x_gamma_brhp = x_vals_brhp.clone() * gamma_b1h1; // [b, r, h, p]
1257            let x_beta_brhp = xs_vals_brhp * beta_b1h1; // [b, r, h, p]
1258
1259            // einsum('brhp,brhn->bhpn', x_gamma, B_cur):
1260            // [b, r, h, p] → permute to [b, h, p, r]
1261            // [b, r, h, n] → permute to [b, h, r, n]
1262            // matmul: [b, h, p, r] @ [b, h, r, n] = [b, h, p, n]
1263            let xBt_state = {
1264                let b_bhrn = b_brhn.clone().permute([0, 2, 1, 3]); // [b, h, r, n]
1265                let xg_bhpr = x_gamma_brhp.permute([0, 2, 3, 1]); // [b, h, p, r]
1266                xg_bhpr.matmul(b_bhrn) // [b, h, p, n]
1267            };
1268            let xBt_prev = {
1269                let b_state_bhrn = cache.k_state_brhn.clone().permute([0, 2, 1, 3]); // [b, h, r, n]
1270                let xb_bhpr = x_beta_brhp.permute([0, 2, 3, 1]); // [b, h, p, r]
1271                xb_bhpr.matmul(b_state_bhrn) // [b, h, p, n]
1272            };
1273
1274            let alpha_bh11 = alpha_bh.unsqueeze_dims::<4>(&[2, 3]);
1275            let new_state_bhpn = alpha_bh11 * cache.ssm_bhpr.clone() + xBt_state + xBt_prev;
1276
1277            // ── Output ────────────────────────────────────────────────────────
1278            // out_r[b, r, h, p] = sum_n C[b, r, h, n] * state[b, h, p, n] + D * x_vals[b, r, h, p]
1279            // = einsum('bhpn,brhn->brhp', state, C)
1280            let out_r_brhp = {
1281                // state: [b, h, p, n], C: [b, r, h, n]
1282                // For each (b, h): [p, n] @ [r, n]^T = [p, r]
1283                // state: [b, h, p, n] → [b, h, p, n]
1284                // C_bhrn: [b, h, r, n] = b_brhn permuted = c_brhn permuted
1285                let c_bhrn = c_brhn.permute([0, 2, 1, 3]); // [b, h, r, n]
1286                // [b, h, p, n] @ [b, h, n, r] = [b, h, p, r]
1287                let c_bhnr = c_bhrn.permute([0, 1, 3, 2]); // [b, h, n, r]
1288                let out_bhpr = new_state_bhpn.clone().matmul(c_bhnr); // [b, h, p, r]
1289                out_bhpr.permute([0, 3, 1, 2]) // [b, r, h, p]
1290            };
1291
1292            // D skip: D[h] * x_vals[b, R, H, P], broadcast [1, 1, H, 1] over [b, R, H, P]
1293            let d_skip = self
1294                .d_h
1295                .val()
1296                .unsqueeze_dims::<4>(&[0, 1, 3]) // [1, 1, h, 1]
1297                .expand([batch, mimo_rank, nheads, per_head_dim]);
1298            let out_r_brhp = out_r_brhp + d_skip * x_vals_brhp;
1299
1300            // ── Gate (or gated norm) and rank aggregation ─────────────────────
1301            // When `out_norm` is set, the SiLU gate is replaced by a per-head
1302            // gated RMSNorm: `RmsNormGated(y, z) = norm(y) * silu(z)`.
1303            let z_bhp = z_bi.reshape([batch, nheads, per_head_dim]);
1304            let y_bi = if mimo_rank > 1 {
1305                let mimo_z_val = self.mimo_z.as_ref().map(|p| p.val()).unwrap();
1306                let mimo_o_val = self.mimo_o.as_ref().map(|p| p.val()).unwrap();
1307
1308                // z_r = z * mimo_z[r]: z[b, h, p] * mimo_z[h, r, p] → [b, r, h, p]
1309                let z_exp =
1310                    z_bhp
1311                        .unsqueeze_dim::<4>(1)
1312                        .expand([batch, mimo_rank, nheads, per_head_dim]);
1313                // mimo_z: [h, r, p] → [r, h, p] → [1, r, h, p]
1314                let mz = mimo_z_val.permute([1, 0, 2]).unsqueeze_dim::<4>(0).expand([
1315                    batch,
1316                    mimo_rank,
1317                    nheads,
1318                    per_head_dim,
1319                ]);
1320                let z_r = z_exp * mz;
1321
1322                // Per-rank gate or gated norm.
1323                let combined = match &self.out_norm {
1324                    Some(norm) => norm.forward(out_r_brhp, z_r),
1325                    None => out_r_brhp * Silu::new().forward(z_r),
1326                };
1327
1328                // Project down: out = sum_r mimo_o[r] * combined[r]
1329                // mimo_o: [h, r, p] → [r, h, p] → [1, r, h, p]
1330                let mo = mimo_o_val.permute([1, 0, 2]).unsqueeze_dim::<4>(0).expand([
1331                    batch,
1332                    mimo_rank,
1333                    nheads,
1334                    per_head_dim,
1335                ]);
1336                let out_bhp: Tensor<B, 3> = (combined * mo).sum_dim(1).squeeze_dim(1);
1337                out_bhp.reshape([batch, d_inner])
1338            } else {
1339                // SISO: squeeze rank dim, gate (or gated norm) over per_head_dim.
1340                let y_bhp: Tensor<B, 3> = out_r_brhp.squeeze_dim(1); // [b, h, p]
1341                let combined = match &self.out_norm {
1342                    Some(norm) => norm.forward(y_bhp, z_bhp),
1343                    None => y_bhp * Silu::new().forward(z_bhp),
1344                };
1345                combined.reshape([batch, d_inner])
1346            };
1347
1348            // ── Out-projection ────────────────────────────────────────────────
1349            let out_bm = self.out_proj.forward(y_bi);
1350            assert_eq!([batch, d_model], out_bm.dims());
1351
1352            // ── Update cache ──────────────────────────────────────────────────
1353            cache.ssm_bhpr = new_state_bhpn;
1354            // k_state: B at current token [b, R, H, N]
1355            cache.k_state_brhn = b_brhn; // already [b, r, h, n]
1356            cache.v_state_bhp = x_bhp;
1357            cache.cum_angle_bhr = new_cum_angle_bha;
1358
1359            (out_bm, cache)
1360        }
1361    }
1362
1363    /// Build MIMO value tensors for x_current and x_state.
1364    ///
1365    /// Returns `(x_vals_brhp, xs_vals_brhp)` both of shape `[batch, mimo_rank, nheads, per_head_dim]`.
1366    fn build_mimo_vals<B: Backend>(
1367        x_bhp: Tensor<B, 3>,
1368        x_state_bhp: Tensor<B, 3>,
1369        mimo_x: Option<&Tensor<B, 3>>,
1370        batch: usize,
1371        mimo_rank: usize,
1372        nheads: usize,
1373        per_head_dim: usize,
1374        _device: &B::Device,
1375    ) -> (Tensor<B, 4>, Tensor<B, 4>) {
1376        match mimo_x {
1377            None => {
1378                // SISO: add rank dim of 1
1379                (
1380                    x_bhp.unsqueeze_dim::<4>(1),
1381                    x_state_bhp.unsqueeze_dim::<4>(1),
1382                )
1383            }
1384            Some(mx) => {
1385                // x: [b, h, p] → [b, 1, h, p] → [b, R, h, p]
1386                let x_exp =
1387                    x_bhp
1388                        .unsqueeze_dim::<4>(1)
1389                        .expand([batch, mimo_rank, nheads, per_head_dim]);
1390                let xs_exp = x_state_bhp.unsqueeze_dim::<4>(1).expand([
1391                    batch,
1392                    mimo_rank,
1393                    nheads,
1394                    per_head_dim,
1395                ]);
1396                // mimo_x: [h, r, p] → [r, h, p] → [1, r, h, p]
1397                let mx_exp = mx.clone().permute([1, 0, 2]).unsqueeze_dim::<4>(0).expand([
1398                    batch,
1399                    mimo_rank,
1400                    nheads,
1401                    per_head_dim,
1402                ]);
1403                (x_exp * mx_exp.clone(), xs_exp * mx_exp)
1404            }
1405        }
1406    }
1407}
1408
1409// ---------------------------------------------------------------------------
1410// Tests
1411// ---------------------------------------------------------------------------
1412
1413#[cfg(all(test, feature = "backend-flex"))]
1414mod tests {
1415    use super::*;
1416    use burn::backend::{Autodiff, Flex};
1417    use burn::tensor::Distribution;
1418
1419    /// Inner (non-autodiff) backend used for materialising values and
1420    /// extracted gradients.
1421    type InnerB = Flex;
1422    /// Autodiff-wrapped backend used to drive `.backward()`.
1423    type B = Autodiff<InnerB>;
1424
1425    type Device = <InnerB as burn::tensor::backend::BackendTypes>::Device;
1426
1427    fn small_config() -> Mamba3Config {
1428        Mamba3Config::new(32) // d_model = 32
1429            .with_state_rank(8)
1430            .with_expand(2)
1431            .with_per_head_dim(8)
1432    }
1433
1434    fn small_config_mimo() -> Mamba3Config {
1435        Mamba3Config::new(32)
1436            .with_state_rank(8)
1437            .with_expand(2)
1438            .with_per_head_dim(8)
1439            .with_mimo_rank(2)
1440    }
1441
    /// A bundle of input + model-parameter gradients extracted from one
    /// forward+backward run.  Each `check_grads_match` call compares these
    /// across two runs that should be mathematically equivalent.
    ///
    /// All tensors live on the inner (non-autodiff) backend, so they are
    /// plain values with no gradient tracking attached.
    struct RunGrads {
        /// Forward output of the run, detached from the autodiff graph.
        out: Tensor<InnerB, 3>,
        /// Gradient of the loss w.r.t. the input tensor.
        d_input: Tensor<InnerB, 3>,
        /// Gradient w.r.t. `in_proj.weight`.
        d_in_proj_w: Tensor<InnerB, 2>,
        /// Gradient w.r.t. `dt_bias_h`.
        d_dt_bias: Tensor<InnerB, 1>,
        /// Gradient w.r.t. the `d_h` skip parameter.
        d_d: Tensor<InnerB, 1>,
        /// Gradient w.r.t. `b_norm.gamma`.
        d_b_norm_gamma: Tensor<InnerB, 1>,
        /// Gradient w.r.t. `c_norm.gamma`.
        d_c_norm_gamma: Tensor<InnerB, 1>,
        /// Gradient w.r.t. `b_bias_hrn`.
        d_b_bias: Tensor<InnerB, 3>,
        /// Gradient w.r.t. `c_bias_hrn`.
        d_c_bias: Tensor<InnerB, 3>,
        /// Gradient w.r.t. `out_proj.weight`.
        d_out_proj_w: Tensor<InnerB, 2>,
    }
1457
1458    /// Run a closure that produces an output tensor from a model and an input
1459    /// (wrapped as a `Param` so it has its own autodiff leaf), then derive a
1460    /// scalar loss with a fixed (non-tracked) random "head" and return the
1461    /// gradients of the input and a representative set of model parameters.
1462    fn run_with_grads(
1463        model: &Mamba3<B>,
1464        input: &Param<Tensor<B, 3>>,
1465        head: &Tensor<InnerB, 3>,
1466        forward: impl FnOnce(&Mamba3<B>, Tensor<B, 3>) -> Tensor<B, 3>,
1467    ) -> RunGrads {
1468        let out = forward(model, input.val());
1469        let out_inner = out.clone().inner();
1470
1471        let head = Tensor::from_inner(head.clone());
1472        let loss = (out * head).sum();
1473        let grads = loss.backward();
1474
1475        RunGrads {
1476            out: out_inner,
1477            d_input: input.val().grad(&grads).expect("grad input"),
1478            d_in_proj_w: model
1479                .in_proj
1480                .weight
1481                .val()
1482                .grad(&grads)
1483                .expect("grad in_proj.weight"),
1484            d_dt_bias: model.dt_bias_h.val().grad(&grads).expect("grad dt_bias_h"),
1485            d_d: model.d_h.val().grad(&grads).expect("grad d_h"),
1486            d_b_norm_gamma: model
1487                .b_norm
1488                .gamma
1489                .val()
1490                .grad(&grads)
1491                .expect("grad b_norm.gamma"),
1492            d_c_norm_gamma: model
1493                .c_norm
1494                .gamma
1495                .val()
1496                .grad(&grads)
1497                .expect("grad c_norm.gamma"),
1498            d_b_bias: model
1499                .b_bias_hrn
1500                .val()
1501                .grad(&grads)
1502                .expect("grad b_bias_hrn"),
1503            d_c_bias: model
1504                .c_bias_hrn
1505                .val()
1506                .grad(&grads)
1507                .expect("grad c_bias_hrn"),
1508            d_out_proj_w: model
1509                .out_proj
1510                .weight
1511                .val()
1512                .grad(&grads)
1513                .expect("grad out_proj.weight"),
1514        }
1515    }
1516
1517    /// Assert that every entry in `a` and `b` agrees to within `grad_tol`,
1518    /// printing every comparison so a failure dump shows the full picture
1519    /// (instead of stopping at the first mismatch).
1520    fn check_grads_match(label: &str, a: &RunGrads, b: &RunGrads, grad_tol: f32) {
1521        let mut failures: Vec<String> = Vec::new();
1522        macro_rules! check {
1523            ($field:ident, $name:expr) => {{
1524                let d = (a.$field.clone() - b.$field.clone())
1525                    .abs()
1526                    .max()
1527                    .into_scalar();
1528                eprintln!("{:>40} {:>16} | max abs diff = {:>10.6}", label, $name, d);
1529                if d >= grad_tol {
1530                    failures.push(format!(
1531                        "{}: grad of {} max abs diff = {:.6} (tol {})",
1532                        label, $name, d, grad_tol
1533                    ));
1534                }
1535            }};
1536        }
1537        check!(d_input, "input");
1538        check!(d_in_proj_w, "in_proj.weight");
1539        check!(d_dt_bias, "dt_bias_h");
1540        check!(d_d, "d_h");
1541        check!(d_b_norm_gamma, "b_norm.gamma");
1542        check!(d_c_norm_gamma, "c_norm.gamma");
1543        check!(d_b_bias, "b_bias_hrn");
1544        check!(d_c_bias, "c_bias_hrn");
1545        check!(d_out_proj_w, "out_proj.weight");
1546        assert!(
1547            failures.is_empty(),
1548            "gradient mismatches:\n  {}",
1549            failures.join("\n  ")
1550        );
1551    }
1552
1553    /// Build a fresh `Param<Tensor>` from a stable inner tensor.
1554    /// A new Param is needed per run so that the autodiff leaf has a fresh
1555    /// node, isolating each backward pass to its own forward graph.
1556    fn param_input(input: &Tensor<InnerB, 3>) -> Param<Tensor<B, 3>> {
1557        Param::from_tensor(Tensor::from_inner(input.clone()))
1558    }
1559
1560    fn run_step_matches_forward(cfg: Mamba3Config) {
1561        let device: Device = Default::default();
1562        let model = cfg.init::<B>(&device);
1563
1564        let batch = 2;
1565        let seq_len = 5;
1566        let d_model = cfg.d_model;
1567
1568        let input = Tensor::<InnerB, 3>::random(
1569            [batch, seq_len, d_model],
1570            Distribution::Normal(0.0, 1.0),
1571            &device,
1572        );
1573        let head = Tensor::<InnerB, 3>::random(
1574            [batch, seq_len, d_model],
1575            Distribution::Normal(0.0, 1.0),
1576            &device,
1577        );
1578
1579        let ssd_path = Mamba3SsdPath::Minimal(Some(4));
1580
1581        let input_fwd = param_input(&input);
1582        let r_fwd = run_with_grads(&model, &input_fwd, &head, |m, x| {
1583            let (out, _) = m.forward(x, None, ssd_path.clone());
1584            out
1585        });
1586
1587        let input_step = param_input(&input);
1588        let r_step = run_with_grads(&model, &input_step, &head, |m, x| {
1589            let mut cache: Option<Mamba3Cache<B>> = None;
1590            let mut outs: Vec<Tensor<B, 2>> = Vec::with_capacity(seq_len);
1591            for t in 0..seq_len {
1592                let token = x.clone().narrow(1, t, 1).squeeze_dim(1);
1593                let (out_t, new_cache) = m.step(token, cache);
1594                cache = Some(new_cache);
1595                outs.push(out_t);
1596            }
1597            Tensor::stack(outs, 1)
1598        });
1599
1600        // ── Forward agreement (existing check) ───────────────────────────
1601        let diff = (r_fwd.out.clone() - r_step.out.clone())
1602            .abs()
1603            .max()
1604            .into_scalar();
1605        assert!(
1606            diff < 1e-4,
1607            "step() vs forward() max absolute difference = {diff:.6} (expected < 1e-4)"
1608        );
1609
1610        // ── Gradient agreement ───────────────────────────────────────────
1611        check_grads_match("step vs forward", &r_fwd, &r_step, 1e-3);
1612    }
1613
1614    #[test]
1615    fn step_matches_forward() {
1616        run_step_matches_forward(small_config());
1617    }
1618
1619    #[test]
1620    fn step_matches_forward_ngroups2() {
1621        let cfg = Mamba3Config::new(32)
1622            .with_state_rank(8)
1623            .with_expand(2)
1624            .with_per_head_dim(16)
1625            .with_ngroups(2);
1626        run_step_matches_forward(cfg);
1627    }
1628
1629    #[test]
1630    fn step_matches_forward_mimo() {
1631        run_step_matches_forward(small_config_mimo());
1632    }
1633
1634    #[test]
1635    fn step_matches_forward_mimo_ngroups2() {
1636        let cfg = Mamba3Config::new(32)
1637            .with_state_rank(8)
1638            .with_expand(2)
1639            .with_per_head_dim(16)
1640            .with_ngroups(2)
1641            .with_mimo_rank(2);
1642        run_step_matches_forward(cfg);
1643    }
1644
1645    /// forward(full) ≡ forward(prefix) then forward(suffix, cache_from_prefix).
1646    ///
1647    /// Verifies stateful chunked-prefill: the β term at the start of the second
1648    /// chunk must see `x_{-1}` and `B_{-1}` from the cache, not zeros.
1649    fn run_split_matches_full(cfg: Mamba3Config) {
1650        let device: Device = Default::default();
1651        let model = cfg.init::<B>(&device);
1652
1653        let batch = 2;
1654        let seq_len = 6;
1655        let split = 2;
1656        let d_model = cfg.d_model;
1657
1658        let input = Tensor::<InnerB, 3>::random(
1659            [batch, seq_len, d_model],
1660            Distribution::Normal(0.0, 1.0),
1661            &device,
1662        );
1663        let head = Tensor::<InnerB, 3>::random(
1664            [batch, seq_len, d_model],
1665            Distribution::Normal(0.0, 1.0),
1666            &device,
1667        );
1668
1669        let ssd_path = Mamba3SsdPath::Minimal(Some(4));
1670
1671        let input_full = param_input(&input);
1672        let r_full = run_with_grads(&model, &input_full, &head, |m, x| {
1673            let (out, _) = m.forward(x, None, ssd_path.clone());
1674            out
1675        });
1676
1677        let input_split = param_input(&input);
1678        let r_split = run_with_grads(&model, &input_split, &head, |m, x| {
1679            let prefix = x.clone().narrow(1, 0, split);
1680            let suffix = x.narrow(1, split, seq_len - split);
1681            let (out_prefix, cache) = m.forward(prefix, None, ssd_path.clone());
1682            let (out_suffix, _) = m.forward(suffix, Some(cache), ssd_path.clone());
1683            Tensor::cat(vec![out_prefix, out_suffix], 1)
1684        });
1685
1686        // ── Forward agreement (existing check) ───────────────────────────
1687        let diff = (r_full.out.clone() - r_split.out.clone())
1688            .abs()
1689            .max()
1690            .into_scalar();
1691        assert!(
1692            diff < 1e-4,
1693            "split forward vs full forward max absolute difference = {diff:.6} (expected < 1e-4)"
1694        );
1695
1696        // ── Gradient agreement ───────────────────────────────────────────
1697        check_grads_match("split vs full", &r_full, &r_split, 1e-3);
1698    }
1699
1700    #[test]
1701    fn split_matches_full() {
1702        run_split_matches_full(small_config());
1703    }
1704
1705    #[test]
1706    fn split_matches_full_ngroups2() {
1707        let cfg = Mamba3Config::new(32)
1708            .with_state_rank(8)
1709            .with_expand(2)
1710            .with_per_head_dim(16)
1711            .with_ngroups(2);
1712        run_split_matches_full(cfg);
1713    }
1714
1715    #[test]
1716    fn split_matches_full_mimo() {
1717        run_split_matches_full(small_config_mimo());
1718    }
1719
1720    #[test]
1721    fn split_matches_full_mimo_ngroups2() {
1722        let cfg = Mamba3Config::new(32)
1723            .with_state_rank(8)
1724            .with_expand(2)
1725            .with_per_head_dim(16)
1726            .with_ngroups(2)
1727            .with_mimo_rank(2);
1728        run_split_matches_full(cfg);
1729    }
1730
1731    // ── rope_fraction = 0.5 (partial RoPE) ──────────────────────────────────
1732
1733    #[test]
1734    fn step_matches_forward_rope_half() {
1735        let cfg = Mamba3Config::new(32)
1736            .with_state_rank(8)
1737            .with_expand(2)
1738            .with_per_head_dim(8)
1739            .with_rope_fraction(0.5);
1740        run_step_matches_forward(cfg);
1741    }
1742
1743    #[test]
1744    fn step_matches_forward_rope_half_mimo() {
1745        let cfg = Mamba3Config::new(32)
1746            .with_state_rank(8)
1747            .with_expand(2)
1748            .with_per_head_dim(8)
1749            .with_mimo_rank(2)
1750            .with_rope_fraction(0.5);
1751        run_step_matches_forward(cfg);
1752    }
1753
1754    #[test]
1755    fn split_matches_full_rope_half() {
1756        let cfg = Mamba3Config::new(32)
1757            .with_state_rank(8)
1758            .with_expand(2)
1759            .with_per_head_dim(8)
1760            .with_rope_fraction(0.5);
1761        run_split_matches_full(cfg);
1762    }
1763
1764    // ── has_outproj_norm = true (gated RMSNorm) ─────────────────────────────
1765
1766    #[test]
1767    fn step_matches_forward_outproj_norm() {
1768        let cfg = Mamba3Config::new(32)
1769            .with_state_rank(8)
1770            .with_expand(2)
1771            .with_per_head_dim(8)
1772            .with_has_outproj_norm(true);
1773        run_step_matches_forward(cfg);
1774    }
1775
1776    #[test]
1777    fn step_matches_forward_outproj_norm_mimo() {
1778        let cfg = Mamba3Config::new(32)
1779            .with_state_rank(8)
1780            .with_expand(2)
1781            .with_per_head_dim(8)
1782            .with_mimo_rank(2)
1783            .with_has_outproj_norm(true);
1784        run_step_matches_forward(cfg);
1785    }
1786
1787    #[test]
1788    fn split_matches_full_outproj_norm() {
1789        let cfg = Mamba3Config::new(32)
1790            .with_state_rank(8)
1791            .with_expand(2)
1792            .with_per_head_dim(8)
1793            .with_has_outproj_norm(true);
1794        run_split_matches_full(cfg);
1795    }
1796
1797    // ── Both features combined ──────────────────────────────────────────────
1798
1799    #[test]
1800    fn step_matches_forward_rope_half_outproj_norm_mimo() {
1801        let cfg = Mamba3Config::new(32)
1802            .with_state_rank(8)
1803            .with_expand(2)
1804            .with_per_head_dim(8)
1805            .with_mimo_rank(2)
1806            .with_rope_fraction(0.5)
1807            .with_has_outproj_norm(true);
1808        run_step_matches_forward(cfg);
1809    }
1810}