Skip to main content

burn_mamba/mamba3/
cache.rs

//! # Mamba-3 Inference Caches
//!
//! During autoregressive (token-by-token) generation, four pieces of state
//! must be preserved between calls:
//!
//! 1. **SSM hidden state** — `hₜ ∈ ℝ^{P×N}` per head, compressed context.
//! 2. **Previous K state** — `B_{t-1}` per rank `[batch, mimo_rank, nheads, state_rank]`,
//!    needed for the β term of the trapezoidal recurrence.
//! 3. **Previous V state** — `x_{t-1}` per head `[batch, nheads, per_head_dim]`,
//!    paired with the K state to reconstruct the β term `β B_{t-1} ⊗ x_{t-1}`.
//! 4. **Cumulative RoPE angle** — the accumulated rotation angle up to position
//!    `t`, needed to correctly continue data-dependent rotary embeddings.
//!
//! Note: Mamba-3 has **no conv cache**. The short 1-D convolution present in
//! Mamba-1/Mamba-2 is removed; its role is absorbed by the trapezoidal
//! discretization and the learnable B/C biases.
17
18use crate::mamba3::prelude::*;
19use crate::utils::sanity::sanity as san;
20use burn::module::Module;
21use burn::prelude::*;
22
23// ---------------------------------------------------------------------------
24// Mamba3Caches  (one cache entry per layer)
25// ---------------------------------------------------------------------------
26
27/// A collection of per-layer caches for a complete Mamba-3 network.
28#[derive(Module, Debug)]
29pub struct Mamba3Caches<B: Backend> {
30    /// Per-layer caches.  Length equals the number of virtual layers.
31    pub caches: Vec<Mamba3Cache<B>>,
32}
33
34/// Configuration / factory for [`Mamba3Caches`].
35#[derive(Config, Debug)]
36pub struct Mamba3CachesConfig {
37    /// Number of cache slots (= number of virtual layers).
38    pub n_real_caches: usize,
39
40    /// Shared configuration that determines the shape of each cache.
41    pub cache: Mamba3CacheConfig,
42}
43
44impl Mamba3CachesConfig {
45    /// Convenience constructor from a block config.
46    pub fn new_from_block_config(
47        n_real_caches: usize,
48        batch: usize,
49        block_config: Mamba3Config,
50    ) -> Self {
51        Self {
52            n_real_caches,
53            cache: Mamba3CacheConfig::new_from_block_config(batch, block_config),
54        }
55    }
56
57    /// Allocate all cache tensors (zero-initialised) on `device`.
58    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3Caches<B> {
59        let caches = (0..self.n_real_caches)
60            .map(|_| self.cache.clone().init(device))
61            .collect();
62        Mamba3Caches { caches }
63    }
64}
65
66// ---------------------------------------------------------------------------
67// Mamba3Cache  (state for a single layer)
68// ---------------------------------------------------------------------------
69
70/// The mutable state carried between decoding steps for a **single** Mamba-3 layer.
71///
72/// All tensors are updated at every call to [`crate::mamba3::mamba3::Mamba3::step`].
73#[derive(Module, Debug)]
74pub struct Mamba3Cache<B: Backend> {
75    /// **SSM hidden state** `hₜ`.
76    ///
77    /// Updated via the trapezoidal recurrence:
78    /// `hₜ = αₜ hₜ₋₁ + βₜ (sum_r K_{t-1}[r] ⊗ (V_{t-1} * mimo_x[r])) + γₜ (sum_r Bₜ[r] ⊗ (xₜ * mimo_x[r]))`
79    ///
80    /// Shape: `[batch, nheads, per_head_dim, state_rank]`
81    pub ssm_bhpr: Tensor<B, 4>,
82
83    /// **Previous token's B per rank** = `B_{t-1}[r]`.
84    ///
85    /// Used to reconstruct the β term: `β * sum_r B_{t-1}[r] ⊗ (x_{t-1} * mimo_x[r])`.
86    /// For SISO (mimo_rank=1) this is shape `[batch, 1, nheads, state_rank]`.
87    ///
88    /// Shape: `[batch, mimo_rank, nheads, state_rank]`
89    pub k_state_brhn: Tensor<B, 4>,
90
91    /// **Previous token's x** = `x_{t-1}`.
92    ///
93    /// Combined with `k_state_brhn` and `mimo_x` to produce the β term.
94    ///
95    /// Shape: `[batch, nheads, per_head_dim]`
96    pub v_state_bhp: Tensor<B, 3>,
97
98    /// **Cumulative data-dependent RoPE angle** up to the current position.
99    ///
100    /// Each step updates: `cum_angle_{t} = cum_angle_{t-1} + Δ_t · tanh(θ_t) · π`
101    ///
102    /// Starts at zero for fresh sequences; continued across calls for streaming.
103    ///
104    /// Shape: `[batch, nheads, num_rope_angles]`
105    pub cum_angle_bhr: Tensor<B, 3>,
106}
107
108impl<B: Backend> Mamba3Cache<B> {
109    pub fn sanity(&self) {
110        san(&self.ssm_bhpr);
111        san(&self.k_state_brhn);
112        san(&self.v_state_bhp);
113        san(&self.cum_angle_bhr);
114    }
115}
116
117/// Configuration / factory for a single [`Mamba3Cache`].
118#[derive(Config, Debug)]
119pub struct Mamba3CacheConfig {
120    /// Batch size.
121    pub batch: usize,
122
123    /// State rank N.
124    #[config(default = 128)]
125    pub state_rank: usize,
126
127    /// Head dimension P.
128    #[config(default = 64)]
129    pub per_head_dim: usize,
130
131    /// Number of SSM heads H.
132    pub nheads: usize,
133
134    /// MIMO rank R.  1 = SISO.
135    #[config(default = 1)]
136    pub mimo_rank: usize,
137
138    /// Number of RoPE angle pairs = `rope_dim / 2` = `(state_rank * rope_fraction) / 2`
139    /// (rounded down to even via `Mamba3Config::rope_dim`).
140    pub num_rope_angles: usize,
141}
142
143impl Mamba3CacheConfig {
144    /// Derive cache shapes from a Mamba-3 block configuration plus a batch size.
145    pub fn new_from_block_config(batch: usize, block_config: Mamba3Config) -> Self {
146        Self {
147            batch,
148            state_rank: block_config.state_rank,
149            per_head_dim: block_config.per_head_dim,
150            nheads: block_config.nheads(),
151            mimo_rank: block_config.mimo_rank,
152            num_rope_angles: block_config.num_rope_angles(),
153        }
154    }
155
156    /// Allocate zero-initialised cache tensors on `device`.
157    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3Cache<B> {
158        let ssm_bhpr = Tensor::zeros(
159            [self.batch, self.nheads, self.per_head_dim, self.state_rank],
160            device,
161        );
162        let k_state_brhn = Tensor::zeros(
163            [self.batch, self.mimo_rank, self.nheads, self.state_rank],
164            device,
165        );
166        let v_state_bhp = Tensor::zeros([self.batch, self.nheads, self.per_head_dim], device);
167        let cum_angle_bhr = Tensor::zeros([self.batch, self.nheads, self.num_rope_angles], device);
168        Mamba3Cache {
169            ssm_bhpr,
170            k_state_brhn,
171            v_state_bhp,
172            cum_angle_bhr,
173        }
174    }
175}