Skip to main content

burn_mamba/mamba2/bidi/naive/
layer.rs

1use crate::mamba2::bidi::naive::{OutputMerge, OutputMergeConfig};
2use crate::mamba2::prelude::*;
3use crate::schedule::BidiSchedule;
4use crate::utils::rms_norm::{RmsNorm, RmsNormConfig};
5use burn::prelude::*;
6
7#[derive(Module, Debug)]
8pub struct Mamba2BidiLayers<B: Backend> {
9    pub n_real_layers: usize,
10    #[module(skip)]
11    pub n_virtual_layers: Option<(usize, BidiSchedule)>,
12    /// # Shape
13    /// - [n_real_layers]
14    pub real_layers: Vec<Mamba2Layer<B>>,
15    pub ignore_first_residual: bool,
16    pub ignore_last_residual: bool,
17    /// # Shape
18    /// - [n_real_layers / 2]
19    pub outputs_merge: Vec<OutputMerge<B>>,
20}
21
22#[derive(Config, Debug)]
23pub struct Mamba2BidiLayersConfig {
24    pub n_real_layers: usize,
25    #[config(default = "None")]
26    pub n_virtual_layers: Option<(usize, BidiSchedule)>,
27    pub mamba_block: Mamba2Config,
28    #[config(default = false)]
29    pub ignore_first_residual: bool,
30    #[config(default = false)]
31    pub ignore_last_residual: bool,
32
33    /// # Shape
34    /// - [n_real_layers / 2]
35    pub outputs_merge: Vec<OutputMergeConfig>,
36}
37
38impl Mamba2BidiLayersConfig {
39    /// Returns the initialized model.
40    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2BidiLayers<B> {
41        let d_model = self.mamba_block.d_model;
42        let mut real_layers = Vec::with_capacity(self.n_real_layers);
43        let mut outputs_merge = Vec::with_capacity(self.n_real_layers);
44        for _ in 0..self.n_real_layers {
45            let block_config = self.mamba_block.clone();
46            let layer = Mamba2LayerConfig::new(block_config).init(device);
47            real_layers.push(layer);
48        }
49        for i in 0..self.n_real_layers / 2 {
50            let output_merge = self.outputs_merge.get(i).unwrap().init(d_model, device);
51            outputs_merge.push(output_merge);
52        }
53
54        Mamba2BidiLayers {
55            n_real_layers: self.n_real_layers,
56            n_virtual_layers: self.n_virtual_layers.clone(),
57            real_layers,
58            ignore_first_residual: self.ignore_first_residual,
59            ignore_last_residual: self.ignore_last_residual,
60            outputs_merge,
61        }
62    }
63}
64
65impl<B: Backend + Mamba2BackendExt> Mamba2BidiLayers<B> {
66    /// # Shapes
67    ///   - Input [batch, sequence, d_model]
68    ///   - Output [batch, sequence, d_model]
69    pub fn forward(
70        &self,
71        mut x: Tensor<B, 3>,
72        caches: Option<Mamba2Caches<B>>,
73        // straight_caches: Option<Mamba2Caches<B>>,
74        // reverse_caches: Option<Mamba2Caches<B>>,
75        ssd_path: Mamba2SsdPath,
76    ) -> (Tensor<B, 3>, Mamba2Caches<B>) {
77        let n_virtual_layers = self
78            .n_virtual_layers
79            .as_ref()
80            .map(|(l, _schedule)| {
81                assert!(l.is_multiple_of(2), "Bidi virtual layers are used in pairs");
82                *l
83            })
84            .unwrap_or({
85                assert!(
86                    self.n_real_layers.is_multiple_of(2),
87                    "Bidi layers are used in pairs"
88                );
89                // virtual layers fallback to the real layers
90                self.n_real_layers
91            });
92
93        let caches = caches.unwrap_or_else(|| {
94            let device = &x.device();
95            let [batch, _sequence, _d_model] = x.dims();
96            let layer0_block = &self.real_layers[0].mamba_block;
97            let [conv_dim, _, conv_kernel] = layer0_block.conv1d.weight.dims();
98
99            Mamba2CachesConfig::new(
100                n_virtual_layers,
101                Mamba2CacheConfig {
102                    batch,
103                    state_rank: layer0_block.state_rank,
104                    conv_kernel,
105                    conv_dim,
106                    per_head_dim: layer0_block.per_head_dim(),
107                    nheads: layer0_block.nheads(),
108                },
109            )
110            .init(device)
111        });
112
113        // assertions
114        assert_eq!(
115            caches.caches.len(),
116            n_virtual_layers,
117            "straight and reverse layers in forward() currently cannot share caches"
118        );
119
120        let mut caches: Vec<Option<Mamba2Cache<B>>> = caches.caches.into_iter().map(Some).collect();
121
122        for i in 0..n_virtual_layers / 2 {
123            // use real layers by reference (clone)
124            let (straight_i, reverse_i) = (i * 2, i * 2 + 1);
125            let (straight_layer_idx, reverse_layer_idx) =
126                if let Some((n_virtual_layers, bidi_schedule)) = &self.n_virtual_layers {
127                    (
128                        bidi_schedule.real_idx(straight_i, *n_virtual_layers, self.n_real_layers),
129                        bidi_schedule.real_idx(reverse_i, *n_virtual_layers, self.n_real_layers),
130                    )
131                } else {
132                    (straight_i, reverse_i)
133                };
134            let straight_layer = self.real_layers.get(straight_layer_idx).unwrap();
135            let reverse_layer = self.real_layers.get(reverse_layer_idx).unwrap();
136
137            let straight_cache_idx = straight_i;
138            let reverse_cache_idx = reverse_i;
139            let straight_cache =
140                core::mem::take(caches.get_mut(straight_cache_idx).unwrap()).unwrap();
141            let reverse_cache =
142                core::mem::take(caches.get_mut(reverse_cache_idx).unwrap()).unwrap();
143
144            let residual_scale = if (self.ignore_first_residual && i == 0)
145                || (self.ignore_last_residual && i + 1 == n_virtual_layers / 2)
146            {
147                0.0
148            } else {
149                1.0
150            };
151
152            let bidi_pair = Mamba2BidiLayerPair {
153                straight_norm: straight_layer.norm.clone(),
154                reverse_norm: reverse_layer.norm.clone(),
155                straight_block: straight_layer.mamba_block.clone(),
156                reverse_block: reverse_layer.mamba_block.clone(),
157                output_merge: self.outputs_merge.get(i).unwrap().clone(),
158                residual_scale,
159            };
160
161            let (x_, straight_cache_, reverse_cache_) = bidi_pair.forward(
162                x,
163                Some(straight_cache),
164                Some(reverse_cache),
165                ssd_path.clone(),
166            );
167            x = x_;
168            caches[straight_cache_idx] = Some(straight_cache_);
169            caches[reverse_cache_idx] = Some(reverse_cache_);
170        }
171
172        let caches = Mamba2Caches {
173            caches: caches.into_iter().map(|c| c.unwrap()).collect(),
174        };
175
176        (x, caches)
177    }
178}
179
180#[derive(Module, Debug)]
181pub struct Mamba2BidiLayerPair<B: Backend> {
182    pub straight_norm: RmsNorm<B>,
183    pub reverse_norm: RmsNorm<B>,
184    pub straight_block: Mamba2<B>,
185    pub reverse_block: Mamba2<B>,
186    pub output_merge: OutputMerge<B>,
187    pub residual_scale: f32,
188}
189
190#[derive(Config, Debug)]
191pub struct Mamba2BidiLayerPairConfig {
192    pub straight_block: Mamba2Config,
193    pub reverse_block: Mamba2Config,
194    #[config(default = 1.0)]
195    pub residual_scale: f32,
196    pub output_merge: OutputMergeConfig,
197}
198
199impl Mamba2BidiLayerPairConfig {
200    /// Returns the initialized model.
201    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba2BidiLayerPair<B> {
202        let d_model = self.straight_block.d_model;
203        Mamba2BidiLayerPair {
204            straight_norm: RmsNormConfig::new(self.straight_block.d_model).init(device),
205            reverse_norm: RmsNormConfig::new(self.reverse_block.d_model).init(device),
206            straight_block: self.straight_block.init(device),
207            reverse_block: self.reverse_block.init(device),
208            residual_scale: self.residual_scale,
209            output_merge: self.output_merge.init(d_model, device),
210        }
211    }
212}
213
214impl<B: Backend + Mamba2BackendExt> Mamba2BidiLayerPair<B> {
215    /// # Shapes
216    ///   - Input [batch, sequence, d_model]
217    ///   - Output.0 [batch, sequence, d_model]
218    pub fn forward(
219        &self,
220        x: Tensor<B, 3>,
221        straight_cache: Option<Mamba2Cache<B>>,
222        reverse_cache: Option<Mamba2Cache<B>>,
223        ssd_path: Mamba2SsdPath,
224    ) -> (Tensor<B, 3>, Mamba2Cache<B>, Mamba2Cache<B>) {
225        let [batch, sequence, d_model] = x.dims();
226
227        let res = x.clone() * self.residual_scale;
228
229        // x is read as >x₀>x₁>x₂>x₃
230        // x_rev is read as >x₃>x₂>x₁>x₀, i.e. x₀<x₁<x₂<x₃<
231        let x_rev = x.clone().flip([1]); // flip sequence-wise
232
233        // each layer (as stored) carries their own norm,
234        // but perhaps it's redundant to apply two of them after the flip.
235        // i.e. maybe a single norm applied before the flip is better
236        let x = self.straight_norm.forward(x);
237        let x_rev = self.reverse_norm.forward(x_rev);
238
239        // straight reads inputs as:
240        // t₀ >x₀
241        // t₁ >x₀>x₁
242        // t₂ >x₀>x₁>x₂
243        // t₃ >x₀>x₁>x₂>x₃
244        let (x, straight_cache) = self
245            .straight_block
246            .forward(x, straight_cache, ssd_path.clone());
247        debug_assert_eq!([batch, sequence, d_model], x.dims());
248
249        // reverse reads inputs as:
250        // t₀        x₃<
251        // t₁      x₂<x₃<
252        // t₂   x₁<x₂<x₃<
253        // t₃ x₀<x₁<x₂<x₃<
254        let (x_rev, reverse_cache) = self.reverse_block.forward(x_rev, reverse_cache, ssd_path);
255        debug_assert_eq!([batch, sequence, d_model], x_rev.dims());
256
257        // re-align the reversed read:
258        // t₀ x₀<x₁<x₂<x₃<
259        // t₁   x₁<x₂<x₃<
260        // t₂      x₂<x₃<
261        // t₃        x₃<
262        let x_rev = x_rev.flip([1]);
263
264        // merge both reads:
265        // t₀ merge(>x₀ , x₀<x₁<x₂<x₃<)
266        // t₁ merge(>x₀>x₁ , x₁<x₂<x₃<)
267        // t₂ merge(>x₀>x₁>x₂ , x₂<x₃<)
268        // t₃ merge(>x₀>x₁>x₂>x₃ , x₃<)
269        let merged = self.output_merge.forward(x, x_rev);
270
271        (merged + res, straight_cache, reverse_cache)
272    }
273}