burn_mamba/modules/loss/cross_entropy.rs
1//! Cross-entropy loss accepting floating-point target distributions.
2//!
3//! A lighter variant of [`burn::nn::loss::CrossEntropyLoss`]: it takes soft
4//! float targets (one-hot, soft labels, or raw logits) rather than integer
5//! class indices, and drops padding/weights/label-smoothing. `output_logits`
6//! and `target_logits` control whether each side is normalised (log-softmax /
7//! softmax) before the loss.
8
9use burn::module::Module;
10use burn::prelude::*;
11use burn::tensor::activation::{log_softmax, softmax};
12
13/// Configuration to create a [`CrossEntropyLoss`] using the [`CrossEntropyLossConfig::init`].
14#[derive(Config, Debug)]
15pub struct CrossEntropyLossConfig {
16 /// Treat the outputs as logits, applying log-softmax when computing the loss.
17 ///
18 /// When `false`, outputs are assumed to be probabilities (e.g. post-softmax).
19 #[config(default = true)]
20 pub output_logits: bool,
21
22 /// Treat the targets as logits, applying softmax to normalize them before computing the loss.
23 ///
24 /// When `false`, targets are assumed to already be a valid probability distribution
25 /// (e.g. one-hot or soft labels that sum to 1).
26 #[config(default = false)]
27 pub target_logits: bool,
28}
29
30impl CrossEntropyLossConfig {
31 /// Initialize [`CrossEntropyLoss`].
32 pub fn init(&self) -> CrossEntropyLoss {
33 CrossEntropyLoss {
34 output_logits: self.output_logits,
35 target_logits: self.target_logits,
36 }
37 }
38}
39
40/// Calculate the cross-entropy loss from the output logits and the targets.
41///
42/// Unlike the full [`burn::nn::loss::CrossEntropyLoss`], this variant accepts
43/// floating-point targets (e.g. one-hot, soft label distributions, or un-normalized logits)
44/// rather than integer class indices, and omits padding, per-class weights, and label smoothing.
45///
46/// Should be created using [CrossEntropyLossConfig].
47#[derive(Module, Debug)]
48pub struct CrossEntropyLoss {
49 /// Treat the outputs as logits.
50 pub output_logits: bool,
51 /// Treat the targets as logits.
52 pub target_logits: bool,
53}
54
55impl CrossEntropyLoss {
56 /// Compute the criterion on the output tensor.
57 ///
58 /// # Shapes
59 ///
60 /// - logits: `[batch_size, num_classes]`
61 /// - targets: `[batch_size, num_classes]`
62 pub fn forward(&self, logits: Tensor<2>, targets: Tensor<2>) -> Tensor<1> {
63 let log_probs = if self.output_logits {
64 // Numerically stable via log-softmax
65 log_softmax(logits, 1)
66 } else {
67 // outputs are probabilities; clamp at -100.0 after log to avoid undefined values
68 // for zero-probability classes (mirrors the BCE treatment of log(0))
69 logits.log().clamp_min(-100.0)
70 };
71
72 let targets = if self.target_logits {
73 softmax(targets, 1)
74 } else {
75 targets
76 };
77
78 (targets * log_probs).sum_dim(1).mean().neg()
79 }
80}