burn_mamba/mamba3/single_ssd/cache.rs
1//! # Mamba-3 Single-pass SSD Inference Cache
2//!
3//! The cache used by [`crate::mamba3::mamba3::Mamba3::forward_single_ssd`]
4//! (the single-pass SSD algorithm — see the Triton SISO and Tilelang MIMO
5//! reference kernels).
//! The four fields mirror those of [`Mamba3Cache`] (three tensors plus the
//! rotation state), but the **SSM accumulator carries different semantics**:
8//!
9//! - [`Mamba3Cache`]: `ssm_bhpr` holds the double-ssd trapezoidal hidden state
10//! `hₜ = αₜ hₜ₋₁ + βₜ Bₜ₋₁ ⊗ xₜ₋₁ + γₜ Bₜ ⊗ xₜ`.
11//! - [`Mamba3SingleSsdCache`]: `ssm_bhpr` holds the **trapezoid accumulator** `h'ₜ`
12//! defined by `h'ₜ = αₜ h'ₜ₋₁ + scaleₜ Bₜ ⊗ xₜ`, where
13//! `scaleₜ = γₜ + (1 − λₜ₊₁) · Δₜ₊₁`. The single-ssd form gives the correct output
14//! `yₜ = Cₜᵀ h'ₜ` for all positions except the diagonal (s = t), which is
15//! patched by an explicit `γₜ · (Cₜᵀ Bₜ) · xₜ` correction term in the kernel.
16//!
17//! Because the two accumulators differ, the two caches are not interchangeable.
18//! The distinct type prevents accidentally feeding a `forward_double_ssd` cache into
19//! `forward_single_ssd` (or vice versa) mid-sequence — that would silently corrupt state.
20
21use crate::mamba3::prelude::*;
22use crate::modules::sanity as san;
23use burn::module::Module;
24use burn::prelude::*;
25
26// ---------------------------------------------------------------------------
27// Mamba3SingleSsdCaches (one cache entry per layer)
28// ---------------------------------------------------------------------------
29
30/// A collection of per-layer single-ssd form caches for a complete Mamba-3 network.
31#[derive(Module, Debug)]
32pub struct Mamba3SingleSsdCaches {
33 /// Per-layer caches. Length equals the number of virtual layers.
34 pub caches: Vec<Mamba3SingleSsdCache>,
35}
36
37/// Configuration / factory for [`Mamba3SingleSsdCaches`].
38#[derive(Config, Debug)]
39pub struct Mamba3SingleSsdCachesConfig {
40 /// Number of cache slots (= number of virtual layers).
41 pub n_real_caches: usize,
42
43 /// Shared configuration that determines the shape of each cache.
44 pub cache: Mamba3SingleSsdCacheConfig,
45}
46
47impl Mamba3SingleSsdCachesConfig {
48 /// Convenience constructor from a block config.
49 pub fn new_from_block_config(
50 n_real_caches: usize,
51 batch: usize,
52 block_config: Mamba3Config,
53 ) -> Self {
54 Self {
55 n_real_caches,
56 cache: Mamba3SingleSsdCacheConfig::new_from_block_config(batch, block_config),
57 }
58 }
59
60 /// Allocate all cache tensors (zero-initialised) on `device`.
61 pub fn init(&self, device: &Device) -> Mamba3SingleSsdCaches {
62 let caches = (0..self.n_real_caches)
63 .map(|_| self.cache.clone().init(device))
64 .collect();
65 Mamba3SingleSsdCaches { caches }
66 }
67}
68
69// ---------------------------------------------------------------------------
70// Mamba3SingleSsdCache (state for a single layer)
71// ---------------------------------------------------------------------------
72
73/// Mutable state for a single Mamba-3 layer running the single-ssd form algorithm.
74///
75/// Tensor shapes match [`Mamba3Cache`]. The semantic difference lives entirely
76/// in `ssm_bhpr` (see the module-level documentation).
77#[derive(Module, Debug)]
78pub struct Mamba3SingleSsdCache {
79 /// **SingleSsd-form SSM accumulator** `h'ₜ`.
80 ///
81 /// Update rule: `h'ₜ = αₜ h'ₜ₋₁ + scaleₜ · sumₘ Bₜ[m] ⊗ (xₜ ⊙ mimo_xₘ)`.
82 /// Different from `Mamba3Cache::ssm_bhpr`.
83 ///
84 /// Shape: `[batch, nheads, per_head_dim, state_rank]`
85 pub ssm_bhpr: Tensor<4>,
86
87 /// **Previous token's K per mimo rank** = post-RoPE, post-bias `Bₜ₋₁[m]`.
88 ///
89 /// Used at the start of the next forward_single_ssd call to seed the boundary β
90 /// contribution `(1 − λ₀) · Δ₀ · Bₜ₋₁ ⊗ xₜ₋₁` (which the previous call could
91 /// not yet add because it did not know `λ₀, Δ₀`).
92 ///
93 /// Shape: `[batch, mimo_rank, nheads, state_rank]`
94 pub k_state_bmhr: Tensor<4>,
95
96 /// **Previous token's x** = `xₜ₋₁`.
97 ///
98 /// Paired with [`Self::k_state_bmhr`] to form the boundary β term.
99 ///
100 /// Shape: `[batch, nheads, per_head_dim]`
101 pub v_state_bhp: Tensor<3>,
102
103 /// **Cumulative data-dependent rotation** up to the current position
104 /// ([`RotationState`]).
105 ///
106 /// Same role as in [`Mamba3Cache`]: continued across calls for streaming.
107 /// Carries the same value as the double-ssd cache's field (the `From` impls
108 /// move it across), so the two caches still inter-convert by field identity.
109 pub rotation: RotationState,
110}
111
112impl Mamba3SingleSsdCache {
113 /// Run the [`NaN`/`Inf` guards](crate::utils::sanity) on every cached tensor.
114 pub fn sanity(&self) {
115 san(&self.ssm_bhpr);
116 san(&self.k_state_bmhr);
117 san(&self.v_state_bhp);
118 self.rotation.sanity();
119 }
120}
121
122/// Configuration / factory for a single [`Mamba3SingleSsdCache`].
123#[derive(Config, Debug)]
124pub struct Mamba3SingleSsdCacheConfig {
125 /// Batch size.
126 pub batch: usize,
127
128 /// State rank.
129 #[config(default = 128)]
130 pub state_rank: usize,
131
132 /// Head dimension per_head_dim.
133 #[config(default = 64)]
134 pub per_head_dim: usize,
135
136 /// Number of SSM heads.
137 pub nheads: usize,
138
139 /// MIMO rank. 1 = SISO.
140 #[config(default = 1)]
141 pub mimo_rank: usize,
142
143 /// Number of RoPE angle pairs
144 /// (see [`crate::mamba3::double_ssd::cache::Mamba3DoubleSsdCacheConfig::num_rope_angles`]).
145 pub num_rope_angles: usize,
146
147 /// Which positional rotation the block uses (see
148 /// [`crate::mamba3::double_ssd::cache::Mamba3DoubleSsdCacheConfig::rotation`]).
149 #[config(default = "crate::mamba3::rotation::RotationKind::Complex2D")]
150 pub rotation: RotationKind,
151
152 /// Number of quaternion blocks (`rope_dim / 4`); only used for
153 /// [`RotationKind::Quaternion4D`].
154 #[config(default = 1)]
155 pub num_quat_blocks: usize,
156}
157
158impl Mamba3SingleSsdCacheConfig {
159 /// Derive cache shapes from a Mamba-3 block configuration plus a batch size.
160 pub fn new_from_block_config(batch: usize, block_config: Mamba3Config) -> Self {
161 Self {
162 batch,
163 state_rank: block_config.state_rank,
164 per_head_dim: block_config.per_head_dim,
165 nheads: block_config.nheads(),
166 mimo_rank: block_config.mimo_rank,
167 num_rope_angles: block_config.num_rope_angles(),
168 rotation: block_config.rotation,
169 num_quat_blocks: block_config.num_quat_blocks(),
170 }
171 }
172
173 /// Allocate zero/identity-initialised cache tensors on `device`.
174 pub fn init(&self, device: &Device) -> Mamba3SingleSsdCache {
175 let ssm_bhpr = Tensor::zeros(
176 [self.batch, self.nheads, self.per_head_dim, self.state_rank],
177 device,
178 );
179 let k_state_bmhr = Tensor::zeros(
180 [self.batch, self.mimo_rank, self.nheads, self.state_rank],
181 device,
182 );
183 let v_state_bhp = Tensor::zeros([self.batch, self.nheads, self.per_head_dim], device);
184 let rotation = match self.rotation {
185 RotationKind::Quaternion4D => RotationState::identity_quaternion(
186 self.batch,
187 self.nheads,
188 self.num_quat_blocks,
189 device,
190 ),
191 RotationKind::Complex2D => {
192 RotationState::zeros_angle(self.batch, self.nheads, self.num_rope_angles, device)
193 }
194 };
195 Mamba3SingleSsdCache {
196 ssm_bhpr,
197 k_state_bmhr,
198 v_state_bhp,
199 rotation,
200 }
201 }
202}