1use crate::prelude::*;
2use burn::prelude::*;
3
/// A stack of per-slot caches that can be taken apart into, and rebuilt from,
/// a vector of optional slots.
///
/// Implementations are provided in this file for the Mamba1, Mamba2 and Mamba3
/// cache stacks, each behind its own cargo feature.
pub trait CacheStack: Sized {
    /// The cache type stored in a single slot.
    type Cache;
    /// Returns the number of slots held by this stack.
    fn slot_count(&self) -> usize;
    /// Consumes the stack and returns its slots as a vector of `Option`s.
    fn into_slots(self) -> Vec<Option<Self::Cache>>;
    /// Builds a stack from a vector of slots.
    ///
    /// The Mamba1 and Mamba2 implementations in this file panic if any slot
    /// is `None`. NOTE(review): the Mamba3 implementation delegates to
    /// `from_options`, whose handling of `None` is not visible here — confirm.
    fn from_slots(slots: Vec<Option<Self::Cache>>) -> Self;
}
21
/// Cache stack of any supported Mamba variant.
///
/// Each variant is compiled only when the matching cargo feature is enabled;
/// with none of `mamba1`, `mamba2` or `mamba3` enabled the enum has no
/// variants.
#[derive(Debug)]
pub enum MambaCaches {
    /// Mamba1 cache stack (requires the `mamba1` feature).
    #[cfg(feature = "mamba1")]
    Mamba1(crate::mamba1::prelude::Mamba1Caches),
    /// Mamba2 cache stack (requires the `mamba2` feature).
    #[cfg(feature = "mamba2")]
    Mamba2(crate::mamba2::prelude::Mamba2Caches),
    /// Mamba3 cache stack (requires the `mamba3` feature).
    #[cfg(feature = "mamba3")]
    Mamba3(crate::mamba3::prelude::Mamba3Caches),
}
39
/// Adapters that plug the Mamba2 types into the crate's generic cache and
/// block traits. Compiled only with the `mamba2` feature.
#[cfg(feature = "mamba2")]
mod impl_mamba2 {
    use super::*;
    use crate::mamba2::prelude::{
        Mamba2, Mamba2Cache, Mamba2CacheConfig, Mamba2Caches, Mamba2CachesConfig, Mamba2Config,
        Mamba2SsdPath,
    };

    impl CacheStack for Mamba2Caches {
        type Cache = Mamba2Cache;

        /// Number of caches held by the stack.
        fn slot_count(&self) -> usize {
            self.caches.len()
        }

        /// Moves every cache out of the stack, wrapping each one in `Some`.
        fn into_slots(self) -> Vec<Option<Mamba2Cache>> {
            self.caches.into_iter().map(Some).collect()
        }

        /// Rebuilds the stack from `slots`, preserving their order.
        ///
        /// # Panics
        ///
        /// Panics if any slot is `None`. The panic message names the index of
        /// the empty slot so the offending caller can be located.
        fn from_slots(slots: Vec<Option<Mamba2Cache>>) -> Self {
            let caches = slots
                .into_iter()
                .enumerate()
                .map(|(index, slot)| {
                    // Every slot must have been refilled before rebuilding the
                    // stack; an empty slot is a caller bug, so fail loudly and
                    // say which slot it was.
                    slot.unwrap_or_else(|| {
                        panic!("Mamba2Caches::from_slots: slot {index} is empty (None)")
                    })
                })
                .collect();
            Self { caches }
        }
    }

    impl MambaBlock for Mamba2 {
        type Cache = Mamba2Cache;
        type Caches = Mamba2Caches;
        type SsdPath = Mamba2SsdPath;

        /// Delegates to `Mamba2::forward` with the given input, optional cache
        /// and SSD path.
        fn block_forward(
            &self,
            x: Tensor<3>,
            cache: Option<Mamba2Cache>,
            ssd_path: Mamba2SsdPath,
        ) -> (Tensor<3>, Mamba2Cache) {
            self.forward(x, cache, ssd_path)
        }

        /// Delegates to `Mamba2::step` with the given input and optional cache.
        fn block_step(&self, x: Tensor<2>, cache: Option<Mamba2Cache>) -> (Tensor<2>, Mamba2Cache) {
            self.step(x, cache)
        }

        /// Creates a fresh stack of `n_virtual` caches, taking the batch size
        /// from the first dimension of the rank-3 input and allocating on the
        /// input's device.
        fn zero_caches_3d(&self, x: &Tensor<3>, n_virtual: usize) -> Mamba2Caches {
            let [batch, _seq, _d] = x.dims();
            self.make_zero(batch, n_virtual, &x.device())
        }

        /// Creates a fresh stack of `n_virtual` caches, taking the batch size
        /// from the first dimension of the rank-2 input and allocating on the
        /// input's device.
        fn zero_caches_2d(&self, x: &Tensor<2>, n_virtual: usize) -> Mamba2Caches {
            let [batch, _d] = x.dims();
            self.make_zero(batch, n_virtual, &x.device())
        }
    }

