Skip to main content

burn_mamba/modules/activation/
log_sigmoid.rs

1//! Numerically-stable log-sigmoid: `log σ(x) = −log(1 + e^−x)`.
2//!
3//! The wider float formats evaluate `log(1 / (1 + e^−x))` directly; the fp16
4//! path uses the stable identity `log σ(x) = −softplus(−x)` (see
5//! [`softplus`](crate::utils::softplus)) to avoid overflow.
6
7use burn::prelude::*;
8use burn::tensor::DType;
9
10/// Applies the log-sigmoid function element-wise: `log(1 / (1 + e^−x))`.
11///
12/// Panics on non-float element types.
13pub fn log_sigmoid<const D: usize>(x: Tensor<D>) -> Tensor<D> {
14    match x.dtype() {
15        DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
16            // log_sigmoid(x) = log(1 / (1 + exp(-x)))
17            (x.neg().exp() + 1.).recip().log()
18        }
19        DType::F16 => {
20            // log_sigmoid(x) = -softplus(-x)
21            -crate::modules::softplus(x.neg())
22        }
23        DType::I64
24        | DType::I32
25        | DType::I16
26        | DType::I8
27        | DType::U64
28        | DType::U32
29        | DType::U16
30        | DType::U8 => {
31            unreachable!()
32        }
33        DType::Bool(_) => {
34            unreachable!()
35        }
36        DType::QFloat(_) => {
37            unimplemented!()
38        }
39    }
40}