burn_mamba/mamba2/bidi/naive/
output_merge.rs1use burn::nn::{Linear, LinearConfig};
2use burn::prelude::*;
3
4#[derive(Module, Clone, Debug)]
6pub struct NoOp;
7
8#[allow(clippy::large_enum_variant)]
9#[derive(Module, Debug)]
10pub enum OutputMerge<B: Backend> {
11 Mean(NoOp),
12 CatLinear(Linear<B>),
15}
16
17impl<B: Backend> OutputMerge<B> {
22 pub fn forward(&self, straight: Tensor<B, 3>, reverse: Tensor<B, 3>) -> Tensor<B, 3> {
23 let [batch_size, sequence_len, d_model] = straight.dims();
24 assert_eq!(straight.dims(), reverse.dims());
25 match self {
26 OutputMerge::Mean(_noop) => (straight + reverse) * 0.5,
27 OutputMerge::CatLinear(proj) => {
28 let cat = Tensor::cat([straight, reverse].to_vec(), 2);
29 assert_eq!([batch_size, sequence_len, 2 * d_model], cat.dims());
30 let merged = proj.forward(cat);
31 assert_eq!([batch_size, sequence_len, d_model], merged.dims());
32 merged
33 }
34 }
35 }
36}
37
38#[derive(Config, Debug)]
39pub enum OutputMergeConfig {
40 Mean,
41 CatLinear,
42}
43
44impl OutputMergeConfig {
45 pub fn mean(n_real_layers: usize) -> Vec<Self> {
46 vec![Self::Mean; n_real_layers / 2]
47 }
48 pub fn cat_linear(n_real_layers: usize) -> Vec<Self> {
49 vec![Self::CatLinear; n_real_layers / 2]
50 }
51
52 pub fn init<B: Backend>(&self, d_model: usize, device: &B::Device) -> OutputMerge<B> {
53 match self {
54 OutputMergeConfig::Mean => OutputMerge::Mean(NoOp),
55 OutputMergeConfig::CatLinear => {
56 let cat_linear = LinearConfig::new(d_model * 2, d_model).init(device);
57 OutputMerge::CatLinear(cat_linear)
58 }
59 }
60 }
61}