pub struct MultiGateResidual {
pub w_beta: Param<Tensor<1>>,
pub w_alpha: Param<Tensor<1>>,
pub b_beta: Param<Tensor<1>>,
pub d_model: usize,
pub n_stream: usize,
}
One layer’s Multi-Gate Residual parameters: the mixer query w⁽ᵝ⁾ + bias
b⁽ᵝ⁾, and the aggregator (AttnPool) query w⁽ᵅ⁾.
Fields
w_beta: Param<Tensor<1>>
Mixer query w⁽ᵝ⁾ ∈ ℝ^d (the per-stream sigmoid gate), [d_model].
w_alpha: Param<Tensor<1>>
Aggregator query w⁽ᵅ⁾ ∈ ℝ^d (the AttnPool softmax), [d_model].
b_beta: Param<Tensor<1>>
Per-stream mixer gate bias b⁽ᵝ⁾, [n_stream].
d_model: usize
Model width d.
n_stream: usize
Number of parallel residual streams n.
Implementations
impl MultiGateResidual
fn scale(&self) -> f32
fn rms_denom<const D: usize>(&self, x: Tensor<D>) -> Tensor<D>
The parameter-free RMS denominator d(x), of shape […, 1], such that RMSNorm
(matching the RmsNorm math with γ ≡ 1) is x / d(x). Returning the
denominator rather than the normalised tensor lets Self::normed_score
fold it out of the feature-axis score reduction, so the full-width
normalised tensor is never built. The fp16 path keeps the same
overflow-safe max-rescale, folded into the same scalar denominator.
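As a rough illustration of the denominator and the max-rescale idea, here is a minimal sketch on a single feature vector using plain slices instead of `Tensor`. The free functions, the `eps` placement (`rms(x) + eps`, taken from the score formula in `normed_score`), and the use of f32 overflow as a stand-in for the fp16 case are all assumptions for the example, not the crate's actual implementation.

```rust
/// Hypothetical stand-in for `rms_denom` on one feature vector:
/// d(x) = sqrt(mean(x^2)) + eps, so that RMSNorm with gamma = 1 is x / d(x).
fn rms_denom_naive(x: &[f32], eps: f32) -> f32 {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    mean_sq.sqrt() + eps
}

/// Same quantity, but every element is divided by max|x| before squaring,
/// so no squared term exceeds 1. The max is multiplied back in afterwards,
/// which keeps the result a single scalar denominator.
fn rms_denom_rescaled(x: &[f32], eps: f32) -> f32 {
    let m = x.iter().fold(0.0f32, |acc, v| acc.max(v.abs()));
    if m == 0.0 {
        // All-zero (or empty) input: the RMS is zero, only eps remains.
        return eps;
    }
    let mean_sq = x
        .iter()
        .map(|v| {
            let s = v / m;
            s * s
        })
        .sum::<f32>()
        / x.len() as f32;
    m * mean_sq.sqrt() + eps
}
```

For ordinary inputs the two functions agree up to rounding; for inputs whose squares overflow the float type, only the rescaled variant stays finite.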
fn normed_score<const R: usize>(&self, x: Tensor<R>, w: Tensor<R>) -> Tensor<R>
The RMSNorm-then-dot score scale · Σ_feat(x · w) / (rms(x)+eps),
shape [‥, 1]. The RMS denominator is constant over the feature axis, so
it is folded out of the reduction (via Self::rms_denom) — equal to
Σ_feat(rms_norm(x) · w) · scale but without materialising the full-width
normalised tensor.
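The fold described above can be checked on a single feature vector. The sketch below uses plain slices rather than `Tensor`; the helper names, the `scale` and `eps` arguments, and the `rms(x) + eps` denominator form are assumptions made for illustration only.

```rust
/// Assumed denominator form, following the score formula: d(x) = rms(x) + eps.
fn rms_denom(x: &[f32], eps: f32) -> f32 {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    mean_sq.sqrt() + eps
}

/// Unfused: materialise the full-width normalised vector, then take the dot product.
fn score_unfused(x: &[f32], w: &[f32], scale: f32, eps: f32) -> f32 {
    let d = rms_denom(x, eps);
    let normed: Vec<f32> = x.iter().map(|v| v / d).collect();
    scale * normed.iter().zip(w).map(|(n, wi)| n * wi).sum::<f32>()
}

/// Fused: the denominator is constant over the feature axis, so it is
/// divided out once after the reduction and no normalised vector is allocated.
fn score_fused(x: &[f32], w: &[f32], scale: f32, eps: f32) -> f32 {
    let dot: f32 = x.iter().zip(w).map(|(xi, wi)| xi * wi).sum();
    scale * dot / rms_denom(x, eps)
}
```

Both functions compute the same score up to floating-point rounding; the fused form simply avoids the intermediate buffer.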
fn mix_pool<const R: usize>(
    &self,
    layer_output: Tensor<R>,
    streams: Tensor<R>,
) -> (Tensor<R>, Tensor<R>)
The shared mix + pool, generic over the streams rank R (the stream
axis is R-2, the feature axis R-1). Self::forward (R = 4) and
Self::step (R = 3) only differ by that rank, so both lift their
layer_output to a singleton stream axis, call this, and drop it again.
All reductions keep their axis (size 1) for broadcasting, so scores/gates
are […, n_stream, 1] throughout.
- layer_output: F_l lifted to a unit stream axis, […, 1, d_model]
- streams: the n_stream residual streams, […, n_stream, d_model]
Returns (h, streams') with h still carrying its unit stream axis
([…, 1, d_model]) and streams' the same shape as streams.
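The exact update rule is not spelled out on this page, so the following sketch covers only the two ingredients the field docs name: a per-stream sigmoid gate with bias, and a softmax pool over the stream axis. It works on one position with plain `Vec`s, takes the per-stream scores as already computed, and all function names are hypothetical.

```rust
/// Per-stream sigmoid gate: one value in (0, 1) per stream,
/// from a mixer score plus the per-stream bias (assumed combination).
fn sigmoid_gates(scores: &[f32], b_beta: &[f32]) -> Vec<f32> {
    scores
        .iter()
        .zip(b_beta)
        .map(|(s, b)| 1.0 / (1.0 + (-(s + b)).exp()))
        .collect()
}

/// Softmax over the stream axis, then a weighted sum of the streams,
/// giving a single [d_model] vector. Panics if `streams` is empty.
fn softmax_pool(scores: &[f32], streams: &[Vec<f32>]) -> Vec<f32> {
    // Subtract the max score before exponentiating for numerical stability.
    let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = scores.iter().map(|s| (s - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let d_model = streams[0].len();
    let mut h = vec![0.0f32; d_model];
    for (e, stream) in exps.iter().zip(streams) {
        let weight = e / total;
        for (hj, sj) in h.iter_mut().zip(stream) {
            *hj += weight * sj;
        }
    }
    h
}
```

With equal scores the pool reduces to a plain average of the streams, which is a convenient sanity check.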
pub fn forward(
    &self,
    layer_output: Tensor<3>,
    streams: Tensor<4>,
) -> (Tensor<3>, Tensor<4>)
Full-sequence mix + pool.
- layer_output: this layer’s transform F_l, [batch, sequence, d_model]
- streams: the n_stream residual streams, [batch, sequence, n_stream, d_model]
Returns (h, streams'): the pooled input h for the next layer
([batch, sequence, d_model]) and the updated streams (same shape as the input streams).
pub fn step(
    &self,
    layer_output: Tensor<2>,
    streams: Tensor<3>,
) -> (Tensor<2>, Tensor<3>)
Single-token mix + pool (the Self::forward math with the sequence axis
dropped).
- layer_output: [batch, d_model]
- streams: [batch, n_stream, d_model]
Returns (h, streams'): [batch, d_model] and [batch, n_stream, d_model].