burn_mamba/utils/
rms_norm.rs1use crate::utils::div_eps;
2use burn::module::{Content, DisplaySettings, ModuleDisplay, Param};
3use burn::nn::Initializer;
4use burn::prelude::*;
5use burn::tensor::{DType, f16};
6
7#[derive(Config, Debug)]
9pub struct RmsNormConfig {
10 pub d_model: usize,
12}
13
14impl RmsNormConfig {
15 pub fn init<B: Backend>(&self, device: &B::Device) -> RmsNorm<B> {
17 let gamma = Initializer::Ones.init([self.d_model], device);
18 RmsNorm { gamma }
19 }
20}
21
22#[derive(Module, Debug)]
33#[module(custom_display)]
34pub struct RmsNorm<B: Backend> {
35 pub gamma: Param<Tensor<B, 1>>,
37}
38
39impl<B: Backend> RmsNorm<B> {
40 pub fn forward<const D: usize>(&self, x: Tensor<B, D>) -> Tensor<B, D> {
47 let normalized = match x.dtype() {
48 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
49 let div_eps = div_eps::<B>();
50 let rms = (x.clone() * x.clone()).mean_dim(D - 1).sqrt();
51 let normalized = (x / (rms + div_eps)) * self.gamma.val().unsqueeze();
52 normalized
53 }
54 DType::F16 => {
55 use burn::tensor::ElementConversion;
56 let div_eps: f16 = f16::from_elem(div_eps::<B>()) * f16::from_f32(2.);
57 let max = x.clone().no_grad().detach().abs().max().expand(x.shape());
59 let x_ = x.clone() / (max.clone() + div_eps); let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt(); let normalized =
62 (x / (rms_partial + div_eps)) / max.sqrt() * self.gamma.val().unsqueeze();
63 normalized
64 }
65 DType::I64
66 | DType::I32
67 | DType::I16
68 | DType::I8
69 | DType::U64
70 | DType::U32
71 | DType::U16
72 | DType::U8 => {
73 unreachable!()
74 }
75 DType::Bool(_) => {
76 unreachable!()
77 }
78 DType::QFloat(_) => {
79 unimplemented!()
80 }
81 };
82 normalized
83 }
84}
85
86impl<B: Backend> ModuleDisplay for RmsNorm<B> {
87 fn custom_settings(&self) -> Option<DisplaySettings> {
88 DisplaySettings::new()
89 .with_new_line_after_attribute(false)
90 .optional()
91 }
92
93 fn custom_content(&self, content: Content) -> Option<Content> {
94 let [d_model] = self.gamma.shape().dims();
95 content.add("d_model", &d_model).optional()
96 }
97}