burn_mamba/modules/loss/bce.rs
1//! Binary cross-entropy loss.
2//!
3//! When `logits = true` the loss is computed in a numerically stable way from
4//! raw logits via [`log_sigmoid`]; otherwise the inputs are treated as
5//! probabilities and the logs are clamped to avoid `−∞`.
6
7use crate::modules::log_sigmoid;
8use burn::module::Module;
9use burn::prelude::*;
10
11/// Configuration to create a [`BinaryCrossEntropyLoss`] using the [`BinaryCrossEntropyLossConfig::init`].
12#[derive(Config, Debug)]
13pub struct BinaryCrossEntropyLossConfig {
14 /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
15 #[config(default = false)]
16 pub logits: bool,
17}
18
19impl BinaryCrossEntropyLossConfig {
20 /// Initialize [`BinaryCrossEntropyLoss`].
21 pub fn init(&self) -> BinaryCrossEntropyLoss {
22 BinaryCrossEntropyLoss {
23 logits: self.logits,
24 }
25 }
26}
27
28/// Calculate the binary cross entropy loss from the input logits and the targets.
29///
30/// Should be created using [BinaryCrossEntropyLossConfig]
31#[derive(Module, Debug)]
32pub struct BinaryCrossEntropyLoss {
33 /// Treat the inputs as logits
34 pub logits: bool,
35}
36
37impl BinaryCrossEntropyLoss {
38 /// Compute the criterion on the input tensor.
39 ///
40 /// # Shapes
41 ///
42 /// Binary:
43 /// - logits: `[batch_size]`
44 /// - targets: `[batch_size]`
45 ///
46 /// Multi-label:
47 /// - logits: `[batch_size, num_classes]`
48 /// - targets: `[batch_size, num_classes]`
49 pub fn forward<const D: usize>(&self, logits: Tensor<D>, targets: Tensor<D>) -> Tensor<1> {
50 let loss = if self.logits {
51 // Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
52 (targets.neg() + 1.) * logits.clone() - log_sigmoid(logits)
53 } else {
54 // - (target * log(input) + (1 - target) * log(1 - input))
55 // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
56 (targets.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
57 - targets * logits.log().clamp_min(-100.0)
58 };
59
60 loss.mean()
61 }
62}