Skip to main content

burn_mamba/modules/
cache.rs

1use crate::prelude::*;
2use burn::prelude::*;

// ===========================================================================
// Unifying enums: one runtime + one serializable Config across all families
// ===========================================================================

/// The uniform interface a per-network cache collection exposes for the generic
/// [`Layers`] loop.
///
/// A collection is viewed as one slot per (virtual) layer. In this file,
/// `Mamba3Caches` implements it by forwarding to its own
/// `caches_len`/`into_options`/`from_options`, while `Mamba1Caches` and
/// `Mamba2Caches` map directly over their `caches` vector.
pub trait CacheStack: Sized {
    /// The per-layer cache element.
    type Cache;

    /// Number of per-(virtual-)layer slots.
    fn slot_count(&self) -> usize;

    /// Move each slot into an `Option` so the loop can `take` without cloning.
    fn into_slots(self) -> Vec<Option<Self::Cache>>;

    /// Inverse of [`Self::into_slots`]: rebuild the collection from its slots.
    ///
    /// # Panics
    ///
    /// Implementations may panic if any slot is still `None` when this is
    /// called; the Mamba-1 and Mamba-2 implementations in this file do.
    fn from_slots(slots: Vec<Option<Self::Cache>>) -> Self;
}

/// Per-family cache collection, tagged at runtime.
///
/// Exactly one variant exists per enabled Mamba family (mirroring
/// [`MambaLatentNet`]); each variant wraps that family's own cache collection.
///
/// This is deliberately plain runtime state, not a `Module`: the caches are
/// passed through `forward`/`step` and are never recorded or optimised. A
/// `Module` derive could not apply anyway, since `Mamba3Caches` is itself a
/// non-`Module` enum.
#[derive(Debug)]
pub enum MambaCaches {
    #[cfg(feature = "mamba1")]
    /// Caches for the Mamba-1 family.
    Mamba1(crate::mamba1::prelude::Mamba1Caches),

    #[cfg(feature = "mamba2")]
    /// Caches for the Mamba-2 family.
    Mamba2(crate::mamba2::prelude::Mamba2Caches),

    #[cfg(feature = "mamba3")]
    /// Caches for the Mamba-3 family.
    Mamba3(crate::mamba3::prelude::Mamba3Caches),
}

// ===========================================================================
// Per-family impls
// ===========================================================================

#[cfg(feature = "mamba2")]
mod impl_mamba2 {
    use super::*;
    use crate::mamba2::prelude::{
        Mamba2, Mamba2Cache, Mamba2CacheConfig, Mamba2Caches, Mamba2CachesConfig, Mamba2Config,
        Mamba2SsdPath,
    };

    /// `Mamba2Caches` holds its per-layer caches in a `caches` vector, so each
    /// slot maps one-to-one onto an element of that vector.
    impl CacheStack for Mamba2Caches {
        type Cache = Mamba2Cache;

        fn slot_count(&self) -> usize {
            self.caches.len()
        }

        fn into_slots(self) -> Vec<Option<Mamba2Cache>> {
            self.caches.into_iter().map(Some).collect()
        }

        /// Rebuild the collection from fully-refilled slots.
        ///
        /// # Panics
        ///
        /// Panics if any slot is `None`; the message names the first empty
        /// slot so a layer that failed to hand back its cache is easy to find.
        fn from_slots(slots: Vec<Option<Mamba2Cache>>) -> Self {
            let caches = slots
                .into_iter()
                .enumerate()
                .map(|(slot, cache)| {
                    cache.unwrap_or_else(|| {
                        panic!("Mamba2Caches::from_slots: slot {slot} was not refilled")
                    })
                })
                .collect();
            Self { caches }
        }
    }

    impl MambaBlock for Mamba2 {
        type Cache = Mamba2Cache;
        type Caches = Mamba2Caches;
        type SsdPath = Mamba2SsdPath;

        /// Rank-3 pass; delegates to `Mamba2::forward`.
        fn block_forward(
            &self,
            x: Tensor<3>,
            cache: Option<Mamba2Cache>,
            ssd_path: Mamba2SsdPath,
        ) -> (Tensor<3>, Mamba2Cache) {
            self.forward(x, cache, ssd_path)
        }

        /// Rank-2 pass; delegates to `Mamba2::step`.
        fn block_step(&self, x: Tensor<2>, cache: Option<Mamba2Cache>) -> (Tensor<2>, Mamba2Cache) {
            self.step(x, cache)
        }

        /// Fresh caches for `n_virtual` slots, with the batch size taken from
        /// the first dimension of the rank-3 input `x`.
        fn zero_caches_3d(&self, x: &Tensor<3>, n_virtual: usize) -> Mamba2Caches {
            let [batch, _seq, _d] = x.dims();
            self.make_zero(batch, n_virtual, &x.device())
        }

        /// Same as `zero_caches_3d`, for the rank-2 input used by `block_step`.
        fn zero_caches_2d(&self, x: &Tensor<2>, n_virtual: usize) -> Mamba2Caches {
            let [batch, _d] = x.dims();
            self.make_zero(batch, n_virtual, &x.device())
        }
    }

