burn_mamba/modules/loss/
mse.rs1use crate::utils::div_eps;
8use burn::module::Module;
9use burn::nn::loss::Reduction;
10use burn::tensor::{DType, Tensor, f16};
11
12#[derive(Module, Debug)]
14pub struct MseLoss;
15
16impl Default for MseLoss {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl MseLoss {
23 pub fn new() -> Self {
25 Self
26 }
27
28 pub fn forward(
35 &self,
36 logits: Tensor<2>,
37 targets: Tensor<2>,
38 reduction: Reduction,
39 ) -> Tensor<1> {
40 let [batch_size, _num_targets] = logits.dims();
41 match logits.dtype() {
42 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
43 let tensor = self.forward_no_reduction(logits, targets);
44 match reduction {
45 Reduction::Mean | Reduction::Auto => tensor.mean(),
46 Reduction::BatchMean => tensor.mean() / batch_size as f32,
47 Reduction::Sum => tensor.sum(),
48 }
49 }
50 DType::F16 => {
51 use burn::tensor::ElementConversion;
52 let div_eps: f16 = f16::from_elem(div_eps(logits.dtype())) * f16::from_f32(2.);
53 let sub = logits.sub(targets);
55 let max = sub.clone().no_grad().detach().abs().max();
56 let sub_ = sub.clone() / (max.clone().expand(sub.shape()) + div_eps); let partial = sub * sub_; let reduced_partial = match reduction {
59 Reduction::Mean | Reduction::Auto => partial.mean(),
60 Reduction::BatchMean => partial.mean() / batch_size as f32,
61 Reduction::Sum => partial.sum(),
62 };
63 reduced_partial * max
64 }
65 DType::I64
66 | DType::I32
67 | DType::I16
68 | DType::I8
69 | DType::U64
70 | DType::U32
71 | DType::U16
72 | DType::U8 => {
73 unreachable!()
74 }
75 DType::Bool(_) => {
76 unreachable!()
77 }
78 DType::QFloat(_) => {
79 unimplemented!()
80 }
81 }
82 }
83
84 pub fn forward_no_reduction<const D: usize>(
86 &self,
87 logits: Tensor<D>,
88 targets: Tensor<D>,
89 ) -> Tensor<D> {
90 logits.sub(targets).square()
91 }
92}