burn_mamba/modules/activation/log_sigmoid.rs
1//! Numerically-stable log-sigmoid: `log σ(x) = −log(1 + e^−x)`.
2//!
//! Evaluation dispatches on the tensor's dtype; the inline comments in
//! [`log_sigmoid`] name the formula used for each float format. The fp16 path
//! uses the stable identity `log σ(x) = −softplus(−x)` (see
//! [`softplus`](crate::modules::softplus)) to avoid overflow in half precision.
6
7use burn::prelude::*;
8use burn::tensor::DType;
9
10/// Applies the log-sigmoid function element-wise: `log(1 / (1 + e^−x))`.
11///
12/// Panics on non-float element types.
13pub fn log_sigmoid<const D: usize>(x: Tensor<D>) -> Tensor<D> {
14 match x.dtype() {
15 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
16 // log_sigmoid(x) = log(1 / (1 + exp(-x)))
17 (x.neg().exp() + 1.).recip().log()
18 }
19 DType::F16 => {
20 // log_sigmoid(x) = -softplus(-x)
21 -crate::modules::softplus(x.neg())
22 }
23 DType::I64
24 | DType::I32
25 | DType::I16
26 | DType::I8
27 | DType::U64
28 | DType::U32
29 | DType::U16
30 | DType::U8 => {
31 unreachable!()
32 }
33 DType::Bool(_) => {
34 unreachable!()
35 }
36 DType::QFloat(_) => {
37 unimplemented!()
38 }
39 }
40}