// burn_mamba/mamba2/bidi/naive/layer.rs
use crate::mamba2::bidi::naive::{OutputMerge, OutputMergeConfig};
use crate::mamba2::prelude::*;
use crate::schedule::BidiSchedule;
use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
use burn::prelude::*;

7#[derive(Module, Debug)]
8pub struct Mamba2BidiLayers<B: Backend> {
9 pub n_real_layers: usize,
10 #[module(skip)]
11 pub n_virtual_layers: Option<(usize, BidiSchedule)>,
12 pub real_layers: Vec<Mamba2Layer<B>>,
15 pub ignore_first_residual: bool,
16 pub ignore_last_residual: bool,
17 pub outputs_merge: Vec<OutputMerge<B>>,
20}
21
22#[derive(Config, Debug)]
23pub struct Mamba2BidiLayersConfig {
24 pub n_real_layers: usize,
25 #[config(default = "None")]
26 pub n_virtual_layers: Option<(usize, BidiSchedule)>,
27 pub mamba_block: Mamba2Config,
28 #[config(default = false)]
29 pub ignore_first_residual: bool,
30 #[config(default = false)]
31 pub ignore_last_residual: bool,
32
33 pub outputs_merge: Vec<OutputMergeConfig>,
36}
37
38impl Mamba2BidiLayersConfig {
39 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2BidiLayers<B> {
41 let d_model = self.mamba_block.d_model;
42 let mut real_layers = Vec::with_capacity(self.n_real_layers);
43 let mut outputs_merge = Vec::with_capacity(self.n_real_layers);
44 for _ in 0..self.n_real_layers {
45 let block_config = self.mamba_block.clone();
46 let layer = Mamba2LayerConfig::new(block_config).init(device);
47 real_layers.push(layer);
48 }
49 for i in 0..self.n_real_layers / 2 {
50 let output_merge = self.outputs_merge.get(i).unwrap().init(d_model, device);
51 outputs_merge.push(output_merge);
52 }
53
54 Mamba2BidiLayers {
55 n_real_layers: self.n_real_layers,
56 n_virtual_layers: self.n_virtual_layers.clone(),
57 real_layers,
58 ignore_first_residual: self.ignore_first_residual,
59 ignore_last_residual: self.ignore_last_residual,
60 outputs_merge,
61 }
62 }
63}
64
65impl<B: Backend + Mamba2BackendExt> Mamba2BidiLayers<B> {
66 pub fn forward(
70 &self,
71 mut x: Tensor<B, 3>,
72 caches: Option<Mamba2Caches<B>>,
73 ssd_path: Mamba2SsdPath,
76 ) -> (Tensor<B, 3>, Mamba2Caches<B>) {
77 let n_virtual_layers = self
78 .n_virtual_layers
79 .as_ref()
80 .map(|(l, _schedule)| {
81 assert!(l.is_multiple_of(2), "Bidi virtual layers are used in pairs");
82 *l
83 })
84 .unwrap_or({
85 assert!(
86 self.n_real_layers.is_multiple_of(2),
87 "Bidi layers are used in pairs"
88 );
89 self.n_real_layers
91 });
92
93 let caches = caches.unwrap_or_else(|| {
94 let device = &x.device();
95 let [batch, _sequence, _d_model] = x.dims();
96 let layer0_block = &self.real_layers[0].mamba_block;
97 let [conv_dim, _, conv_kernel] = layer0_block.conv1d.weight.dims();
98
99 Mamba2CachesConfig::new(
100 n_virtual_layers,
101 Mamba2CacheConfig {
102 batch,
103 state_rank: layer0_block.state_rank,
104 conv_kernel,
105 conv_dim,
106 per_head_dim: layer0_block.per_head_dim(),
107 nheads: layer0_block.nheads(),
108 },
109 )
110 .init(device)
111 });
112
113 assert_eq!(
115 caches.caches.len(),
116 n_virtual_layers,
117 "straight and reverse layers in forward() currently cannot share caches"
118 );
119
120 let mut caches: Vec<Option<Mamba2Cache<B>>> = caches.caches.into_iter().map(Some).collect();
121
122 for i in 0..n_virtual_layers / 2 {
123 let (straight_i, reverse_i) = (i * 2, i * 2 + 1);
125 let (straight_layer_idx, reverse_layer_idx) =
126 if let Some((n_virtual_layers, bidi_schedule)) = &self.n_virtual_layers {
127 (
128 bidi_schedule.real_idx(straight_i, *n_virtual_layers, self.n_real_layers),
129 bidi_schedule.real_idx(reverse_i, *n_virtual_layers, self.n_real_layers),
130 )
131 } else {
132 (straight_i, reverse_i)
133 };
134 let straight_layer = self.real_layers.get(straight_layer_idx).unwrap();
135 let reverse_layer = self.real_layers.get(reverse_layer_idx).unwrap();
136
137 let straight_cache_idx = straight_i;
138 let reverse_cache_idx = reverse_i;
139 let straight_cache =
140 core::mem::take(caches.get_mut(straight_cache_idx).unwrap()).unwrap();
141 let reverse_cache =
142 core::mem::take(caches.get_mut(reverse_cache_idx).unwrap()).unwrap();
143
144 let residual_scale = if (self.ignore_first_residual && i == 0)
145 || (self.ignore_last_residual && i + 1 == n_virtual_layers / 2)
146 {
147 0.0
148 } else {
149 1.0
150 };
151
152 let bidi_pair = Mamba2BidiLayerPair {
153 straight_norm: straight_layer.norm.clone(),
154 reverse_norm: reverse_layer.norm.clone(),
155 straight_block: straight_layer.mamba_block.clone(),
156 reverse_block: reverse_layer.mamba_block.clone(),
157 output_merge: self.outputs_merge.get(i).unwrap().clone(),
158 residual_scale,
159 };
160
161 let (x_, straight_cache_, reverse_cache_) = bidi_pair.forward(
162 x,
163 Some(straight_cache),
164 Some(reverse_cache),
165 ssd_path.clone(),
166 );
167 x = x_;
168 caches[straight_cache_idx] = Some(straight_cache_);
169 caches[reverse_cache_idx] = Some(reverse_cache_);
170 }
171
172 let caches = Mamba2Caches {
173 caches: caches.into_iter().map(|c| c.unwrap()).collect(),
174 };
175
176 (x, caches)
177 }
178}
179
180#[derive(Module, Debug)]
181pub struct Mamba2BidiLayerPair<B: Backend> {
182 pub straight_norm: RmsNorm<B>,
183 pub reverse_norm: RmsNorm<B>,
184 pub straight_block: Mamba2<B>,
185 pub reverse_block: Mamba2<B>,
186 pub output_merge: OutputMerge<B>,
187 pub residual_scale: f32,
188}
189
190#[derive(Config, Debug)]
191pub struct Mamba2BidiLayerPairConfig {
192 pub straight_block: Mamba2Config,
193 pub reverse_block: Mamba2Config,
194 #[config(default = 1.0)]
195 pub residual_scale: f32,
196 pub output_merge: OutputMergeConfig,
197}
198
199impl Mamba2BidiLayerPairConfig {
200 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2BidiLayerPair<B> {
202 let d_model = self.straight_block.d_model;
203 Mamba2BidiLayerPair {
204 straight_norm: RmsNormConfig::new(self.straight_block.d_model).init(device),
205 reverse_norm: RmsNormConfig::new(self.reverse_block.d_model).init(device),
206 straight_block: self.straight_block.init(device),
207 reverse_block: self.reverse_block.init(device),
208 residual_scale: self.residual_scale,
209 output_merge: self.output_merge.init(d_model, device),
210 }
211 }
212}
213
214impl<B: Backend + Mamba2BackendExt> Mamba2BidiLayerPair<B> {
215 pub fn forward(
219 &self,
220 x: Tensor<B, 3>,
221 straight_cache: Option<Mamba2Cache<B>>,
222 reverse_cache: Option<Mamba2Cache<B>>,
223 ssd_path: Mamba2SsdPath,
224 ) -> (Tensor<B, 3>, Mamba2Cache<B>, Mamba2Cache<B>) {
225 let [batch, sequence, d_model] = x.dims();
226
227 let res = x.clone() * self.residual_scale;
228
229 let x_rev = x.clone().flip([1]); let x = self.straight_norm.forward(x);
237 let x_rev = self.reverse_norm.forward(x_rev);
238
239 let (x, straight_cache) = self
245 .straight_block
246 .forward(x, straight_cache, ssd_path.clone());
247 debug_assert_eq!([batch, sequence, d_model], x.dims());
248
249 let (x_rev, reverse_cache) = self.reverse_block.forward(x_rev, reverse_cache, ssd_path);
255 debug_assert_eq!([batch, sequence, d_model], x_rev.dims());
256
257 let x_rev = x_rev.flip([1]);
263
264 let merged = self.output_merge.forward(x, x_rev);
270
271 (merged + res, straight_cache, reverse_cache)
272 }
273}