burn_mamba/mamba3/bidi/naive/
layer.rs1use crate::mamba3::bidi::naive::{OutputMerge, OutputMergeConfig};
2use crate::mamba3::prelude::*;
3use crate::schedule::BidiSchedule;
4use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
5use burn::prelude::*;
6
7#[derive(Module, Debug)]
8pub struct Mamba3BidiLayers<B: Backend> {
9 pub n_real_layers: usize,
10 #[module(skip)]
11 pub n_virtual_layers: Option<(usize, BidiSchedule)>,
12 pub real_layers: Vec<Mamba3Layer<B>>,
15 pub ignore_first_residual: bool,
16 pub ignore_last_residual: bool,
17 pub outputs_merge: Vec<OutputMerge<B>>,
20}
21
22#[derive(Config, Debug)]
23pub struct Mamba3BidiLayersConfig {
24 pub n_real_layers: usize,
25 #[config(default = "None")]
26 pub n_virtual_layers: Option<(usize, BidiSchedule)>,
27 pub mamba_block: Mamba3Config,
28 #[config(default = false)]
29 pub ignore_first_residual: bool,
30 #[config(default = false)]
31 pub ignore_last_residual: bool,
32
33 pub outputs_merge: Vec<OutputMergeConfig>,
36}
37
38impl Mamba3BidiLayersConfig {
39 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3BidiLayers<B> {
41 let d_model = self.mamba_block.d_model;
42 let mut real_layers = Vec::with_capacity(self.n_real_layers);
43 let mut outputs_merge = Vec::with_capacity(self.n_real_layers);
44 for _ in 0..self.n_real_layers {
45 let block_config = self.mamba_block.clone();
46 let layer = Mamba3LayerConfig::new(block_config).init(device);
47 real_layers.push(layer);
48 }
49 for i in 0..self.n_real_layers / 2 {
50 let output_merge = self.outputs_merge.get(i).unwrap().init(d_model, device);
51 outputs_merge.push(output_merge);
52 }
53
54 Mamba3BidiLayers {
55 n_real_layers: self.n_real_layers,
56 n_virtual_layers: self.n_virtual_layers.clone(),
57 real_layers,
58 ignore_first_residual: self.ignore_first_residual,
59 ignore_last_residual: self.ignore_last_residual,
60 outputs_merge,
61 }
62 }
63}
64
65impl<B: Backend + Mamba3BackendExt> Mamba3BidiLayers<B> {
66 pub fn forward(
70 &self,
71 mut x: Tensor<B, 3>,
72 caches: Option<Mamba3Caches<B>>,
73 ssd_path: Mamba3SsdPath,
76 ) -> (Tensor<B, 3>, Mamba3Caches<B>) {
77 let n_virtual_layers = self
78 .n_virtual_layers
79 .as_ref()
80 .map(|(l, _schedule)| {
81 assert!(l.is_multiple_of(2), "Bidi virtual layers are used in pairs");
82 *l
83 })
84 .unwrap_or({
85 assert!(
86 self.n_real_layers.is_multiple_of(2),
87 "Bidi layers are used in pairs"
88 );
89 self.n_real_layers
91 });
92
93 let caches = caches.unwrap_or_else(|| {
94 let device = &x.device();
95 let [batch, _sequence, _d_model] = x.dims();
96 let layer0_block = &self.real_layers[0].mamba_block;
97
98 Mamba3CachesConfig::new(
99 n_virtual_layers,
100 Mamba3CacheConfig {
101 batch,
102 state_rank: layer0_block.state_rank,
103 num_rope_angles: layer0_block.num_rope_angles,
104 per_head_dim: layer0_block.per_head_dim(),
105 nheads: layer0_block.nheads(),
106 mimo_rank: layer0_block.mimo_rank,
107 },
108 )
109 .init(device)
110 });
111
112 assert_eq!(
114 caches.caches.len(),
115 n_virtual_layers,
116 "straight and reverse layers in forward() currently cannot share caches"
117 );
118
119 let mut caches: Vec<Option<Mamba3Cache<B>>> = caches.caches.into_iter().map(Some).collect();
120
121 for i in 0..n_virtual_layers / 2 {
122 let (straight_i, reverse_i) = (i * 2, i * 2 + 1);
124 let (straight_layer_idx, reverse_layer_idx) =
125 if let Some((n_virtual_layers, bidi_schedule)) = &self.n_virtual_layers {
126 (
127 bidi_schedule.real_idx(straight_i, *n_virtual_layers, self.n_real_layers),
128 bidi_schedule.real_idx(reverse_i, *n_virtual_layers, self.n_real_layers),
129 )
130 } else {
131 (straight_i, reverse_i)
132 };
133 let straight_layer = self.real_layers.get(straight_layer_idx).unwrap();
134 let reverse_layer = self.real_layers.get(reverse_layer_idx).unwrap();
135
136 let straight_cache_idx = straight_i;
137 let reverse_cache_idx = reverse_i;
138 let straight_cache =
139 core::mem::take(caches.get_mut(straight_cache_idx).unwrap()).unwrap();
140 let reverse_cache =
141 core::mem::take(caches.get_mut(reverse_cache_idx).unwrap()).unwrap();
142
143 let residual_scale = if (self.ignore_first_residual && i == 0)
144 || (self.ignore_last_residual && i + 1 == n_virtual_layers / 2)
145 {
146 0.0
147 } else {
148 1.0
149 };
150
151 let bidi_pair = Mamba3BidiLayerPair {
152 straight_norm: straight_layer.norm.clone(),
153 reverse_norm: reverse_layer.norm.clone(),
154 straight_block: straight_layer.mamba_block.clone(),
155 reverse_block: reverse_layer.mamba_block.clone(),
156 output_merge: self.outputs_merge.get(i).unwrap().clone(),
157 residual_scale,
158 };
159
160 let (x_, straight_cache_, reverse_cache_) = bidi_pair.forward(
161 x,
162 Some(straight_cache),
163 Some(reverse_cache),
164 ssd_path.clone(),
165 );
166 x = x_;
167 caches[straight_cache_idx] = Some(straight_cache_);
168 caches[reverse_cache_idx] = Some(reverse_cache_);
169 }
170
171 let caches = Mamba3Caches {
172 caches: caches.into_iter().map(|c| c.unwrap()).collect(),
173 };
174
175 (x, caches)
176 }
177}
178
179#[derive(Module, Debug)]
180pub struct Mamba3BidiLayerPair<B: Backend> {
181 pub straight_norm: RmsNorm<B>,
182 pub reverse_norm: RmsNorm<B>,
183 pub straight_block: Mamba3<B>,
184 pub reverse_block: Mamba3<B>,
185 pub output_merge: OutputMerge<B>,
186 pub residual_scale: f32,
187}
188
189#[derive(Config, Debug)]
190pub struct Mamba3BidiLayerPairConfig {
191 pub straight_block: Mamba3Config,
192 pub reverse_block: Mamba3Config,
193 #[config(default = 1.0)]
194 pub residual_scale: f32,
195 pub output_merge: OutputMergeConfig,
196}
197
198impl Mamba3BidiLayerPairConfig {
199 pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3BidiLayerPair<B> {
201 let d_model = self.straight_block.d_model;
202 Mamba3BidiLayerPair {
203 straight_norm: RmsNormConfig::new(self.straight_block.d_model).init(device),
204 reverse_norm: RmsNormConfig::new(self.reverse_block.d_model).init(device),
205 straight_block: self.straight_block.init(device),
206 reverse_block: self.reverse_block.init(device),
207 residual_scale: self.residual_scale,
208 output_merge: self.output_merge.init(d_model, device),
209 }
210 }
211}
212
213impl<B: Backend + Mamba3BackendExt> Mamba3BidiLayerPair<B> {
214 pub fn forward(
218 &self,
219 x: Tensor<B, 3>,
220 straight_cache: Option<Mamba3Cache<B>>,
221 reverse_cache: Option<Mamba3Cache<B>>,
222 ssd_path: Mamba3SsdPath,
223 ) -> (Tensor<B, 3>, Mamba3Cache<B>, Mamba3Cache<B>) {
224 let [batch, sequence, d_model] = x.dims();
225
226 let res = x.clone() * self.residual_scale;
227
228 let x_rev = x.clone().flip([1]); let x = self.straight_norm.forward(x);
236 let x_rev = self.reverse_norm.forward(x_rev);
237
238 let (x, straight_cache) = self
244 .straight_block
245 .forward(x, straight_cache, ssd_path.clone());
246 debug_assert_eq!([batch, sequence, d_model], x.dims());
247
248 let (x_rev, reverse_cache) = self.reverse_block.forward(x_rev, reverse_cache, ssd_path);
254 debug_assert_eq!([batch, sequence, d_model], x_rev.dims());
255
256 let x_rev = x_rev.flip([1]);
262
263 let merged = self.output_merge.forward(x, x_rev);
269
270 (merged + res, straight_cache, reverse_cache)
271 }
272}