Module rms_norm

Root-mean-square normalisation over the last dimension (fp16-safe); also used as the Mamba-3 QK-Norm.

RMSNorm(x) = x / rms(x) · γ where rms(x) = √(mean(x²)). Unlike LayerNorm there is no mean-subtraction or bias — only a learnable per-channel scale γ. Used both as the Pre-LN of every residual block and, in Mamba-3, as the QK-Norm applied to the B/C projections.
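As an illustration of the formula above, here is a minimal sketch in plain Rust over a single row (one slice standing in for the last dimension). The function name `rms_norm_row` and the `eps` parameter are assumptions made for this example; they are not taken from this module's API, and the formula as stated has no epsilon term.

```rust
/// RMSNorm over one row: y_i = x_i / rms(x) * gamma_i,
/// where rms(x) = sqrt(mean(x^2)).
///
/// `eps` is an assumption of this sketch (a common guard against division
/// by zero); pass 0.0 to match the formula exactly as written.
fn rms_norm_row(x: &[f32], gamma: &[f32], eps: f32) -> Vec<f32> {
    assert!(!x.is_empty(), "input row must be non-empty");
    assert_eq!(x.len(), gamma.len(), "gamma must have one scale per channel");

    // mean(x^2) over the row; note there is no mean-subtraction.
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();

    // Scale each element by 1/rms(x), then by its per-channel gamma.
    x.iter().zip(gamma).map(|(v, g)| v * inv_rms * g).collect()
}
```

With `gamma` set to all ones and `eps = 0.0`, the output row has a root-mean-square of 1, which is the defining property of the normalisation.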

The fp16 path avoids forming x² directly, which overflows for moderately large activations (e.g. 256·256 = 65536 exceeds the largest finite fp16 value, 65504). Instead it first normalises against max(|x|) so the squared values stay ≤ 1, then rescales. See rms_norm_gated for the SiLU-gated variant.
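The max-scaling idea relies on the identity rms(x) = m · √(mean((x/m)²)) with m = max(|x|). The sketch below shows only that algebra: it uses `f32` as a stand-in for fp16, and the names `rms_max_scaled` and `F16_MAX` are hypothetical, not part of this module. The actual fp16 code path may differ in its details.

```rust
/// Largest finite fp16 value, used here only to illustrate the overflow threshold.
const F16_MAX: f32 = 65504.0;

/// Computes rms(x) as m * sqrt(mean((x / m)^2)) with m = max(|x|).
/// Every squared term is at most 1, so the accumulation cannot overflow
/// even when x * x would exceed the representable range.
fn rms_max_scaled(x: &[f32]) -> f32 {
    // m = max(|x|); an all-zero (or empty) row has rms 0.
    let m = x.iter().fold(0.0f32, |acc, v| acc.max(v.abs()));
    if m == 0.0 {
        return 0.0;
    }

    // Square the pre-scaled values, each of which lies in [-1, 1].
    let mean_sq = x
        .iter()
        .map(|v| {
            let s = v / m;
            s * s
        })
        .sum::<f32>()
        / x.len() as f32;

    // Undo the scaling: rms(x) = m * rms(x / m).
    m * mean_sq.sqrt()
}
```

For a row of 256s, the naive square 256·256 = 65536 is already above `F16_MAX`, whereas the scaled path only ever squares 1.0 and recovers an rms of 256.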

Structs§

RmsNorm
Applies RMS normalisation over an input tensor along the last dimension: y = x / √(mean(x²)) · γ.
RmsNormConfig
Configuration to create a RmsNorm layer.
RmsNormRecord
The record type for the module.
RmsNormRecordItem
The record item type for the module.