burn_mamba/utils/
log_sigmoid.rs1use burn::prelude::*;
2use burn::tensor::DType;
3
4pub fn log_sigmoid<const D: usize, B: Backend>(x: Tensor<B, D>) -> Tensor<B, D> {
8 match x.dtype() {
9 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
10 (x.neg().exp() + 1.).recip().log()
12 }
13 DType::F16 => {
14 -crate::utils::softplus::softplus(x.neg())
16 }
17 DType::I64
18 | DType::I32
19 | DType::I16
20 | DType::I8
21 | DType::U64
22 | DType::U32
23 | DType::U16
24 | DType::U8 => {
25 unreachable!()
26 }
27 DType::Bool(_) => {
28 unreachable!()
29 }
30 DType::QFloat(_) => {
31 unimplemented!()
32 }
33 }
34}