Skip to main content

burn_mamba/mamba2/
cache.rs

//! # Mamba-2 Inference Caches
//!
//! This module defines the state that must be preserved between calls during
//! autoregressive (token-by-token) generation.  During *training* or *prefill*
//! the full sequence is available at once and the chunked SSD algorithm is used
//! (see [`crate::mamba2::Mamba2::forward`]).  During *decoding* the model
//! processes one token per step and the SSM operates in its pure recurrent
//! form (see [`crate::mamba2::Mamba2::step`]):
//!
//! ```text
//!   hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ        (state update)
//!   yₜ = Cₜᵀ hₜ + D xₜ            (output)
//! ```
//!
//! Two pieces of state are required per layer:
//!
//! 1. **Convolution cache** — the last `conv_kernel` inputs to the depthwise
//!    Conv1d, kept so that every decoding step can apply the causal filter
//!    without re-processing previous tokens.
//!
//! 2. **SSM hidden state** — per head, the matrix `hₜ ∈ ℝ^{N×P}` (stored
//!    transposed, as `[P, N]`), which compresses the entire past context into
//!    a fixed-size representation regardless of how many tokens have been
//!    generated.  This is the key memory-efficiency advantage of SSMs over
//!    attention: the KV-cache of a Transformer grows linearly with the
//!    sequence length `T`, whereas the SSM state is always O(P·N) per head.

28use crate::mamba2::prelude::*;
29use crate::utils::sanity::sanity as san;
30use burn::module::Module;
31use burn::prelude::*;

// ---------------------------------------------------------------------------
// Mamba2Caches  (one cache entry per layer)
// ---------------------------------------------------------------------------

37/// A collection of per-layer caches for a complete Mamba-2 network.
38///
39/// During autoregressive decoding, a [`Mamba2Caches`] instance is threaded
40/// through every call to [`crate::layer::Mamba2Layers::step`].  Each element
41/// of `caches` corresponds to one (virtual) layer in the network.
42#[derive(Module, Debug)]
43pub struct Mamba2Caches<B: Backend> {
44    /// Per-layer caches.
45    ///
46    /// Length: `n_real_caches` (the number of *virtual* layers, which may
47    /// exceed the number of *real* weight layers when weight-sharing / layer
48    /// scheduling is in use).
49    pub caches: Vec<Mamba2Cache<B>>,
50}
51
52/// Configuration / factory for [`Mamba2Caches`].
53#[derive(Config, Debug)]
54pub struct Mamba2CachesConfig {
55    /// Number of cache slots.  Equals the number of virtual layers in the
56    /// network (one cache per layer, even when layers share weights).
57    pub n_real_caches: usize,
58
59    /// Shared configuration that determines the shape of each individual
60    /// cache tensor.
61    pub cache: Mamba2CacheConfig,
62}
63
64impl Mamba2CachesConfig {
65    /// Convenience constructor that derives cache shapes directly from a
66    /// [`Mamba2Config`] block configuration.
67    pub fn new_from_block_config(
68        n_real_caches: usize,
69        batch: usize,
70        block_config: Mamba2Config,
71    ) -> Self {
72        Self {
73            n_real_caches,
74            cache: Mamba2CacheConfig::new_from_block_config(batch, block_config),
75        }
76    }
77
78    /// Allocate all cache tensors (zero-initialised) on `device`.
79    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Caches<B> {
80        let caches = (0..self.n_real_caches)
81            .map(|_| self.cache.clone().init(device))
82            .collect();
83        Mamba2Caches { caches }
84    }
85}
86
// ---------------------------------------------------------------------------
// Mamba2Cache  (state for a single layer)
// ---------------------------------------------------------------------------

