// Source: burn_mamba/mamba3/bidi/naive/output_merge.rs

1use burn::nn::{Linear, LinearConfig};
2use burn::prelude::*;
3
/// Parameter-free placeholder used when a `Module` value is expected but no state or
/// computation is needed (e.g. as the payload of `OutputMerge::Mean`).
#[derive(Module, Clone, Debug)]
pub struct NoOp;
7
/// Strategy for merging the straight and reverse outputs of a bidirectional block into a
/// single tensor; see `forward` for the exact computation of each variant.
// `CatLinear` carries a whole `Linear<B>` layer while `Mean` carries only the zero-sized
// `NoOp`, which triggers clippy's size-imbalance lint; the size difference is accepted
// here rather than boxing the layer.
#[allow(clippy::large_enum_variant)]
#[derive(Module, Debug)]
pub enum OutputMerge<B: Backend> {
    /// Element-wise mean of the two inputs; has no parameters.
    Mean(NoOp),
    /// Concatenate both inputs along the feature axis, then project back to `d_model`.
    ///
    /// # Shape
    /// - [2 * d_model, d_model]
    ///   (as built by `OutputMergeConfig::init` via `LinearConfig::new(d_model * 2, d_model)`)
    CatLinear(Linear<B>),
}
16
17/// # Shapes
18///   - Input straight [batch_size, sequence_len, d_model]
19///   - Input reverse [batch_size, sequence_len, d_model]
20///   - Output [batch_size, sequence_len, d_model]
21impl<B: Backend> OutputMerge<B> {
22    pub fn forward(&self, straight: Tensor<B, 3>, reverse: Tensor<B, 3>) -> Tensor<B, 3> {
23        let [batch_size, sequence_len, d_model] = straight.dims();
24        assert_eq!(straight.dims(), reverse.dims());
25        match self {
26            OutputMerge::Mean(_noop) => (straight + reverse) * 0.5,
27            OutputMerge::CatLinear(proj) => {
28                let cat = Tensor::cat([straight, reverse].to_vec(), 2);
29                assert_eq!([batch_size, sequence_len, 2 * d_model], cat.dims());
30                let merged = proj.forward(cat);
31                assert_eq!([batch_size, sequence_len, d_model], merged.dims());
32                merged
33            }
34        }
35    }
36}
37
/// Configuration selecting which [`OutputMerge`] variant `OutputMergeConfig::init` builds.
#[derive(Config, Debug)]
pub enum OutputMergeConfig {
    /// Builds [`OutputMerge::Mean`]: element-wise average, no parameters.
    Mean,
    /// Builds [`OutputMerge::CatLinear`]: concatenation followed by a linear projection
    /// from `2 * d_model` to `d_model` features.
    CatLinear,
}
43
44impl OutputMergeConfig {
45    pub fn mean(n_real_layers: usize) -> Vec<Self> {
46        vec![Self::Mean; n_real_layers / 2]
47    }
48    pub fn cat_linear(n_real_layers: usize) -> Vec<Self> {
49        vec![Self::CatLinear; n_real_layers / 2]
50    }
51
52    pub fn init<B: Backend>(&self, d_model: usize, device: &B::Device) -> OutputMerge<B> {
53        match self {
54            OutputMergeConfig::Mean => OutputMerge::Mean(NoOp),
55            OutputMergeConfig::CatLinear => {
56                let cat_linear = LinearConfig::new(d_model * 2, d_model).init(device);
57                OutputMerge::CatLinear(cat_linear)
58            }
59        }
60    }
61}