Module multi_gate

Multi-Gate Residuals (MGR): a multi-stream, gated, depth-wise residual scheme that replaces the plain additive skip of a Layers stack (see Residuals).

Instead of one residual stream, MGR keeps n_stream parallel streams sᵢ, all seeded from the stack input. After each layer, that layer's MultiGateResidual applies two convex, norm-bounded operations (paper §“Our Architecture”):

  1. Mixer (independent sigmoid gate) — each stream is interpolated towards the current layer output F_l by a per-stream gate βᵢ: sᵢ' = (1−βᵢ)·sᵢ + βᵢ·F_l, with βᵢ = σ( (w⁽ᵝ⁾ · RMSNorm(sᵢ))/√d + b⁽ᵝ⁾ᵢ ).
  2. Aggregator (depth-wise attention pooling, “AttnPool”) — the updated streams are pooled into the next layer’s input h by a softmax over streams: αᵢ = softmax_i( (w⁽ᵅ⁾ · RMSNorm(sᵢ'))/√d ), h = Σᵢ αᵢ·sᵢ'.

Both w vectors are learnable in ℝ^d (init zero), the RMSNorm is parameter-free, and b⁽ᵝ⁾ is a per-stream learnable bias. Only the independent (sigmoid) gate is implemented; the paper’s competitive (softmax) variant is omitted.

MGR is purely point-wise over (batch, sequence): the streams evolve only along depth, never along the sequence. Consequently, forward over a sequence equals step unrolled token by token, and step carries no extra state (each token rebuilds its own depth streams).
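The mixer, the aggregator, and the per-token depth loop described above can be sketched in plain Rust over `Vec<f32>`. This is a minimal illustration, not the module's implementation: `GateParams`, `mix`, `pool`, `step`, the `1e-6` epsilon inside `rms_norm`, the function-pointer layers, and returning the last pooled `h` as the stack output are all assumptions made for the example.

```rust
/// Parameter-free RMSNorm: x / sqrt(mean(x^2) + eps). The eps value is an assumption.
fn rms_norm(x: &[f32]) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv = 1.0 / (mean_sq + 1e-6).sqrt();
    x.iter().map(|v| v * inv).collect()
}

fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn sigmoid(z: f32) -> f32 {
    1.0 / (1.0 + (-z).exp())
}

/// Hypothetical plain-Rust stand-in for one layer's MGR parameters.
struct GateParams {
    w_beta: Vec<f32>,  // mixer query, length d
    b_beta: Vec<f32>,  // per-stream gate bias, length n_stream
    w_alpha: Vec<f32>, // aggregator (AttnPool) query, length d
}

/// Mixer: s_i' = (1 - beta_i) * s_i + beta_i * F_l, with an independent sigmoid gate.
fn mix(p: &GateParams, streams: &mut [Vec<f32>], f_l: &[f32]) {
    let scale = (p.w_beta.len() as f32).sqrt();
    for (s, b) in streams.iter_mut().zip(&p.b_beta) {
        let beta = sigmoid(dot(&p.w_beta, &rms_norm(&s[..])) / scale + b);
        for (x, f) in s.iter_mut().zip(f_l) {
            *x = (1.0 - beta) * *x + beta * f;
        }
    }
}

/// Aggregator: h = sum_i alpha_i * s_i', with alpha a softmax over streams.
fn pool(p: &GateParams, streams: &[Vec<f32>]) -> Vec<f32> {
    let d = p.w_alpha.len();
    let scale = (d as f32).sqrt();
    let logits: Vec<f32> = streams
        .iter()
        .map(|s| dot(&p.w_alpha, &rms_norm(s)) / scale)
        .collect();
    // Subtract the max logit for a numerically stable softmax.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|z| (z - max).exp()).collect();
    let total: f32 = exps.iter().sum();
    let mut h = vec![0.0; d];
    for (s, e) in streams.iter().zip(&exps) {
        let alpha = e / total;
        for (out, x) in h.iter_mut().zip(s) {
            *out += alpha * x;
        }
    }
    h
}

/// One token through the whole stack. The streams are rebuilt from `x` on
/// every call, so nothing is carried between tokens.
fn step(
    params: &[GateParams],
    layers: &[fn(&[f32]) -> Vec<f32>], // hypothetical stand-ins for the real layers
    x: &[f32],
    n_stream: usize,
) -> Vec<f32> {
    let mut streams = vec![x.to_vec(); n_stream]; // all seeded from the stack input
    let mut h = x.to_vec();
    for (p, layer) in params.iter().zip(layers) {
        let f_l = layer(&h);
        mix(p, &mut streams, &f_l);
        h = pool(p, &streams);
    }
    h
}
```

With both `w` vectors at their zero initialisation the logits vanish, so every gate equals σ(b⁽ᵝ⁾ᵢ) and the pooling weights are uniform (1/n_stream). Because `step` rebuilds `streams` from `x` on each call, applying it once per token is the same as a sequence-level forward in this sketch.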

Gate-bias initialisation. Following Highway Networks, a negative init_bias biases the gates towards carry (small updates) at the start of training. The paper scales it with depth L: b_init = ln( √(L/L_base)·(exp(−b_base)+1) − n ) (with L_base = 21, b_base = −3); here init_bias is taken directly so the caller may apply that formula. Default 0 (gates open at σ(0)=0.5).
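A caller who wants the depth-scaled value could compute it with a small helper like the one below. It is a literal transcription of the expression quoted above, not part of this module: the name `depth_scaled_bias` is hypothetical, and because the text does not define `n`, it is left as a parameter. With L = L_base and n = 1 the expression evaluates to −b_base, so the sign convention should be checked against the paper before the result is passed as init_bias.

```rust
fn sigmoid(z: f64) -> f64 {
    1.0 / (1.0 + (-z).exp())
}

/// Literal transcription of
///   b_init = ln( sqrt(L / L_base) * (exp(-b_base) + 1) - n ).
/// `n` is not defined in the surrounding text, so it is taken as an argument.
/// Returns None when the logarithm's argument is not positive.
fn depth_scaled_bias(depth: f64, n: f64, l_base: f64, b_base: f64) -> Option<f64> {
    let arg = (depth / l_base).sqrt() * ((-b_base).exp() + 1.0) - n;
    if arg > 0.0 {
        Some(arg.ln())
    } else {
        None
    }
}
```

For reference, `sigmoid(0.0)` is 0.5 (the default, half-open gate), while `sigmoid(-3.0)` is below 0.05, which is the carry-biased regime the Highway-style initialisation aims for.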

Structs§

MultiGate
A stack of MultiGateResiduals for the enclosing Layers. When per_virtual is false there is one module per real layer (virtual layers reuse them by real index); when true there is one per virtual layer (each virtual pass owns its own).
MultiGateRecord
The record type for the module.
MultiGateRecordItem
The record item type for the module.
MultiGateResidual
One layer’s Multi-Gate Residual parameters: the mixer query w⁽ᵝ⁾ + bias b⁽ᵝ⁾, and the aggregator (AttnPool) query w⁽ᵅ⁾.
MultiGateResidualConfig
Configuration for a single MultiGateResidual.
MultiGateResidualRecord
The record type for the module.
MultiGateResidualRecordItem
The record item type for the module.

Enums§

Residuals
How a Layers stack threads residuals between layers: the plain additive skip, or Multi-Gate Residuals.
ResidualsConfig
Configuration / factory for Residuals.
ResidualsConfigSerde 🔒
ResidualsRecord
The record type for the module.
ResidualsRecordItem
The record item type for the module.