Skip to main content

burn_mamba/utils/loss/
cross_entropy.rs

1use burn::module::Module;
2use burn::prelude::*;
3use burn::tensor::activation::{log_softmax, softmax};
4
5/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
6#[derive(Config, Debug)]
7pub struct CrossEntropyLossConfig {
8    /// Treat the outputs as logits, applying log-softmax when computing the loss.
9    ///
10    /// When `false`, outputs are assumed to be probabilities (e.g. post-softmax).
11    #[config(default = true)]
12    pub output_logits: bool,
13
14    /// Treat the targets as logits, applying softmax to normalize them before computing the loss.
15    ///
16    /// When `false`, targets are assumed to already be a valid probability distribution
17    /// (e.g. one-hot or soft labels that sum to 1).
18    #[config(default = false)]
19    pub target_logits: bool,
20}
21
22impl CrossEntropyLossConfig {
23    /// Initialize [Cross-entropy loss](CrossEntropyLoss).
24    pub fn init(&self) -> CrossEntropyLoss {
25        CrossEntropyLoss {
26            output_logits: self.output_logits,
27            target_logits: self.target_logits,
28        }
29    }
30}
31
32/// Calculate the cross-entropy loss from the output logits and the targets.
33///
34/// Unlike the full [`CrossEntropyLoss`](super::CrossEntropyLoss), this variant accepts
35/// floating-point targets (e.g. one-hot, soft label distributions, or un-normalized logits)
36/// rather than integer class indices, and omits padding, per-class weights, and label smoothing.
37///
38/// Should be created using [CrossEntropyLossConfig].
39#[derive(Module, Clone, Debug)]
40pub struct CrossEntropyLoss {
41    /// Treat the outputs as logits.
42    pub output_logits: bool,
43    /// Treat the targets as logits.
44    pub target_logits: bool,
45}
46
47impl CrossEntropyLoss {
48    /// Compute the criterion on the output tensor.
49    ///
50    /// # Shapes
51    ///
52    /// - logits: `[batch_size, num_classes]`
53    /// - targets: `[batch_size, num_classes]`
54    pub fn forward<B: Backend>(&self, logits: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
55        let log_probs = if self.output_logits {
56            // Numerically stable via log-softmax
57            log_softmax(logits, 1)
58        } else {
59            // outputs are probabilities; clamp at -100.0 after log to avoid undefined values
60            // for zero-probability classes (mirrors the BCE treatment of log(0))
61            logits.log().clamp_min(-100.0)
62        };
63
64        let targets = if self.target_logits {
65            softmax(targets, 1)
66        } else {
67            targets
68        };
69
70        (targets * log_probs).sum_dim(1).mean().neg()
71    }
72}