burn_mamba/mamba3/double_ssd/cache.rs
1//! # Mamba-3 Inference Caches
2//!
//! During autoregressive (token-by-token) generation, four pieces of state
//! must be preserved between calls:
5//!
6//! 1. **SSM hidden state** — `hₜ ∈ ℝ^{per_head_dim×state_rank}` per head, compressed context.
7//! 2. **Previous K state** — `Bₜ₋₁` per rank `[batch, mimo_rank, nheads, state_rank]`,
8//! needed for the β term of the (double-ssd) trapezoidal recurrence.
9//! 3. **Previous V state** — `xₜ₋₁` per head `[batch, nheads, per_head_dim]`,
10//! paired with k_state to reconstruct β Bₜ₋₁ ⊗ xₜ₋₁.
11//! 4. **Cumulative RoPE angle** — the accumulated rotation angle up to position
12//! `t`, needed to correctly continue data-dependent rotary embeddings.
13//!
//! Note: Mamba-3 has **no conv cache** (the short 1-dimensional convolution present in
//! Mamba-2 is removed; its role is absorbed by the trapezoidal discretization
//! and the learnable B/C biases).
17
18use crate::mamba3::prelude::*;
19use crate::modules::sanity as san;
20use burn::module::Module;
21use burn::prelude::*;
22
23// ---------------------------------------------------------------------------
24// Mamba3DoubleSsdCaches (one cache entry per layer)
25// ---------------------------------------------------------------------------
26
27/// A collection of per-layer caches for a complete Mamba-3 network.
28#[derive(Module, Debug)]
29pub struct Mamba3DoubleSsdCaches {
30 /// Per-layer caches. Length equals the number of virtual layers.
31 pub caches: Vec<Mamba3DoubleSsdCache>,
32}
33
34/// Configuration / factory for [`Mamba3DoubleSsdCaches`].
35#[derive(Config, Debug)]
36pub struct Mamba3DoubleSsdCachesConfig {
37 /// Number of cache slots (= number of virtual layers).
38 pub n_real_caches: usize,
39
40 /// Shared configuration that determines the shape of each cache.
41 pub cache: Mamba3DoubleSsdCacheConfig,
42}
43
44impl Mamba3DoubleSsdCachesConfig {
45 /// Convenience constructor from a block config.
46 pub fn new_from_block_config(
47 n_real_caches: usize,
48 batch: usize,
49 block_config: Mamba3Config,
50 ) -> Self {
51 Self {
52 n_real_caches,
53 cache: Mamba3DoubleSsdCacheConfig::new_from_block_config(batch, block_config),
54 }
55 }
56
57 /// Allocate all cache tensors (zero-initialised) on `device`.
58 pub fn init(&self, device: &Device) -> Mamba3DoubleSsdCaches {
59 let caches = (0..self.n_real_caches)
60 .map(|_| self.cache.clone().init(device))
61 .collect();
62 Mamba3DoubleSsdCaches { caches }
63 }
64}
65
66// ---------------------------------------------------------------------------
67// Mamba3DoubleSsdCache (state for a single layer)
68// ---------------------------------------------------------------------------
69
70/// The mutable state carried between decoding steps for a **single** Mamba-3 layer.
71///
72/// All tensors are updated at every call to [`crate::mamba3::mamba3::Mamba3::step`].
73#[derive(Module, Debug)]
74pub struct Mamba3DoubleSsdCache {
75 /// **SSM hidden state** `hₜ`.
76 ///
77 /// Updated via the (double-ssd) trapezoidal recurrence:
78 /// `hₜ = αₜ hₜ₋₁ + βₜ (sumₘ Kₜ₋₁[m] ⊗ (Vₜ₋₁ * mimo_x[m])) + γₜ (sumₘ Bₜ[m] ⊗ (xₜ * mimo_x[m]))`
79 ///
80 /// Shape: `[batch, nheads, per_head_dim, state_rank]`
81 pub ssm_bhpr: Tensor<4>,
82
83 /// **Previous token's B per mimo rank** = `Bₜ₋₁[m]`.
84 ///
85 /// Used to reconstruct the β term: `β * sum_r Bₜ₋₁[m] ⊗ (xₜ₋₁ * mimo_x[m])`.
86 /// For SISO (mimo_rank=1) this is shape `[batch, 1, nheads, state_rank]`.
87 ///
88 /// Shape: `[batch, mimo_rank, nheads, state_rank]`
89 pub k_state_bmhr: Tensor<4>,
90
91 /// **Previous token's x** = `xₜ₋₁`.
92 ///
93 /// Combined with `k_state_bmhr` and `mimo_x` to produce the β term.
94 ///
95 /// Shape: `[batch, nheads, per_head_dim]`
96 pub v_state_bhp: Tensor<3>,
97
98 /// **Cumulative data-dependent rotation** up to the current position
99 /// ([`RotationState`]): the abelian RoPE angle for
100 /// [`Complex2D`](crate::mamba3::rotation::RotationKind::Complex2D) (each step
101 /// `cum_angleₜ = cum_angleₜ₋₁ + Δₜ · tanh(θₜ) · π`), or the cumulative unit
102 /// quaternion for [`Quaternion4D`](crate::mamba3::rotation::RotationKind::Quaternion4D).
103 ///
104 /// Starts at the identity for fresh sequences; continued across calls for
105 /// streaming.
106 pub rotation: RotationState,
107}
108
109impl Mamba3DoubleSsdCache {
110 /// Run the [`NaN`/`Inf` guards](crate::utils::sanity) on every cached tensor.
111 pub fn sanity(&self) {
112 san(&self.ssm_bhpr);
113 san(&self.k_state_bmhr);
114 san(&self.v_state_bhp);
115 self.rotation.sanity();
116 }
117}
118
119/// Configuration / factory for a single [`Mamba3DoubleSsdCache`].
120#[derive(Config, Debug)]
121pub struct Mamba3DoubleSsdCacheConfig {
122 /// Batch size.
123 pub batch: usize,
124
125 /// State rank.
126 #[config(default = 128)]
127 pub state_rank: usize,
128
129 /// Head dimension per_head_dim.
130 #[config(default = 64)]
131 pub per_head_dim: usize,
132
133 /// Number of SSM heads.
134 pub nheads: usize,
135
136 /// MIMO rank. 1 = SISO.
137 #[config(default = 1)]
138 pub mimo_rank: usize,
139
140 /// Number of RoPE angle pairs = `rope_dim / 2` = `(state_rank * rope_fraction) / 2`
141 /// (rounded down to even via `Mamba3Config::rope_dim`).
142 pub num_rope_angles: usize,
143
144 /// Which positional rotation the block uses ([`RotationKind`]); selects the
145 /// accumulator variant — [`RotationState::Quaternion`] for
146 /// [`RotationKind::Quaternion4D`], else [`RotationState::Angle`].
147 #[config(default = "crate::mamba3::rotation::RotationKind::Complex2D")]
148 pub rotation: RotationKind,
149
150 /// Number of quaternion blocks (`rope_dim / 4`); only used for
151 /// [`RotationKind::Quaternion4D`].
152 #[config(default = 1)]
153 pub num_quat_blocks: usize,
154}
155
156impl Mamba3DoubleSsdCacheConfig {
157 /// Derive cache shapes from a Mamba-3 block configuration plus a batch size.
158 pub fn new_from_block_config(batch: usize, block_config: Mamba3Config) -> Self {
159 Self {
160 batch,
161 state_rank: block_config.state_rank,
162 per_head_dim: block_config.per_head_dim,
163 nheads: block_config.nheads(),
164 mimo_rank: block_config.mimo_rank,
165 num_rope_angles: block_config.num_rope_angles(),
166 rotation: block_config.rotation,
167 num_quat_blocks: block_config.num_quat_blocks(),
168 }
169 }
170
171 /// Allocate zero/identity-initialised cache tensors on `device`.
172 pub fn init(&self, device: &Device) -> Mamba3DoubleSsdCache {
173 let ssm_bhpr = Tensor::zeros(
174 [self.batch, self.nheads, self.per_head_dim, self.state_rank],
175 device,
176 );
177 let k_state_bmhr = Tensor::zeros(
178 [self.batch, self.mimo_rank, self.nheads, self.state_rank],
179 device,
180 );
181 let v_state_bhp = Tensor::zeros([self.batch, self.nheads, self.per_head_dim], device);
182 let rotation = match self.rotation {
183 RotationKind::Quaternion4D => RotationState::identity_quaternion(
184 self.batch,
185 self.nheads,
186 self.num_quat_blocks,
187 device,
188 ),
189 RotationKind::Complex2D => {
190 RotationState::zeros_angle(self.batch, self.nheads, self.num_rope_angles, device)
191 }
192 };
193 Mamba3DoubleSsdCache {
194 ssm_bhpr,
195 k_state_bmhr,
196 v_state_bhp,
197 rotation,
198 }
199 }
200}