Skip to main content

burn_mamba/utils/
log_sigmoid.rs

1use burn::prelude::*;
2use burn::tensor::DType;
3
4/// Applies the log sigmoid function element-wise.
5///
6/// `log_sigmoid(x) = log(1 / (1 + exp(-x)))`
7pub fn log_sigmoid<const D: usize, B: Backend>(x: Tensor<B, D>) -> Tensor<B, D> {
8    match x.dtype() {
9        DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
10            // log_sigmoid(x) = log(1 / (1 + exp(-x)))
11            (x.neg().exp() + 1.).recip().log()
12        }
13        DType::F16 => {
14            // log_sigmoid(x) = -softplus(-x)
15            -crate::utils::softplus::softplus(x.neg())
16        }
17        DType::I64
18        | DType::I32
19        | DType::I16
20        | DType::I8
21        | DType::U64
22        | DType::U32
23        | DType::U16
24        | DType::U8 => {
25            unreachable!()
26        }
27        DType::Bool(_) => {
28            unreachable!()
29        }
30        DType::QFloat(_) => {
31            unimplemented!()
32        }
33    }
34}