
Function apply_rope 

pub fn apply_rope<const D: usize>(
    x: Tensor<D>,
    angles: Tensor<D>,
    rotate_pairwise: bool,
) -> Tensor<D>

Apply rotary position embeddings to x along its last dimension.
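The rotation applied to each pair can be written as below. This is a sketch under two assumptions not stated on this page: pair k uses angles[..., k], and the kernels use the conventional counter-clockwise sign. Here (i, j) are the indices of pair k, which are (2k, 2k+1) when rotate_pairwise = true and (k, k + state_rank/2) when rotate_pairwise = false.

```latex
\begin{pmatrix} y_i \\ y_j \end{pmatrix}
=
\begin{pmatrix} \cos\theta_k & -\sin\theta_k \\ \sin\theta_k & \cos\theta_k \end{pmatrix}
\begin{pmatrix} x_i \\ x_j \end{pmatrix},
\qquad k = 0, \dots, \tfrac{\mathrm{state\_rank}}{2} - 1
```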

Two pairing conventions are supported, selected by rotate_pairwise:

  • rotate_pairwise = true: interleaved (GPT-J style). Adjacent pairs (0,1), (2,3), … are rotated together. Used by the SISO Triton kernel (mamba3_siso_*.py).
  • rotate_pairwise = false: half-and-half (GPT-NeoX style). Position n is paired with n + state_rank/2, for n < state_rank/2. Used by the MIMO TileLang kernel (mamba3_mimo_fwd.py).
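The two conventions differ only in which indices form a pair. The sketch below lists the pairs for a given state_rank; rope_pairs is a hypothetical helper written for illustration, not part of this crate, and it assumes pair k is rotated by angle k.

```rust
/// Hypothetical helper (not part of this crate): the (i, j) index pairs that
/// are rotated together along a last dimension of size `state_rank`.
/// Assumes pair k is rotated by angles[..., k].
fn rope_pairs(state_rank: usize, rotate_pairwise: bool) -> Vec<(usize, usize)> {
    assert!(state_rank % 2 == 0, "state_rank must be even");
    let half = state_rank / 2;
    (0..half)
        .map(|k| {
            if rotate_pairwise {
                // Interleaved: adjacent elements (2k, 2k + 1).
                (2 * k, 2 * k + 1)
            } else {
                // Half-and-half: element k pairs with k + state_rank / 2.
                (k, k + half)
            }
        })
        .collect()
}

fn main() {
    println!("{:?}", rope_pairs(8, true)); // [(0, 1), (2, 3), (4, 5), (6, 7)]
    println!("{:?}", rope_pairs(8, false)); // [(0, 4), (1, 5), (2, 6), (3, 7)]
}
```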

Reference: mamba3.py:335 sets rotate_pairwise = not self.is_mimo, so the SISO path uses the interleaved pairing and the MIMO path uses the half-and-half pairing.

Shapes

  • x: [..., state_rank] where state_rank is even
  • angles: [..., state_rank / 2] (one angle per pair)
  • output: same shape as x
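The shape contract above can be illustrated with a scalar reference version on plain slices. apply_rope_flat is hypothetical and is not this crate's Tensor-based implementation: it assumes a row-major buffer whose leading dimensions are flattened into rows, pair k rotated by angle k, and the counter-clockwise sign convention.

```rust
/// Hypothetical scalar reference (not this crate's API).
/// `x` is row-major with shape [rows, state_rank];
/// `angles` has shape [rows, state_rank / 2] (one angle per pair).
fn apply_rope_flat(x: &[f32], angles: &[f32], state_rank: usize, rotate_pairwise: bool) -> Vec<f32> {
    assert!(state_rank > 0 && state_rank % 2 == 0, "state_rank must be positive and even");
    assert_eq!(x.len() % state_rank, 0, "x must hold a whole number of rows");
    let half = state_rank / 2;
    let rows = x.len() / state_rank;
    assert_eq!(angles.len(), rows * half, "angles must have shape [..., state_rank / 2]");

    let mut out = x.to_vec(); // output has the same shape as x
    for r in 0..rows {
        let x_row = &x[r * state_rank..(r + 1) * state_rank];
        let a_row = &angles[r * half..(r + 1) * half];
        let out_row = &mut out[r * state_rank..(r + 1) * state_rank];
        for k in 0..half {
            // Pick the pair according to the convention.
            let (i, j) = if rotate_pairwise { (2 * k, 2 * k + 1) } else { (k, k + half) };
            let (sin, cos) = a_row[k].sin_cos();
            // 2-D rotation of (x_i, x_j) by angle k.
            out_row[i] = x_row[i] * cos - x_row[j] * sin;
            out_row[j] = x_row[i] * sin + x_row[j] * cos;
        }
    }
    out
}

fn main() {
    // One row, state_rank = 4: rotate pair 0 by 90 degrees, leave pair 1 alone.
    let x = [1.0_f32, 2.0, 3.0, 4.0];
    let angles = [std::f32::consts::FRAC_PI_2, 0.0];
    println!("{:?}", apply_rope_flat(&x, &angles, 4, true)); // approximately [-2, 1, 3, 4]
    println!("{:?}", apply_rope_flat(&x, &angles, 4, false)); // approximately [-3, 2, 1, 4]
}
```

With the same inputs, the two conventions rotate different elements: interleaved mixes x[0] with x[1], while half-and-half mixes x[0] with x[2].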