Skip to main content

burn_mamba/mamba1/
cache.rs

1use crate::mamba1::prelude::*;
2use burn::prelude::*;
3use burn::{
4    module::{Module, Param},
5    nn::Initializer,
6};
7
8#[derive(Module, Debug)]
9pub struct Mamba1Cache<B: Backend> {
10    /// # Shape
11    /// [batch, d_inner, d_conv]
12    pub conv: Param<Tensor<B, 3>>,
13    /// # Shape
14    /// [batch, d_inner, d_state]
15    pub ssm: Param<Tensor<B, 3>>,
16}
17
18#[derive(Config, Debug)]
19pub struct Mamba1CacheConfig {
20    pub batch: usize,
21
22    /// latent state dimension (`N` in Algorithm 2 from the Mamba paper).
23    #[config(default = 16)]
24    pub d_state: usize,
25
26    #[config(default = 4)]
27    pub d_conv: usize,
28
29    pub d_inner: usize,
30}
31
32impl Mamba1CacheConfig {
33    pub fn new_from_block_config(batch: usize, block_config: Mamba1Config) -> Self {
34        Self {
35            batch,
36            d_state: block_config.d_state,
37            d_conv: block_config.d_conv,
38            d_inner: block_config.d_inner(),
39        }
40    }
41
42    /// Returns the initialized model.
43    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1Cache<B> {
44        let conv = Initializer::Zeros.init([self.batch, self.d_inner, self.d_conv], device);
45        let ssm = Initializer::Zeros.init([self.batch, self.d_inner, self.d_state], device);
46        Mamba1Cache { conv, ssm }
47    }
48}
49
50#[derive(Module, Debug)]
51pub struct Mamba1Caches<B: Backend> {
52    /// # Shape
53    /// [n_layers]
54    pub caches: Vec<Mamba1Cache<B>>,
55}
56
57#[derive(Config, Debug)]
58pub struct Mamba1CachesConfig {
59    pub n_layers: usize,
60    pub cache: Mamba1CacheConfig,
61}
62
63impl Mamba1CachesConfig {
64    pub fn new_from_block_config(
65        n_layers: usize,
66        batch: usize,
67        block_config: Mamba1Config,
68    ) -> Self {
69        Self {
70            n_layers,
71            cache: Mamba1CacheConfig::new_from_block_config(batch, block_config),
72        }
73    }
74
75    /// Returns the initialized model.
76    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1Caches<B> {
77        let mut caches: Vec<Mamba1Cache<B>> = Vec::with_capacity(self.n_layers);
78        for _ in 0..self.n_layers {
79            let cache: Mamba1Cache<B> = self.cache.clone().init(device);
80            caches.push(cache);
81        }
82        Mamba1Caches { caches }
83    }
84}