burn_mamba/mamba2/cache.rs
1//! # Mamba-2 Inference Caches
2//!
3//! This module defines the state that must be preserved between calls during
4//! autoregressive (token-by-token) generation. During *training* or *prefill*
5//! the full sequence is available at once and the chunked SSD algorithm is used
6//! (see [`crate::mamba2::Mamba2::forward`]). During *decoding* the model
7//! processes one token per step and the SSM operates in its pure recurrent
8//! form (see [`crate::mamba2::Mamba2::step`]):
9//!
10//! ```text
11//! hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ (state update)
12//! yₜ = Cₜᵀ hₜ + D xₜ (output)
13//! ```
14//!
15//! Two pieces of state are required per layer:
16//!
17//! 1. **Convolution cache** — the last `conv_kernel` inputs to the depthwise
18//! Conv1d, kept so that every decoding step can apply the causal filter
19//! without re-processing previous tokens.
20//!
//! 2. **SSM hidden state** — the matrix `hₜ ∈ ℝ^{N×P}` (one per head; the
//!    cache stores it transposed, as `[P, N]`), which compresses the entire
//!    past context into a fixed-size representation regardless of how many
//!    tokens have been generated. This is the key memory-efficiency advantage
//!    of SSMs over attention: a Transformer's KV-cache grows linearly with the
//!    sequence length `T`, whereas the SSM state is always O(P·N) per head.
27
28use crate::mamba2::prelude::*;
29use crate::utils::sanity::sanity as san;
30use burn::module::Module;
31use burn::prelude::*;
32
33// ---------------------------------------------------------------------------
34// Mamba2Caches (one cache entry per layer)
35// ---------------------------------------------------------------------------
36
37/// A collection of per-layer caches for a complete Mamba-2 network.
38///
39/// During autoregressive decoding, a [`Mamba2Caches`] instance is threaded
40/// through every call to [`crate::layer::Mamba2Layers::step`]. Each element
41/// of `caches` corresponds to one (virtual) layer in the network.
42#[derive(Module, Debug)]
43pub struct Mamba2Caches<B: Backend> {
44 /// Per-layer caches.
45 ///
46 /// Length: `n_real_caches` (the number of *virtual* layers, which may
47 /// exceed the number of *real* weight layers when weight-sharing / layer
48 /// scheduling is in use).
49 pub caches: Vec<Mamba2Cache<B>>,
50}
51
52/// Configuration / factory for [`Mamba2Caches`].
53#[derive(Config, Debug)]
54pub struct Mamba2CachesConfig {
55 /// Number of cache slots. Equals the number of virtual layers in the
56 /// network (one cache per layer, even when layers share weights).
57 pub n_real_caches: usize,
58
59 /// Shared configuration that determines the shape of each individual
60 /// cache tensor.
61 pub cache: Mamba2CacheConfig,
62}
63
64impl Mamba2CachesConfig {
65 /// Convenience constructor that derives cache shapes directly from a
66 /// [`Mamba2Config`] block configuration.
67 pub fn new_from_block_config(
68 n_real_caches: usize,
69 batch: usize,
70 block_config: Mamba2Config,
71 ) -> Self {
72 Self {
73 n_real_caches,
74 cache: Mamba2CacheConfig::new_from_block_config(batch, block_config),
75 }
76 }
77
78 /// Allocate all cache tensors (zero-initialised) on `device`.
79 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Caches<B> {
80 let caches = (0..self.n_real_caches)
81 .map(|_| self.cache.clone().init(device))
82 .collect();
83 Mamba2Caches { caches }
84 }
85}
86
87// ---------------------------------------------------------------------------
88// Mamba2Cache (state for a single layer)
89// ---------------------------------------------------------------------------
90
91/// The mutable state carried between decoding steps for a **single** Mamba-2
92/// layer.
93///
94/// Both tensors are updated in-place (via Burn's functional clone) at every
95/// call to [`crate::mamba2::Mamba2::step`].
96#[derive(Module, Debug)]
97pub struct Mamba2Cache<B: Backend> {
98 /// **Convolution rolling window.**
99 ///
100 /// Stores the last `conv_kernel` pre-activation feature vectors fed into
101 /// the depthwise Conv1d. At each step, the oldest column is discarded and
102 /// the new token's projection is appended (a left-shift followed by an
103 /// insert into the rightmost column), maintaining strict causality.
104 ///
105 /// Shape: `[batch, conv_dim, conv_kernel]`
106 /// - `conv_dim = d_inner + 2 · ngroups · state_rank`
107 /// - `conv_kernel` is typically 4
108 pub conv_bvk: Tensor<B, 3>,
109
110 /// **SSM hidden state** `hₜ`.
111 ///
112 /// This is the O(P·N) compressed summary of all tokens seen so far.
113 /// Updated via `hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ` at each decoding step.
114 ///
115 /// The tensor is indexed as `[batch, nheads, per_head_dim, state_rank]`
116 /// (i.e. `[B, H, P, N]` in the paper's notation), which is the transpose
117 /// of the mathematical `hₜ ∈ ℝ^{N×P}` but equivalent in content.
118 ///
119 /// Shape: `[batch, nheads, per_head_dim, state_rank]`
120 pub ssm_bhpr: Tensor<B, 4>,
121}
122
123impl<B: Backend> Mamba2Cache<B> {
124 pub fn sanity(&self) {
125 san(&self.conv_bvk);
126 san(&self.ssm_bhpr);
127 }
128}
129
130/// Configuration / factory for a single [`Mamba2Cache`].
131#[derive(Config, Debug)]
132pub struct Mamba2CacheConfig {
133 /// Batch size.
134 pub batch: usize,
135
136 /// State rank `N` — the number of latent dimensions in the SSM hidden
137 /// state. Corresponds to `state_rank` in [`Mamba2Config`].
138 #[config(default = 128)]
139 pub state_rank: usize,
140
141 /// Causal convolution window length. Corresponds to `conv_kernel` in
142 /// [`Mamba2Config`].
143 #[config(default = 4)]
144 pub conv_kernel: usize,
145
146 /// Number of channels entering (and leaving) the depthwise convolution.
147 /// Equal to `d_inner + 2 · ngroups · state_rank`.
148 pub conv_dim: usize,
149
150 /// Head dimension `P`. Corresponds to `per_head_dim` in [`Mamba2Config`].
151 #[config(default = 64)]
152 pub per_head_dim: usize,
153
154 /// Number of SSM heads `H`.
155 pub nheads: usize,
156}
157
158impl Mamba2CacheConfig {
159 /// Derive cache shapes from a Mamba-2 block configuration plus a batch
160 /// size.
161 pub fn new_from_block_config(batch: usize, block_config: Mamba2Config) -> Self {
162 Self {
163 batch,
164 state_rank: block_config.state_rank,
165 conv_kernel: block_config.conv_kernel,
166 conv_dim: block_config.conv_dim(),
167 per_head_dim: block_config.per_head_dim,
168 nheads: block_config.nheads(),
169 }
170 }
171
172 /// Allocate zero-initialised cache tensors on `device`.
173 ///
174 /// Zero initialisation is correct because:
175 /// - The convolution cache represents "no previous tokens" (identity padding).
176 /// - The SSM state represents `h₀ = 0` (zero initial condition), which is
177 /// the standard default. Learnable initial state (if configured) are
178 /// added on top of this inside [`crate::mamba2::Mamba2::forward`] /
179 /// [`crate::mamba2::Mamba2::step`].
180 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Cache<B> {
181 let conv_bvk = Tensor::zeros(
182 Shape::new([self.batch, self.conv_dim, self.conv_kernel]),
183 device,
184 );
185 let ssm_bhpr = Tensor::zeros(
186 Shape::new([self.batch, self.nheads, self.per_head_dim, self.state_rank]),
187 device,
188 );
189 Mamba2Cache { conv_bvk, ssm_bhpr }
190 }
191}