    impl Mamba2 {
        /// Build `n_virtual` per-layer caches for `batch` sequences on `device`.
        ///
        /// `conv_dim` and `conv_kernel` are read from the first and last dims of
        /// the `conv1d` weight so the caches always match this block's
        /// parameters. The contents come from `Mamba2CachesConfig::init`
        /// (expected to be zero-filled — confirm in `Mamba2CachesConfig`).
        fn make_zero(&self, batch: usize, n_virtual: usize, device: &Device) -> Mamba2Caches {
            let [conv_dim, _, conv_kernel] = self.conv1d.weight.dims();
            Mamba2CachesConfig::new(
                n_virtual,
                Mamba2CacheConfig {
                    batch,
                    state_rank: self.state_rank,
                    conv_kernel,
                    conv_dim,
                    per_head_dim: self.per_head_dim(),
                    nheads: self.nheads(),
                },
            )
            .init(device)
        }
    }

    impl MambaBlockConfig for Mamba2Config {
        type Block = Mamba2;

        /// The configured `d_model`.
        fn d_model(&self) -> usize {
            self.d_model
        }

        /// Delegates to `Mamba2Config::init`.
        fn init_block(&self, device: &Device) -> Mamba2 {
            self.init(device)
        }
    }
}

#[cfg(feature = "mamba3")]
mod impl_mamba3 {
    use super::*;
    use crate::mamba3::prelude::{Mamba3, Mamba3Cache, Mamba3Caches, Mamba3Config, Mamba3SsdPath};
    use crate::mamba3::single_ssd::prelude::{
        Mamba3SingleSsdCacheConfig, Mamba3SingleSsdCaches, Mamba3SingleSsdCachesConfig,
    };

    /// Build `n_virtual` single-ssd caches for `batch` sequences on `device`,
    /// taking every other size (state rank, rope angles, heads, MIMO rank,
    /// rotation, quaternion blocks) from `mamba_block`.
    ///
    /// (A missing cache defaults to the single-ssd pathway — ≈½ the SSD memory
    /// of double-ssd — for either rotation kind.)
    fn zero_single_ssd_caches(
        mamba_block: &Mamba3,
        batch: usize,
        n_virtual: usize,
        device: &Device,
    ) -> Mamba3SingleSsdCaches {
        Mamba3SingleSsdCachesConfig::new(
            n_virtual,
            Mamba3SingleSsdCacheConfig {
                batch,
                state_rank: mamba_block.state_rank,
                num_rope_angles: mamba_block.num_rope_angles,
                per_head_dim: mamba_block.per_head_dim(),
                nheads: mamba_block.nheads(),
                mimo_rank: mamba_block.mimo_rank,
                rotation: mamba_block.rotation,
                num_quat_blocks: mamba_block.num_quat_blocks,
            },
        )
        .init(device)
    }

    /// Forwards to `Mamba3Caches`' own `caches_len`/`into_options`/`from_options`.
    impl CacheStack for Mamba3Caches {
        type Cache = Mamba3Cache;

        fn slot_count(&self) -> usize {
            self.caches_len()
        }

        fn into_slots(self) -> Vec<Option<Mamba3Cache>> {
            self.into_options()
        }

        fn from_slots(slots: Vec<Option<Mamba3Cache>>) -> Self {
            Self::from_options(slots)
        }
    }

    impl MambaBlock for Mamba3 {
        type Cache = Mamba3Cache;
        type Caches = Mamba3Caches;
        type SsdPath = Mamba3SsdPath;

        /// Rank-3 pass; delegates to `Mamba3::forward`.
        fn block_forward(
            &self,
            x: Tensor<3>,
            cache: Option<Mamba3Cache>,
            ssd_path: Mamba3SsdPath,
        ) -> (Tensor<3>, Mamba3Cache) {
            self.forward(x, cache, ssd_path)
        }

        /// Rank-2 pass; delegates to `Mamba3::step`.
        fn block_step(&self, x: Tensor<2>, cache: Option<Mamba3Cache>) -> (Tensor<2>, Mamba3Cache) {
            self.step(x, cache)
        }

        /// Cache-free rank-2 pass; delegates to `Mamba3::step_infinite`
        /// (see that method for its exact semantics).
        fn block_step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
            self.step_infinite(x)
        }

        /// Delegates to `Mamba3::step_n_approx` with the step count `n`
        /// (see that method for the approximation it makes).
        fn block_step_n_approx(
            &self,
            x: Tensor<2>,
            n: usize,
            cache: Option<Mamba3Cache>,
        ) -> (Tensor<2>, Mamba3Cache) {
            self.step_n_approx(x, n, cache)
        }

        /// Fresh single-ssd caches for `n_virtual` slots, with the batch size
        /// taken from the first dimension of the rank-3 input `x`.
        fn zero_caches_3d(&self, x: &Tensor<3>, n_virtual: usize) -> Mamba3Caches {
            let [batch, _seq, _d] = x.dims();
            zero_single_ssd_caches(self, batch, n_virtual, &x.device()).into()
        }

