1use crate::mamba1::prelude::*;
2use crate::utils::silu::Silu;
3use burn::prelude::*;
4use burn::{
5 module::{Module, Param},
6 nn::conv::{Conv1d, Conv1dConfig},
7 nn::{Initializer, Linear, LinearConfig, PaddingConfig1d},
8};
9
/// Mamba1 block: input projection, causal depthwise convolution, selective state-space scan,
/// SiLU gating and output projection (see [`Mamba1::forward`]).
///
/// Shapes noted below are the ones produced by `Mamba1Config::init`.
#[derive(Module, Debug)]
pub struct Mamba1<B: Backend> {
    /// Projects `d_model` features to `2 * d_inner`; `forward` splits the result into the
    /// SSM branch and the gating branch.
    pub in_proj: Linear<B>,

    /// Depthwise convolution over the sequence axis (`groups = d_inner`, kernel `d_conv`,
    /// `d_conv - 1` padding; `forward` keeps only the first `sequence` outputs).
    pub conv1d: Conv1d<B>,

    /// Projects `d_inner` to `dt_rank + 2 * d_state`: the low-rank Δ input, `B` and `C`.
    pub x_proj: Linear<B>,

    /// Lifts the low-rank Δ from `dt_rank` to `d_inner`; softplus is applied to its output.
    pub dt_proj: Linear<B>,

    /// `[d_inner, d_state]`; the state matrix used by the scan is `A = -exp(a_log)`.
    pub a_log: Param<Tensor<B, 2>>,

    /// `[d_inner]` skip weight: the scan output gets `d * u` added to it.
    pub d: Param<Tensor<B, 1>>,

    /// Projects `d_inner` back to `d_model`.
    pub out_proj: Linear<B>,
}
38
/// Hyper-parameters of a [`Mamba1`] block; build the module with [`Mamba1Config::init`].
#[derive(Config, Debug)]
pub struct Mamba1Config {
    /// Model feature width: input width of `in_proj` and output width of `out_proj`.
    pub d_model: usize,

    #[config(default = 16)]
    /// Width of the per-channel latent SSM state (second axis of `a_log`, width of `B`/`C`).
    /// Also the divisor of the default `dt_rank`.
    pub d_state: usize,

    #[config(default = 4)]
    /// Kernel size of the causal depthwise convolution.
    pub d_conv: usize,

    #[config(default = 2)]
    /// Expansion factor: `d_inner` defaults to `expand * d_model`.
    pub expand: usize,

    #[config(default = 1e-3)]
    /// Lower end of the range the initial Δ is sampled from (log-uniformly), encoded into
    /// `dt_proj.bias`.
    pub dt_min: f64,

    #[config(default = 1e-1)]
    /// Upper end of the range the initial Δ is sampled from (log-uniformly).
    pub dt_max: f64,

    #[config(default = 1.)]
    /// Scales the uniform init of `dt_proj.weight`, which is drawn in
    /// `±dt_scale / sqrt(dt_rank)`.
    pub dt_scale: f64,

    #[config(default = 1e-4)]
    /// Minimum value the sampled initial Δ is clamped to.
    pub dt_init_floor: f64,

    #[config(default = true)]
    /// Whether `conv1d` has a bias term.
    pub conv_bias: bool,

    #[config(default = false)]
    /// Whether `in_proj` and `out_proj` have bias terms.
    pub bias: bool,

    /// Rank of the Δ projection (`x_proj` output slice feeding `dt_proj`).
    /// When unset, `ceil(d_model / d_state)` is used.
    pub dt_rank: Option<usize>,

