Skip to main content

burn_mamba/mamba3/bidi/naive/
layer.rs

1use crate::mamba3::bidi::naive::{OutputMerge, OutputMergeConfig};
2use crate::mamba3::prelude::*;
3use crate::schedule::BidiSchedule;
4use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
5use burn::prelude::*;
6
7#[derive(Module, Debug)]
8pub struct Mamba3BidiLayers<B: Backend> {
9    pub n_real_layers: usize,
10    #[module(skip)]
11    pub n_virtual_layers: Option<(usize, BidiSchedule)>,
12    /// # Shape
13    /// - [n_real_layers]
14    pub real_layers: Vec<Mamba3Layer<B>>,
15    pub ignore_first_residual: bool,
16    pub ignore_last_residual: bool,
17    /// # Shape
18    /// - [n_real_layers / 2]
19    pub outputs_merge: Vec<OutputMerge<B>>,
20}
21
22#[derive(Config, Debug)]
23pub struct Mamba3BidiLayersConfig {
24    pub n_real_layers: usize,
25    #[config(default = "None")]
26    pub n_virtual_layers: Option<(usize, BidiSchedule)>,
27    pub mamba_block: Mamba3Config,
28    #[config(default = false)]
29    pub ignore_first_residual: bool,
30    #[config(default = false)]
31    pub ignore_last_residual: bool,
32
33    /// # Shape
34    /// - [n_real_layers / 2]
35    pub outputs_merge: Vec<OutputMergeConfig>,
36}
37
38impl Mamba3BidiLayersConfig {
39    /// Returns the initialized model.
40    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3BidiLayers<B> {
41        let d_model = self.mamba_block.d_model;
42        let mut real_layers = Vec::with_capacity(self.n_real_layers);
43        let mut outputs_merge = Vec::with_capacity(self.n_real_layers);
44        for _ in 0..self.n_real_layers {
45            let block_config = self.mamba_block.clone();
46            let layer = Mamba3LayerConfig::new(block_config).init(device);
47            real_layers.push(layer);
48        }
49        for i in 0..self.n_real_layers / 2 {
50            let output_merge = self.outputs_merge.get(i).unwrap().init(d_model, device);
51            outputs_merge.push(output_merge);
52        }
53
54        Mamba3BidiLayers {
55            n_real_layers: self.n_real_layers,
56            n_virtual_layers: self.n_virtual_layers.clone(),
57            real_layers,
58            ignore_first_residual: self.ignore_first_residual,
59            ignore_last_residual: self.ignore_last_residual,
60            outputs_merge,
61        }
62    }
63}
64
65impl<B: Backend + Mamba3BackendExt> Mamba3BidiLayers<B> {
66    /// # Shapes
67    ///   - Input [batch, sequence, d_model]
68    ///   - Output [batch, sequence, d_model]
69    pub fn forward(
70        &self,
71        mut x: Tensor<B, 3>,
72        caches: Option<Mamba3Caches<B>>,
73        // straight_caches: Option<Mamba3Caches<B>>,
74        // reverse_caches: Option<Mamba3Caches<B>>,
75        ssd_path: Mamba3SsdPath,
76    ) -> (Tensor<B, 3>, Mamba3Caches<B>) {
77        let n_virtual_layers = self
78            .n_virtual_layers
79            .as_ref()
80            .map(|(l, _schedule)| {
81                assert!(l.is_multiple_of(2), "Bidi virtual layers are used in pairs");
82                *l
83            })
84            .unwrap_or({
85                assert!(
86                    self.n_real_layers.is_multiple_of(2),
87                    "Bidi layers are used in pairs"
88                );
89                // virtual layers fallback to the real layers
90                self.n_real_layers
91            });
92
93        let caches = caches.unwrap_or_else(|| {
94            let device = &x.device();
95            let [batch, _sequence, _d_model] = x.dims();
96            let layer0_block = &self.real_layers[0].mamba_block;
97
98            Mamba3CachesConfig::new(
99                n_virtual_layers,
100                Mamba3CacheConfig {
101                    batch,
102                    state_rank: layer0_block.state_rank,
103                    num_rope_angles: layer0_block.num_rope_angles,
104                    per_head_dim: layer0_block.per_head_dim(),
105                    nheads: layer0_block.nheads(),
106                    mimo_rank: layer0_block.mimo_rank,
107                },
108            )
109            .init(device)
110        });
111
112        // assertions
113        assert_eq!(
114            caches.caches.len(),
115            n_virtual_layers,
116            "straight and reverse layers in forward() currently cannot share caches"
117        );
118
119        let mut caches: Vec<Option<Mamba3Cache<B>>> = caches.caches.into_iter().map(Some).collect();
120
121        for i in 0..n_virtual_layers / 2 {
122            // use real layers by reference (clone)
123            let (straight_i, reverse_i) = (i * 2, i * 2 + 1);
124            let (straight_layer_idx, reverse_layer_idx) =
125                if let Some((n_virtual_layers, bidi_schedule)) = &self.n_virtual_layers {
126                    (
127                        bidi_schedule.real_idx(straight_i, *n_virtual_layers, self.n_real_layers),
128                        bidi_schedule.real_idx(reverse_i, *n_virtual_layers, self.n_real_layers),
129                    )
130                } else {
131                    (straight_i, reverse_i)
132                };
133            let straight_layer = self.real_layers.get(straight_layer_idx).unwrap();
134            let reverse_layer = self.real_layers.get(reverse_layer_idx).unwrap();
135
136            let straight_cache_idx = straight_i;
137            let reverse_cache_idx = reverse_i;
138            let straight_cache =
139                core::mem::take(caches.get_mut(straight_cache_idx).unwrap()).unwrap();
140            let reverse_cache =
141                core::mem::take(caches.get_mut(reverse_cache_idx).unwrap()).unwrap();
142
143            let residual_scale = if (self.ignore_first_residual && i == 0)
144                || (self.ignore_last_residual && i + 1 == n_virtual_layers / 2)
145            {
146                0.0
147            } else {
148                1.0
149            };
150
151            let bidi_pair = Mamba3BidiLayerPair {
152                straight_norm: straight_layer.norm.clone(),
153                reverse_norm: reverse_layer.norm.clone(),
154                straight_block: straight_layer.mamba_block.clone(),
155                reverse_block: reverse_layer.mamba_block.clone(),
156                output_merge: self.outputs_merge.get(i).unwrap().clone(),
157                residual_scale,
158            };
159
160            let (x_, straight_cache_, reverse_cache_) = bidi_pair.forward(
161                x,
162                Some(straight_cache),
163                Some(reverse_cache),
164                ssd_path.clone(),
165            );
166            x = x_;
167            caches[straight_cache_idx] = Some(straight_cache_);
168            caches[reverse_cache_idx] = Some(reverse_cache_);
169        }
170
171        let caches = Mamba3Caches {
172            caches: caches.into_iter().map(|c| c.unwrap()).collect(),
173        };
174
175        (x, caches)
176    }
177}
178
179#[derive(Module, Debug)]
180pub struct Mamba3BidiLayerPair<B: Backend> {
181    pub straight_norm: RmsNorm<B>,
182    pub reverse_norm: RmsNorm<B>,
183    pub straight_block: Mamba3<B>,
184    pub reverse_block: Mamba3<B>,
185    pub output_merge: OutputMerge<B>,
186    pub residual_scale: f32,
187}
188
189#[derive(Config, Debug)]
190pub struct Mamba3BidiLayerPairConfig {
191    pub straight_block: Mamba3Config,
192    pub reverse_block: Mamba3Config,
193    #[config(default = 1.0)]
194    pub residual_scale: f32,
195    pub output_merge: OutputMergeConfig,
196}
197
198impl Mamba3BidiLayerPairConfig {
199    /// Returns the initialized model.
200    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba3BidiLayerPair<B> {
201        let d_model = self.straight_block.d_model;
202        Mamba3BidiLayerPair {
203            straight_norm: RmsNormConfig::new(self.straight_block.d_model).init(device),
204            reverse_norm: RmsNormConfig::new(self.reverse_block.d_model).init(device),
205            straight_block: self.straight_block.init(device),
206            reverse_block: self.reverse_block.init(device),
207            residual_scale: self.residual_scale,
208            output_merge: self.output_merge.init(d_model, device),
209        }
210    }
211}
212
213impl<B: Backend + Mamba3BackendExt> Mamba3BidiLayerPair<B> {
214    /// # Shapes
215    ///   - Input [batch, sequence, d_model]
216    ///   - Output.0 [batch, sequence, d_model]
217    pub fn forward(
218        &self,
219        x: Tensor<B, 3>,
220        straight_cache: Option<Mamba3Cache<B>>,
221        reverse_cache: Option<Mamba3Cache<B>>,
222        ssd_path: Mamba3SsdPath,
223    ) -> (Tensor<B, 3>, Mamba3Cache<B>, Mamba3Cache<B>) {
224        let [batch, sequence, d_model] = x.dims();
225
226        let res = x.clone() * self.residual_scale;
227
228        // x is read as >x₀>x₁>x₂>x₃
229        // x_rev is read as >x₃>x₂>x₁>x₀, i.e. x₀<x₁<x₂<x₃<
230        let x_rev = x.clone().flip([1]); // flip sequence-wise
231
232        // each layer (as stored) carries their own norm,
233        // but perhaps it's redundant to apply two of them after the flip.
234        // i.e. maybe a single norm applied before the flip is better
235        let x = self.straight_norm.forward(x);
236        let x_rev = self.reverse_norm.forward(x_rev);
237
238        // straight reads inputs as:
239        // t₀ >x₀
240        // t₁ >x₀>x₁
241        // t₂ >x₀>x₁>x₂
242        // t₃ >x₀>x₁>x₂>x₃
243        let (x, straight_cache) = self
244            .straight_block
245            .forward(x, straight_cache, ssd_path.clone());
246        debug_assert_eq!([batch, sequence, d_model], x.dims());
247
248        // reverse reads inputs as:
249        // t₀        x₃<
250        // t₁      x₂<x₃<
251        // t₂   x₁<x₂<x₃<
252        // t₃ x₀<x₁<x₂<x₃<
253        let (x_rev, reverse_cache) = self.reverse_block.forward(x_rev, reverse_cache, ssd_path);
254        debug_assert_eq!([batch, sequence, d_model], x_rev.dims());
255
256        // re-align the reversed read:
257        // t₀ x₀<x₁<x₂<x₃<
258        // t₁   x₁<x₂<x₃<
259        // t₂      x₂<x₃<
260        // t₃        x₃<
261        let x_rev = x_rev.flip([1]);
262
263        // merge both reads:
264        // t₀ merge(>x₀ , x₀<x₁<x₂<x₃<)
265        // t₁ merge(>x₀>x₁ , x₁<x₂<x₃<)
266        // t₂ merge(>x₀>x₁>x₂ , x₂<x₃<)
267        // t₃ merge(>x₀>x₁>x₂>x₃ , x₃<)
268        let merged = self.output_merge.forward(x, x_rev);
269
270        (merged + res, straight_cache, reverse_cache)
271    }
272}