burn_mamba/modules/activation/softplus.rs
//! Softplus activation: `softplus(x) = log(1 + eˣ)`, a smooth approximation of ReLU.
//!
//! Used to produce the strictly-positive discretisation step `Δ` (and, in
//! Mamba-3, the data-dependent `A`). The F16 path uses the numerically stable
//! identity `softplus(x) = max(x, 0) + log(1 + e^−|x|)` to avoid overflow in
//! `eˣ`. The other float formats use `log1p(eˣ)` directly. This is cheaper, but
//! `eˣ` still overflows to `+∞` for large `x` (above ≈ 88.7 in F32 and BF16).
7
8use burn::prelude::*;
9use burn::tensor::DType;
10
11/// Applies the softplus function element-wise: `log(1 + eˣ)`.
12///
13/// Panics on non-float element types.
14pub fn softplus<const D: usize>(x: Tensor<D>) -> Tensor<D> {
15 match x.dtype() {
16 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
17 // softplus = log(e^x + 1)
18 x.exp().log1p()
19 }
20 DType::F16 => {
21 // (x.exp() + 1.).log()
22
23 // max(a,b) = (a + b + |a-b|)/2
24 // softplus = max(x, 0) + log(e^-|x| + 1)
25 // = (x + |x|) / 2 + log(e^-|x| + 1)
26 let xabs = x.clone().abs();
27 (x + xabs.clone()) / 2. + xabs.neg().exp().log1p()
28 }
29 DType::I64
30 | DType::I32
31 | DType::I16
32 | DType::I8
33 | DType::U64
34 | DType::U32
35 | DType::U16
36 | DType::U8 => {
37 unreachable!()
38 }
39 DType::Bool(_) => {
40 unreachable!()
41 }
42 DType::QFloat(_) => {
43 unimplemented!()
44 }
45 }
46}