burn_mamba/utils/
softplus.rs1use burn::prelude::*;
2use burn::tensor::DType;
3
4pub fn softplus<const D: usize, B: Backend>(x: Tensor<B, D>) -> Tensor<B, D> {
8 match x.dtype() {
9 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
10 x.exp().log1p()
12 }
13 DType::F16 => {
14 let xabs = x.clone().abs();
20 (x + xabs.clone()) / 2. + xabs.neg().exp().log1p()
21 }
22 DType::I64
23 | DType::I32
24 | DType::I16
25 | DType::I8
26 | DType::U64
27 | DType::U32
28 | DType::U16
29 | DType::U8 => {
30 unreachable!()
31 }
32 DType::Bool(_) => {
33 unreachable!()
34 }
35 DType::QFloat(_) => {
36 unimplemented!()
37 }
38 }
39}