Skip to main content

burn_mamba/modules/activation/
softplus.rs

1//! Softplus activation: `softplus(x) = log(1 + eˣ)`, a smooth ReLU.
2//!
3//! Used to produce the strictly-positive discretisation step `Δ` (and, in
4//! Mamba-3, the data-dependent `A`).  The fp16 path uses the numerically-stable
5//! identity `softplus(x) = max(x, 0) + log(1 + e^−|x|)` to avoid overflow in
6//! `eˣ`; the wider formats use `log1p(eˣ)` directly.
7
8use burn::prelude::*;
9use burn::tensor::DType;
10
11/// Applies the softplus function element-wise: `log(1 + eˣ)`.
12///
13/// Panics on non-float element types.
14pub fn softplus<const D: usize>(x: Tensor<D>) -> Tensor<D> {
15    match x.dtype() {
16        DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
17            // softplus = log(e^x + 1)
18            x.exp().log1p()
19        }
20        DType::F16 => {
21            // (x.exp() + 1.).log()
22
23            // max(a,b) = (a + b + |a-b|)/2
24            // softplus = max(x, 0) + log(e^-|x| + 1)
25            //          = (x + |x|) / 2 + log(e^-|x| + 1)
26            let xabs = x.clone().abs();
27            (x + xabs.clone()) / 2. + xabs.neg().exp().log1p()
28        }
29        DType::I64
30        | DType::I32
31        | DType::I16
32        | DType::I8
33        | DType::U64
34        | DType::U32
35        | DType::U16
36        | DType::U8 => {
37            unreachable!()
38        }
39        DType::Bool(_) => {
40            unreachable!()
41        }
42        DType::QFloat(_) => {
43            unimplemented!()
44        }
45    }
46}