Skip to main content

burn_mamba/utils/
rms_norm_gated.rs

1use crate::utils::div_eps;
2use crate::utils::silu::Silu;
3use burn::module::{Content, DisplaySettings, ModuleDisplay, Param};
4use burn::nn::Initializer;
5use burn::prelude::*;
6use burn::tensor::{DType, f16};
7
8/// Configuration to create a [RmsNormGated](RmsNormGated) layer.
9#[derive(Config, Debug)]
10pub struct RmsNormGatedConfig {
11    /// The size of the input features.
12    pub d_model: usize,
13    // // TODO: config epsilon is no longer used.
14    // /// A value required for numerical stability. Default: 1e-5
15    // #[config(default = 1e-5)]
16    // pub epsilon: f64,
17    /// Whether to apply normalization before gating. Default: true
18    #[config(default = true)]
19    pub norm_before_gate: bool,
20}
21
22impl RmsNormGatedConfig {
23    /// Initialize a new [RmsNormGated](RmsNormGated) module.
24    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNormGated<B> {
25        // assert!(self.epsilon > 0.0, "epsilon must be positive.");
26
27        let gamma = Initializer::Ones.init([self.d_model], device);
28
29        RmsNormGated {
30            gamma,
31            // epsilon: self.epsilon,
32            norm_before_gate: self.norm_before_gate,
33        }
34    }
35}
36
37/// Applies Gated Rms Normalization over an input tensor along the last dimension.
38///
39/// - If `norm_before_gate=true`: `Y = (X / sqrt(mean(X^2) + eps) * gamma) * SiLU(z)`
40/// - If `norm_before_gate=false`: `Y = (X * SiLU(z)) / sqrt(mean((X * SiLU(z))^2) + eps) * gamma`
41///
42/// Where:
43/// - `X` is the input tensor
44/// - `Y` is the output tensor
45/// - `z` is the gating tensor
46/// - `gamma` is the learnable weight
47/// - `mean` is the mean operation
48/// - `eps` is a small value to avoid division by zero.
49///
50/// Should be created using the [RmsNormGatedConfig](RmsNormGatedConfig) configuration.
51#[derive(Module, Debug)]
52#[module(custom_display)]
53pub struct RmsNormGated<B: Backend> {
54    /// The learnable parameter to scale the normalized tensor.
55    pub gamma: Param<Tensor<B, 1>>,
56    // // TODO: config epsilon is no longer used.
57    // /// A value required for numerical stability.
58    // pub epsilon: f64,
59    /// Whether to normalize before applying the gating.
60    pub norm_before_gate: bool,
61}
62
63impl<B: Backend> RmsNormGated<B> {
64    /// Applies the forward pass on the input tensor with gating.
65    ///
66    /// # Shapes
67    /// - input `x`: `[..., any, d_model]`
68    /// - input `z`: `[..., any, d_model]`
69    /// - output: `[..., any, d_model]`
70    pub fn forward<const D: usize>(&self, x: Tensor<B, D>, z: Tensor<B, D>) -> Tensor<B, D> {
71        let silu = Silu::new();
72
73        let x = if self.norm_before_gate {
74            // gate will be applied later
75            x
76        } else {
77            // gate before norm
78            x * silu.forward(z.clone())
79            // x * burn::tensor::activation::leaky_relu(z.clone(), 0.01)
80        };
81
82        let normalized = match x.dtype() {
83            DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
84                let div_eps = div_eps::<B>();
85
86                let rms = (x.clone() * x.clone()).mean_dim(D - 1).sqrt();
87                let normalized = (x / (rms + div_eps)) * self.gamma.val().unsqueeze();
88                normalized
89            }
90            DType::F16 => {
91                use burn::tensor::ElementConversion;
92                let div_eps: f16 = f16::from_elem(div_eps::<B>()) * f16::from_f32(2.);
93
94                // avoid calculating x² directly (due to overflow e.g. on 256 * 256)
95                let max = x.clone().no_grad().detach().abs().max().expand(x.shape());
96                let x_ = x.clone() / (max.clone() + div_eps); // |x_| <= 1
97                let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt(); // √(x²) = √(x²/max) * √max
98                let normalized =
99                    (x / (rms_partial + div_eps)) / max.sqrt() * self.gamma.val().unsqueeze();
100                normalized
101            }
102            DType::I64
103            | DType::I32
104            | DType::I16
105            | DType::I8
106            | DType::U64
107            | DType::U32
108            | DType::U16
109            | DType::U8 => {
110                unreachable!()
111            }
112            DType::Bool(_) => {
113                unreachable!()
114            }
115            DType::QFloat(_) => {
116                unimplemented!()
117            }
118        };
119
120        if self.norm_before_gate {
121            // gate gets applied late (now)
122            normalized * silu.forward(z)
123        } else {
124            // gate already got applied before
125            normalized
126        }
127    }
128}
129
130impl<B: Backend> ModuleDisplay for RmsNormGated<B> {
131    fn custom_settings(&self) -> Option<DisplaySettings> {
132        DisplaySettings::new()
133            .with_new_line_after_attribute(false)
134            .optional()
135    }
136
137    fn custom_content(&self, content: Content) -> Option<Content> {
138        let [d_model] = self.gamma.shape().dims();
139        content
140            .add("d_model", &d_model)
141            // .add("epsilon", &self.epsilon)
142            .optional()
143    }
144}