    impl Mamba2 {
        /// Shared body of `zero_caches_3d` / `zero_caches_2d`: initialises
        /// `n_virtual` caches on `device` from a config derived from this block.
        ///
        /// `conv_dim` and `conv_kernel` are read from the first and third
        /// dimensions of the conv1d weight.
        fn make_zero(&self, batch: usize, n_virtual: usize, device: &Device) -> Mamba2Caches {
            let [conv_dim, _, conv_kernel] = self.conv1d.weight.dims();
            Mamba2CachesConfig::new(
                n_virtual,
                Mamba2CacheConfig {
                    batch,
                    state_rank: self.state_rank,
                    conv_kernel,
                    conv_dim,
                    per_head_dim: self.per_head_dim(),
                    nheads: self.nheads(),
                },
            )
            .init(device)
        }
    }

    impl MambaBlockConfig for Mamba2Config {
        type Block = Mamba2;

        /// Returns the configured `d_model`.
        fn d_model(&self) -> usize {
            self.d_model
        }

        /// Initialises a `Mamba2` block on `device` via `Mamba2Config::init`.
        fn init_block(&self, device: &Device) -> Mamba2 {
            self.init(device)
        }
    }
}
121
/// Adapters that plug the Mamba3 types into the crate's generic cache and
/// block traits. Compiled only with the `mamba3` feature.
#[cfg(feature = "mamba3")]
mod impl_mamba3 {
    use super::*;
    use crate::mamba3::prelude::{Mamba3, Mamba3Cache, Mamba3Caches, Mamba3Config, Mamba3SsdPath};
    use crate::mamba3::single_ssd::prelude::{
        Mamba3SingleSsdCacheConfig, Mamba3SingleSsdCaches, Mamba3SingleSsdCachesConfig,
    };

    /// Initialises `n_virtual` single-SSD caches on `device`, using a
    /// per-cache config copied from the fields and accessors of `mamba_block`.
    fn zero_single_ssd_caches(
        mamba_block: &Mamba3,
        batch: usize,
        n_virtual: usize,
        device: &Device,
    ) -> Mamba3SingleSsdCaches {
        let per_cache = Mamba3SingleSsdCacheConfig {
            batch,
            state_rank: mamba_block.state_rank,
            num_rope_angles: mamba_block.num_rope_angles,
            per_head_dim: mamba_block.per_head_dim(),
            nheads: mamba_block.nheads(),
            mimo_rank: mamba_block.mimo_rank,
            rotation: mamba_block.rotation,
            num_quat_blocks: mamba_block.num_quat_blocks,
        };
        let stack_config = Mamba3SingleSsdCachesConfig::new(n_virtual, per_cache);
        stack_config.init(device)
    }

    impl CacheStack for Mamba3Caches {
        type Cache = Mamba3Cache;

        /// Slot count as reported by `Mamba3Caches::caches_len`.
        fn slot_count(&self) -> usize {
            self.caches_len()
        }

        /// Consumes the stack through `Mamba3Caches::into_options`.
        fn into_slots(self) -> Vec<Option<Mamba3Cache>> {
            self.into_options()
        }

        /// Rebuilds the stack through `Mamba3Caches::from_options`.
        fn from_slots(slots: Vec<Option<Mamba3Cache>>) -> Self {
            Self::from_options(slots)
        }
    }

    impl MambaBlock for Mamba3 {
        type Cache = Mamba3Cache;
        type Caches = Mamba3Caches;
        type SsdPath = Mamba3SsdPath;

        /// Forwards to `Mamba3::forward`.
        fn block_forward(
            &self,
            x: Tensor<3>,
            cache: Option<Mamba3Cache>,
            ssd_path: Mamba3SsdPath,
        ) -> (Tensor<3>, Mamba3Cache) {
            self.forward(x, cache, ssd_path)
        }

        /// Forwards to `Mamba3::step`.
        fn block_step(&self, x: Tensor<2>, cache: Option<Mamba3Cache>) -> (Tensor<2>, Mamba3Cache) {
            self.step(x, cache)
        }

        /// Forwards to `Mamba3::step_infinite`.
        fn block_step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
            self.step_infinite(x)
        }

        /// Forwards to `Mamba3::step_n_approx`.
        fn block_step_n_approx(
            &self,
            x: Tensor<2>,
            n: usize,
            cache: Option<Mamba3Cache>,
        ) -> (Tensor<2>, Mamba3Cache) {
            self.step_n_approx(x, n, cache)
        }

        /// Creates `n_virtual` fresh caches for a rank-3 input; the batch size
        /// is the input's first dimension and the caches live on its device.
        fn zero_caches_3d(&self, x: &Tensor<3>, n_virtual: usize) -> Mamba3Caches {
            let batch = x.dims()[0];
            let single_ssd = zero_single_ssd_caches(self, batch, n_virtual, &x.device());
            single_ssd.into()
        }

