Skip to main content

burn_mamba/utils/
mod.rs

1use ElementConversion;
2use burn::prelude::*;
3use burn::tensor::{DType, Element};
4
5pub mod log_sigmoid;
6pub mod loss;
7pub mod rms_norm;
8pub mod rms_norm_gated;
9pub mod sanity;
10pub mod scheduler;
11pub mod silu;
12pub mod softplus;
13
14pub fn stable_max<B: Backend>() -> B::FloatElem {
15    match <B::FloatElem as Element>::dtype() {
16        DType::F64 => f64::MAX.elem(),
17        DType::F32 | DType::Flex32 => f32::MAX.elem(),
18        DType::F16 => burn::tensor::f16::MAX.elem(),
19        DType::BF16 => burn::tensor::bf16::MAX.elem(),
20        DType::I64
21        | DType::I32
22        | DType::I16
23        | DType::I8
24        | DType::U64
25        | DType::U32
26        | DType::U16
27        | DType::U8 => {
28            unreachable!()
29        }
30        DType::Bool(_) => {
31            unreachable!()
32        }
33        DType::QFloat(_) => {
34            unimplemented!()
35        }
36    }
37}
38
39pub fn div_eps_f32<B: Backend>() -> f32 {
40    match <B::FloatElem as Element>::dtype() {
41        // 4.0693917e-16
42        DType::F64 => {
43            let raw_exp = -(-f64::MIN_EXP as f32 * 2.3f32).powf(0.35f32);
44            let eps_exp = (f64::EPSILON as f32).log10();
45            let avg = (raw_exp + eps_exp) / 2f32;
46            10f32.powf(avg)
47        }
48        // 8.1584695e-8
49        DType::F32 | DType::Flex32 => {
50            let raw_exp = -(-f32::MIN_EXP as f32 * 2.3f32).powf(0.35f32);
51            let eps_exp = f32::EPSILON.log10();
52            let avg = (raw_exp + eps_exp) / 2f32;
53            10f32.powf(avg)
54        }
55        // 7.1209995e-4
56        DType::F16 => {
57            let raw_exp = -(-burn::tensor::f16::MIN_EXP.to_f32() * 2.3f32).powf(0.35f32);
58            let eps_exp = burn::tensor::f16::EPSILON.to_f32().log10();
59            let avg = (raw_exp + eps_exp) / 2f32;
60            10f32.powf(avg)
61        }
62        // 2.0885676e-5
63        DType::BF16 => {
64            let raw_exp = -(-burn::tensor::bf16::MIN_EXP.to_f32() * 2.3f32).powf(0.35f32);
65            let eps_exp = burn::tensor::bf16::EPSILON.to_f32().log10();
66            let avg = (raw_exp + eps_exp) / 2f32;
67            10f32.powf(avg)
68        }
69        DType::I64
70        | DType::I32
71        | DType::I16
72        | DType::I8
73        | DType::U64
74        | DType::U32
75        | DType::U16
76        | DType::U8
77        | DType::Bool(_) => {
78            unreachable!()
79        }
80        DType::QFloat(_) => {
81            unimplemented!()
82        }
83    }
84}
85
86pub fn div_eps<B: Backend>() -> B::FloatElem {
87    div_eps_f32::<B>().elem()
88}