burn_mamba/modules/multi_gate.rs
1//! Multi-Gate Residuals (MGR) — a depth-wise residual scheme replacing the plain
2//! additive skip of a [`Layers`](crate::modules::Layers) stack.
3//!
4//! Instead of one residual stream, MGR keeps **`n_stream` parallel streams**
5//! `sᵢ` (all seeded from the stack input). Between layers, one
6//! [`MultiGateResidual`] per layer does two convex, norm-bounded operations
7//! (paper §"Our Architecture"):
8//!
9//! 1. **Mixer** (independent sigmoid gate) — each stream is interpolated towards
10//! the current layer output `F_l` by a per-stream gate `βᵢ`:
11//! `sᵢ' = (1−βᵢ)·sᵢ + βᵢ·F_l`, with
12//! `βᵢ = σ( (w⁽ᵝ⁾ · RMSNorm(sᵢ))/√d + b⁽ᵝ⁾ᵢ )`.
13//! 2. **Aggregator** (depth-wise attention pooling, "AttnPool") — the updated
14//! streams are pooled into the next layer's input `h` by a softmax over
15//! streams: `αᵢ = softmax_i( (w⁽ᵅ⁾ · RMSNorm(sᵢ'))/√d )`, `h = Σᵢ αᵢ·sᵢ'`.
16//!
17//! Both `w` vectors are learnable in `ℝ^d` (init zero), the RMSNorm is
18//! parameter-free, and `b⁽ᵝ⁾` is a per-stream learnable bias. Only the
19//! **independent** (sigmoid) gate is implemented; the paper's competitive
20//! (softmax) variant is omitted.
21//!
22//! MGR is purely **point-wise over `(batch, sequence)`** — the streams only
23//! evolve along *depth*, never along the sequence — so `forward` over a sequence
24//! equals `step` unrolled token-by-token, and `step` carries no extra state
25//! (each token rebuilds its own depth-streams).
26//!
27//! **Gate-bias initialisation.** Following Highway Networks, a negative
28//! `init_bias` biases the gates towards *carry* (small updates) at the start of
29//! training. The paper scales it with depth `L`:
30//! `b_init = ln( √(L/L_base)·(exp(−b_base)+1) − n )` (with `L_base = 21`,
31//! `b_base = −3`); here `init_bias` is taken directly so the caller may apply
32//! that formula. Default `0` (gates open at `σ(0)=0.5`).
33
34use crate::modules::bidi::NoOp;
35use crate::utils::div_eps;
36use burn::config::Config;
37use burn::module::Param;
38use burn::nn::Initializer;
39use burn::prelude::*;
40use burn::tensor::activation::{sigmoid, softmax};
41use burn::tensor::{DType, f16};
42
/// One layer's Multi-Gate Residual parameters: the mixer query `w⁽ᵝ⁾` + bias
/// `b⁽ᵝ⁾`, and the aggregator (AttnPool) query `w⁽ᵅ⁾`.
///
/// The struct only stores parameters and the two sizes they were built for;
/// the mix + pool math lives in [`MultiGateResidual::forward`] (full sequence)
/// and [`MultiGateResidual::step`] (single token). Instances are created by
/// [`MultiGateResidualConfig::init`].
#[derive(Module, Debug)]
pub struct MultiGateResidual {
    /// Mixer query `w⁽ᵝ⁾ ∈ ℝ^d` (the per-stream sigmoid gate), `[d_model]`.
    /// Zero-initialised by [`MultiGateResidualConfig::init`].
    pub w_beta: Param<Tensor<1>>,
    /// Aggregator query `w⁽ᵅ⁾ ∈ ℝ^d` (the AttnPool softmax), `[d_model]`.
    /// Zero-initialised by [`MultiGateResidualConfig::init`].
    pub w_alpha: Param<Tensor<1>>,
    /// Per-stream mixer gate bias `b⁽ᵝ⁾`, `[n_stream]`. Every entry starts at
    /// the configured `init_bias`.
    pub b_beta: Param<Tensor<1>>,
    /// Model width `d`. Used for the `d^(-1/2)` score scale and checked against
    /// the feature axis of the incoming streams.
    // Marked `#[module(skip)]`: plain configuration, not a learnable parameter.
    #[module(skip)]
    pub d_model: usize,
    /// Number of parallel residual streams `n`. Checked against the stream axis
    /// of the incoming streams.
    // Marked `#[module(skip)]`: plain configuration, not a learnable parameter.
    #[module(skip)]
    pub n_stream: usize,
}
60
61impl MultiGateResidual {
62 fn scale(&self) -> f32 {
63 (self.d_model as f32).powf(-0.5)
64 }
65
66 /// The parameter-free RMS denominator `d(x) ∈ [‥, 1]` such that the RMSNorm
67 /// (matching [`RmsNorm`] math with `γ ≡ 1`) is `x / d(x)`. Returning the
68 /// denominator rather than the normalised tensor lets [`Self::normed_score`]
69 /// fold it out of the (feature-axis) score reduction, so the full-width
70 /// normalised tensor is never built. The fp16 path keeps the same
71 /// overflow-safe max-rescale, folded into the same scalar denominator.
72 ///
73 /// [`RmsNorm`]: crate::modules::RmsNorm
74 fn rms_denom<const D: usize>(&self, x: Tensor<D>) -> Tensor<D> {
75 match x.dtype() {
76 DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
77 let eps = div_eps(x.dtype());
78 (x.clone() * x).mean_dim(D - 1).sqrt() + eps
79 }
80 DType::F16 => {
81 use burn::tensor::ElementConversion;
82 let eps: f16 = f16::from_elem(div_eps(x.dtype())) * f16::from_f32(2.);
83 // Single global scalar `max`, reshaped to `[1; D]` so it
84 // broadcasts against the `[‥, 1]` partial RMS.
85 let max = x.clone().no_grad().detach().abs().max().reshape([1; D]);
86 let x_ = x.clone() / (max.clone() + eps); // x_.abs() <= 1
87 let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt();
88 (rms_partial + eps) * max.sqrt()
89 }
90 _ => unreachable!("rms_denom expects a float dtype"),
91 }
92 }
93
94 /// The RMSNorm-then-dot score `scale · Σ_feat(x · w) / (rms(x)+eps)`,
95 /// shape `[‥, 1]`. The RMS denominator is constant over the feature axis, so
96 /// it is folded out of the reduction (via [`Self::rms_denom`]) — equal to
97 /// `Σ_feat(rms_norm(x) · w) · scale` but without materialising the full-width
98 /// normalised tensor.
99 fn normed_score<const R: usize>(&self, x: Tensor<R>, w: Tensor<R>) -> Tensor<R> {
100 let dot = (x.clone() * w).sum_dim(R - 1);
101 dot * self.scale() / self.rms_denom(x)
102 }
103
104 /// The shared mix + pool, generic over the streams rank `R` (the *stream*
105 /// axis is `R-2`, the *feature* axis `R-1`). [`Self::forward`] (`R = 4`) and
106 /// [`Self::step`] (`R = 3`) only differ by that rank, so both lift their
107 /// `layer_output` to a singleton stream axis, call this, and drop it again.
108 /// All reductions keep their axis (size 1) for broadcasting, so scores/gates
109 /// are `[…, n_stream, 1]` throughout.
110 ///
111 /// - `layer_output`: `F_l` lifted to a unit stream axis, `[…, 1, d_model]`
112 /// - `streams`: the `n_stream` residual streams, `[…, n_stream, d_model]`
113 ///
114 /// Returns `(h, streams')` with `h` still carrying its unit stream axis
115 /// (`[…, 1, d_model]`) and `streams'` the same shape as `streams`.
116 fn mix_pool<const R: usize>(
117 &self,
118 layer_output: Tensor<R>,
119 streams: Tensor<R>,
120 ) -> (Tensor<R>, Tensor<R>) {
121 let dims = streams.dims();
122 let (stream_axis, feat_axis) = (R - 2, R - 1);
123 assert_eq!(
124 dims[feat_axis], self.d_model,
125 "stream width must equal d_model"
126 );
127 assert_eq!(
128 dims[stream_axis], self.n_stream,
129 "stream count must equal n_stream"
130 );
131
132 // `b_beta` reshaped to broadcast on the stream axis: `[1, …, n_stream, 1]`.
133 let mut bias_shape = [1usize; R];
134 bias_shape[stream_axis] = self.n_stream;
135 let b_beta = self.b_beta.val().reshape(bias_shape);
136 // The query vectors broadcast on the feature axis: `[1, …, 1, d_model]`.
137 let w_beta = self.w_beta.val().unsqueeze::<R>();
138 let w_alpha = self.w_alpha.val().unsqueeze::<R>();
139
140 // Mixer: independent per-stream sigmoid gate, `β`: `[…, n_stream, 1]`.
141 let beta = sigmoid(self.normed_score(streams.clone(), w_beta) + b_beta);
142 // Lerp `(1−β)·streams + β·layer_output` (equal to the paper's
143 // `streams + β·(layer_output − streams)`) — written so no full-width
144 // intermediate is retained: `streams` is the already-saved input and
145 // `layer_output` is `[…, 1, d_model]`, so neither `mul` saves a new
146 // `[…, n_stream, d_model]` tensor and the `+` saves nothing.
147 let new_streams = streams * (-beta.clone() + 1.0) + layer_output * beta;
148
149 // Aggregator: depth-wise attention pooling (softmax over the stream axis).
150 let score_ns = self.normed_score(new_streams.clone(), w_alpha);
151 let alpha = softmax(score_ns, stream_axis);
152 let new_h = (alpha * new_streams.clone()).sum_dim(stream_axis);
153 (new_h, new_streams)
154 }
155
156 /// Full-sequence mix + pool.
157 ///
158 /// - `layer_output`: this layer's transform `F_l`, `[batch, sequence, d_model]`
159 /// - `streams`: the `n_stream` residual streams, `[batch, sequence, n_stream, d_model]`
160 ///
161 /// Returns `(h, streams')`: the pooled input `h` for the next layer
162 /// (`[batch, sequence, d_model]`) and the updated streams (same shape as in).
163 pub fn forward(&self, layer_output: Tensor<3>, streams: Tensor<4>) -> (Tensor<3>, Tensor<4>) {
164 let (new_h, new_streams) = self.mix_pool::<4>(layer_output.unsqueeze_dim(2), streams);
165 (new_h.squeeze_dim(2), new_streams)
166 }
167
168 /// Single-token mix + pool (the [`Self::forward`] math with the sequence axis
169 /// dropped).
170 ///
171 /// - `layer_output`: `[batch, d_model]`
172 /// - `streams`: `[batch, n_stream, d_model]`
173 ///
174 /// Returns `(h, streams')`: `[batch, d_model]` and `[batch, n_stream, d_model]`.
175 pub fn step(&self, layer_output: Tensor<2>, streams: Tensor<3>) -> (Tensor<2>, Tensor<3>) {
176 let (new_h, new_streams) = self.mix_pool::<3>(layer_output.unsqueeze_dim(1), streams);
177 (new_h.squeeze_dim(1), new_streams)
178 }
179}
180
/// Configuration for a single [`MultiGateResidual`].
///
/// `d_model` and `n_stream` are required; `init_bias` is optional and
/// defaults to `0.0`. Call [`MultiGateResidualConfig::init`] to allocate the
/// parameters on a device.
#[derive(Config, Debug)]
pub struct MultiGateResidualConfig {
    /// Model width `d`: the length of the two query vectors.
    pub d_model: usize,
    /// Number of parallel residual streams `n`: the length of the gate bias.
    pub n_stream: usize,
    /// Initial value for every entry of the gate bias `b⁽ᵝ⁾` (see module header).
    /// A negative value starts the sigmoid gates closer to *carry*.
    #[config(default = 0.0)]
    pub init_bias: f64,
}
192
193impl MultiGateResidualConfig {
194 /// Allocate one layer's MGR parameters (`w⁽ᵝ⁾`, `w⁽ᵅ⁾` zero; `b⁽ᵝ⁾` constant).
195 pub fn init(&self, device: &Device) -> MultiGateResidual {
196 MultiGateResidual {
197 w_beta: Initializer::Zeros.init::<1, _>([self.d_model], device),
198 w_alpha: Initializer::Zeros.init::<1, _>([self.d_model], device),
199 b_beta: Param::from_tensor(Tensor::full([self.n_stream], self.init_bias, device)),
200 d_model: self.d_model,
201 n_stream: self.n_stream,
202 }
203 }
204}
205
/// A stack of [`MultiGateResidual`]s for the enclosing
/// [`Layers`](crate::modules::Layers). When `per_virtual` is `false` there is one
/// module **per real layer** (virtual layers reuse them by real index); when
/// `true` there is one **per virtual layer** (each virtual pass owns its own).
///
/// Use [`MultiGate::module_index`] to turn a `(virtual_idx, real_idx)` layer
/// position into the matching index of [`MultiGate::layers`].
#[derive(Module, Debug)]
pub struct MultiGate {
    /// The MGR modules: length `n_real_layers` (per-real) or `n_virtual_layers`
    /// (per-virtual) — see [`Self::per_virtual`].
    pub layers: Vec<MultiGateResidual>,
    /// Number of parallel residual streams `n` (the same value every entry of
    /// [`Self::layers`] was built with).
    // Marked `#[module(skip)]`: plain configuration, not a learnable parameter.
    #[module(skip)]
    pub n_stream: usize,
    /// `true` ⇒ one MGR per *virtual* layer (indexed by virtual position);
    /// `false` ⇒ one per *real* layer (reused across virtual passes by real index).
    // Marked `#[module(skip)]`: plain configuration, not a learnable parameter.
    #[module(skip)]
    pub per_virtual: bool,
}
223
224impl MultiGate {
225 /// Index into [`Self::layers`] for a given `(virtual_idx, real_idx)` layer
226 /// position: the virtual index when each virtual layer owns its MGR
227 /// ([`Self::per_virtual`]), otherwise the real index.
228 pub fn module_index(&self, virtual_idx: usize, real_idx: usize) -> usize {
229 if self.per_virtual {
230 virtual_idx
231 } else {
232 real_idx
233 }
234 }
235}
236
/// How a [`Layers`](crate::modules::Layers) stack threads residuals between
/// layers: the plain additive skip, or Multi-Gate Residuals.
///
/// Built from a [`ResidualsConfig`] via [`ResidualsConfig::init`].
#[derive(Module, Debug)]
pub enum Residuals {
    /// Plain Pre-LN additive residual — each [`Layer`](crate::modules::Layer)
    /// adds its own skip connection. Carries a [`NoOp`] placeholder, so this
    /// variant holds no state of its own.
    Standard(NoOp),
    /// Multi-Gate Residuals: `n_stream` parallel streams with per-layer gated
    /// mixing + attention pooling, parameterised by the wrapped [`MultiGate`].
    MultiGate(MultiGate),
}
248
/// Configuration / factory for [`Residuals`].
///
/// Each variant mirrors the [`Residuals`] variant of the same name;
/// [`ResidualsConfig::init`] performs the conversion.
#[derive(Config, Debug)]
pub enum ResidualsConfig {
    /// Plain additive Pre-LN residual.
    Standard,
    /// Multi-Gate Residuals over `n_stream` streams.
    ///
    /// No `#[config(default = …)]` is declared on these fields, so — unlike
    /// [`MultiGateResidualConfig::init_bias`] — `init_bias` has no default
    /// declared here.
    MultiGate {
        /// Number of parallel residual streams `n`.
        n_stream: usize,
        /// Initial gate bias (see [`MultiGateResidualConfig::init_bias`]).
        init_bias: f64,
        /// `true` ⇒ one MGR per *virtual* layer; `false` ⇒ one per *real* layer
        /// (reused across virtual passes). See [`MultiGate::per_virtual`].
        per_virtual_layer: bool,
    },
}
265
266impl ResidualsConfig {
267 /// Build the runtime [`Residuals`] for a stack of `n_real_layers` real weight
268 /// sets unrolled over `n_virtual_layers` (virtual) passes. The MGR module
269 /// count follows `per_virtual_layer` (one per virtual layer vs one per real
270 /// layer).
271 pub fn init(
272 &self,
273 d_model: usize,
274 n_real_layers: usize,
275 n_virtual_layers: usize,
276 device: &Device,
277 ) -> Residuals {
278 match self {
279 ResidualsConfig::Standard => Residuals::Standard(NoOp),
280 ResidualsConfig::MultiGate {
281 n_stream,
282 init_bias,
283 per_virtual_layer,
284 } => {
285 let count = if *per_virtual_layer {
286 n_virtual_layers
287 } else {
288 n_real_layers
289 };
290 let layers = (0..count)
291 .map(|_| {
292 MultiGateResidualConfig::new(d_model, *n_stream)
293 .with_init_bias(*init_bias)
294 .init(device)
295 })
296 .collect();
297 Residuals::MultiGate(MultiGate {
298 layers,
299 n_stream: *n_stream,
300 per_virtual: *per_virtual_layer,
301 })
302 }
303 }
304 }
305}
306
307#[cfg(test)]
308mod tests;