burn_mamba/mamba1/
layer.rs1use crate::mamba1::prelude::*;
2use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
3use burn::prelude::*;
4
5#[derive(Module, Debug)]
6pub struct Mamba1Layer<B: Backend> {
7 pub norm: RmsNorm<B>,
8 pub mamba_block: Mamba1<B>,
9}
10
11#[derive(Config, Debug)]
12pub struct Mamba1LayerConfig {
13 pub mamba_block: Mamba1Config,
14}
15
16impl Mamba1LayerConfig {
17 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1Layer<B> {
19 Mamba1Layer {
20 norm: RmsNormConfig::new(self.mamba_block.d_model).init(device),
21 mamba_block: Mamba1Config::new(self.mamba_block.d_model)
22 .with_d_state(self.mamba_block.d_state)
23 .with_dt_rank(self.mamba_block.dt_rank)
24 .with_d_conv(self.mamba_block.d_conv)
25 .with_d_inner(self.mamba_block.d_inner)
26 .init(device),
27 }
28 }
29}
30
31impl<B: Backend> Mamba1Layer<B> {
32 pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
38 let [batch, sequence, d_model] = x.dims();
39
40 let res = x.clone();
41 let x = self.norm.forward(x);
42
43 let x = self.mamba_block.forward(x);
44 debug_assert_eq!([batch, sequence, d_model], x.dims());
45
46 x + res
47 }
48}
49
50impl<B: Backend> Mamba1Layer<B> {
51 pub fn step(&self, x: Tensor<B, 2>, cache: Mamba1Cache<B>) -> (Tensor<B, 2>, Mamba1Cache<B>) {
57 let [batch, d_model] = x.dims();
58
59 let res = x.clone();
60 let x = self.norm.forward(x);
61 let (x, cache) = self.mamba_block.step(x, cache);
62 debug_assert_eq!([batch, d_model], x.dims());
63
64 (x + res, cache)
65 }
66}