Skip to main content

burn_mamba/utils/
rms_norm.rs

1use crate::utils::div_eps;
2use burn::module::{Content, DisplaySettings, ModuleDisplay, Param};
3use burn::nn::Initializer;
4use burn::prelude::*;
5use burn::tensor::{DType, f16};
6
7/// Configuration to create a [RmsNorm](RmsNorm) layer.
8#[derive(Config, Debug)]
9pub struct RmsNormConfig {
10    /// The size of the input features.
11    pub d_model: usize,
12}
13
14impl RmsNormConfig {
15    /// Initialize a new [RmsNorm](RmsNorm) module.
16    pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
17        let gamma = Initializer::Ones.init([self.d_model], device);
18        RmsNorm { gamma }
19    }
20}
21
22/// Applies Rms Normalization over an input tensor along the last dimension.
23///
24/// Where:
25/// - `X` is the input tensor
26/// - `Y` is the output tensor
27/// - `z` is the gating tensor
28/// - `gamma` is the learnable weight
29/// - `mean` is the mean operation
30///
31/// Should be created using the [RmsNormConfig](RmsNormConfig) configuration.
32#[derive(Module, Debug)]
33#[module(custom_display)]
34pub struct RmsNorm<B: Backend> {
35    /// The learnable parameter to scale the normalized tensor.
36    pub gamma: Param<Tensor<B, 1>>,
37}
38
39impl<B: Backend> RmsNorm<B> {
40    /// Applies the forward pass on the input tensor with gating.
41    ///
42    /// # Shapes
43    /// - input `x`: `[..., any, d_model]`
44    /// - input `z`: `[..., any, d_model]`
45    /// - output: `[..., any, d_model]`
46    pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
47        let normalized = match x.dtype() {
48            DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
49                let div_eps = div_eps::<B>();
50                let rms = (x.clone() * x.clone()).mean_dim(D - 1).sqrt();
51                let normalized = (x / (rms + div_eps)) * self.gamma.val().unsqueeze();
52                normalized
53            }
54            DType::F16 => {
55                use burn::tensor::ElementConversion;
56                let div_eps: f16 = f16::from_elem(div_eps::<B>()) * f16::from_f32(2.);
57                // avoid calculating x² directly (due to overflow e.g. on 256 * 256)
58                let max = x.clone().no_grad().detach().abs().max().expand(x.shape());
59                let x_ = x.clone() / (max.clone() + div_eps); // x_.abs() <= 1
60                let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt(); // √(x²) = √(x²/max) * √max
61                let normalized =
62                    (x / (rms_partial + div_eps)) / max.sqrt() * self.gamma.val().unsqueeze();
63                normalized
64            }
65            DType::I64
66            | DType::I32
67            | DType::I16
68            | DType::I8
69            | DType::U64
70            | DType::U32
71            | DType::U16
72            | DType::U8 => {
73                unreachable!()
74            }
75            DType::Bool(_) => {
76                unreachable!()
77            }
78            DType::QFloat(_) => {
79                unimplemented!()
80            }
81        };
82        normalized
83    }
84}
85
86impl<B: Backend> ModuleDisplay for RmsNorm<B> {
87    fn custom_settings(&self) -> Option<DisplaySettings> {
88        DisplaySettings::new()
89            .with_new_line_after_attribute(false)
90            .optional()
91    }
92
93    fn custom_content(&self, content: Content) -> Option<Content> {
94        let [d_model] = self.gamma.shape().dims();
95        content.add("d_model", &d_model).optional()
96    }
97}