Skip to main content

burn_mamba/modules/loss/
mse.rs

1//! Mean squared error loss.
2//!
3//! The fp16 path avoids forming `(logits − targets)²` directly (which overflows
4//! for large differences) by factoring out `max(|diff|)` before squaring, then
5//! multiplying it back in after the reduction.
6
7use crate::utils::div_eps;
8use burn::module::Module;
9use burn::nn::loss::Reduction;
10use burn::tensor::{DType, Tensor, f16};
11
12/// Calculate the mean squared error loss from the input logits and the targets.
13#[derive(Module, Debug)]
14pub struct MseLoss;
15
16impl Default for MseLoss {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl MseLoss {
23    /// Create the criterion.
24    pub fn new() -> Self {
25        Self
26    }
27
28    /// Compute the criterion on the input tensor.
29    ///
30    /// # Shapes
31    ///
32    /// - logits: `[batch_size, num_targets]`
33    /// - targets: `[batch_size, num_targets]`
34    pub fn forward(
35        &self,
36        logits: Tensor<2>,
37        targets: Tensor<2>,
38        reduction: Reduction,
39    ) -> Tensor<1> {
40        let [batch_size, _num_targets] = logits.dims();
41        match logits.dtype() {
42            DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
43                let tensor = self.forward_no_reduction(logits, targets);
44                match reduction {
45                    Reduction::Mean | Reduction::Auto => tensor.mean(),
46                    Reduction::BatchMean => tensor.mean() / batch_size as f32,
47                    Reduction::Sum => tensor.sum(),
48                }
49            }
50            DType::F16 => {
51                use burn::tensor::ElementConversion;
52                let div_eps: f16 = f16::from_elem(div_eps(logits.dtype())) * f16::from_f32(2.);
53                // avoid calculating sub² directly (due to overflow e.g. on 256 * 256)
54                let sub = logits.sub(targets);
55                let max = sub.clone().no_grad().detach().abs().max();
56                let sub_ = sub.clone() / (max.clone().expand(sub.shape()) + div_eps); // sub_.abs() <= 1
57                let partial = sub * sub_; // sub² = partial * max
58                let reduced_partial = match reduction {
59                    Reduction::Mean | Reduction::Auto => partial.mean(),
60                    Reduction::BatchMean => partial.mean() / batch_size as f32,
61                    Reduction::Sum => partial.sum(),
62                };
63                reduced_partial * max
64            }
65            DType::I64
66            | DType::I32
67            | DType::I16
68            | DType::I8
69            | DType::U64
70            | DType::U32
71            | DType::U16
72            | DType::U8 => {
73                unreachable!()
74            }
75            DType::Bool(_) => {
76                unreachable!()
77            }
78            DType::QFloat(_) => {
79                unimplemented!()
80            }
81        }
82    }
83
84    /// Compute the criterion on the input tensor without reducing.
85    pub fn forward_no_reduction<const D: usize>(
86        &self,
87        logits: Tensor<D>,
88        targets: Tensor<D>,
89    ) -> Tensor<D> {
90        logits.sub(targets).square()
91    }
92}