burn_mamba/mamba3/cache.rs
1//! # Mamba-3 Inference Caches
2//!
//! During autoregressive (token-by-token) generation, four pieces of state
4//! must be preserved between calls:
5//!
6//! 1. **SSM hidden state** — `hₜ ∈ ℝ^{P×N}` per head, compressed context.
7//! 2. **Previous K state** — `B_{t-1}` per rank `[batch, mimo_rank, nheads, state_rank]`,
8//! needed for the β term of the trapezoidal recurrence.
9//! 3. **Previous V state** — `x_{t-1}` per head `[batch, nheads, per_head_dim]`,
10//! paired with k_state to reconstruct β B_{t-1} ⊗ x_{t-1}.
11//! 4. **Cumulative RoPE angle** — the accumulated rotation angle up to position
12//! `t`, needed to correctly continue data-dependent rotary embeddings.
13//!
//! Note: Mamba-3 has **no conv cache**. The short 1-D convolution present in
//! earlier Mamba blocks (e.g. Mamba-2) is removed; its role is absorbed by the
//! trapezoidal discretization and the learnable B/C biases.
17
18use crate::mamba3::prelude::*;
19use crate::utils::sanity::sanity as san;
20use burn::module::Module;
21use burn::prelude::*;
22
23// ---------------------------------------------------------------------------
24// Mamba3Caches (one cache entry per layer)
25// ---------------------------------------------------------------------------
26
27/// A collection of per-layer caches for a complete Mamba-3 network.
28#[derive(Module, Debug)]
29pub struct Mamba3Caches<B: Backend> {
30 /// Per-layer caches. Length equals the number of virtual layers.
31 pub caches: Vec<Mamba3Cache<B>>,
32}
33
34/// Configuration / factory for [`Mamba3Caches`].
35#[derive(Config, Debug)]
36pub struct Mamba3CachesConfig {
37 /// Number of cache slots (= number of virtual layers).
38 pub n_real_caches: usize,
39
40 /// Shared configuration that determines the shape of each cache.
41 pub cache: Mamba3CacheConfig,
42}
43
44impl Mamba3CachesConfig {
45 /// Convenience constructor from a block config.
46 pub fn new_from_block_config(
47 n_real_caches: usize,
48 batch: usize,
49 block_config: Mamba3Config,
50 ) -> Self {
51 Self {
52 n_real_caches,
53 cache: Mamba3CacheConfig::new_from_block_config(batch, block_config),
54 }
55 }
56
57 /// Allocate all cache tensors (zero-initialised) on `device`.
58 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3Caches<B> {
59 let caches = (0..self.n_real_caches)
60 .map(|_| self.cache.clone().init(device))
61 .collect();
62 Mamba3Caches { caches }
63 }
64}
65
66// ---------------------------------------------------------------------------
67// Mamba3Cache (state for a single layer)
68// ---------------------------------------------------------------------------
69
70/// The mutable state carried between decoding steps for a **single** Mamba-3 layer.
71///
72/// All tensors are updated at every call to [`crate::mamba3::mamba3::Mamba3::step`].
73#[derive(Module, Debug)]
74pub struct Mamba3Cache<B: Backend> {
75 /// **SSM hidden state** `hₜ`.
76 ///
77 /// Updated via the trapezoidal recurrence:
78 /// `hₜ = αₜ hₜ₋₁ + βₜ (sum_r K_{t-1}[r] ⊗ (V_{t-1} * mimo_x[r])) + γₜ (sum_r Bₜ[r] ⊗ (xₜ * mimo_x[r]))`
79 ///
80 /// Shape: `[batch, nheads, per_head_dim, state_rank]`
81 pub ssm_bhpr: Tensor<B, 4>,
82
83 /// **Previous token's B per rank** = `B_{t-1}[r]`.
84 ///
85 /// Used to reconstruct the β term: `β * sum_r B_{t-1}[r] ⊗ (x_{t-1} * mimo_x[r])`.
86 /// For SISO (mimo_rank=1) this is shape `[batch, 1, nheads, state_rank]`.
87 ///
88 /// Shape: `[batch, mimo_rank, nheads, state_rank]`
89 pub k_state_brhn: Tensor<B, 4>,
90
91 /// **Previous token's x** = `x_{t-1}`.
92 ///
93 /// Combined with `k_state_brhn` and `mimo_x` to produce the β term.
94 ///
95 /// Shape: `[batch, nheads, per_head_dim]`
96 pub v_state_bhp: Tensor<B, 3>,
97
98 /// **Cumulative data-dependent RoPE angle** up to the current position.
99 ///
100 /// Each step updates: `cum_angle_{t} = cum_angle_{t-1} + Δ_t · tanh(θ_t) · π`
101 ///
102 /// Starts at zero for fresh sequences; continued across calls for streaming.
103 ///
104 /// Shape: `[batch, nheads, num_rope_angles]`
105 pub cum_angle_bhr: Tensor<B, 3>,
106}
107
108impl<B: Backend> Mamba3Cache<B> {
109 pub fn sanity(&self) {
110 san(&self.ssm_bhpr);
111 san(&self.k_state_brhn);
112 san(&self.v_state_bhp);
113 san(&self.cum_angle_bhr);
114 }
115}
116
117/// Configuration / factory for a single [`Mamba3Cache`].
118#[derive(Config, Debug)]
119pub struct Mamba3CacheConfig {
120 /// Batch size.
121 pub batch: usize,
122
123 /// State rank N.
124 #[config(default = 128)]
125 pub state_rank: usize,
126
127 /// Head dimension P.
128 #[config(default = 64)]
129 pub per_head_dim: usize,
130
131 /// Number of SSM heads H.
132 pub nheads: usize,
133
134 /// MIMO rank R. 1 = SISO.
135 #[config(default = 1)]
136 pub mimo_rank: usize,
137
138 /// Number of RoPE angle pairs = `rope_dim / 2` = `(state_rank * rope_fraction) / 2`
139 /// (rounded down to even via `Mamba3Config::rope_dim`).
140 pub num_rope_angles: usize,
141}
142
143impl Mamba3CacheConfig {
144 /// Derive cache shapes from a Mamba-3 block configuration plus a batch size.
145 pub fn new_from_block_config(batch: usize, block_config: Mamba3Config) -> Self {
146 Self {
147 batch,
148 state_rank: block_config.state_rank,
149 per_head_dim: block_config.per_head_dim,
150 nheads: block_config.nheads(),
151 mimo_rank: block_config.mimo_rank,
152 num_rope_angles: block_config.num_rope_angles(),
153 }
154 }
155
156 /// Allocate zero-initialised cache tensors on `device`.
157 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3Cache<B> {
158 let ssm_bhpr = Tensor::zeros(
159 [self.batch, self.nheads, self.per_head_dim, self.state_rank],
160 device,
161 );
162 let k_state_brhn = Tensor::zeros(
163 [self.batch, self.mimo_rank, self.nheads, self.state_rank],
164 device,
165 );
166 let v_state_bhp = Tensor::zeros([self.batch, self.nheads, self.per_head_dim], device);
167 let cum_angle_bhr = Tensor::zeros([self.batch, self.nheads, self.num_rope_angles], device);
168 Mamba3Cache {
169 ssm_bhpr,
170 k_state_brhn,
171 v_state_bhp,
172 cum_angle_bhr,
173 }
174 }
175}