burn_mamba/utils/silu.rs
1use burn::prelude::*;
2
/// SiLU (sigmoid linear unit) activation module.
///
/// Computes `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
///
/// The struct carries no fields, so it can be built with [`Silu::new`] or
/// through its derived `Default` implementation.
#[derive(Module, Clone, Debug, Default)]
pub struct Silu;
6
7impl Silu {
8 /// Create the module.
9 pub fn new() -> Self {
10 Self {}
11 }
12
13 /// Applies the forward pass on the input tensor.
14 ///
15 /// # Shapes
16 ///
17 /// - input: `[..., any]`
18 /// - output: `[..., any]`
19 pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
20 // silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
21 input.clone() / ((-input).exp() + 1.0)
22 }
23}