// burn_mamba/utils/rms_norm_gated.rs
1use crate::utils::div_eps;
2use crate::utils::silu::Silu;
3use burn::module::{Content, DisplaySettings, ModuleDisplay, Param};
4use burn::nn::Initializer;
5use burn::prelude::*;
6use burn::tensor::{DType, f16};
7
8/// Configuration to create a [RmsNormGated](RmsNormGated) layer.
9#[derive(Config, Debug)]
10pub struct RmsNormGatedConfig {
11 /// The size of the input features.
12 pub d_model: usize,
13 // // TODO: config epsilon is no longer used.
14 // /// A value required for numerical stability. Default: 1e-5
15 // #[config(default = 1e-5)]
16 // pub epsilon: f64,
17 /// Whether to apply normalization before gating. Default: true
18 #[config(default = true)]
19 pub norm_before_gate: bool,
20}
21
22impl RmsNormGatedConfig {
23 /// Initialize a new [RmsNormGated](RmsNormGated) module.
24 pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNormGated<B> {
25 // assert!(self.epsilon > 0.0, "epsilon must be positive.");
26
27 let gamma = Initializer::Ones.init([self.d_model], device);
28
29 RmsNormGated {
30 gamma,
31 // epsilon: self.epsilon,
32 norm_before_gate: self.norm_before_gate,
33 }
34 }
35}
36
37/// Applies Gated Rms Normalization over an input tensor along the last dimension.
38///
39/// - If `norm_before_gate=true`: `Y = (X / sqrt(mean(X^2) + eps) * gamma) * SiLU(z)`
40/// - If `norm_before_gate=false`: `Y = (X * SiLU(z)) / sqrt(mean((X * SiLU(z))^2) + eps) * gamma`
41///
42/// Where:
43/// - `X` is the input tensor
44/// - `Y` is the output tensor
45/// - `z` is the gating tensor
46/// - `gamma` is the learnable weight
47/// - `mean` is the mean operation
48/// - `eps` is a small value to avoid division by zero.
49///
50/// Should be created using the [RmsNormGatedConfig](RmsNormGatedConfig) configuration.
51#[derive(Module, Debug)]
52#[module(custom_display)]
53pub struct RmsNormGated<B: Backend> {
54 /// The learnable parameter to scale the normalized tensor.
55 pub gamma: Param<Tensor<B, 1>>,
56 // // TODO: config epsilon is no longer used.
57 // /// A value required for numerical stability.
58 // pub epsilon: f64,
59 /// Whether to normalize before applying the gating.
60 pub norm_before_gate: bool,
61}
62
63impl<B: Backend> RmsNormGated<B> {
64 /// Applies the forward pass on the input tensor with gating.
65 ///
66 /// # Shapes
67 /// - input `x`: `[..., any, d_model]`
68 /// - input `z`: `[..., any, d_model]`
69 /// - output: `[..., any, d_model]`
70 pub fn forward<const D: usize>(&self, x: Tensor<B, D>, z: Tensor<B, D>) -> Tensor<B, D> {
71 let silu = Silu::new();
72
73 let x = if self.norm_before_gate {
74 // gate will be applied later
75 x
76 } else {
77 // gate before norm
78 x * silu.forward(z.clone())
79 // x * burn::tensor::activation::leaky_relu(z.clone(), 0.01)
80 };
81
82 let normalized = match x.dtype() {
83 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
84 let div_eps = div_eps::<B>();
85
86 let rms = (x.clone() * x.clone()).mean_dim(D - 1).sqrt();
87 let normalized = (x / (rms + div_eps)) * self.gamma.val().unsqueeze();
88 normalized
89 }
90 DType::F16 => {
91 use burn::tensor::ElementConversion;
92 let div_eps: f16 = f16::from_elem(div_eps::<B>()) * f16::from_f32(2.);
93
94 // avoid calculating x² directly (due to overflow e.g. on 256 * 256)
95 let max = x.clone().no_grad().detach().abs().max().expand(x.shape());
96 let x_ = x.clone() / (max.clone() + div_eps); // |x_| <= 1
97 let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt(); // √(x²) = √(x²/max) * √max
98 let normalized =
99 (x / (rms_partial + div_eps)) / max.sqrt() * self.gamma.val().unsqueeze();
100 normalized
101 }
102 DType::I64
103 | DType::I32
104 | DType::I16
105 | DType::I8
106 | DType::U64
107 | DType::U32
108 | DType::U16
109 | DType::U8 => {
110 unreachable!()
111 }
112 DType::Bool(_) => {
113 unreachable!()
114 }
115 DType::QFloat(_) => {
116 unimplemented!()
117 }
118 };
119
120 if self.norm_before_gate {
121 // gate gets applied late (now)
122 normalized * silu.forward(z)
123 } else {
124 // gate already got applied before
125 normalized
126 }
127 }
128}
129
130impl<B: Backend> ModuleDisplay for RmsNormGated<B> {
131 fn custom_settings(&self) -> Option<DisplaySettings> {
132 DisplaySettings::new()
133 .with_new_line_after_attribute(false)
134 .optional()
135 }
136
137 fn custom_content(&self, content: Content) -> Option<Content> {
138 let [d_model] = self.gamma.shape().dims();
139 content
140 .add("d_model", &d_model)
141 // .add("epsilon", &self.epsilon)
142 .optional()
143 }
144}