Skip to main content

burn_mamba/mamba3/double_ssd/
cache.rs

//! # Mamba-3 Inference Caches
//!
//! During autoregressive (token-by-token) generation, four pieces of state
//! must be preserved between calls:
//!
//! 1. **SSM hidden state** — `hₜ ∈ ℝ^{per_head_dim×state_rank}` per head, the compressed context.
//! 2. **Previous K state** — `Bₜ₋₁` per MIMO rank, shape `[batch, mimo_rank, nheads, state_rank]`,
//!    needed for the β term of the (double-ssd) trapezoidal recurrence.
//! 3. **Previous V state** — `xₜ₋₁` per head, shape `[batch, nheads, per_head_dim]`,
//!    paired with the K state to reconstruct the β term `β Bₜ₋₁ ⊗ xₜ₋₁`.
//! 4. **Cumulative rotation** — the accumulated data-dependent rotation up to position
//!    `t` (a RoPE angle, or a unit quaternion when quaternion rotations are used), needed
//!    to correctly continue the rotary embeddings.
//!
//! Note: Mamba-3 has **no conv cache**. The short 1-dimensional convolution present in
//! earlier Mamba versions (Mamba-1/Mamba-2) is removed; its role is absorbed by the
//! trapezoidal discretization and the learnable B/C biases.
17
18use crate::mamba3::prelude::*;
19use crate::modules::sanity as san;
20use burn::module::Module;
21use burn::prelude::*;
22
23// ---------------------------------------------------------------------------
24// Mamba3DoubleSsdCaches  (one cache entry per layer)
25// ---------------------------------------------------------------------------
26
27/// A collection of per-layer caches for a complete Mamba-3 network.
28#[derive(Module, Debug)]
29pub struct Mamba3DoubleSsdCaches {
30    /// Per-layer caches.  Length equals the number of virtual layers.
31    pub caches: Vec<Mamba3DoubleSsdCache>,
32}
33
34/// Configuration / factory for [`Mamba3DoubleSsdCaches`].
35#[derive(Config, Debug)]
36pub struct Mamba3DoubleSsdCachesConfig {
37    /// Number of cache slots (= number of virtual layers).
38    pub n_real_caches: usize,
39
40    /// Shared configuration that determines the shape of each cache.
41    pub cache: Mamba3DoubleSsdCacheConfig,
42}
43
44impl Mamba3DoubleSsdCachesConfig {
45    /// Convenience constructor from a block config.
46    pub fn new_from_block_config(
47        n_real_caches: usize,
48        batch: usize,
49        block_config: Mamba3Config,
50    ) -> Self {
51        Self {
52            n_real_caches,
53            cache: Mamba3DoubleSsdCacheConfig::new_from_block_config(batch, block_config),
54        }
55    }
56
57    /// Allocate all cache tensors (zero-initialised) on `device`.
58    pub fn init(&self, device: &Device) -> Mamba3DoubleSsdCaches {
59        let caches = (0..self.n_real_caches)
60            .map(|_| self.cache.clone().init(device))
61            .collect();
62        Mamba3DoubleSsdCaches { caches }
63    }
64}
65
66// ---------------------------------------------------------------------------
67// Mamba3DoubleSsdCache  (state for a single layer)
68// ---------------------------------------------------------------------------
69
70/// The mutable state carried between decoding steps for a **single** Mamba-3 layer.
71///
72/// All tensors are updated at every call to [`crate::mamba3::mamba3::Mamba3::step`].
73#[derive(Module, Debug)]
74pub struct Mamba3DoubleSsdCache {
75    /// **SSM hidden state** `hₜ`.
76    ///
77    /// Updated via the (double-ssd) trapezoidal recurrence:
78    /// `hₜ = αₜ hₜ₋₁ + βₜ (sumₘ Kₜ₋₁[m] ⊗ (Vₜ₋₁ * mimo_x[m])) + γₜ (sumₘ Bₜ[m] ⊗ (xₜ * mimo_x[m]))`
79    ///
80    /// Shape: `[batch, nheads, per_head_dim, state_rank]`
81    pub ssm_bhpr: Tensor<4>,
82
83    /// **Previous token's B per mimo rank** = `Bₜ₋₁[m]`.
84    ///
85    /// Used to reconstruct the β term: `β * sum_r Bₜ₋₁[m] ⊗ (xₜ₋₁ * mimo_x[m])`.
86    /// For SISO (mimo_rank=1) this is shape `[batch, 1, nheads, state_rank]`.
87    ///
88    /// Shape: `[batch, mimo_rank, nheads, state_rank]`
89    pub k_state_bmhr: Tensor<4>,
90
91    /// **Previous token's x** = `xₜ₋₁`.
92    ///
93    /// Combined with `k_state_bmhr` and `mimo_x` to produce the β term.
94    ///
95    /// Shape: `[batch, nheads, per_head_dim]`
96    pub v_state_bhp: Tensor<3>,
97
98    /// **Cumulative data-dependent rotation** up to the current position
99    /// ([`RotationState`]): the abelian RoPE angle for
100    /// [`Complex2D`](crate::mamba3::rotation::RotationKind::Complex2D) (each step
101    /// `cum_angleₜ = cum_angleₜ₋₁ + Δₜ · tanh(θₜ) · π`), or the cumulative unit
102    /// quaternion for [`Quaternion4D`](crate::mamba3::rotation::RotationKind::Quaternion4D).
103    ///
104    /// Starts at the identity for fresh sequences; continued across calls for
105    /// streaming.
106    pub rotation: RotationState,
107}
108
109impl Mamba3DoubleSsdCache {
110    /// Run the [`NaN`/`Inf` guards](crate::utils::sanity) on every cached tensor.
111    pub fn sanity(&self) {
112        san(&self.ssm_bhpr);
113        san(&self.k_state_bmhr);
114        san(&self.v_state_bhp);
115        self.rotation.sanity();
116    }
117}
118
119/// Configuration / factory for a single [`Mamba3DoubleSsdCache`].
120#[derive(Config, Debug)]
121pub struct Mamba3DoubleSsdCacheConfig {
122    /// Batch size.
123    pub batch: usize,
124
125    /// State rank.
126    #[config(default = 128)]
127    pub state_rank: usize,
128
129    /// Head dimension per_head_dim.
130    #[config(default = 64)]
131    pub per_head_dim: usize,
132
133    /// Number of SSM heads.
134    pub nheads: usize,
135
136    /// MIMO rank.  1 = SISO.
137    #[config(default = 1)]
138    pub mimo_rank: usize,
139
140    /// Number of RoPE angle pairs = `rope_dim / 2` = `(state_rank * rope_fraction) / 2`
141    /// (rounded down to even via `Mamba3Config::rope_dim`).
142    pub num_rope_angles: usize,
143
144    /// Which positional rotation the block uses ([`RotationKind`]); selects the
145    /// accumulator variant — [`RotationState::Quaternion`] for
146    /// [`RotationKind::Quaternion4D`], else [`RotationState::Angle`].
147    #[config(default = "crate::mamba3::rotation::RotationKind::Complex2D")]
148    pub rotation: RotationKind,
149
150    /// Number of quaternion blocks (`rope_dim / 4`); only used for
151    /// [`RotationKind::Quaternion4D`].
152    #[config(default = 1)]
153    pub num_quat_blocks: usize,
154}
155
156impl Mamba3DoubleSsdCacheConfig {
157    /// Derive cache shapes from a Mamba-3 block configuration plus a batch size.
158    pub fn new_from_block_config(batch: usize, block_config: Mamba3Config) -> Self {
159        Self {
160            batch,
161            state_rank: block_config.state_rank,
162            per_head_dim: block_config.per_head_dim,
163            nheads: block_config.nheads(),
164            mimo_rank: block_config.mimo_rank,
165            num_rope_angles: block_config.num_rope_angles(),
166            rotation: block_config.rotation,
167            num_quat_blocks: block_config.num_quat_blocks(),
168        }
169    }
170
171    /// Allocate zero/identity-initialised cache tensors on `device`.
172    pub fn init(&self, device: &Device) -> Mamba3DoubleSsdCache {
173        let ssm_bhpr = Tensor::zeros(
174            [self.batch, self.nheads, self.per_head_dim, self.state_rank],
175            device,
176        );
177        let k_state_bmhr = Tensor::zeros(
178            [self.batch, self.mimo_rank, self.nheads, self.state_rank],
179            device,
180        );
181        let v_state_bhp = Tensor::zeros([self.batch, self.nheads, self.per_head_dim], device);
182        let rotation = match self.rotation {
183            RotationKind::Quaternion4D => RotationState::identity_quaternion(
184                self.batch,
185                self.nheads,
186                self.num_quat_blocks,
187                device,
188            ),
189            RotationKind::Complex2D => {
190                RotationState::zeros_angle(self.batch, self.nheads, self.num_rope_angles, device)
191            }
192        };
193        Mamba3DoubleSsdCache {
194            ssm_bhpr,
195            k_state_bmhr,
196            v_state_bhp,
197            rotation,
198        }
199    }
200}