burn_mamba/modules/mod.rs
1use burn::config::Config;
2use burn::prelude::*;
3
4/// Custom activations (fp16-stable `silu` / `softplus` / `log_sigmoid`).
5pub mod activation;
6/// Bidirectional layer stacks (straight + reversed passes, merged per pair).
7pub mod bidi;
8/// The per-network cache collection trait ([`CacheStack`]) + [`MambaCaches`].
9pub mod cache;
10/// A single Pre-LN residual layer wrapping one SSM block ([`Layer`]).
11pub mod layer;
12/// The (virtual-)layer stack over real weight sets ([`Layers`]).
13pub mod layers;
14/// Loss functions (binary cross-entropy, cross-entropy, mean squared error).
15pub mod loss;
/// Tensor helpers: `segsum`, `gqa`, `rope`, typed `split`, and `sanity` guards.
17pub mod misc;
18/// Multi-Gate Residuals: multi-stream gated depth-wise residuals ([`Residuals`]).
19pub mod multi_gate;
20/// Family-generic networks ([`MambaLatentNet`] / [`MambaVocabNet`]).
21pub mod network;
22/// RMS norms ([`RmsNorm`] QK-norm + [`RmsNormGated`]), fp16-safe.
23pub mod norm;
24
25pub use activation::log_sigmoid::log_sigmoid;
26pub use activation::silu::Silu;
27pub use activation::softplus::softplus;
28pub use misc::gqa::gqa_expand_to_heads;
29pub use misc::rope::{apply_rope, apply_rope_partial, wrap_angle};
30pub use misc::sanity::sanity;
31pub use misc::segsum::segsum;
32pub use misc::split::split_into;
33pub use norm::rms_norm::{RmsNorm, RmsNormConfig};
34pub use norm::rms_norm_gated::{RmsNormGated, RmsNormGatedConfig};
35
36pub use bidi::{MambaBidiLayers, MambaBidiLayersConfig};
37pub use cache::{CacheStack, MambaCaches};
38pub use layer::Layer;
39pub use layers::{Layers, LayersBuilder};
40pub use multi_gate::{
41 MultiGate, MultiGateResidual, MultiGateResidualConfig, Residuals, ResidualsConfig,
42};
43pub use network::{MambaLatentNet, MambaLatentNetConfig, MambaVocabNet, MambaVocabNetConfig};
44
45/// Per-family block interface the generic [`Layer`]/[`Layers`] delegate to.
46pub trait MambaBlock: Module {
47 /// Per-block streaming cache (one layer's worth of state).
48 type Cache;
49 /// The per-network cache collection for this family.
50 type Caches: CacheStack<Cache = Self::Cache>;
51 /// SSD algorithm / chunk-length selector. `()` for families without one.
52 type SsdPath;
53
54 /// Full-sequence (chunked) pass — training / prefill.
55 fn block_forward(
56 &self,
57 x: Tensor<3>,
58 cache: Option<Self::Cache>,
59 ssd_path: Self::SsdPath,
60 ) -> (Tensor<3>, Self::Cache);
61
62 /// Single-token recurrent step — decoding.
63 fn block_step(&self, x: Tensor<2>, cache: Option<Self::Cache>) -> (Tensor<2>, Self::Cache);
64
65 /// Closed-form **stationary fixed point**: the limit of
66 /// [`Self::block_step`] outputs when the same constant token is stepped
67 /// forever. The limit forgets the starting state, so no cache is taken or
68 /// returned. The default implementation panics — only Mamba-3 currently
69 /// provides the closed form (see
70 /// [`Mamba3::step_infinite`](crate::mamba3::prelude::Mamba3::step_infinite)).
71 fn block_step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
72 let _ = x;
73 unimplemented!("block_step_infinite: constant-input shortcuts are only implemented for Mamba-3")
74 }
75
76 /// Closed-form jump equivalent to `n` consecutive [`Self::block_step`]
77 /// calls on the same constant token: the last step's output and the final
78 /// cache, in O(1). The default implementation panics — only Mamba-3
79 /// currently provides it (see
80 /// [`Mamba3::step_n_approx`](crate::mamba3::prelude::Mamba3::step_n_approx)).
81 fn block_step_n_approx(
82 &self,
83 x: Tensor<2>,
84 n: usize,
85 cache: Option<Self::Cache>,
86 ) -> (Tensor<2>, Self::Cache) {
87 let _ = (x, n, cache);
88 unimplemented!("block_step_n_approx: constant-input shortcuts are only implemented for Mamba-3")
89 }
90
91 /// Build `n_virtual` zero caches sized for a `[batch, sequence, d_model]` input.
92 fn zero_caches_3d(&self, x: &Tensor<3>, n_virtual: usize) -> Self::Caches;
93 /// Build `n_virtual` zero caches sized for a `[batch, d_model]` input.
94 fn zero_caches_2d(&self, x: &Tensor<2>, n_virtual: usize) -> Self::Caches;
95}
96
97/// A block *config* that knows its `d_model` and how to build its [`MambaBlock`].
98/// Lets the generic builders construct `Layers<M>` without knowing the family.
99pub trait MambaBlockConfig: Config {
100 /// The block this config builds.
101 type Block: MambaBlock;
102 /// Model width, used to size each layer's pre-norm.
103 fn d_model(&self) -> usize;
104 /// Allocate and initialise the block on `device`.
105 fn init_block(&self, device: &Device) -> Self::Block;
106}
107
108// ===========================================================================
109// Unifying enums: one runtime + one serializable Config across all families
110// ===========================================================================
111//
// The family-generic networks re-exported above (`MambaLatentNet` /
// `MambaVocabNet`) are family-typed (their block parameter `M` is fixed at the
// type level). To let an example (or a user) choose the family at *runtime* —
// and to serialize that choice for docs/config round-trips — we wrap the three
// monomorphisations in enums. `#[derive(Module)]` and `#[derive(Config)]` both
// support enums (verified), so this stays first-class Burn.
117
/// Family-tagged SSD-path selector for the unified API.
///
/// Every variant wraps that family's concrete path, so callers can pick the
/// algorithm / chunk explicitly. The `*_default` constructors are a shortcut
/// for "just use the family default" and are not the only way to build one.
///
/// Each variant exists only when the Cargo feature of the same name is enabled.
#[derive(Clone, Debug)]
pub enum MambaSsdPath {
    /// Mamba-1: no SSD chunking, so the variant carries no payload.
    #[cfg(feature = "mamba1")]
    Mamba1,
    /// Mamba-2: wraps the family's SSD path.
    #[cfg(feature = "mamba2")]
    Mamba2(crate::mamba2::prelude::Mamba2SsdPath),
    /// Mamba-3: wraps the family's SSD path.
    #[cfg(feature = "mamba3")]
    Mamba3(crate::mamba3::prelude::Mamba3SsdPath),
}
135
136impl MambaSsdPath {
137 /// The Mamba-2 default path (`SerialRecalculated`, optimal chunk).
138 #[cfg(feature = "mamba2")]
139 pub fn mamba2_default() -> Self {
140 Self::Mamba2(Default::default())
141 }
142 /// The Mamba-3 default path (`SerialRecalculated`, optimal chunk).
143 #[cfg(feature = "mamba3")]
144 pub fn mamba3_default() -> Self {
145 Self::Mamba3(Default::default())
146 }
147}