Skip to main content

burn_mamba/mamba3/single_ssd/
cache.rs

1//! # Mamba-3 Single-pass SSD Inference Cache
2//!
3//! The cache used by [`crate::mamba3::mamba3::Mamba3::forward_single_ssd`]
4//! (the single-pass SSD algorithm — see the Triton SISO and Tilelang MIMO
5//! reference kernels).
//! The four fields mirror those of [`Mamba3Cache`], but the
//! **SSM accumulator carries different semantics**:
8//!
9//! - [`Mamba3Cache`]: `ssm_bhpr` holds the double-ssd trapezoidal hidden state
10//!   `hₜ = αₜ hₜ₋₁ + βₜ Bₜ₋₁ ⊗ xₜ₋₁ + γₜ Bₜ ⊗ xₜ`.
11//! - [`Mamba3SingleSsdCache`]: `ssm_bhpr` holds the **trapezoid accumulator** `h'ₜ`
12//!   defined by `h'ₜ = αₜ h'ₜ₋₁ + scaleₜ Bₜ ⊗ xₜ`, where
13//!   `scaleₜ = γₜ + (1 − λₜ₊₁) · Δₜ₊₁`. The single-ssd form gives the correct output
14//!   `yₜ = Cₜᵀ h'ₜ` for all positions except the diagonal (s = t), which is
15//!   patched by an explicit `γₜ · (Cₜᵀ Bₜ) · xₜ` correction term in the kernel.
16//!
17//! Because the two accumulators differ, the two caches are not interchangeable.
18//! The distinct type prevents accidentally feeding a `forward_double_ssd` cache into
19//! `forward_single_ssd` (or vice versa) mid-sequence — that would silently corrupt state.
20
21use crate::mamba3::prelude::*;
22use crate::modules::sanity as san;
23use burn::module::Module;
24use burn::prelude::*;
25
26// ---------------------------------------------------------------------------
27// Mamba3SingleSsdCaches  (one cache entry per layer)
28// ---------------------------------------------------------------------------
29
30/// A collection of per-layer single-ssd form caches for a complete Mamba-3 network.
31#[derive(Module, Debug)]
32pub struct Mamba3SingleSsdCaches {
33    /// Per-layer caches. Length equals the number of virtual layers.
34    pub caches: Vec<Mamba3SingleSsdCache>,
35}
36
37/// Configuration / factory for [`Mamba3SingleSsdCaches`].
38#[derive(Config, Debug)]
39pub struct Mamba3SingleSsdCachesConfig {
40    /// Number of cache slots (= number of virtual layers).
41    pub n_real_caches: usize,
42
43    /// Shared configuration that determines the shape of each cache.
44    pub cache: Mamba3SingleSsdCacheConfig,
45}
46
47impl Mamba3SingleSsdCachesConfig {
48    /// Convenience constructor from a block config.
49    pub fn new_from_block_config(
50        n_real_caches: usize,
51        batch: usize,
52        block_config: Mamba3Config,
53    ) -> Self {
54        Self {
55            n_real_caches,
56            cache: Mamba3SingleSsdCacheConfig::new_from_block_config(batch, block_config),
57        }
58    }
59
60    /// Allocate all cache tensors (zero-initialised) on `device`.
61    pub fn init(&self, device: &Device) -> Mamba3SingleSsdCaches {
62        let caches = (0..self.n_real_caches)
63            .map(|_| self.cache.clone().init(device))
64            .collect();
65        Mamba3SingleSsdCaches { caches }
66    }
67}
68
69// ---------------------------------------------------------------------------
70// Mamba3SingleSsdCache  (state for a single layer)
71// ---------------------------------------------------------------------------
72
73/// Mutable state for a single Mamba-3 layer running the single-ssd form algorithm.
74///
75/// Tensor shapes match [`Mamba3Cache`]. The semantic difference lives entirely
76/// in `ssm_bhpr` (see the module-level documentation).
77#[derive(Module, Debug)]
78pub struct Mamba3SingleSsdCache {
79    /// **SingleSsd-form SSM accumulator** `h'ₜ`.
80    ///
81    /// Update rule: `h'ₜ = αₜ h'ₜ₋₁ + scaleₜ · sumₘ Bₜ[m] ⊗ (xₜ ⊙ mimo_xₘ)`.
82    /// Different from `Mamba3Cache::ssm_bhpr`.
83    ///
84    /// Shape: `[batch, nheads, per_head_dim, state_rank]`
85    pub ssm_bhpr: Tensor<4>,
86
87    /// **Previous token's K per mimo rank** = post-RoPE, post-bias `Bₜ₋₁[m]`.
88    ///
89    /// Used at the start of the next forward_single_ssd call to seed the boundary β
90    /// contribution `(1 − λ₀) · Δ₀ · Bₜ₋₁ ⊗ xₜ₋₁` (which the previous call could
91    /// not yet add because it did not know `λ₀, Δ₀`).
92    ///
93    /// Shape: `[batch, mimo_rank, nheads, state_rank]`
94    pub k_state_bmhr: Tensor<4>,
95
96    /// **Previous token's x** = `xₜ₋₁`.
97    ///
98    /// Paired with [`Self::k_state_bmhr`] to form the boundary β term.
99    ///
100    /// Shape: `[batch, nheads, per_head_dim]`
101    pub v_state_bhp: Tensor<3>,
102
103    /// **Cumulative data-dependent rotation** up to the current position
104    /// ([`RotationState`]).
105    ///
106    /// Same role as in [`Mamba3Cache`]: continued across calls for streaming.
107    /// Carries the same value as the double-ssd cache's field (the `From` impls
108    /// move it across), so the two caches still inter-convert by field identity.
109    pub rotation: RotationState,
110}
111
112impl Mamba3SingleSsdCache {
113    /// Run the [`NaN`/`Inf` guards](crate::utils::sanity) on every cached tensor.
114    pub fn sanity(&self) {
115        san(&self.ssm_bhpr);
116        san(&self.k_state_bmhr);
117        san(&self.v_state_bhp);
118        self.rotation.sanity();
119    }
120}
121
122/// Configuration / factory for a single [`Mamba3SingleSsdCache`].
123#[derive(Config, Debug)]
124pub struct Mamba3SingleSsdCacheConfig {
125    /// Batch size.
126    pub batch: usize,
127
128    /// State rank.
129    #[config(default = 128)]
130    pub state_rank: usize,
131
132    /// Head dimension per_head_dim.
133    #[config(default = 64)]
134    pub per_head_dim: usize,
135
136    /// Number of SSM heads.
137    pub nheads: usize,
138
139    /// MIMO rank. 1 = SISO.
140    #[config(default = 1)]
141    pub mimo_rank: usize,
142
143    /// Number of RoPE angle pairs
144    /// (see [`crate::mamba3::double_ssd::cache::Mamba3DoubleSsdCacheConfig::num_rope_angles`]).
145    pub num_rope_angles: usize,
146
147    /// Which positional rotation the block uses (see
148    /// [`crate::mamba3::double_ssd::cache::Mamba3DoubleSsdCacheConfig::rotation`]).
149    #[config(default = "crate::mamba3::rotation::RotationKind::Complex2D")]
150    pub rotation: RotationKind,
151
152    /// Number of quaternion blocks (`rope_dim / 4`); only used for
153    /// [`RotationKind::Quaternion4D`].
154    #[config(default = 1)]
155    pub num_quat_blocks: usize,
156}
157
158impl Mamba3SingleSsdCacheConfig {
159    /// Derive cache shapes from a Mamba-3 block configuration plus a batch size.
160    pub fn new_from_block_config(batch: usize, block_config: Mamba3Config) -> Self {
161        Self {
162            batch,
163            state_rank: block_config.state_rank,
164            per_head_dim: block_config.per_head_dim,
165            nheads: block_config.nheads(),
166            mimo_rank: block_config.mimo_rank,
167            num_rope_angles: block_config.num_rope_angles(),
168            rotation: block_config.rotation,
169            num_quat_blocks: block_config.num_quat_blocks(),
170        }
171    }
172
173    /// Allocate zero/identity-initialised cache tensors on `device`.
174    pub fn init(&self, device: &Device) -> Mamba3SingleSsdCache {
175        let ssm_bhpr = Tensor::zeros(
176            [self.batch, self.nheads, self.per_head_dim, self.state_rank],
177            device,
178        );
179        let k_state_bmhr = Tensor::zeros(
180            [self.batch, self.mimo_rank, self.nheads, self.state_rank],
181            device,
182        );
183        let v_state_bhp = Tensor::zeros([self.batch, self.nheads, self.per_head_dim], device);
184        let rotation = match self.rotation {
185            RotationKind::Quaternion4D => RotationState::identity_quaternion(
186                self.batch,
187                self.nheads,
188                self.num_quat_blocks,
189                device,
190            ),
191            RotationKind::Complex2D => {
192                RotationState::zeros_angle(self.batch, self.nheads, self.num_rope_angles, device)
193            }
194        };
195        Mamba3SingleSsdCache {
196            ssm_bhpr,
197            k_state_bmhr,
198            v_state_bhp,
199            rotation,
200        }
201    }
202}