Skip to main content

burn_mamba/modules/
multi_gate.rs

1//! Multi-Gate Residuals (MGR) — a depth-wise residual scheme replacing the plain
2//! additive skip of a [`Layers`](crate::modules::Layers) stack.
3//!
4//! Instead of one residual stream, MGR keeps **`n_stream` parallel streams**
5//! `sᵢ` (all seeded from the stack input). Between layers, one
6//! [`MultiGateResidual`] per layer does two convex, norm-bounded operations
7//! (paper §"Our Architecture"):
8//!
9//! 1. **Mixer** (independent sigmoid gate) — each stream is interpolated towards
10//!    the current layer output `F_l` by a per-stream gate `βᵢ`:
11//!    `sᵢ' = (1−βᵢ)·sᵢ + βᵢ·F_l`, with
12//!    `βᵢ = σ( (w⁽ᵝ⁾ · RMSNorm(sᵢ))/√d + b⁽ᵝ⁾ᵢ )`.
13//! 2. **Aggregator** (depth-wise attention pooling, "AttnPool") — the updated
14//!    streams are pooled into the next layer's input `h` by a softmax over
15//!    streams: `αᵢ = softmax_i( (w⁽ᵅ⁾ · RMSNorm(sᵢ'))/√d )`, `h = Σᵢ αᵢ·sᵢ'`.
16//!
17//! Both `w` vectors are learnable in `ℝ^d` (init zero), the RMSNorm is
18//! parameter-free, and `b⁽ᵝ⁾` is a per-stream learnable bias. Only the
19//! **independent** (sigmoid) gate is implemented; the paper's competitive
20//! (softmax) variant is omitted.
21//!
22//! MGR is purely **point-wise over `(batch, sequence)`** — the streams only
23//! evolve along *depth*, never along the sequence — so `forward` over a sequence
24//! equals `step` unrolled token-by-token, and `step` carries no extra state
25//! (each token rebuilds its own depth-streams).
26//!
27//! **Gate-bias initialisation.** Following Highway Networks, a negative
28//! `init_bias` biases the gates towards *carry* (small updates) at the start of
29//! training. The paper scales it with depth `L`:
30//! `b_init = ln( √(L/L_base)·(exp(−b_base)+1) − n )` (with `L_base = 21`,
31//! `b_base = −3`); here `init_bias` is taken directly so the caller may apply
32//! that formula. Default `0` (gates open at `σ(0)=0.5`).
33
34use crate::modules::bidi::NoOp;
35use crate::utils::div_eps;
36use burn::config::Config;
37use burn::module::Param;
38use burn::nn::Initializer;
39use burn::prelude::*;
40use burn::tensor::activation::{sigmoid, softmax};
41use burn::tensor::{DType, f16};
42
43/// One layer's Multi-Gate Residual parameters: the mixer query `w⁽ᵝ⁾` + bias
44/// `b⁽ᵝ⁾`, and the aggregator (AttnPool) query `w⁽ᵅ⁾`.
45#[derive(Module, Debug)]
46pub struct MultiGateResidual {
47    /// Mixer query `w⁽ᵝ⁾ ∈ ℝ^d` (the per-stream sigmoid gate), `[d_model]`.
48    pub w_beta: Param<Tensor<1>>,
49    /// Aggregator query `w⁽ᵅ⁾ ∈ ℝ^d` (the AttnPool softmax), `[d_model]`.
50    pub w_alpha: Param<Tensor<1>>,
51    /// Per-stream mixer gate bias `b⁽ᵝ⁾`, `[n_stream]`.
52    pub b_beta: Param<Tensor<1>>,
53    /// Model width `d`.
54    #[module(skip)]
55    pub d_model: usize,
56    /// Number of parallel residual streams `n`.
57    #[module(skip)]
58    pub n_stream: usize,
59}
60
61impl MultiGateResidual {
62    fn scale(&self) -> f32 {
63        (self.d_model as f32).powf(-0.5)
64    }
65
66    /// The parameter-free RMS denominator `d(x) ∈ [‥, 1]` such that the RMSNorm
67    /// (matching [`RmsNorm`] math with `γ ≡ 1`) is `x / d(x)`. Returning the
68    /// denominator rather than the normalised tensor lets [`Self::normed_score`]
69    /// fold it out of the (feature-axis) score reduction, so the full-width
70    /// normalised tensor is never built. The fp16 path keeps the same
71    /// overflow-safe max-rescale, folded into the same scalar denominator.
72    ///
73    /// [`RmsNorm`]: crate::modules::RmsNorm
74    fn rms_denom<const D: usize>(&self, x: Tensor<D>) -> Tensor<D> {
75        match x.dtype() {
76            DType::F64 | DType::F32 | DType::Flex32 | DType::BF16 => {
77                let eps = div_eps(x.dtype());
78                (x.clone() * x).mean_dim(D - 1).sqrt() + eps
79            }
80            DType::F16 => {
81                use burn::tensor::ElementConversion;
82                let eps: f16 = f16::from_elem(div_eps(x.dtype())) * f16::from_f32(2.);
83                // Single global scalar `max`, reshaped to `[1; D]` so it
84                // broadcasts against the `[‥, 1]` partial RMS.
85                let max = x.clone().no_grad().detach().abs().max().reshape([1; D]);
86                let x_ = x.clone() / (max.clone() + eps); // x_.abs() <= 1
87                let rms_partial = (x.clone() * x_).mean_dim(D - 1).sqrt();
88                (rms_partial + eps) * max.sqrt()
89            }
90            _ => unreachable!("rms_denom expects a float dtype"),
91        }
92    }
93
94    /// The RMSNorm-then-dot score `scale · Σ_feat(x · w) / (rms(x)+eps)`,
95    /// shape `[‥, 1]`. The RMS denominator is constant over the feature axis, so
96    /// it is folded out of the reduction (via [`Self::rms_denom`]) — equal to
97    /// `Σ_feat(rms_norm(x) · w) · scale` but without materialising the full-width
98    /// normalised tensor.
99    fn normed_score<const R: usize>(&self, x: Tensor<R>, w: Tensor<R>) -> Tensor<R> {
100        let dot = (x.clone() * w).sum_dim(R - 1);
101        dot * self.scale() / self.rms_denom(x)
102    }
103
104    /// The shared mix + pool, generic over the streams rank `R` (the *stream*
105    /// axis is `R-2`, the *feature* axis `R-1`). [`Self::forward`] (`R = 4`) and
106    /// [`Self::step`] (`R = 3`) only differ by that rank, so both lift their
107    /// `layer_output` to a singleton stream axis, call this, and drop it again.
108    /// All reductions keep their axis (size 1) for broadcasting, so scores/gates
109    /// are `[…, n_stream, 1]` throughout.
110    ///
111    /// - `layer_output`: `F_l` lifted to a unit stream axis, `[…, 1, d_model]`
112    /// - `streams`: the `n_stream` residual streams, `[…, n_stream, d_model]`
113    ///
114    /// Returns `(h, streams')` with `h` still carrying its unit stream axis
115    /// (`[…, 1, d_model]`) and `streams'` the same shape as `streams`.
116    fn mix_pool<const R: usize>(
117        &self,
118        layer_output: Tensor<R>,
119        streams: Tensor<R>,
120    ) -> (Tensor<R>, Tensor<R>) {
121        let dims = streams.dims();
122        let (stream_axis, feat_axis) = (R - 2, R - 1);
123        assert_eq!(
124            dims[feat_axis], self.d_model,
125            "stream width must equal d_model"
126        );
127        assert_eq!(
128            dims[stream_axis], self.n_stream,
129            "stream count must equal n_stream"
130        );
131
132        // `b_beta` reshaped to broadcast on the stream axis: `[1, …, n_stream, 1]`.
133        let mut bias_shape = [1usize; R];
134        bias_shape[stream_axis] = self.n_stream;
135        let b_beta = self.b_beta.val().reshape(bias_shape);
136        // The query vectors broadcast on the feature axis: `[1, …, 1, d_model]`.
137        let w_beta = self.w_beta.val().unsqueeze::<R>();
138        let w_alpha = self.w_alpha.val().unsqueeze::<R>();
139
140        // Mixer: independent per-stream sigmoid gate, `β`: `[…, n_stream, 1]`.
141        let beta = sigmoid(self.normed_score(streams.clone(), w_beta) + b_beta);
142        // Lerp `(1−β)·streams + β·layer_output` (equal to the paper's
143        // `streams + β·(layer_output − streams)`) — written so no full-width
144        // intermediate is retained: `streams` is the already-saved input and
145        // `layer_output` is `[…, 1, d_model]`, so neither `mul` saves a new
146        // `[…, n_stream, d_model]` tensor and the `+` saves nothing.
147        let new_streams = streams * (-beta.clone() + 1.0) + layer_output * beta;
148
149        // Aggregator: depth-wise attention pooling (softmax over the stream axis).
150        let score_ns = self.normed_score(new_streams.clone(), w_alpha);
151        let alpha = softmax(score_ns, stream_axis);
152        let new_h = (alpha * new_streams.clone()).sum_dim(stream_axis);
153        (new_h, new_streams)
154    }
155
156    /// Full-sequence mix + pool.
157    ///
158    /// - `layer_output`: this layer's transform `F_l`, `[batch, sequence, d_model]`
159    /// - `streams`: the `n_stream` residual streams, `[batch, sequence, n_stream, d_model]`
160    ///
161    /// Returns `(h, streams')`: the pooled input `h` for the next layer
162    /// (`[batch, sequence, d_model]`) and the updated streams (same shape as in).
163    pub fn forward(&self, layer_output: Tensor<3>, streams: Tensor<4>) -> (Tensor<3>, Tensor<4>) {
164        let (new_h, new_streams) = self.mix_pool::<4>(layer_output.unsqueeze_dim(2), streams);
165        (new_h.squeeze_dim(2), new_streams)
166    }
167
168    /// Single-token mix + pool (the [`Self::forward`] math with the sequence axis
169    /// dropped).
170    ///
171    /// - `layer_output`: `[batch, d_model]`
172    /// - `streams`: `[batch, n_stream, d_model]`
173    ///
174    /// Returns `(h, streams')`: `[batch, d_model]` and `[batch, n_stream, d_model]`.
175    pub fn step(&self, layer_output: Tensor<2>, streams: Tensor<3>) -> (Tensor<2>, Tensor<3>) {
176        let (new_h, new_streams) = self.mix_pool::<3>(layer_output.unsqueeze_dim(1), streams);
177        (new_h.squeeze_dim(1), new_streams)
178    }
179}
180
181/// Configuration for a single [`MultiGateResidual`].
182#[derive(Config, Debug)]
183pub struct MultiGateResidualConfig {
184    /// Model width `d`.
185    pub d_model: usize,
186    /// Number of parallel residual streams `n`.
187    pub n_stream: usize,
188    /// Initial value for every entry of the gate bias `b⁽ᵝ⁾` (see module header).
189    #[config(default = 0.0)]
190    pub init_bias: f64,
191}
192
193impl MultiGateResidualConfig {
194    /// Allocate one layer's MGR parameters (`w⁽ᵝ⁾`, `w⁽ᵅ⁾` zero; `b⁽ᵝ⁾` constant).
195    pub fn init(&self, device: &Device) -> MultiGateResidual {
196        MultiGateResidual {
197            w_beta: Initializer::Zeros.init::<1, _>([self.d_model], device),
198            w_alpha: Initializer::Zeros.init::<1, _>([self.d_model], device),
199            b_beta: Param::from_tensor(Tensor::full([self.n_stream], self.init_bias, device)),
200            d_model: self.d_model,
201            n_stream: self.n_stream,
202        }
203    }
204}
205
206/// A stack of [`MultiGateResidual`]s for the enclosing
207/// [`Layers`](crate::modules::Layers). When `per_virtual` is `false` there is one
208/// module **per real layer** (virtual layers reuse them by real index); when
209/// `true` there is one **per virtual layer** (each virtual pass owns its own).
210#[derive(Module, Debug)]
211pub struct MultiGate {
212    /// The MGR modules: length `n_real_layers` (per-real) or `n_virtual_layers`
213    /// (per-virtual) — see [`Self::per_virtual`].
214    pub layers: Vec<MultiGateResidual>,
215    /// Number of parallel residual streams `n`.
216    #[module(skip)]
217    pub n_stream: usize,
218    /// `true` ⇒ one MGR per *virtual* layer (indexed by virtual position);
219    /// `false` ⇒ one per *real* layer (reused across virtual passes by real index).
220    #[module(skip)]
221    pub per_virtual: bool,
222}
223
224impl MultiGate {
225    /// Index into [`Self::layers`] for a given `(virtual_idx, real_idx)` layer
226    /// position: the virtual index when each virtual layer owns its MGR
227    /// ([`Self::per_virtual`]), otherwise the real index.
228    pub fn module_index(&self, virtual_idx: usize, real_idx: usize) -> usize {
229        if self.per_virtual {
230            virtual_idx
231        } else {
232            real_idx
233        }
234    }
235}
236
237/// How a [`Layers`](crate::modules::Layers) stack threads residuals between
238/// layers: the plain additive skip, or Multi-Gate Residuals.
239#[derive(Module, Debug)]
240pub enum Residuals {
241    /// Plain Pre-LN additive residual — each [`Layer`](crate::modules::Layer)
242    /// adds its own skip connection.
243    Standard(NoOp),
244    /// Multi-Gate Residuals: `n_stream` parallel streams with per-layer gated
245    /// mixing + attention pooling.
246    MultiGate(MultiGate),
247}
248
249/// Configuration / factory for [`Residuals`].
250#[derive(Config, Debug)]
251pub enum ResidualsConfig {
252    /// Plain additive Pre-LN residual.
253    Standard,
254    /// Multi-Gate Residuals over `n_stream` streams.
255    MultiGate {
256        /// Number of parallel residual streams `n`.
257        n_stream: usize,
258        /// Initial gate bias (see [`MultiGateResidualConfig::init_bias`]).
259        init_bias: f64,
260        /// `true` ⇒ one MGR per *virtual* layer; `false` ⇒ one per *real* layer
261        /// (reused across virtual passes). See [`MultiGate::per_virtual`].
262        per_virtual_layer: bool,
263    },
264}
265
266impl ResidualsConfig {
267    /// Build the runtime [`Residuals`] for a stack of `n_real_layers` real weight
268    /// sets unrolled over `n_virtual_layers` (virtual) passes. The MGR module
269    /// count follows `per_virtual_layer` (one per virtual layer vs one per real
270    /// layer).
271    pub fn init(
272        &self,
273        d_model: usize,
274        n_real_layers: usize,
275        n_virtual_layers: usize,
276        device: &Device,
277    ) -> Residuals {
278        match self {
279            ResidualsConfig::Standard => Residuals::Standard(NoOp),
280            ResidualsConfig::MultiGate {
281                n_stream,
282                init_bias,
283                per_virtual_layer,
284            } => {
285                let count = if *per_virtual_layer {
286                    n_virtual_layers
287                } else {
288                    n_real_layers
289                };
290                let layers = (0..count)
291                    .map(|_| {
292                        MultiGateResidualConfig::new(d_model, *n_stream)
293                            .with_init_bias(*init_bias)
294                            .init(device)
295                    })
296                    .collect();
297                Residuals::MultiGate(MultiGate {
298                    layers,
299                    n_stream: *n_stream,
300                    per_virtual: *per_virtual_layer,
301                })
302            }
303        }
304    }
305}
306
307#[cfg(test)]
308mod tests;