        /// Creates `n_virtual` fresh caches for a rank-2 input; the batch size
        /// is the input's first dimension and the caches live on its device.
        fn zero_caches_2d(&self, x: &Tensor<2>, n_virtual: usize) -> Mamba3Caches {
            let batch = x.dims()[0];
            let single_ssd = zero_single_ssd_caches(self, batch, n_virtual, &x.device());
            single_ssd.into()
        }
    }

    impl MambaBlockConfig for Mamba3Config {
        type Block = Mamba3;

        /// Returns the configured `d_model`.
        fn d_model(&self) -> usize {
            self.d_model
        }

        /// Initialises a `Mamba3` block on `device` via `Mamba3Config::init`.
        fn init_block(&self, device: &Device) -> Mamba3 {
            self.init(device)
        }
    }
}
215
/// Adapters that plug the Mamba1 types into the crate's generic cache and
/// block traits. Compiled only with the `mamba1` feature.
#[cfg(feature = "mamba1")]
mod impl_mamba1 {
    use super::*;
    use crate::mamba1::prelude::{
        Mamba1, Mamba1Cache, Mamba1CacheConfig, Mamba1Caches, Mamba1CachesConfig, Mamba1Config,
    };

    impl CacheStack for Mamba1Caches {
        type Cache = Mamba1Cache;

        /// Number of caches held by the stack.
        fn slot_count(&self) -> usize {
            self.caches.len()
        }

        /// Moves every cache out of the stack, wrapping each one in `Some`.
        fn into_slots(self) -> Vec<Option<Mamba1Cache>> {
            self.caches.into_iter().map(Some).collect()
        }

        /// Rebuilds the stack from `slots`, preserving their order.
        ///
        /// # Panics
        ///
        /// Panics if any slot is `None`. The panic message names the index of
        /// the empty slot so the offending caller can be located.
        fn from_slots(slots: Vec<Option<Mamba1Cache>>) -> Self {
            let caches = slots
                .into_iter()
                .enumerate()
                .map(|(index, slot)| {
                    // Every slot must have been refilled before rebuilding the
                    // stack; an empty slot is a caller bug, so fail loudly and
                    // say which slot it was.
                    slot.unwrap_or_else(|| {
                        panic!("Mamba1Caches::from_slots: slot {index} is empty (None)")
                    })
                })
                .collect();
            Self { caches }
        }
    }

    impl MambaBlock for Mamba1 {
        type Cache = Mamba1Cache;
        type Caches = Mamba1Caches;
        /// Unit type: the `_ssd_path` argument of `block_forward` is ignored
        /// because `Mamba1::forward` is called without one.
        type SsdPath = ();

        /// Delegates to `Mamba1::forward`; `_ssd_path` is unused.
        fn block_forward(
            &self,
            x: Tensor<3>,
            cache: Option<Mamba1Cache>,
            _ssd_path: (),
        ) -> (Tensor<3>, Mamba1Cache) {
            self.forward(x, cache)
        }

        /// Delegates to `Mamba1::step` with the given input and optional cache.
        fn block_step(&self, x: Tensor<2>, cache: Option<Mamba1Cache>) -> (Tensor<2>, Mamba1Cache) {
            self.step(x, cache)
        }

        /// Creates a fresh stack of `n_virtual` caches, taking the batch size
        /// from the first dimension of the rank-3 input and allocating on the
        /// input's device.
        fn zero_caches_3d(&self, x: &Tensor<3>, n_virtual: usize) -> Mamba1Caches {
            let [batch, _seq, _d] = x.dims();
            self.make_zero(batch, n_virtual, &x.device())
        }

        /// Creates a fresh stack of `n_virtual` caches, taking the batch size
        /// from the first dimension of the rank-2 input and allocating on the
        /// input's device.
        fn zero_caches_2d(&self, x: &Tensor<2>, n_virtual: usize) -> Mamba1Caches {
            let [batch, _d] = x.dims();
            self.make_zero(batch, n_virtual, &x.device())
        }
    }

    impl Mamba1 {
        /// Derives a per-cache config for the given `batch` from this block's
        /// parameters: `d_inner` and `state_rank` are the two dimensions of
        /// `a_log`, and `conv_kernel` is the third dimension of the conv1d
        /// weight.
        fn cache_config(&self, batch: usize) -> Mamba1CacheConfig {
            let [d_inner, state_rank] = self.a_log.dims();
            let [_, _, conv_kernel] = self.conv1d.weight.dims();
            Mamba1CacheConfig::new(batch, d_inner)
                .with_state_rank(state_rank)
                .with_conv_kernel(conv_kernel)
        }

        /// Shared body of `zero_caches_3d` / `zero_caches_2d`: initialises
        /// `n_virtual` caches on `device` using `cache_config(batch)`.
        fn make_zero(&self, batch: usize, n_virtual: usize, device: &Device) -> Mamba1Caches {
            Mamba1CachesConfig::new(n_virtual, self.cache_config(batch)).init(device)
        }
    }

    impl MambaBlockConfig for Mamba1Config {
        type Block = Mamba1;

        /// Returns the configured `d_model`.
        fn d_model(&self) -> usize {
            self.d_model
        }

        /// Initialises a `Mamba1` block on `device` via `Mamba1Config::init`.
        fn init_block(&self, device: &Device) -> Mamba1 {
            self.init(device)
        }
    }
}