Skip to main content

burn_mamba/mamba1/
layer.rs

1use crate::mamba1::prelude::*;
2use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
3use burn::prelude::*;
4
5#[derive(Module, Debug)]
6pub struct Mamba1Layer<B: Backend> {
7    pub norm: RmsNorm<B>,
8    pub mamba_block: Mamba1<B>,
9}
10
11#[derive(Config, Debug)]
12pub struct Mamba1LayerConfig {
13    pub mamba_block: Mamba1Config,
14}
15
16impl Mamba1LayerConfig {
17    /// Returns the initialized model.
18    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1Layer<B> {
19        Mamba1Layer {
20            norm: RmsNormConfig::new(self.mamba_block.d_model).init(device),
21            mamba_block: Mamba1Config::new(self.mamba_block.d_model)
22                .with_d_state(self.mamba_block.d_state)
23                .with_dt_rank(self.mamba_block.dt_rank)
24                .with_d_conv(self.mamba_block.d_conv)
25                .with_d_inner(self.mamba_block.d_inner)
26                .init(device),
27        }
28    }
29}
30
31impl<B: Backend> Mamba1Layer<B> {
32    /// See also [`Self::step`].
33    ///
34    /// # Shapes
35    ///   - Input [batch, sequence, d_model]
36    ///   - Output [batch, sequence, d_model]
37    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
38        let [batch, sequence, d_model] = x.dims();
39
40        let res = x.clone();
41        let x = self.norm.forward(x);
42
43        let x = self.mamba_block.forward(x);
44        debug_assert_eq!([batch, sequence, d_model], x.dims());
45
46        x + res
47    }
48}
49
50impl<B: Backend> Mamba1Layer<B> {
51    /// See also [`Self::forward`].
52    ///
53    /// # Shapes
54    ///   - Input [batch, d_model]
55    ///   - Output [batch, d_model]
56    pub fn step(&self, x: Tensor<B, 2>, cache: Mamba1Cache<B>) -> (Tensor<B, 2>, Mamba1Cache<B>) {
57        let [batch, d_model] = x.dims();
58
59        let res = x.clone();
60        let x = self.norm.forward(x);
61        let (x, cache) = self.mamba_block.step(x, cache);
62        debug_assert_eq!([batch, d_model], x.dims());
63
64        (x + res, cache)
65    }
66}