burn_mamba/utils/loss/bce.rs
1use crate::utils::log_sigmoid::log_sigmoid;
2use burn::module::Module;
3use burn::prelude::*;
4
5/// Configuration to create a [Binary Cross-entropy loss](BinaryCrossEntropyLoss) using the [init function](BinaryCrossEntropyLossConfig::init).
6#[derive(Config, Debug)]
7pub struct BinaryCrossEntropyLossConfig {
8 /// Treat the inputs as logits, applying a sigmoid activation when computing the loss.
9 #[config(default = false)]
10 pub logits: bool,
11}
12
13impl BinaryCrossEntropyLossConfig {
14 /// Initialize [Binary Cross-entropy loss](BinaryCrossEntropyLoss).
15 pub fn init(&self) -> BinaryCrossEntropyLoss {
16 BinaryCrossEntropyLoss {
17 logits: self.logits,
18 }
19 }
20}
21
22/// Calculate the binary cross entropy loss from the input logits and the targets.
23///
24/// Should be created using [BinaryCrossEntropyLossConfig]
25#[derive(Module, Clone, Debug)]
26pub struct BinaryCrossEntropyLoss {
27 /// Treat the inputs as logits
28 pub logits: bool,
29}
30
31impl BinaryCrossEntropyLoss {
32 /// Compute the criterion on the input tensor.
33 ///
34 /// # Shapes
35 ///
36 /// Binary:
37 /// - logits: `[batch_size]`
38 /// - targets: `[batch_size]`
39 ///
40 /// Multi-label:
41 /// - logits: `[batch_size, num_classes]`
42 /// - targets: `[batch_size, num_classes]`
43 pub fn forward<const D: usize, B: Backend>(
44 &self,
45 logits: Tensor<B, D>,
46 targets: Tensor<B, D>,
47 ) -> Tensor<B, 1> {
48 let loss = if self.logits {
49 // Numerically stable by combining `log(sigmoid(x))` with `log_sigmoid(x)`
50 (targets.neg() + 1.) * logits.clone() - log_sigmoid(logits)
51 } else {
52 // - (target * log(input) + (1 - target) * log(1 - input))
53 // https://github.com/tracel-ai/burn/issues/2739: clamp at -100.0 to avoid undefined values
54 (targets.clone() - 1) * logits.clone().neg().log1p().clamp_min(-100.0)
55 - targets * logits.log().clamp_min(-100.0)
56 };
57
58 loss.mean()
59 }
60}