Skip to main content

burn_mamba/modules/activation/
silu.rs

//! SiLU (a.k.a. swish) activation: `silu(x) = x · sigmoid(x)`.
//!
//! Implemented as `x / (1 + exp(−x))`, which is fp16-aware (no separate
//! `sigmoid` op); SiLU is used for the gating branches throughout the Mamba blocks.

6use burn::prelude::*;
7
8/// SiLU activation module: `silu(x) = x · sigmoid(x) = x / (1 + exp(−x))`.
9#[derive(Module, Debug, Default)]
10pub struct Silu;
11
12impl Silu {
13    /// Create the module.
14    pub fn new() -> Self {
15        Self {}
16    }
17
18    /// Applies the forward pass on the input tensor.
19    ///
20    /// # Shapes
21    ///
22    /// - input: `[..., any]`
23    /// - output: `[..., any]`
24    pub fn forward<const D: usize>(&self, input: Tensor<D>) -> Tensor<D> {
25        // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
26        input.clone() / ((-input).exp() + 1.0)
27    }
28}