burn_mamba/utils/loss/
mse.rs1use crate::utils::div_eps;
2use burn::module::Module;
3use burn::nn::loss::Reduction;
4use burn::tensor::{DType, Tensor, backend::Backend, f16};
5
/// Mean squared error loss.
///
/// Stateless unit struct: it carries no parameters or configuration. The
/// reduction mode is supplied per call to `forward`.
#[derive(Module, Clone, Debug)]
pub struct MseLoss;
9
10impl Default for MseLoss {
11 fn default() -> Self {
12 Self::new()
13 }
14}
15
16impl MseLoss {
17 pub fn new() -> Self {
19 Self
20 }
21
22 pub fn forward<B: Backend>(
29 &self,
30 logits: Tensor<B, 2>,
31 targets: Tensor<B, 2>,
32 reduction: Reduction,
33 ) -> Tensor<B, 1> {
34 let [batch_size, _num_targets] = logits.dims();
35 match logits.dtype() {
36 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
37 let tensor = self.forward_no_reduction(logits, targets);
38 match reduction {
39 Reduction::Mean | Reduction::Auto => tensor.mean(),
40 Reduction::BatchMean => tensor.mean() / batch_size as f32,
41 Reduction::Sum => tensor.sum(),
42 }
43 }
44 DType::F16 => {
45 use burn::tensor::ElementConversion;
46 let div_eps: f16 = f16::from_elem(div_eps::<B>()) * f16::from_f32(2.);
47 let sub = logits.sub(targets);
49 let max = sub.clone().no_grad().detach().abs().max();
50 let sub_ = sub.clone() / (max.clone().expand(sub.shape()) + div_eps); let partial = sub * sub_; let reduced_partial = match reduction {
53 Reduction::Mean | Reduction::Auto => partial.mean(),
54 Reduction::BatchMean => partial.mean() / batch_size as f32,
55 Reduction::Sum => partial.sum(),
56 };
57 reduced_partial * max
58 }
59 DType::I64
60 | DType::I32
61 | DType::I16
62 | DType::I8
63 | DType::U64
64 | DType::U32
65 | DType::U16
66 | DType::U8 => {
67 unreachable!()
68 }
69 DType::Bool(_) => {
70 unreachable!()
71 }
72 DType::QFloat(_) => {
73 unimplemented!()
74 }
75 }
76 }
77
78 pub fn forward_no_reduction<const D: usize, B: Backend>(
80 &self,
81 logits: Tensor<B, D>,
82 targets: Tensor<B, D>,
83 ) -> Tensor<B, D> {
84 logits.sub(targets).square()
85 }
86}