// burn_mamba/modules/norm/
rms_norm.rs1use crate::utils::div_eps;
16use burn::module::{Content, DisplaySettings, ModuleDisplay, Param};
17use burn::nn::Initializer;
18use burn::prelude::*;
19use burn::tensor::{DType, f16};
20
21#[derive(Config, Debug)]
23pub struct RmsNormConfig {
24 pub d_model: usize,
26}
27
28impl RmsNormConfig {
29 pub fn init(&self, device: &Device) -> RmsNorm {
31 let gamma = Initializer::Ones.init([self.d_model], device);
32 RmsNorm { gamma }
33 }
34}
35
36#[derive(Module, Debug)]
41#[module(custom_display)]
42pub struct RmsNorm {
43 pub gamma: Param<Tensor<1>>,
45}
46
47impl RmsNorm {
48 pub fn forward<const D: usize>(&self, x: Tensor<D>) -> Tensor<D> {
54 let normalized = match x.dtype() {
55 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
56 let div_eps = div_eps(x.dtype());
57 let rms = (x.clone() * x.clone()).mean_dim(D - 1).sqrt();
58 let normalized = (x / (rms + div_eps)) * self.gamma.val().unsqueeze();
59 normalized
60 }
61 DType::F16 => {
62 use burn::tensor::ElementConversion;
63 let div_eps: f16 = f16::from_elem(div_eps(x.dtype())) * f16::from_f32(2.);
64 let max = x.clone().no_grad().detach().abs().max().expand(x.shape());
66 let x_ = x.clone() / (max.clone() + div_eps); let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt(); let normalized =
69 (x / (rms_partial + div_eps)) / max.sqrt() * self.gamma.val().unsqueeze();
70 normalized
71 }
72 DType::I64
73 | DType::I32
74 | DType::I16
75 | DType::I8
76 | DType::U64
77 | DType::U32
78 | DType::U16
79 | DType::U8 => {
80 unreachable!()
81 }
82 DType::Bool(_) => {
83 unreachable!()
84 }
85 DType::QFloat(_) => {
86 unimplemented!()
87 }
88 };
89 normalized
90 }
91}
92
93impl ModuleDisplay for RmsNorm {
94 fn custom_settings(&self) -> Option<DisplaySettings> {
95 DisplaySettings::new()
96 .with_new_line_after_attribute(false)
97 .optional()
98 }
99
100 fn custom_content(&self, content: Content) -> Option<Content> {
101 let [d_model] = self.gamma.shape().dims();
102 content.add("d_model", &d_model).optional()
103 }
104}