Skip to main content

burn_mamba/utils/
softplus.rs

1use burn::prelude::*;
2use burn::tensor::DType;
3
4/// Applies the SoftPlus function element-wise.
5///
6/// The SoftPlus function is a smooth approximation of the ReLU function.
7pub fn softplus<const D: usize, B: Backend>(x: Tensor<B, D>) -> Tensor<B, D> {
8    match x.dtype() {
9        DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
10            // softplus = log(e^x + 1)
11            x.exp().log1p()
12        }
13        DType::F16 => {
14            // (x.exp() + 1.).log()
15
16            // max(a,b) = (a + b + |a-b|)/2
17            // softplus = max(x, 0) + log(e^-|x| + 1)
18            //          = (x + |x|) / 2 + log(e^-|x| + 1)
19            let xabs = x.clone().abs();
20            (x + xabs.clone()) / 2. + xabs.neg().exp().log1p()
21        }
22        DType::I64
23        | DType::I32
24        | DType::I16
25        | DType::I8
26        | DType::U64
27        | DType::U32
28        | DType::U16
29        | DType::U8 => {
30            unreachable!()
31        }
32        DType::Bool(_) => {
33            unreachable!()
34        }
35        DType::QFloat(_) => {
36            unimplemented!()
37        }
38    }
39}