Skip to main content

burn_mamba/modules/loss/
cross_entropy.rs

1//! Cross-entropy loss accepting floating-point target distributions.
2//!
3//! A lighter variant of [`burn::nn::loss::CrossEntropyLoss`]: it takes soft
4//! float targets (one-hot, soft labels, or raw logits) rather than integer
5//! class indices, and drops padding/weights/label-smoothing.  `output_logits`
6//! and `target_logits` control whether each side is normalised (log-softmax /
7//! softmax) before the loss.
8
9use burn::module::Module;
10use burn::prelude::*;
11use burn::tensor::activation::{log_softmax, softmax};
12
13/// Configuration to create a [`CrossEntropyLoss`] using the [`CrossEntropyLossConfig::init`].
14#[derive(Config, Debug)]
15pub struct CrossEntropyLossConfig {
16    /// Treat the outputs as logits, applying log-softmax when computing the loss.
17    ///
18    /// When `false`, outputs are assumed to be probabilities (e.g. post-softmax).
19    #[config(default = true)]
20    pub output_logits: bool,
21
22    /// Treat the targets as logits, applying softmax to normalize them before computing the loss.
23    ///
24    /// When `false`, targets are assumed to already be a valid probability distribution
25    /// (e.g. one-hot or soft labels that sum to 1).
26    #[config(default = false)]
27    pub target_logits: bool,
28}
29
30impl CrossEntropyLossConfig {
31    /// Initialize [`CrossEntropyLoss`].
32    pub fn init(&self) -> CrossEntropyLoss {
33        CrossEntropyLoss {
34            output_logits: self.output_logits,
35            target_logits: self.target_logits,
36        }
37    }
38}
39
40/// Calculate the cross-entropy loss from the output logits and the targets.
41///
42/// Unlike the full [`burn::nn::loss::CrossEntropyLoss`], this variant accepts
43/// floating-point targets (e.g. one-hot, soft label distributions, or un-normalized logits)
44/// rather than integer class indices, and omits padding, per-class weights, and label smoothing.
45///
46/// Should be created using [CrossEntropyLossConfig].
47#[derive(Module, Debug)]
48pub struct CrossEntropyLoss {
49    /// Treat the outputs as logits.
50    pub output_logits: bool,
51    /// Treat the targets as logits.
52    pub target_logits: bool,
53}
54
55impl CrossEntropyLoss {
56    /// Compute the criterion on the output tensor.
57    ///
58    /// # Shapes
59    ///
60    /// - logits: `[batch_size, num_classes]`
61    /// - targets: `[batch_size, num_classes]`
62    pub fn forward(&self, logits: Tensor<2>, targets: Tensor<2>) -> Tensor<1> {
63        let log_probs = if self.output_logits {
64            // Numerically stable via log-softmax
65            log_softmax(logits, 1)
66        } else {
67            // outputs are probabilities; clamp at -100.0 after log to avoid undefined values
68            // for zero-probability classes (mirrors the BCE treatment of log(0))
69            logits.log().clamp_min(-100.0)
70        };
71
72        let targets = if self.target_logits {
73            softmax(targets, 1)
74        } else {
75            targets
76        };
77
78        (targets * log_probs).sum_dim(1).mean().neg()
79    }
80}