Skip to main content

burn_mamba/utils/
silu.rs

1use burn::prelude::*;
2
/// SiLU (sigmoid linear unit) activation: `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
///
/// Stateless unit struct with no learnable parameters; apply it with [`Silu::forward`].
#[derive(Module, Clone, Debug, Default)]
pub struct Silu;
6
7impl Silu {
8    /// Create the module.
9    pub fn new() -> Self {
10        Self {}
11    }
12
13    /// Applies the forward pass on the input tensor.
14    ///
15    /// # Shapes
16    ///
17    /// - input: `[..., any]`
18    /// - output: `[..., any]`
19    pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
20        // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
21        input.clone() / ((-input).exp() + 1.0)
22    }
23}