Skip to main content

burn_mamba/modules/misc/
rope.rs

1use crate::mamba3::double_ssd::prelude::*;
2use crate::mamba3::helpers;
3use crate::mamba3::prelude::*;
4use crate::mamba3::rotation::{RotationState, rotate_bc_forward, rotate_bc_step};
5use crate::modules::Silu;
6use crate::modules::sanity as san;
7use burn::prelude::*;
8
9// TODO: move to mamba3/rotation/mod.rs
10
11// ---------------------------------------------------------------------------
12// RoPE utility
13// ---------------------------------------------------------------------------
14
15/// Reduce angles modulo `2π` into `[−π, π]`, leaving the autodiff graph intact.
16///
17/// `sin`/`cos` are `2π`-periodic, so subtracting an integer multiple of `2π` is
18/// value-exact. Keeping `|angle| ≤ π` preserves precision in low-bit floats —
19/// roughly half of `f16`'s representable values lie in `|x| ≤ 1`, and the
20/// periodic `sin`/`cos` only lose accuracy when the argument is allowed to drift
21/// to large magnitudes. The same applies to the cumulative angle accumulator,
22/// which would otherwise grow without bound across a long sequence / many decode
23/// steps.
24///
25/// The integer multiple `k` is `detach`ed, so it is a constant with respect to
26/// autodiff: `d/dx (x − k·2π) = 1`, i.e. the backward pass is identical to the
27/// un-wrapped angle. This mirrors the detached `max` rescaling in
28/// [`RmsNormGated`](crate::utils::rms_norm_gated::RmsNormGated).
29pub fn wrap_angle<const D: usize>(angles: Tensor<D>) -> Tensor<D> {
30    let two_pi = 2.0 * std::f32::consts::PI;
31    let k = (angles.clone().detach() * (1.0f32 / two_pi)).round();
32    angles - k * two_pi
33}
34
35/// Apply rotary position embeddings to `x` along its last dimension.
36///
37/// Two pairing conventions are supported, selected by `rotate_pairwise`:
38///
39/// - `rotate_pairwise = true` — **interleaved** (NeoX / Triton style): adjacent
40///   pairs `(0,1)`, `(2,3)`, … are rotated together. Used by the SISO Triton
41///   kernel (`mamba3_siso_*.py`).
42/// - `rotate_pairwise = false` — **half-and-half** (GPT-J style): position `n`
43///   is paired with `n + state_rank/2`. Used by the MIMO Tilelang kernel
44///   (`mamba3_mimo_fwd.py`).
45///
46/// Reference: `mamba3.py:335` sets `rotate_pairwise = not self.is_mimo`.
47///
48/// # Shapes
49/// - `x`:      `[..., state_rank]` where `state_rank` is even
50/// - `angles`: `[..., state_rank / 2]`  (one angle per pair)
51/// - output:   same shape as `x`
52pub fn apply_rope<const D: usize>(
53    x: Tensor<D>,
54    angles: Tensor<D>,
55    rotate_pairwise: bool,
56) -> Tensor<D> {
57    let dims = x.dims();
58    let n = dims[D - 1];
59    let n2 = n / 2;
60    let leading: usize = dims[..D - 1].iter().product();
61
62    let angles_flat = wrap_angle(angles.reshape([leading, n2]));
63    let cos = angles_flat.clone().cos();
64    let sin = angles_flat.sin();
65
66    if rotate_pairwise {
67        // Interleaved: reshape to [leading, n2, 2], pairs along last axis.
68        let x_pairs = x.reshape([leading, n2, 2]);
69        let x0 = x_pairs.clone().narrow(2, 0, 1).squeeze_dim(2);
70        let x1 = x_pairs.narrow(2, 1, 1).squeeze_dim(2);
71
72        let x0r = cos.clone() * x0.clone() - sin.clone() * x1.clone();
73        let x1r = sin * x0 + cos * x1;
74
75        Tensor::cat(
76            vec![x0r.unsqueeze_dim::<3>(2), x1r.unsqueeze_dim::<3>(2)],
77            2,
78        )
79        .reshape(dims)
80    } else {
81        // Half-and-half: reshape to [leading, 2, n2], halves along middle axis.
82        let x_halves = x.reshape([leading, 2, n2]);
83        let x0 = x_halves.clone().narrow(1, 0, 1).squeeze_dim(1);
84        let x1 = x_halves.narrow(1, 1, 1).squeeze_dim(1);
85
86        let x0r = cos.clone() * x0.clone() - sin.clone() * x1.clone();
87        let x1r = sin * x0 + cos * x1;
88
89        Tensor::cat(
90            vec![x0r.unsqueeze_dim::<3>(1), x1r.unsqueeze_dim::<3>(1)],
91            1,
92        )
93        .reshape(dims)
94    }
95}
96
97/// Apply RoPE to only the rotation-active entries of the last dimension; the
98/// remainder passes through unchanged. Falls back to [`apply_rope`] when
99/// `rope_dim == state_rank` (full RoPE), and is the **identity** when
100/// `rope_dim == 0` (RoPE disabled, `rope_fraction = 0`) — `angles` is ignored.
101///
102/// Pairing scheme (must match the reference kernels — see Section
103/// "Data-Dependent RoPE" in the paper, and `mamba3_siso_fwd.py` /
104/// `mamba3_mimo_fwd.py`):
105///
106/// - `rotate_pairwise = true` (SISO, interleaved/NeoX): pairs `(0,1), (2,3), …`.
107///   Only pairs `0..num_rope_angles` are rotated; pairs beyond are passed
108///   through. Equivalent to slicing the first `rope_dim` entries and rotating
109///   them.
110/// - `rotate_pairwise = false` (MIMO, half-and-half/GPT-J): pair distance is
111///   always `state_rank/2`, i.e. element `n` is paired with element
112///   `state_rank/2 + n`. With partial RoPE only the first `num_rope_angles`
113///   pairs are rotated; the remaining elements in both halves pass through.
114pub fn apply_rope_partial<const D: usize>(
115    x: Tensor<D>,
116    angles: Tensor<D>,
117    rope_dim: usize,
118    rotate_pairwise: bool,
119) -> Tensor<D> {
120    if rope_dim == 0 {
121        // RoPE disabled (rope_fraction = 0): identity. The upstream angle data
122        // flow is still computed and cached, but no rotation is applied. This
123        // also avoids zero-width narrows below (Burn has no zero-width tensors).
124        return x;
125    }
126
127    let state_rank = x.dims()[D - 1];
128    if rope_dim == state_rank {
129        return apply_rope::<D>(x, angles, rotate_pairwise);
130    }
131
132    if rotate_pairwise {
133        // Pairs are local — slicing the first rope_dim entries gives the same
134        // result as the reference (which rotates the whole headdim but with
135        // identity cos/sin for the tail pairs).
136        let x_rope = x.clone().narrow(D - 1, 0, rope_dim);
137        let x_rest = x.narrow(D - 1, rope_dim, state_rank - rope_dim);
138        let x_rope_rotated = apply_rope::<D>(x_rope, angles, true);
139        return Tensor::cat(vec![x_rope_rotated, x_rest], D - 1);
140    }
141
142    // Half-and-half partial RoPE: pair distance must be `state_rank/2`, not
143    // `rope_dim/2`. Slicing the first `rope_dim` entries and calling
144    // `apply_rope` would pair within the slice and produce the wrong rotation.
145    let half = state_rank / 2;
146    let num_rope_angles = rope_dim / 2;
147    debug_assert!(
148        num_rope_angles < half,
149        "partial RoPE requires rope_dim < state_rank here"
150    );
151
152    // Split x into the two halves, then within each half separate the
153    // rotation-active prefix from the pass-through suffix.
154    let x_h1 = x.clone().narrow(D - 1, 0, half);
155    let x_h2 = x.narrow(D - 1, half, half);
156    let x_h1_rope = x_h1.clone().narrow(D - 1, 0, num_rope_angles);
157    let x_h1_pass = x_h1.narrow(D - 1, num_rope_angles, half - num_rope_angles);
158    let x_h2_rope = x_h2.clone().narrow(D - 1, 0, num_rope_angles);
159    let x_h2_pass = x_h2.narrow(D - 1, num_rope_angles, half - num_rope_angles);
160
161    // angles: [..., num_rope_angles] — broadcasts element-wise against the rope-active slices.
162    let angles = wrap_angle(angles);
163    let cos = angles.clone().cos();
164    let sin = angles.sin();
165    let x_h1_rot = cos.clone() * x_h1_rope.clone() - sin.clone() * x_h2_rope.clone();
166    let x_h2_rot = sin * x_h1_rope + cos * x_h2_rope;
167
168    // Reassemble: [ first-half-rotated | first-half-passthrough | second-half-rotated | second-half-passthrough ]
169    let x_h1_out = Tensor::cat(vec![x_h1_rot, x_h1_pass], D - 1);
170    let x_h2_out = Tensor::cat(vec![x_h2_rot, x_h2_pass], D - 1);
171    Tensor::cat(vec![x_h1_out, x_h2_out], D - 1)
172}