Skip to main content

burn_mamba/utils/loss/
mse.rs

1use crate::utils::div_eps;
2use burn::module::Module;
3use burn::nn::loss::Reduction;
4use burn::tensor::{DType, Tensor, backend::Backend, f16};
5
6/// Calculate the mean squared error loss from the input logits and the targets.
7#[derive(Module, Clone, Debug)]
8pub struct MseLoss;
9
10impl Default for MseLoss {
11    fn default() -> Self {
12        Self::new()
13    }
14}
15
16impl MseLoss {
17    /// Create the criterion.
18    pub fn new() -> Self {
19        Self
20    }
21
22    /// Compute the criterion on the input tensor.
23    ///
24    /// # Shapes
25    ///
26    /// - logits: [batch_size, num_targets]
27    /// - targets: [batch_size, num_targets]
28    pub fn forward<B: Backend>(
29        &self,
30        logits: Tensor<B, 2>,
31        targets: Tensor<B, 2>,
32        reduction: Reduction,
33    ) -> Tensor<B, 1> {
34        let [batch_size, _num_targets] = logits.dims();
35        match logits.dtype() {
36            DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
37                let tensor = self.forward_no_reduction(logits, targets);
38                match reduction {
39                    Reduction::Mean | Reduction::Auto => tensor.mean(),
40                    Reduction::BatchMean => tensor.mean() / batch_size as f32,
41                    Reduction::Sum => tensor.sum(),
42                }
43            }
44            DType::F16 => {
45                use burn::tensor::ElementConversion;
46                let div_eps: f16 = f16::from_elem(div_eps::<B>()) * f16::from_f32(2.);
47                // avoid calculating sub² directly (due to overflow e.g. on 256 * 256)
48                let sub = logits.sub(targets);
49                let max = sub.clone().no_grad().detach().abs().max();
50                let sub_ = sub.clone() / (max.clone().expand(sub.shape()) + div_eps); // sub_.abs() <= 1
51                let partial = sub * sub_; // sub² = partial * max
52                let reduced_partial = match reduction {
53                    Reduction::Mean | Reduction::Auto => partial.mean(),
54                    Reduction::BatchMean => partial.mean() / batch_size as f32,
55                    Reduction::Sum => partial.sum(),
56                };
57                reduced_partial * max
58            }
59            DType::I64
60            | DType::I32
61            | DType::I16
62            | DType::I8
63            | DType::U64
64            | DType::U32
65            | DType::U16
66            | DType::U8 => {
67                unreachable!()
68            }
69            DType::Bool(_) => {
70                unreachable!()
71            }
72            DType::QFloat(_) => {
73                unimplemented!()
74            }
75        }
76    }
77
78    /// Compute the criterion on the input tensor without reducing.
79    pub fn forward_no_reduction<const D: usize, B: Backend>(
80        &self,
81        logits: Tensor<B, D>,
82        targets: Tensor<B, D>,
83    ) -> Tensor<B, D> {
84        logits.sub(targets).square()
85    }
86}