burn_mamba/modules/activation/silu.rs
1//! SiLU (a.k.a. swish) activation: `silu(x) = x · sigmoid(x)`.
2//!
3//! Implemented as `x / (1 + exp(−x))`, which is fp16-aware (no separate
4//! `sigmoid` op) and used for the gating branches throughout the Mamba blocks.
5
6use burn::prelude::*;
7
8/// SiLU activation module: `silu(x) = x · sigmoid(x) = x / (1 + exp(−x))`.
9#[derive(Module, Debug, Default)]
10pub struct Silu;
11
12impl Silu {
13 /// Create the module.
14 pub fn new() -> Self {
15 Self {}
16 }
17
18 /// Applies the forward pass on the input tensor.
19 ///
20 /// # Shapes
21 ///
22 /// - input: `[..., any]`
23 /// - output: `[..., any]`
24 pub fn forward<const D: usize>(&self, input: Tensor<D>) -> Tensor<D> {
25 // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
26 input.clone() / ((-input).exp() + 1.0)
27 }
28}