Skip to main content

burn_mamba/modules/loss/
bce.rs

1//! Binary cross-entropy loss.
2//!
3//! When `logits = true` the loss is computed in a numerically stable way from
4//! raw logits via [`log_sigmoid`]; otherwise the inputs are treated as
5//! probabilities and the logs are clamped to avoid `−∞`.
6
7use crate::modules::log_sigmoid;
8use burn::module::Module;
9use burn::prelude::*;
10
11/// Configuration to create a [`BinaryCrossEntropyLoss`] using the [`BinaryCrossEntropyLossConfig::init`].
12#[derive(Config, Debug)]
13pub struct BinaryCrossEntropyLossConfig {
14    /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
15    #[config(default = false)]
16    pub logits: bool,
17}
18
19impl BinaryCrossEntropyLossConfig {
20    /// Initialize [`BinaryCrossEntropyLoss`].
21    pub fn init(&self) -> BinaryCrossEntropyLoss {
22        BinaryCrossEntropyLoss {
23            logits: self.logits,
24        }
25    }
26}
27
28/// Calculate the binary cross entropy loss from the input logits and the targets.
29///
30/// Should be created using [BinaryCrossEntropyLossConfig]
31#[derive(Module, Debug)]
32pub struct BinaryCrossEntropyLoss {
33    /// Treat the inputs as logits
34    pub logits: bool,
35}
36
37impl BinaryCrossEntropyLoss {
38    /// Compute the criterion on the input tensor.
39    ///
40    /// # Shapes
41    ///
42    /// Binary:
43    /// - logits: `[batch_size]`
44    /// - targets: `[batch_size]`
45    ///
46    /// Multi-label:
47    /// - logits: `[batch_size, num_classes]`
48    /// - targets: `[batch_size, num_classes]`
49    pub fn forward<const D: usize>(&self, logits: Tensor<D>, targets: Tensor<D>) -> Tensor<1> {
50        let loss = if self.logits {
51            // Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
52            (targets.neg() + 1.) * logits.clone() - log_sigmoid(logits)
53        } else {
54            // - (target * log(input) + (1 - target) * log(1 - input))
55            // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
56            (targets.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
57                - targets * logits.log().clamp_min(-100.0)
58        };
59
60        loss.mean()
61    }
62}