1use ElementConversion;
2use burn::prelude::*;
3use burn::tensor::{DType, Element};
4
5pub mod log_sigmoid;
6pub mod loss;
7pub mod rms_norm;
8pub mod rms_norm_gated;
9pub mod sanity;
10pub mod scheduler;
11pub mod silu;
12pub mod softplus;
13
14pub fn stable_max<B: Backend>() -> B::FloatElem {
15 match <B::FloatElem as Element>::dtype() {
16 DType::F64 => f64::MAX.elem(),
17 DType::F32 | DType::Flex32 => f32::MAX.elem(),
18 DType::F16 => burn::tensor::f16::MAX.elem(),
19 DType::BF16 => burn::tensor::bf16::MAX.elem(),
20 DType::I64
21 | DType::I32
22 | DType::I16
23 | DType::I8
24 | DType::U64
25 | DType::U32
26 | DType::U16
27 | DType::U8 => {
28 unreachable!()
29 }
30 DType::Bool(_) => {
31 unreachable!()
32 }
33 DType::QFloat(_) => {
34 unimplemented!()
35 }
36 }
37}
38
39pub fn div_eps_f32<B: Backend>() -> f32 {
40 match <B::FloatElem as Element>::dtype() {
41 DType::F64 => {
43 let raw_exp = -(-f64::MIN_EXP as f32 * 2.3f32).powf(0.35f32);
44 let eps_exp = (f64::EPSILON as f32).log10();
45 let avg = (raw_exp + eps_exp) / 2f32;
46 10f32.powf(avg)
47 }
48 DType::F32 | DType::Flex32 => {
50 let raw_exp = -(-f32::MIN_EXP as f32 * 2.3f32).powf(0.35f32);
51 let eps_exp = f32::EPSILON.log10();
52 let avg = (raw_exp + eps_exp) / 2f32;
53 10f32.powf(avg)
54 }
55 DType::F16 => {
57 let raw_exp = -(-burn::tensor::f16::MIN_EXP.to_f32() * 2.3f32).powf(0.35f32);
58 let eps_exp = burn::tensor::f16::EPSILON.to_f32().log10();
59 let avg = (raw_exp + eps_exp) / 2f32;
60 10f32.powf(avg)
61 }
62 DType::BF16 => {
64 let raw_exp = -(-burn::tensor::bf16::MIN_EXP.to_f32() * 2.3f32).powf(0.35f32);
65 let eps_exp = burn::tensor::bf16::EPSILON.to_f32().log10();
66 let avg = (raw_exp + eps_exp) / 2f32;
67 10f32.powf(avg)
68 }
69 DType::I64
70 | DType::I32
71 | DType::I16
72 | DType::I8
73 | DType::U64
74 | DType::U32
75 | DType::U16
76 | DType::U8
77 | DType::Bool(_) => {
78 unreachable!()
79 }
80 DType::QFloat(_) => {
81 unimplemented!()
82 }
83 }
84}
85
86pub fn div_eps<B: Backend>() -> B::FloatElem {
87 div_eps_f32::<B>().elem()
88}