Skip to main content

burn_mamba/modules/norm/
rms_norm.rs

1//! Root-mean-square normalisation over the last dimension.
2//!
3//! `RMSNorm(x) = x / rms(x) · γ` where `rms(x) = √(mean(x²))`.  Unlike
4//! LayerNorm there is no mean-subtraction or bias — only a learnable per-channel
5//! scale `γ`.  Used both as the Pre-LN of every residual block and, in Mamba-3,
6//! as the **QK-Norm** applied to the B/C projections.
7//!
8//! The fp16 path avoids forming `x²` directly (which overflows for moderately
9//! large activations, e.g. 256·256): it divides one factor by `max(|x|)` so
10//! each product `x · (x / max|x|)` stays `≤ max(|x|)`, then rescales by
11//! `√max(|x|)`.  See [`rms_norm_gated`] for the SiLU-gated variant.
12//!
13//! [`rms_norm_gated`]: crate::utils::rms_norm_gated
14
15use crate::utils::div_eps;
16use burn::module::{Content, DisplaySettings, ModuleDisplay, Param};
17use burn::nn::Initializer;
18use burn::prelude::*;
19use burn::tensor::{DType, f16};
20
21/// Configuration to create a [`RmsNorm`] layer.
22#[derive(Config, Debug)]
23pub struct RmsNormConfig {
24    /// The size of the input features.
25    pub d_model: usize,
26}
27
28impl RmsNormConfig {
29    /// Initialize a new [`RmsNorm`] module.
30    pub fn init(&self, device: &Device) -> RmsNorm {
31        let gamma = Initializer::Ones.init([self.d_model], device);
32        RmsNorm { gamma }
33    }
34}
35
36/// Applies RMS normalisation over an input tensor along the last dimension:
37/// `y = x / √(mean(x²)) · γ`.
38///
39/// Should be created using the [`RmsNormConfig`] configuration.
40#[derive(Module, Debug)]
41#[module(custom_display)]
42pub struct RmsNorm {
43    /// The learnable per-channel scale `γ`, shape `[d_model]`.
44    pub gamma: Param<Tensor<1>>,
45}
46
47impl RmsNorm {
48    /// Applies the forward pass on the input tensor.
49    ///
50    /// # Shapes
51    /// - input `x`: `[..., d_model]`
52    /// - output: `[..., d_model]`
53    pub fn forward<const D: usize>(&self, x: Tensor<D>) -> Tensor<D> {
54        let normalized = match x.dtype() {
55            DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
56                let div_eps = div_eps(x.dtype());
57                let rms = (x.clone() * x.clone()).mean_dim(D - 1).sqrt();
58                let normalized = (x / (rms + div_eps)) * self.gamma.val().unsqueeze();
59                normalized
60            }
61            DType::F16 => {
62                use burn::tensor::ElementConversion;
63                let div_eps: f16 = f16::from_elem(div_eps(x.dtype())) * f16::from_f32(2.);
64                // avoid calculating x² directly (due to overflow e.g. on 256 * 256)
65                let max = x.clone().no_grad().detach().abs().max().expand(x.shape());
66                let x_ = x.clone() / (max.clone() + div_eps); // x_.abs() <= 1
67                let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt(); // √(x²) = √(x²/max) * √max
68                let normalized =
69                    (x / (rms_partial + div_eps)) / max.sqrt() * self.gamma.val().unsqueeze();
70                normalized
71            }
72            DType::I64
73            | DType::I32
74            | DType::I16
75            | DType::I8
76            | DType::U64
77            | DType::U32
78            | DType::U16
79            | DType::U8 => {
80                unreachable!()
81            }
82            DType::Bool(_) => {
83                unreachable!()
84            }
85            DType::QFloat(_) => {
86                unimplemented!()
87            }
88        };
89        normalized
90    }
91}
92
93impl ModuleDisplay for RmsNorm {
94    fn custom_settings(&self) -> Option<DisplaySettings> {
95        DisplaySettings::new()
96            .with_new_line_after_attribute(false)
97            .optional()
98    }
99
100    fn custom_content(&self, content: Content) -> Option<Content> {
101        let [d_model] = self.gamma.shape().dims();
102        content.add("d_model", &d_model).optional()
103    }
104}