Skip to main content

burn_mamba/modules/norm/
rms_norm_gated.rs

1//! RMS normalisation fused with a SiLU(z) gate — the Mamba-2 output norm.
2//!
3//! `norm_before_gate` selects the order of the two operations:
4//! - `true`  — normalise, then gate:   `y = (x / rms(x) · γ) · SiLU(z)`
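//! - `false` — gate, then normalise:   `y = (x · SiLU(z)) / rms(x · SiLU(z)) · γ`
//!
//! where `rms(v) = sqrt(mean(v²)) + ε` over the last dimension (the epsilon is
//! added to the RMS, not inside the square root).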
7//!
8//! The numerical-stability epsilon is the per-dtype [`div_eps`] (so there is no
9//! configurable epsilon); the fp16 path uses the same `max(|x|)`-rescaling
10//! trick as [`RmsNorm`](crate::utils::rms_norm::RmsNorm).
11
12use crate::modules::Silu;
13use crate::utils::div_eps;
14use burn::module::{Content, DisplaySettings, ModuleDisplay, Param};
15use burn::nn::Initializer;
16use burn::prelude::*;
17use burn::tensor::{DType, f16};
18
/// Configuration to create a [`RmsNormGated`] layer.
#[derive(Config, Debug)]
pub struct RmsNormGatedConfig {
    /// The size of the input features: the last dimension of both `x` and `z`,
    /// and the length of the learnable scale `gamma`.
    pub d_model: usize,
    /// Whether to apply normalization before gating. Default: true
    ///
    /// `true` normalises `x` and then multiplies by `SiLU(z)`; `false` first
    /// forms `x · SiLU(z)` and normalises that product.
    #[config(default = true)]
    pub norm_before_gate: bool,
}
28
29impl RmsNormGatedConfig {
30    /// Initialize a new [`RmsNormGated`] module.
31    pub fn init(&self, device: &Device) -> RmsNormGated {
32        let gamma = Initializer::Ones.init([self.d_model], device);
33        RmsNormGated {
34            gamma,
35            norm_before_gate: self.norm_before_gate,
36        }
37    }
38}
39
40/// Applies Gated Rms Normalization over an input tensor along the last dimension.
41///
42/// - If `norm_before_gate=true`: `Y = (X / sqrt(mean(X^2) + eps) * gamma) * SiLU(z)`
43/// - If `norm_before_gate=false`: `Y = (X * SiLU(z)) / sqrt(mean((X * SiLU(z))^2) + eps) * gamma`
44///
45/// Where:
46/// - `X` is the input tensor
47/// - `Y` is the output tensor
48/// - `z` is the gating tensor
49/// - `gamma` is the learnable weight
50/// - `mean` is the mean operation
51/// - `eps` is a small value to avoid division by zero.
52///
53/// Should be created using the [`RmsNormGatedConfig`] configuration.
54#[derive(Module, Debug)]
55#[module(custom_display)]
56pub struct RmsNormGated {
57    /// The learnable per-channel scale `γ`, shape `[d_model]`.
58    pub gamma: Param<Tensor<1>>,
59    /// Whether to normalize before applying the gating.
60    pub norm_before_gate: bool,
61}
62
63impl RmsNormGated {
64    /// Applies the forward pass on the input tensor with gating.
65    ///
66    /// # Shapes
67    /// - input `x`: `[..., any, d_model]`
68    /// - input `z`: `[..., any, d_model]`
69    /// - output: `[..., any, d_model]`
70    pub fn forward<const D: usize>(&self, x: Tensor<D>, z: Tensor<D>) -> Tensor<D> {
71        let silu = Silu::new();
72
73        let x = if self.norm_before_gate {
74            // gate will be applied later
75            x
76        } else {
77            // gate before norm
78            x * silu.forward(z.clone())
79        };
80
81        let normalized = match x.dtype() {
82            DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
83                let div_eps = div_eps(x.dtype());
84
85                let rms = (x.clone() * x.clone()).mean_dim(D - 1).sqrt();
86                let normalized = (x / (rms + div_eps)) * self.gamma.val().unsqueeze();
87                normalized
88            }
89            DType::F16 => {
90                use burn::tensor::ElementConversion;
91                let div_eps: f16 = f16::from_elem(div_eps(x.dtype())) * f16::from_f32(2.);
92
93                // avoid calculating x² directly (due to overflow e.g. on 256 * 256)
94                let max = x.clone().no_grad().detach().abs().max().expand(x.shape());
95                let x_ = x.clone() / (max.clone() + div_eps); // |x_| <= 1
96                let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt(); // √(x²) = √(x²/max) * √max
97                let normalized =
98                    (x / (rms_partial + div_eps)) / max.sqrt() * self.gamma.val().unsqueeze();
99                normalized
100            }
101            DType::I64
102            | DType::I32
103            | DType::I16
104            | DType::I8
105            | DType::U64
106            | DType::U32
107            | DType::U16
108            | DType::U8 => {
109                unreachable!()
110            }
111            DType::Bool(_) => {
112                unreachable!()
113            }
114            DType::QFloat(_) => {
115                unimplemented!()
116            }
117        };
118
119        if self.norm_before_gate {
120            // gate gets applied late (now)
121            normalized * silu.forward(z)
122        } else {
123            // gate already got applied before
124            normalized
125        }
126    }
127}
128
129impl ModuleDisplay for RmsNormGated {
130    fn custom_settings(&self) -> Option<DisplaySettings> {
131        DisplaySettings::new()
132            .with_new_line_after_attribute(false)
133            .optional()
134    }
135
136    fn custom_content(&self, content: Content) -> Option<Content> {
137        let [d_model] = self.gamma.shape().dims();
138        content.add("d_model", &d_model).optional()
139    }
140}