Skip to main content

burn_mamba/utils/loss/
bce.rs

1use crate::utils::log_sigmoid::log_sigmoid;
2use burn::module::Module;
3use burn::prelude::*;
4
5/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).
6#[derive(Config, Debug)]
7pub struct BinaryCrossEntropyLossConfig {
8    /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
9    #[config(default = false)]
10    pub logits: bool,
11}
12
13impl BinaryCrossEntropyLossConfig {
14    /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
15    pub fn init(&self) -> BinaryCrossEntropyLoss {
16        BinaryCrossEntropyLoss {
17            logits: self.logits,
18        }
19    }
20}
21
22/// Calculate the binary cross entropy loss from the input logits and the targets.
23///
24/// Should be created using [BinaryCrossEntropyLossConfig]
25#[derive(Module, Clone, Debug)]
26pub struct BinaryCrossEntropyLoss {
27    /// Treat the inputs as logits
28    pub logits: bool,
29}
30
31impl BinaryCrossEntropyLoss {
32    /// Compute the criterion on the input tensor.
33    ///
34    /// # Shapes
35    ///
36    /// Binary:
37    /// - logits: `[batch_size]`
38    /// - targets: `[batch_size]`
39    ///
40    /// Multi-label:
41    /// - logits: `[batch_size, num_classes]`
42    /// - targets: `[batch_size, num_classes]`
43    pub fn forward<const D: usize, B: Backend>(
44        &self,
45        logits: Tensor<B, D>,
46        targets: Tensor<B, D>,
47    ) -> Tensor<B, 1> {
48        let loss = if self.logits {
49            // Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
50            (targets.neg() + 1.) * logits.clone() - log_sigmoid(logits)
51        } else {
52            // - (target * log(input) + (1 - target) * log(1 - input))
53            // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
54            (targets.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
55                - targets * logits.log().clamp_min(-100.0)
56        };
57
58        loss.mean()
59    }
60}