// burn_mamba/utils/loss/cross_entropy.rs
1use burn::module::Module;
2use burn::prelude::*;
3use burn::tensor::activation::{log_softmax, softmax};
4
5/// Configuration to create a [Cross-entropy loss](CrossEntropyLoss) using the [init function](CrossEntropyLossConfig::init).
6#[derive(Config, Debug)]
7pub struct CrossEntropyLossConfig {
8 /// Treat the outputs as logits, applying log-softmax when computing the loss.
9 ///
10 /// When `false`, outputs are assumed to be probabilities (e.g. post-softmax).
11 #[config(default = true)]
12 pub output_logits: bool,
13
14 /// Treat the targets as logits, applying softmax to normalize them before computing the loss.
15 ///
16 /// When `false`, targets are assumed to already be a valid probability distribution
17 /// (e.g. one-hot or soft labels that sum to 1).
18 #[config(default = false)]
19 pub target_logits: bool,
20}
21
22impl CrossEntropyLossConfig {
23 /// Initialize [Cross-entropy loss](CrossEntropyLoss).
24 pub fn init(&self) -> CrossEntropyLoss {
25 CrossEntropyLoss {
26 output_logits: self.output_logits,
27 target_logits: self.target_logits,
28 }
29 }
30}
31
32/// Calculate the cross-entropy loss from the output logits and the targets.
33///
34/// Unlike the full [`CrossEntropyLoss`](super::CrossEntropyLoss), this variant accepts
35/// floating-point targets (e.g. one-hot, soft label distributions, or un-normalized logits)
36/// rather than integer class indices, and omits padding, per-class weights, and label smoothing.
37///
38/// Should be created using [CrossEntropyLossConfig].
39#[derive(Module, Clone, Debug)]
40pub struct CrossEntropyLoss {
41 /// Treat the outputs as logits.
42 pub output_logits: bool,
43 /// Treat the targets as logits.
44 pub target_logits: bool,
45}
46
47impl CrossEntropyLoss {
48 /// Compute the criterion on the output tensor.
49 ///
50 /// # Shapes
51 ///
52 /// - logits: `[batch_size, num_classes]`
53 /// - targets: `[batch_size, num_classes]`
54 pub fn forward<B: Backend>(&self, logits: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
55 let log_probs = if self.output_logits {
56 // Numerically stable via log-softmax
57 log_softmax(logits, 1)
58 } else {
59 // outputs are probabilities; clamp at -100.0 after log to avoid undefined values
60 // for zero-probability classes (mirrors the BCE treatment of log(0))
61 logits.log().clamp_min(-100.0)
62 };
63
64 let targets = if self.target_logits {
65 softmax(targets, 1)
66 } else {
67 targets
68 };
69
70 (targets * log_probs).sum_dim(1).mean().neg()
71 }
72}