Skip to main content

burn_mamba/modules/
mod.rs

1use burn::config::Config;
2use burn::prelude::*;
3
4/// Custom activations (fp16-stable `silu` / `softplus` / `log_sigmoid`).
5pub mod activation;
6/// Bidirectional layer stacks (straight + reversed passes, merged per pair).
7pub mod bidi;
8/// The per-network cache collection trait ([`CacheStack`]) + [`MambaCaches`].
9pub mod cache;
10/// A single Pre-LN residual layer wrapping one SSM block ([`Layer`]).
11pub mod layer;
12/// The (virtual-)layer stack over real weight sets ([`Layers`]).
13pub mod layers;
14/// Loss functions (binary cross-entropy, cross-entropy, mean squared error).
15pub mod loss;
16/// Tensor helpers: `segsum`, `gqa`, typed `split`, and `sanity` guards.
17pub mod misc;
18/// Multi-Gate Residuals: multi-stream gated depth-wise residuals ([`Residuals`]).
19pub mod multi_gate;
20/// Family-generic networks ([`MambaLatentNet`] / [`MambaVocabNet`]).
21pub mod network;
22/// RMS norms ([`RmsNorm`] QK-norm + [`RmsNormGated`]), fp16-safe.
23pub mod norm;
24
25pub use activation::log_sigmoid::log_sigmoid;
26pub use activation::silu::Silu;
27pub use activation::softplus::softplus;
28pub use misc::gqa::gqa_expand_to_heads;
29pub use misc::rope::{apply_rope, apply_rope_partial, wrap_angle};
30pub use misc::sanity::sanity;
31pub use misc::segsum::segsum;
32pub use misc::split::split_into;
33pub use norm::rms_norm::{RmsNorm, RmsNormConfig};
34pub use norm::rms_norm_gated::{RmsNormGated, RmsNormGatedConfig};
35
36pub use bidi::{MambaBidiLayers, MambaBidiLayersConfig};
37pub use cache::{CacheStack, MambaCaches};
38pub use layer::Layer;
39pub use layers::{Layers, LayersBuilder};
40pub use multi_gate::{
41    MultiGate, MultiGateResidual, MultiGateResidualConfig, Residuals, ResidualsConfig,
42};
43pub use network::{MambaLatentNet, MambaLatentNetConfig, MambaVocabNet, MambaVocabNetConfig};
44
45/// Per-family block interface the generic [`Layer`]/[`Layers`] delegate to.
46pub trait MambaBlock: Module {
47    /// Per-block streaming cache (one layer's worth of state).
48    type Cache;
49    /// The per-network cache collection for this family.
50    type Caches: CacheStack<Cache = Self::Cache>;
51    /// SSD algorithm / chunk-length selector. `()` for families without one.
52    type SsdPath;
53
54    /// Full-sequence (chunked) pass — training / prefill.
55    fn block_forward(
56        &self,
57        x: Tensor<3>,
58        cache: Option<Self::Cache>,
59        ssd_path: Self::SsdPath,
60    ) -> (Tensor<3>, Self::Cache);
61
62    /// Single-token recurrent step — decoding.
63    fn block_step(&self, x: Tensor<2>, cache: Option<Self::Cache>) -> (Tensor<2>, Self::Cache);
64
65    /// Closed-form **stationary fixed point**: the limit of
66    /// [`Self::block_step`] outputs when the same constant token is stepped
67    /// forever. The limit forgets the starting state, so no cache is taken or
68    /// returned. The default implementation panics — only Mamba-3 currently
69    /// provides the closed form (see
70    /// [`Mamba3::step_infinite`](crate::mamba3::prelude::Mamba3::step_infinite)).
71    fn block_step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
72        let _ = x;
73        unimplemented!("block_step_infinite: constant-input shortcuts are only implemented for Mamba-3")
74    }
75
76    /// Closed-form jump equivalent to `n` consecutive [`Self::block_step`]
77    /// calls on the same constant token: the last step's output and the final
78    /// cache, in O(1). The default implementation panics — only Mamba-3
79    /// currently provides it (see
80    /// [`Mamba3::step_n_approx`](crate::mamba3::prelude::Mamba3::step_n_approx)).
81    fn block_step_n_approx(
82        &self,
83        x: Tensor<2>,
84        n: usize,
85        cache: Option<Self::Cache>,
86    ) -> (Tensor<2>, Self::Cache) {
87        let _ = (x, n, cache);
88        unimplemented!("block_step_n_approx: constant-input shortcuts are only implemented for Mamba-3")
89    }
90
91    /// Build `n_virtual` zero caches sized for a `[batch, sequence, d_model]` input.
92    fn zero_caches_3d(&self, x: &Tensor<3>, n_virtual: usize) -> Self::Caches;
93    /// Build `n_virtual` zero caches sized for a `[batch, d_model]` input.
94    fn zero_caches_2d(&self, x: &Tensor<2>, n_virtual: usize) -> Self::Caches;
95}
96
97/// A block *config* that knows its `d_model` and how to build its [`MambaBlock`].
98/// Lets the generic builders construct `Layers<M>` without knowing the family.
99pub trait MambaBlockConfig: Config {
100    /// The block this config builds.
101    type Block: MambaBlock;
102    /// Model width, used to size each layer's pre-norm.
103    fn d_model(&self) -> usize;
104    /// Allocate and initialise the block on `device`.
105    fn init_block(&self, device: &Device) -> Self::Block;
106}
107
108// ===========================================================================
109// Unifying enums: one runtime + one serializable Config across all families
110// ===========================================================================
111//
112// The generic `LatentNetwork<M>` above is family-typed (`M` is fixed at the type
113// level). To let an example (or a user) choose the family at *runtime* — and to
114// serialize that choice for docs/config round-trips — we wrap the three
115// monomorphisations in enums. `#[derive(Module)]` and `#[derive(Config)]` both
116// support enums (verified), so this stays first-class Burn.
117
/// SSD-path selector for the unified API, explicitly tagged by family.
///
/// Every variant wraps that family's concrete path, so the caller can pick
/// the algorithm/chunk directly. The `*_default` constructors cover the usual
/// "just use the family default" case without making it the *only* choice.
#[derive(Clone, Debug)]
pub enum MambaSsdPath {
    /// Mamba-1: there is no SSD chunking, so the path is the unit type.
    #[cfg(feature = "mamba1")]
    Mamba1,
    /// SSD path for Mamba-2.
    #[cfg(feature = "mamba2")]
    Mamba2(crate::mamba2::prelude::Mamba2SsdPath),
    /// SSD path for Mamba-3.
    #[cfg(feature = "mamba3")]
    Mamba3(crate::mamba3::prelude::Mamba3SsdPath),
}
135
136impl MambaSsdPath {
137    /// The Mamba-2 default path (`SerialRecalculated`, optimal chunk).
138    #[cfg(feature = "mamba2")]
139    pub fn mamba2_default() -> Self {
140        Self::Mamba2(Default::default())
141    }
142    /// The Mamba-3 default path (`SerialRecalculated`, optimal chunk).
143    #[cfg(feature = "mamba3")]
144    pub fn mamba3_default() -> Self {
145        Self::Mamba3(Default::default())
146    }
147}