        /// Same as `zero_caches_3d`, for the rank-2 input used by `block_step`.
        fn zero_caches_2d(&self, x: &Tensor<2>, n_virtual: usize) -> Mamba3Caches {
            let [batch, _d] = x.dims();
            zero_single_ssd_caches(self, batch, n_virtual, &x.device()).into()
        }
    }

    impl MambaBlockConfig for Mamba3Config {
        type Block = Mamba3;

        /// The configured `d_model`.
        fn d_model(&self) -> usize {
            self.d_model
        }

        /// Delegates to `Mamba3Config::init`.
        fn init_block(&self, device: &Device) -> Mamba3 {
            self.init(device)
        }
    }
}

#[cfg(feature = "mamba1")]
mod impl_mamba1 {
    use super::*;
    use crate::mamba1::prelude::{
        Mamba1, Mamba1Cache, Mamba1CacheConfig, Mamba1Caches, Mamba1CachesConfig, Mamba1Config,
    };

    /// `Mamba1Caches` holds its per-layer caches in a `caches` vector, so each
    /// slot maps one-to-one onto an element of that vector.
    impl CacheStack for Mamba1Caches {
        type Cache = Mamba1Cache;

        fn slot_count(&self) -> usize {
            self.caches.len()
        }

        fn into_slots(self) -> Vec<Option<Mamba1Cache>> {
            self.caches.into_iter().map(Some).collect()
        }

        /// Rebuild the collection from fully-refilled slots.
        ///
        /// # Panics
        ///
        /// Panics if any slot is `None`; the message names the first empty
        /// slot so a layer that failed to hand back its cache is easy to find.
        fn from_slots(slots: Vec<Option<Mamba1Cache>>) -> Self {
            let caches = slots
                .into_iter()
                .enumerate()
                .map(|(slot, cache)| {
                    cache.unwrap_or_else(|| {
                        panic!("Mamba1Caches::from_slots: slot {slot} was not refilled")
                    })
                })
                .collect();
            Self { caches }
        }
    }

    impl MambaBlock for Mamba1 {
        type Cache = Mamba1Cache;
        type Caches = Mamba1Caches;
        /// Mamba-1 has no SSD chunking, so there is no path selector.
        type SsdPath = ();

        /// Rank-3 pass; delegates to `Mamba1::forward` (the unit path is ignored).
        fn block_forward(
            &self,
            x: Tensor<3>,
            cache: Option<Mamba1Cache>,
            _ssd_path: (),
        ) -> (Tensor<3>, Mamba1Cache) {
            self.forward(x, cache)
        }

        /// Rank-2 pass; delegates to `Mamba1::step`.
        fn block_step(&self, x: Tensor<2>, cache: Option<Mamba1Cache>) -> (Tensor<2>, Mamba1Cache) {
            self.step(x, cache)
        }

        /// Fresh caches for `n_virtual` slots, with the batch size taken from
        /// the first dimension of the rank-3 input `x`.
        fn zero_caches_3d(&self, x: &Tensor<3>, n_virtual: usize) -> Mamba1Caches {
            let [batch, _seq, _d] = x.dims();
            self.make_zero(batch, n_virtual, &x.device())
        }

        /// Same as `zero_caches_3d`, for the rank-2 input used by `block_step`.
        fn zero_caches_2d(&self, x: &Tensor<2>, n_virtual: usize) -> Mamba1Caches {
            let [batch, _d] = x.dims();
            self.make_zero(batch, n_virtual, &x.device())
        }
    }

    impl Mamba1 {
        /// Per-layer cache config for `batch` sequences, derived from parameter
        /// shapes: `d_inner` and `state_rank` from the two dims of `a_log`, and
        /// `conv_kernel` from the last dim of the `conv1d` weight.
        fn cache_config(&self, batch: usize) -> Mamba1CacheConfig {
            let [d_inner, state_rank] = self.a_log.dims();
            let [_, _, conv_kernel] = self.conv1d.weight.dims();
            Mamba1CacheConfig::new(batch, d_inner)
                .with_state_rank(state_rank)
                .with_conv_kernel(conv_kernel)
        }

        /// Build `n_virtual` per-layer caches on `device` via
        /// `Mamba1CachesConfig::init` (expected to be zero-filled — confirm in
        /// `Mamba1CachesConfig`).
        fn make_zero(&self, batch: usize, n_virtual: usize, device: &Device) -> Mamba1Caches {
            Mamba1CachesConfig::new(n_virtual, self.cache_config(batch)).init(device)
        }
    }

    impl MambaBlockConfig for Mamba1Config {
        type Block = Mamba1;

        /// The configured `d_model`.
        fn d_model(&self) -> usize {
            self.d_model
        }

        /// Delegates to `Mamba1Config::init`.
        fn init_block(&self, device: &Device) -> Mamba1 {
            self.init(device)
        }
    }
}