

Function wrap_angle 

pub fn wrap_angle<const D: usize>(angles: Tensor<D>) -> Tensor<D>

Reduce angles modulo 2π into [−π, π], leaving the autodiff graph intact.
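The reduction can be sketched elementwise. This is a minimal sketch, not the crate's implementation: it assumes a hypothetical stand-in for `Tensor<D>` (a rank-`D` shape plus a flat `Vec<f32>`, with no autodiff), and it assumes the multiple is chosen as `k = round(x / 2π)`.

```rust
use std::f32::consts::{PI, TAU};

/// Hypothetical stand-in for the crate's `Tensor<D>`: a rank-D shape plus flat data.
/// Not the real type; the autodiff graph is omitted.
#[derive(Clone, Debug, PartialEq)]
struct Tensor<const D: usize> {
    shape: [usize; D],
    data: Vec<f32>,
}

/// Elementwise sketch: k = round(x / 2π), result = x − k·2π.
/// Each value lands in [−π, π], up to a few ULPs of rounding at the boundaries.
fn wrap_angle<const D: usize>(angles: Tensor<D>) -> Tensor<D> {
    let data = angles
        .data
        .iter()
        .map(|&x| x - (x / TAU).round() * TAU)
        .collect();
    // The shape (and therefore the rank D) is preserved.
    Tensor { shape: angles.shape, data }
}

fn main() {
    let angles = Tensor { shape: [2, 2], data: vec![0.5, 3.0 * PI, -7.0, 100.0] };
    let wrapped = wrap_angle(angles.clone());
    for (x, w) in angles.data.iter().zip(&wrapped.data) {
        // sin is unchanged by the reduction (up to float rounding).
        println!("{x:>10.4} -> {w:>8.4}   sin: {:.5} vs {:.5}", x.sin(), w.sin());
    }
}
```

Rounding to the nearest integer (rather than flooring) is what centers the output on zero instead of mapping it into [0, 2π).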

sin and cos are 2π-periodic, so subtracting an integer multiple of 2π does not change their values. Keeping |angle| ≤ π preserves precision in low-bit floats: roughly half of f16's representable values lie in |x| ≤ 1, and the absolute spacing between representable values grows with magnitude, so sin/cos lose accuracy when their argument is allowed to drift to large magnitudes. The same applies to the cumulative angle accumulator, which would otherwise grow without bound across a long sequence or many decode steps.
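To make the precision argument concrete, the sketch below uses f32, because stable Rust has no built-in f16. The effect is the same but sets in at larger magnitudes. It measures the gap between adjacent f32 values near 3 and near 10^6, then adds a fixed per-step angle a million times, with and without wrapping after each step, and compares the result against an f64 reference. The helper names `ulp`, `wrap`, `ang_err`, and `accumulate` are hypothetical, and the step size and count are arbitrary illustration values.

```rust
use std::f32::consts::TAU;

/// Gap from a positive, finite f32 (below f32::MAX) to the next representable f32: one ULP.
fn ulp(x: f32) -> f32 {
    f32::from_bits(x.to_bits() + 1) - x
}

/// Scalar reduction into [−π, π]: subtract the nearest integer multiple of 2π.
fn wrap(x: f32) -> f32 {
    x - (x / TAU).round() * TAU
}

/// Signed angular difference a − b folded into [−π, π], computed in f64.
fn ang_err(a: f64, b: f64) -> f64 {
    let two_pi = std::f64::consts::TAU;
    let d = (a - b).rem_euclid(two_pi);
    if d > std::f64::consts::PI { d - two_pi } else { d }
}

/// Adds `delta` to an f32 accumulator `steps` times, optionally wrapping after every step.
fn accumulate(delta: f32, steps: u32, wrap_each_step: bool) -> f32 {
    let mut theta = 0.0_f32;
    for _ in 0..steps {
        theta += delta;
        if wrap_each_step {
            theta = wrap(theta);
        }
    }
    theta
}

fn main() {
    // Absolute resolution of an f32 angle: 2^-22 (~2.4e-7) rad near 3.0, but 0.0625 rad near 1e6.
    println!("ulp(3.0) = {:e}, ulp(1e6) = {}", ulp(3.0), ulp(1.0e6));

    let (delta, steps) = (0.1_f32, 1_000_000);
    // Reference angle computed in f64 from the exact f32 value of delta.
    let reference = delta as f64 * steps as f64;
    let raw = accumulate(delta, steps, false);
    let wrapped = accumulate(delta, steps, true);
    println!("raw accumulator: {raw} (off by {:.1} rad)", raw as f64 - reference);
    println!(
        "wrapped accumulator: {wrapped} (angular error {:.2e} rad)",
        ang_err(wrapped as f64, reference).abs()
    );
}
```

With wrapping, every addition is rounded at the fine spacing of values below 4 in magnitude. Without it, late increments are rounded at the spacing of values near 10^5 (about 0.008), and that error compounds over the run.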

The integer multiple k is detached, so autodiff treats it as a constant: d/dx (x − k·2π) = 1, i.e. the backward pass is identical to that of the unwrapped angle. This mirrors the detached max rescaling in RmsNormGated.
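The detach argument can be checked with a toy forward-mode dual number. This sketch assumes nothing about the crate's reverse-mode engine. `Dual`, `wrap_angle_dual`, and `sin_dual` are hypothetical names, and forward mode is swapped in only because it makes the derivative visible in a few lines. For a scalar chain such as sin(wrap(x)), it yields the same derivative that a backward pass would.

```rust
use std::f64::consts::TAU;

/// Forward-mode dual number: a value and its derivative with respect to the input.
#[derive(Clone, Copy, Debug)]
struct Dual {
    val: f64,
    grad: f64,
}

/// Wrap with a detached k: k is read from the value only, so it carries no derivative,
/// and d/dx (x − k·2π) = 1 passes the incoming derivative through unchanged.
fn wrap_angle_dual(x: Dual) -> Dual {
    let k = (x.val / TAU).round(); // detached: treated as a constant
    Dual { val: x.val - k * TAU, grad: x.grad }
}

/// sin with the chain rule: d/dx sin(u) = cos(u) · du/dx.
fn sin_dual(u: Dual) -> Dual {
    Dual { val: u.val.sin(), grad: u.val.cos() * u.grad }
}

fn main() {
    let x = Dual { val: 10.0, grad: 1.0 }; // seed dx/dx = 1
    let plain = sin_dual(x);
    let wrapped = sin_dual(wrap_angle_dual(x));
    println!("plain:   {plain:?}");
    println!("wrapped: {wrapped:?}");
}
```

Both paths produce the same value sin(x) and the same derivative cos(x) up to float rounding, because k contributes nothing to the derivative.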