burn_mamba/mamba1/
cache.rs1use crate::mamba1::prelude::*;
2use burn::prelude::*;
3use burn::{
4 module::{Module, Param},
5 nn::Initializer,
6};
7
8#[derive(Module, Debug)]
9pub struct Mamba1Cache<B: Backend> {
10 pub conv: Param<Tensor<B, 3>>,
13 pub ssm: Param<Tensor<B, 3>>,
16}
17
18#[derive(Config, Debug)]
19pub struct Mamba1CacheConfig {
20 pub batch: usize,
21
22 #[config(default = 16)]
24 pub d_state: usize,
25
26 #[config(default = 4)]
27 pub d_conv: usize,
28
29 pub d_inner: usize,
30}
31
32impl Mamba1CacheConfig {
33 pub fn new_from_block_config(batch: usize, block_config: Mamba1Config) -> Self {
34 Self {
35 batch,
36 d_state: block_config.d_state,
37 d_conv: block_config.d_conv,
38 d_inner: block_config.d_inner(),
39 }
40 }
41
42 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1Cache<B> {
44 let conv = Initializer::Zeros.init([self.batch, self.d_inner, self.d_conv], device);
45 let ssm = Initializer::Zeros.init([self.batch, self.d_inner, self.d_state], device);
46 Mamba1Cache { conv, ssm }
47 }
48}
49
50#[derive(Module, Debug)]
51pub struct Mamba1Caches<B: Backend> {
52 pub caches: Vec<Mamba1Cache<B>>,
55}
56
57#[derive(Config, Debug)]
58pub struct Mamba1CachesConfig {
59 pub n_layers: usize,
60 pub cache: Mamba1CacheConfig,
61}
62
63impl Mamba1CachesConfig {
64 pub fn new_from_block_config(
65 n_layers: usize,
66 batch: usize,
67 block_config: Mamba1Config,
68 ) -> Self {
69 Self {
70 n_layers,
71 cache: Mamba1CacheConfig::new_from_block_config(batch, block_config),
72 }
73 }
74
75 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1Caches<B> {
77 let mut caches: Vec<Mamba1Cache<B>> = Vec::with_capacity(self.n_layers);
78 for _ in 0..self.n_layers {
79 let cache: Mamba1Cache<B> = self.cache.clone().init(device);
80 caches.push(cache);
81 }
82 Mamba1Caches { caches }
83 }
84}