91/// The mutable state carried between decoding steps for a **single** Mamba-2
92/// layer.
93///
94/// Both tensors are updated in-place (via Burn's functional clone) at every
95/// call to [`crate::mamba2::Mamba2::step`].
96#[derive(Module, Debug)]
97pub struct Mamba2Cache<B: Backend> {
98    /// **Convolution rolling window.**
99    ///
100    /// Stores the last `conv_kernel` pre-activation feature vectors fed into
101    /// the depthwise Conv1d.  At each step, the oldest column is discarded and
102    /// the new token's projection is appended (a left-shift followed by an
103    /// insert into the rightmost column), maintaining strict causality.
104    ///
105    /// Shape: `[batch, conv_dim, conv_kernel]`
106    ///   - `conv_dim  = d_inner + 2 · ngroups · state_rank`
107    ///   - `conv_kernel` is typically 4
108    pub conv_bvk: Tensor<B, 3>,
109
110    /// **SSM hidden state** `hₜ`.
111    ///
112    /// This is the O(P·N) compressed summary of all tokens seen so far.
113    /// Updated via `hₜ = Āₜ hₜ₋₁ + B̄ₜ xₜ` at each decoding step.
114    ///
115    /// The tensor is indexed as `[batch, nheads, per_head_dim, state_rank]`
116    /// (i.e. `[B, H, P, N]` in the paper's notation), which is the transpose
117    /// of the mathematical `hₜ ∈ ℝ^{N×P}` but equivalent in content.
118    ///
119    /// Shape: `[batch, nheads, per_head_dim, state_rank]`
120    pub ssm_bhpr: Tensor<B, 4>,
121}
122
123impl<B: Backend> Mamba2Cache<B> {
124    pub fn sanity(&self) {
125        san(&self.conv_bvk);
126        san(&self.ssm_bhpr);
127    }
128}
129
130/// Configuration / factory for a single [`Mamba2Cache`].
131#[derive(Config, Debug)]
132pub struct Mamba2CacheConfig {
133    /// Batch size.
134    pub batch: usize,
135
136    /// State rank `N` — the number of latent dimensions in the SSM hidden
137    /// state.  Corresponds to `state_rank` in [`Mamba2Config`].
138    #[config(default = 128)]
139    pub state_rank: usize,
140
141    /// Causal convolution window length.  Corresponds to `conv_kernel` in
142    /// [`Mamba2Config`].
143    #[config(default = 4)]
144    pub conv_kernel: usize,
145
146    /// Number of channels entering (and leaving) the depthwise convolution.
147    /// Equal to `d_inner + 2 · ngroups · state_rank`.
148    pub conv_dim: usize,
149
150    /// Head dimension `P`.  Corresponds to `per_head_dim` in [`Mamba2Config`].
151    #[config(default = 64)]
152    pub per_head_dim: usize,
153
154    /// Number of SSM heads `H`.
155    pub nheads: usize,
156}
157
158impl Mamba2CacheConfig {
159    /// Derive cache shapes from a Mamba-2 block configuration plus a batch
160    /// size.
161    pub fn new_from_block_config(batch: usize, block_config: Mamba2Config) -> Self {
162        Self {
163            batch,
164            state_rank: block_config.state_rank,
165            conv_kernel: block_config.conv_kernel,
166            conv_dim: block_config.conv_dim(),
167            per_head_dim: block_config.per_head_dim,
168            nheads: block_config.nheads(),
169        }
170    }
171
172    /// Allocate zero-initialised cache tensors on `device`.
173    ///
174    /// Zero initialisation is correct because:
175    /// - The convolution cache represents "no previous tokens" (identity padding).
176    /// - The SSM state represents `h₀ = 0` (zero initial condition), which is
177    ///   the standard default.  Learnable initial state (if configured) are
178    ///   added on top of this inside [`crate::mamba2::Mamba2::forward`] /
179    ///   [`crate::mamba2::Mamba2::step`].
180    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2Cache<B> {
181        let conv_bvk = Tensor::zeros(
182            Shape::new([self.batch, self.conv_dim, self.conv_kernel]),
183            device,
184        );
185        let ssm_bhpr = Tensor::zeros(
186            Shape::new([self.batch, self.nheads, self.per_head_dim, self.state_rank]),
187            device,
188        );
189        Mamba2Cache { conv_bvk, ssm_bhpr }
190    }
191}