    /// Inner channel width. When unset, `expand * d_model` is used.
    pub d_inner: Option<usize>,
}
89
90impl Mamba1Config {
91 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1<B> {
93 let d_inner = self.d_inner();
94 debug_assert_ne!(self.d_state, 0);
95 debug_assert!(self.d_model + self.d_state > 0);
96 let dt_rank = self.dt_rank();
97
98 let uniform_init = |d_input: usize| {
100 let bound = 1.0 / (d_input as f64).sqrt();
101 Initializer::Uniform {
102 min: -bound,
103 max: bound,
104 }
105 };
106
107 let dt_proj = {
108 use burn::tensor::Distribution;
109 let weight: Tensor<B, 2> = {
110 let dt_init_std = (dt_rank as f64).powf(-0.5) * self.dt_scale;
111 Tensor::random(
112 [dt_rank, d_inner],
113 Distribution::Uniform(-dt_init_std, dt_init_std),
114 device,
115 )
116 };
117 debug_assert_eq!([dt_rank, d_inner], weight.dims());
118 let bias: Tensor<B, 1> = {
119 let expm1 = |t: Tensor<B, 1>| t.exp() - 1.;
123 let dt = Tensor::random([d_inner], Distribution::Uniform(0.0, 1.0), device)
124 * (f64::ln(self.dt_max) - f64::ln(self.dt_min))
125 + f64::ln(self.dt_min);
126 let dt = dt.exp().clamp_min(self.dt_init_floor);
127 let inv_dt = dt.clone() + (-expm1(-dt)).log();
129 inv_dt
130 };
131 debug_assert_eq!([d_inner], bias.dims());
132 Linear {
133 weight: Param::from_tensor(weight),
134 bias: Some(Param::from_tensor(bias)),
135 }
136 };
137
138 let a_log = {
139 let a_row: Tensor<B, 1> =
140 Tensor::<B, 1, Int>::arange(1..self.d_state as i64 + 1, device).float();
141 debug_assert_eq!([self.d_state], a_row.dims());
142 let a_row = a_row.unsqueeze();
143 debug_assert_eq!([1, self.d_state], a_row.dims());
144 let a = a_row.repeat(&[d_inner, 1]);
145 debug_assert_eq!([d_inner, self.d_state], a.dims());
146 let a_log = a.log();
147 Param::from_tensor(a_log)
148 };
149
150 Mamba1 {
151 in_proj: LinearConfig::new(self.d_model, 2 * d_inner)
152 .with_bias(self.bias)
153 .with_initializer(uniform_init(self.d_model))
155 .init(device),
156 conv1d: Conv1dConfig::new(d_inner, d_inner, self.d_conv)
157 .with_padding(PaddingConfig1d::Explicit(self.d_conv - 1, self.d_conv - 1))
160 .with_groups(d_inner)
161 .with_bias(self.conv_bias)
162 .with_initializer(uniform_init(self.d_conv))
165 .init(device),
166 x_proj: LinearConfig::new(d_inner, dt_rank + 2 * self.d_state)
167 .with_bias(false)
168 .with_initializer(uniform_init(d_inner))
170 .init(device),
171 dt_proj,
172 a_log,
173 d: Initializer::Ones.init([d_inner], device),
174 out_proj: LinearConfig::new(d_inner, self.d_model)
175 .with_bias(self.bias)
176 .with_initializer(uniform_init(d_inner))
178 .init(device),
179 }
180 }
181 pub fn d_inner(&self) -> usize {
182 self.d_inner.unwrap_or(self.expand * self.d_model)
183 }
184 pub fn dt_rank(&self) -> usize {
185 self.dt_rank.unwrap_or(self.d_model.div_ceil(self.d_state))
186 }
187}
188
189impl<B: Backend> Mamba1<B> {
190 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
196 let [batch, sequence, d_model] = x.dims();
197 let [d_inner] = self.d.dims();
198 let [_, _, d_conv] = self.conv1d.weight.dims();
199
200 let (xs, res) = {
202 let xs_and_res = self.in_proj.forward(x);
204 debug_assert_eq!([batch, sequence, 2 * d_inner], xs_and_res.dims());
205
206 let mut split = xs_and_res
207 .split_with_sizes(vec![d_inner, d_inner], 2)
208 .into_iter();
209 debug_assert_eq!(split.len(), 2);
210 (split.next().unwrap(), split.next().unwrap())
211 };
212 debug_assert_eq!([batch, sequence, d_inner], xs.dims());
213 debug_assert_eq!([batch, sequence, d_inner], res.dims());
214
215 let xs = {
217 let xs = xs.permute([0, 2, 1]);
219 debug_assert_eq!([batch, d_inner, sequence], xs.dims());
220
221 debug_assert!(d_conv > 0);
222 let xs = self.conv1d.forward(xs);
223 debug_assert_eq!([batch, d_inner, sequence + d_conv - 1], xs.dims());
224
225 let xs = xs.narrow(2, 0, sequence);
226 debug_assert_eq!([batch, d_inner, sequence], xs.dims());
227
228 let xs = xs.permute([0, 2, 1]);
231 debug_assert_eq!([batch, sequence, d_inner], xs.dims());
232
233 let xs = Silu::new().forward(xs);
235 debug_assert_eq!([batch, sequence, d_inner], xs.dims());
236
237 xs
238 };
239 debug_assert_eq!([batch, sequence, d_inner], xs.dims());
240
241 let ss = self.ss(xs);
242 debug_assert_eq!([batch, sequence, d_inner], ss.dims());
243
244 let ys = ss * Silu::new().forward(res);
246 debug_assert_eq!([batch, sequence, d_inner], ys.dims());
247
248 let y = self.out_proj.forward(ys);
249 debug_assert_eq!([batch, sequence, d_model], y.dims());
250
251 y
252 }
253
254 pub fn ss(&self, u: Tensor<B, 3>) -> Tensor<B, 3> {
258 let [batch, sequence, d_inner] = u.dims();
259 let [_d_inner, d_state] = self.a_log.dims();
260 let [dt_rank, _d_inner] = self.dt_proj.weight.dims();
261
262 let a = self.a_log.val().exp().neg();
267 debug_assert_eq!([d_inner, d_state], a.dims());
268
269 let x_dbl = self.x_proj.forward(u.clone());
270 debug_assert_eq!([batch, sequence, dt_rank + 2 * d_state], x_dbl.dims());
271
272 let mut split = x_dbl
276 .split_with_sizes(vec![dt_rank, d_state, d_state], 2)
277 .into_iter();
278 let delta = split.next().unwrap();
279 let b = split.next().unwrap();
280 let c = split.next().unwrap();
281 debug_assert_eq!([batch, sequence, dt_rank], delta.dims());
282 debug_assert_eq!([batch, sequence, d_state], b.dims());
283 debug_assert_eq!([batch, sequence, d_state], c.dims());
284
285 let delta = self.dt_proj.forward(delta);
288 debug_assert_eq!([batch, sequence, d_inner], delta.dims());
289
290 let delta = burn::tensor::activation::softplus(delta, 1.);
291
292 let delta = delta.permute([1, 0, 2]);
294 debug_assert_eq!([sequence, batch, d_inner], delta.dims());
295
296 let c = c.permute([1, 0, 2]);
298 debug_assert_eq!([sequence, batch, d_state], c.dims());
299
300 Self::selective_scan(delta, a, b, c, self.d.val(), u)
301 }
302
303 pub fn selective_scan(
319 delta: Tensor<B, 3>,
320 a: Tensor<B, 2>,
321 b: Tensor<B, 3>,
322 c: Tensor<B, 3>,
323 d: Tensor<B, 1>,
324 u: Tensor<B, 3>,
325 ) -> Tensor<B, 3> {
326 let device = &u.device();
327 let [sequence, batch, d_inner] = delta.dims();
328 let [_d_inner, d_state] = a.dims();
329 let outer_shape = [sequence, batch, d_inner, d_state];
330
331 let (delta_a, delta_bu) = {
336 let delta = delta.unsqueeze_dim(3);
337 debug_assert_eq!([sequence, batch, d_inner, 1], delta.dims());
338 let delta = delta.expand(outer_shape);
339 debug_assert_eq!(outer_shape, delta.dims());
340
341 let a = a.unsqueeze_dims(&[0, 1]);
342 debug_assert_eq!([1, 1, d_inner, d_state], a.dims());
343 let a = a.expand(outer_shape);
344 debug_assert_eq!(outer_shape, a.dims());
345 let delta_a = (delta.clone() * a).exp();
346 debug_assert_eq!(outer_shape, delta_a.dims());
347
348 let b = b.permute([1, 0, 2]);
350 debug_assert_eq!([sequence, batch, d_state], b.dims());
351 let b = b.unsqueeze_dim(2);
352 debug_assert_eq!([sequence, batch, 1, d_state], b.dims());
353 let b = b.expand(outer_shape);
354 debug_assert_eq!(outer_shape, b.dims());
355 let delta_b = delta * b;
356 debug_assert_eq!(outer_shape, delta_b.dims());
357
358 let u = u.clone().permute([1, 0, 2]);
360 debug_assert_eq!([sequence, batch, d_inner], u.dims());
361 let u = u.unsqueeze_dim(3);
362 debug_assert_eq!([sequence, batch, d_inner, 1], u.dims());
363 let u = u.expand(outer_shape);
364 debug_assert_eq!(outer_shape, u.dims());
365 let delta_bu = delta_b * u;
366 debug_assert_eq!(outer_shape, delta_bu.dims());
367
368 (delta_a, delta_bu)
369 };
370 debug_assert_eq!(outer_shape, delta_a.dims());
371 debug_assert_eq!(outer_shape, delta_bu.dims());
372
373 let delta_a = delta_a.split(1, 0);
380 debug_assert_eq!(delta_a.len(), sequence);
381
382 let delta_bu = delta_bu.split(1, 0);
383 debug_assert_eq!(delta_bu.len(), sequence);
384
385 let c = c.unsqueeze_dim(3);
386 debug_assert_eq!([sequence, batch, d_state, 1], c.dims());
387 let c = c.split(1, 0);
388 debug_assert_eq!(c.len(), sequence);
389
390 let inner_shape = [batch, d_inner, d_state];
391 let mut xs: Tensor<B, 3> = Tensor::zeros(inner_shape, device);
392 let mut ys = Vec::with_capacity(sequence); for ((delta_a, delta_bu), c) in delta_a
394 .into_iter()
395 .zip(delta_bu.into_iter())
396 .zip(c.into_iter())
397 {
398 let delta_a = delta_a.squeeze_dim(0);
399 debug_assert_eq!(inner_shape, delta_a.dims());
400 let delta_bu = delta_bu.squeeze_dim(0);
401 debug_assert_eq!(inner_shape, delta_bu.dims());
402 let c = c.squeeze_dim(0);
403 debug_assert_eq!([batch, d_state, 1], c.dims());
404
405 xs = (xs.clone() * delta_a) + delta_bu;
406 let y = xs.clone().matmul(c);
407 debug_assert_eq!([batch, d_inner, 1], y.dims());
408 let y = y.squeeze_dim(2);
409 debug_assert_eq!([batch, d_inner], y.dims());
410 ys.push(y);
411 }
412
413 let ys = Tensor::stack(ys, 1);
414 debug_assert_eq!([batch, sequence, d_inner], ys.dims());
415
416 let d = d.unsqueeze_dims(&[0, 1]);
417 debug_assert_eq!([1, 1, d_inner], d.dims());
418 let d = d.expand([batch, sequence, d_inner]);
419
420 let ys = ys + (d * u);
421 debug_assert_eq!([batch, sequence, d_inner], ys.dims());
422
423 ys
424 }
425}
426
427mod step {
428 use super::*;
429
430 impl<B: Backend> Mamba1<B> {
431 pub fn step(
435 &self,
436 x: Tensor<B, 2>,
437 mut cache: Mamba1Cache<B>,
438 ) -> (Tensor<B, 2>, Mamba1Cache<B>) {
439 let [batch, d_inner, d_conv] = cache.conv.dims();
440 let [_batch, d_model] = x.dims();
441
442 let (xs, res) = {
444 let xs_and_res = self.in_proj.forward(x);
446 debug_assert_eq!([batch, 2 * d_inner], xs_and_res.dims());
447
448 let mut split = xs_and_res
449 .split_with_sizes(vec![d_inner, d_inner], 1)
450 .into_iter();
451 (split.next().unwrap(), split.next().unwrap())
452 };
453 debug_assert_eq!([batch, d_inner], xs.dims());
454 debug_assert_eq!([batch, d_inner], res.dims());
455
456 cache.conv = cache.conv.map(|conv| {
458 let t0 = conv.narrow(2, 1, d_conv - 1);
460 debug_assert_eq!([batch, d_inner, d_conv - 1], t0.dims());
461
462 let conv = Tensor::cat([t0, xs.unsqueeze_dim(2)].to_vec(), 2);
464 debug_assert_eq!([batch, d_inner, d_conv], conv.dims());
465
466 conv
467 });
468 let xs = {
469 let conv1d = self.conv1d.weight.val();
470 debug_assert_eq!([d_inner, 1, d_conv], conv1d.dims());
472 let conv1d = conv1d.permute([1, 0, 2]);
474 debug_assert_eq!([1, d_inner, d_conv], conv1d.dims());
475 let conv1d = conv1d.expand([batch, d_inner, d_conv]);
476 debug_assert_eq!([batch, d_inner, d_conv], conv1d.dims());
477
478 let xs = cache.conv.val() * conv1d;
479 let xs = xs.sum_dim(2);
480 debug_assert_eq!([batch, d_inner, 1], xs.dims());
481 let xs = xs.squeeze_dim(2);
482 debug_assert_eq!([batch, d_inner], xs.dims());
483
484 let conv1d_bias = self.conv1d.bias.as_ref().unwrap().val();
486 debug_assert_eq!([d_inner], conv1d_bias.dims());
488 let conv1d_bias = conv1d_bias.unsqueeze();
489 debug_assert_eq!([1, d_inner], conv1d_bias.dims());
490 let xs = xs + conv1d_bias;
491
492 let xs = Silu::new().forward(xs);
494 debug_assert_eq!([batch, d_inner], xs.dims());
495
496 xs
497 };
498 debug_assert_eq!([batch, d_inner], xs.dims());
499
500 let (ss, cache) = self.ss_step(xs, cache);
501 debug_assert_eq!([batch, d_inner], ss.dims());
502
503 let ys = ss * Silu::new().forward(res);
505 debug_assert_eq!([batch, d_inner], ys.dims());
506
507 let y = self.out_proj.forward(ys);
508 debug_assert_eq!([batch, d_model], y.dims());
509
510 (y, cache)
511 }
512
513 pub fn ss_step(
521 &self,
522 u: Tensor<B, 2>,
523 cache: Mamba1Cache<B>,
524 ) -> (Tensor<B, 2>, Mamba1Cache<B>) {
525 let [batch, d_inner, d_state] = cache.ssm.dims();
526 let [dt_rank, _d_inner] = self.dt_proj.weight.dims();
527
528 let a = self.a_log.val().exp().neg();
533 debug_assert_eq!([d_inner, d_state], a.dims());
534
535 let x_dbl = self.x_proj.forward(u.clone());
536 debug_assert_eq!([batch, dt_rank + 2 * d_state], x_dbl.dims());
537
538 let mut split = x_dbl
542 .split_with_sizes(vec![dt_rank, d_state, d_state], 1)
543 .into_iter();
544 let delta = split.next().unwrap();
545 let b = split.next().unwrap();
546 let c = split.next().unwrap();
547 debug_assert_eq!([batch, dt_rank], delta.dims());
548 debug_assert_eq!([batch, d_state], b.dims());
549 debug_assert_eq!([batch, d_state], c.dims());
550
551 let delta = self.dt_proj.forward(delta);
554 debug_assert_eq!([batch, d_inner], delta.dims());
555 let delta = burn::tensor::activation::softplus(delta, 1.);
556
557 Self::selective_scan_step(delta, a, b, c, self.d.val(), u, cache)
558 }
559
560 pub fn selective_scan_step(
576 delta: Tensor<B, 2>,
577 a: Tensor<B, 2>,
578 b: Tensor<B, 2>,
579 c: Tensor<B, 2>,
580 d: Tensor<B, 1>,
581 u: Tensor<B, 2>,
582 mut cache: Mamba1Cache<B>,
583 ) -> (Tensor<B, 2>, Mamba1Cache<B>) {
584 let [batch, d_inner, d_state] = cache.ssm.dims();
585 let outer_shape = [batch, d_inner, d_state];
586
587 let (delta_a, delta_bu) = {
592 let delta = delta.unsqueeze_dim(2);
593 debug_assert_eq!([batch, d_inner, 1], delta.dims());
594 let delta = delta.expand(outer_shape);
595 debug_assert_eq!(outer_shape, delta.dims());
596
597 let a = a.unsqueeze();
598 debug_assert_eq!([1, d_inner, d_state], a.dims());
599 let a = a.expand(outer_shape);
600 debug_assert_eq!(outer_shape, a.dims());
601 let delta_a = (delta.clone() * a).exp();
602 debug_assert_eq!(outer_shape, delta_a.dims());
603
604 let b = b.unsqueeze_dim(1);
605 debug_assert_eq!([batch, 1, d_state], b.dims());
606 let b = b.expand(outer_shape);
607 debug_assert_eq!(outer_shape, b.dims());
608 let delta_b = delta * b;
609 debug_assert_eq!(outer_shape, delta_b.dims());
610
611 let u = u.clone().unsqueeze_dim(2);
612 debug_assert_eq!([batch, d_inner, 1], u.dims());
613 let u = u.expand(outer_shape);
614 debug_assert_eq!(outer_shape, u.dims());
615 let delta_bu = delta_b * u;
616 debug_assert_eq!(outer_shape, delta_bu.dims());
617
618 (delta_a, delta_bu)
619 };
620 debug_assert_eq!(outer_shape, delta_a.dims());
621 debug_assert_eq!(outer_shape, delta_bu.dims());
622
623 cache.ssm = cache.ssm.map(|ssm| (ssm * delta_a) + delta_bu);
624
625 let c = c.unsqueeze_dim(2);
626 debug_assert_eq!([batch, d_state, 1], c.dims());
627
628 let y = cache.ssm.val().matmul(c);
629 debug_assert_eq!([batch, d_inner, 1], y.dims());
630 let y = y.squeeze_dim(2);
631 debug_assert_eq!([batch, d_inner], y.dims());
632
633 let d = d.unsqueeze();
634 debug_assert_eq!([1, d_inner], d.dims());
635 let d = d.expand([batch, d_inner]);
636 debug_assert_eq!([batch, d_inner], d.dims());
637
638 let y = y + (d * u);
639 debug_assert_eq!([batch, d_inner], y.dims());
640
641 (y, cache)
642 }
643 }
644}