Skip to main content

burn_mamba/mamba1/
network.rs

//! Uses Mamba1 and other modules to build a Mamba1 model capable of running the state-spaces/mamba-130m text prediction models.
//!
//! References:
//! - https://github.com/huggingface/candle/blob/fd7c8565646039e35925b8730d27ddad195d7e73/candle-examples/examples/mamba-minimal/
//! - https://github.com/johnma2006/mamba-minimal/blob/61f01953ca153f8c4a850d7111beecbf4be9cee1/
6
7use crate::mamba1::prelude::*;
8use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
9use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig};
10use burn::prelude::*;
11
12#[derive(Module, Debug)]
13pub struct Mamba1Network<B: Backend> {
14    pub embedding: Embedding<B>,
15    pub layers: Vec<Mamba1Layer<B>>,
16    pub norm_f: RmsNorm<B>,
17    /// If missing, re-utilizes a transposed `embedding` weight.
18    pub lm_head: Option<Linear<B>>,
19}
20
21#[derive(Config, Debug)]
22pub struct Mamba1NetworkConfig {
23    pub n_layer: usize,
24
25    /// If vocab_size is divisible by pad_vocab_size_multiple, this should be considered the unpadded vocab size.
26    /// Otherwise, this is padded into `((vocab_size / self.pad_vocab_size_multiple) + 1) * pad_vocab_size_multiple`.
27    pub vocab_size: usize,
28
29    /// If no pad is required, vocab_size must be divisible by pad_vocab_size_multiple.
30    /// If pad is required, vocab_size increases until it's divisible by pad_vocab_size_multiple.
31    ///
32    /// To disable vocab padding, you can set this to `1`.
33    pub pad_vocab_size_multiple: usize,
34
35    pub mamba_block: Mamba1Config,
36
37    /// If set to true, `lm_head` is set to `None` and it re-utilizes the transposed `embedding` weights.
38    pub missing_lm_head: bool,
39}
40
41impl Mamba1NetworkConfig {
42    /// Returns the initialized model.
43    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1Network<B> {
44        let mut layers = Vec::with_capacity(self.n_layer);
45        for _ in 0..self.n_layer {
46            let layer = Mamba1LayerConfig::new(self.mamba_block.clone()).init(device);
47            layers.push(layer);
48        }
49
50        let padded_vocab_size = {
51            if self.vocab_size.is_multiple_of(self.pad_vocab_size_multiple) {
52                self.vocab_size
53            } else {
54                ((self.vocab_size / self.pad_vocab_size_multiple) + 1)
55                    * self.pad_vocab_size_multiple
56            }
57        };
58
59        Mamba1Network {
60            embedding: EmbeddingConfig::new(padded_vocab_size, self.mamba_block.d_model)
61                .init(device),
62            layers,
63            norm_f: RmsNormConfig::new(self.mamba_block.d_model).init(device),
64            lm_head: if self.missing_lm_head {
65                None
66            } else {
67                Some(
68                    LinearConfig::new(self.mamba_block.d_model, padded_vocab_size)
69                        .with_bias(false)
70                        .init(device),
71                )
72            },
73        }
74    }
75}
76
77impl<B: Backend> Mamba1Network<B> {
78    /// See also [`Self::step`].
79    ///
80    /// # Shapes
81    ///   - Input [batch, sequence]
82    ///   - Output [batch, sequence, d_model]
83    pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
84        let [batch, sequence] = x.dims();
85        let [padded_vocab, d_model] = self.embedding.weight.dims();
86
87        let mut x = self.embedding.forward(x);
88        debug_assert_eq!([batch, sequence, d_model], x.dims());
89
90        for layer in self.layers.iter() {
91            x = layer.forward(x);
92        }
93
94        x = self.norm_f.forward(x);
95        if let Some(lm_head) = &self.lm_head {
96            x = lm_head.forward(x);
97        } else {
98            // let weight = self.embedding.weight.clone().map(|w| w.swap_dims(0, 1));
99            let weight = self.embedding.weight.clone().map(|w| w.permute([1, 0]));
100            debug_assert_eq!([d_model, padded_vocab], weight.dims());
101
102            let linear = Linear { weight, bias: None };
103            x = linear.forward(x);
104        };
105        debug_assert_eq!([batch, sequence, padded_vocab], x.dims());
106
107        x
108    }
109
110    /// See also [`Self::forward`].
111    ///
112    /// # Shapes
113    ///   - Input [batch]
114    ///   - Output [batch, d_model]
115    pub fn step(
116        &self,
117        x: Tensor<B, 1, Int>,
118        mut caches: Mamba1Caches<B>,
119    ) -> (Tensor<B, 2>, Mamba1Caches<B>) {
120        let [batch] = x.dims();
121        let [padded_vocab, d_model] = self.embedding.weight.dims();
122
123        let x = x.unsqueeze_dim(1);
124        debug_assert_eq!([batch, 1], x.dims());
125
126        let x = self.embedding.forward(x);
127        debug_assert_eq!([batch, 1, d_model], x.dims());
128        let mut x = x.squeeze_dim(1);
129        debug_assert_eq!([batch, d_model], x.dims());
130
131        for (i, layer) in self.layers.iter().enumerate() {
132            let (x_, cache) = layer.step(x, caches.caches[i].clone());
133            x = x_;
134            caches.caches[i] = cache;
135        }
136
137        x = self.norm_f.forward(x);
138        if let Some(lm_head) = &self.lm_head {
139            x = lm_head.forward(x);
140        } else {
141            // let weight = self.embedding.weight.clone().map(|w| w.swap_dims(0, 1));
142            let weight = self.embedding.weight.clone().map(|w| w.permute([1, 0]));
143            debug_assert_eq!([d_model, padded_vocab], weight.dims());
144
145            let linear = Linear { weight, bias: None };
146            x = linear.forward(x);
147        };
148        debug_assert_eq!([batch, padded_vocab], x.dims());
149
150        (x, caches)
151    }
152}