// burn_mamba/mamba1/network.rs
use std::mem;

use burn::nn::{Embedding, EmbeddingConfig, Linear, LinearConfig};
use burn::prelude::*;

use crate::mamba1::prelude::*;
use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};

12#[derive(Module, Debug)]
13pub struct Mamba1Network<B: Backend> {
14 pub embedding: Embedding<B>,
15 pub layers: Vec<Mamba1Layer<B>>,
16 pub norm_f: RmsNorm<B>,
17 pub lm_head: Option<Linear<B>>,
19}
20
21#[derive(Config, Debug)]
22pub struct Mamba1NetworkConfig {
23 pub n_layer: usize,
24
25 pub vocab_size: usize,
28
29 pub pad_vocab_size_multiple: usize,
34
35 pub mamba_block: Mamba1Config,
36
37 pub missing_lm_head: bool,
39}
40
41impl Mamba1NetworkConfig {
42 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1Network<B> {
44 let mut layers = Vec::with_capacity(self.n_layer);
45 for _ in 0..self.n_layer {
46 let layer = Mamba1LayerConfig::new(self.mamba_block.clone()).init(device);
47 layers.push(layer);
48 }
49
50 let padded_vocab_size = {
51 if self.vocab_size.is_multiple_of(self.pad_vocab_size_multiple) {
52 self.vocab_size
53 } else {
54 ((self.vocab_size / self.pad_vocab_size_multiple) + 1)
55 * self.pad_vocab_size_multiple
56 }
57 };
58
59 Mamba1Network {
60 embedding: EmbeddingConfig::new(padded_vocab_size, self.mamba_block.d_model)
61 .init(device),
62 layers,
63 norm_f: RmsNormConfig::new(self.mamba_block.d_model).init(device),
64 lm_head: if self.missing_lm_head {
65 None
66 } else {
67 Some(
68 LinearConfig::new(self.mamba_block.d_model, padded_vocab_size)
69 .with_bias(false)
70 .init(device),
71 )
72 },
73 }
74 }
75}
76
77impl<B: Backend> Mamba1Network<B> {
78 pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
84 let [batch, sequence] = x.dims();
85 let [padded_vocab, d_model] = self.embedding.weight.dims();
86
87 let mut x = self.embedding.forward(x);
88 debug_assert_eq!([batch, sequence, d_model], x.dims());
89
90 for layer in self.layers.iter() {
91 x = layer.forward(x);
92 }
93
94 x = self.norm_f.forward(x);
95 if let Some(lm_head) = &self.lm_head {
96 x = lm_head.forward(x);
97 } else {
98 let weight = self.embedding.weight.clone().map(|w| w.permute([1, 0]));
100 debug_assert_eq!([d_model, padded_vocab], weight.dims());
101
102 let linear = Linear { weight, bias: None };
103 x = linear.forward(x);
104 };
105 debug_assert_eq!([batch, sequence, padded_vocab], x.dims());
106
107 x
108 }
109
110 pub fn step(
116 &self,
117 x: Tensor<B, 1, Int>,
118 mut caches: Mamba1Caches<B>,
119 ) -> (Tensor<B, 2>, Mamba1Caches<B>) {
120 let [batch] = x.dims();
121 let [padded_vocab, d_model] = self.embedding.weight.dims();
122
123 let x = x.unsqueeze_dim(1);
124 debug_assert_eq!([batch, 1], x.dims());
125
126 let x = self.embedding.forward(x);
127 debug_assert_eq!([batch, 1, d_model], x.dims());
128 let mut x = x.squeeze_dim(1);
129 debug_assert_eq!([batch, d_model], x.dims());
130
131 for (i, layer) in self.layers.iter().enumerate() {
132 let (x_, cache) = layer.step(x, caches.caches[i].clone());
133 x = x_;
134 caches.caches[i] = cache;
135 }
136
137 x = self.norm_f.forward(x);
138 if let Some(lm_head) = &self.lm_head {
139 x = lm_head.forward(x);
140 } else {
141 let weight = self.embedding.weight.clone().map(|w| w.permute([1, 0]));
143 debug_assert_eq!([d_model, padded_vocab], weight.dims());
144
145 let linear = Linear { weight, bias: None };
146 x = linear.forward(x);
147 };
148 debug_assert_eq!([batch, padded_vocab], x.dims());
149
150 (x, caches)
151 }
152}