burn_mamba/modules/misc/rope.rs
1use crate::mamba3::double_ssd::prelude::*;
2use crate::mamba3::helpers;
3use crate::mamba3::prelude::*;
4use crate::mamba3::rotation::{RotationState, rotate_bc_forward, rotate_bc_step};
5use crate::modules::Silu;
6use crate::modules::sanity as san;
7use burn::prelude::*;
8
9// TODO: move to mamba3/rotation/mod.rs
10
11// ---------------------------------------------------------------------------
12// RoPE utility
13// ---------------------------------------------------------------------------
14
15/// Reduce angles modulo `2π` into `[−π, π]`, leaving the autodiff graph intact.
16///
17/// `sin`/`cos` are `2π`-periodic, so subtracting an integer multiple of `2π` is
18/// value-exact. Keeping `|angle| ≤ π` preserves precision in low-bit floats —
19/// roughly half of `f16`'s representable values lie in `|x| ≤ 1`, and the
20/// periodic `sin`/`cos` only lose accuracy when the argument is allowed to drift
21/// to large magnitudes. The same applies to the cumulative angle accumulator,
22/// which would otherwise grow without bound across a long sequence / many decode
23/// steps.
24///
25/// The integer multiple `k` is `detach`ed, so it is a constant with respect to
26/// autodiff: `d/dx (x − k·2π) = 1`, i.e. the backward pass is identical to the
27/// un-wrapped angle. This mirrors the detached `max` rescaling in
28/// [`RmsNormGated`](crate::utils::rms_norm_gated::RmsNormGated).
29pub fn wrap_angle<const D: usize>(angles: Tensor<D>) -> Tensor<D> {
30 let two_pi = 2.0 * std::f32::consts::PI;
31 let k = (angles.clone().detach() * (1.0f32 / two_pi)).round();
32 angles - k * two_pi
33}
34
35/// Apply rotary position embeddings to `x` along its last dimension.
36///
37/// Two pairing conventions are supported, selected by `rotate_pairwise`:
38///
39/// - `rotate_pairwise = true` — **interleaved** (NeoX / Triton style): adjacent
40/// pairs `(0,1)`, `(2,3)`, … are rotated together. Used by the SISO Triton
41/// kernel (`mamba3_siso_*.py`).
42/// - `rotate_pairwise = false` — **half-and-half** (GPT-J style): position `n`
43/// is paired with `n + state_rank/2`. Used by the MIMO Tilelang kernel
44/// (`mamba3_mimo_fwd.py`).
45///
46/// Reference: `mamba3.py:335` sets `rotate_pairwise = not self.is_mimo`.
47///
48/// # Shapes
49/// - `x`: `[..., state_rank]` where `state_rank` is even
50/// - `angles`: `[..., state_rank / 2]` (one angle per pair)
51/// - output: same shape as `x`
52pub fn apply_rope<const D: usize>(
53 x: Tensor<D>,
54 angles: Tensor<D>,
55 rotate_pairwise: bool,
56) -> Tensor<D> {
57 let dims = x.dims();
58 let n = dims[D - 1];
59 let n2 = n / 2;
60 let leading: usize = dims[..D - 1].iter().product();
61
62 let angles_flat = wrap_angle(angles.reshape([leading, n2]));
63 let cos = angles_flat.clone().cos();
64 let sin = angles_flat.sin();
65
66 if rotate_pairwise {
67 // Interleaved: reshape to [leading, n2, 2], pairs along last axis.
68 let x_pairs = x.reshape([leading, n2, 2]);
69 let x0 = x_pairs.clone().narrow(2, 0, 1).squeeze_dim(2);
70 let x1 = x_pairs.narrow(2, 1, 1).squeeze_dim(2);
71
72 let x0r = cos.clone() * x0.clone() - sin.clone() * x1.clone();
73 let x1r = sin * x0 + cos * x1;
74
75 Tensor::cat(
76 vec![x0r.unsqueeze_dim::<3>(2), x1r.unsqueeze_dim::<3>(2)],
77 2,
78 )
79 .reshape(dims)
80 } else {
81 // Half-and-half: reshape to [leading, 2, n2], halves along middle axis.
82 let x_halves = x.reshape([leading, 2, n2]);
83 let x0 = x_halves.clone().narrow(1, 0, 1).squeeze_dim(1);
84 let x1 = x_halves.narrow(1, 1, 1).squeeze_dim(1);
85
86 let x0r = cos.clone() * x0.clone() - sin.clone() * x1.clone();
87 let x1r = sin * x0 + cos * x1;
88
89 Tensor::cat(
90 vec![x0r.unsqueeze_dim::<3>(1), x1r.unsqueeze_dim::<3>(1)],
91 1,
92 )
93 .reshape(dims)
94 }
95}
96
97/// Apply RoPE to only the rotation-active entries of the last dimension; the
98/// remainder passes through unchanged. Falls back to [`apply_rope`] when
99/// `rope_dim == state_rank` (full RoPE), and is the **identity** when
100/// `rope_dim == 0` (RoPE disabled, `rope_fraction = 0`) — `angles` is ignored.
101///
102/// Pairing scheme (must match the reference kernels — see Section
103/// "Data-Dependent RoPE" in the paper, and `mamba3_siso_fwd.py` /
104/// `mamba3_mimo_fwd.py`):
105///
106/// - `rotate_pairwise = true` (SISO, interleaved/NeoX): pairs `(0,1), (2,3), …`.
107/// Only pairs `0..num_rope_angles` are rotated; pairs beyond are passed
108/// through. Equivalent to slicing the first `rope_dim` entries and rotating
109/// them.
110/// - `rotate_pairwise = false` (MIMO, half-and-half/GPT-J): pair distance is
111/// always `state_rank/2`, i.e. element `n` is paired with element
112/// `state_rank/2 + n`. With partial RoPE only the first `num_rope_angles`
113/// pairs are rotated; the remaining elements in both halves pass through.
114pub fn apply_rope_partial<const D: usize>(
115 x: Tensor<D>,
116 angles: Tensor<D>,
117 rope_dim: usize,
118 rotate_pairwise: bool,
119) -> Tensor<D> {
120 if rope_dim == 0 {
121 // RoPE disabled (rope_fraction = 0): identity. The upstream angle data
122 // flow is still computed and cached, but no rotation is applied. This
123 // also avoids zero-width narrows below (Burn has no zero-width tensors).
124 return x;
125 }
126
127 let state_rank = x.dims()[D - 1];
128 if rope_dim == state_rank {
129 return apply_rope::<D>(x, angles, rotate_pairwise);
130 }
131
132 if rotate_pairwise {
133 // Pairs are local — slicing the first rope_dim entries gives the same
134 // result as the reference (which rotates the whole headdim but with
135 // identity cos/sin for the tail pairs).
136 let x_rope = x.clone().narrow(D - 1, 0, rope_dim);
137 let x_rest = x.narrow(D - 1, rope_dim, state_rank - rope_dim);
138 let x_rope_rotated = apply_rope::<D>(x_rope, angles, true);
139 return Tensor::cat(vec![x_rope_rotated, x_rest], D - 1);
140 }
141
142 // Half-and-half partial RoPE: pair distance must be `state_rank/2`, not
143 // `rope_dim/2`. Slicing the first `rope_dim` entries and calling
144 // `apply_rope` would pair within the slice and produce the wrong rotation.
145 let half = state_rank / 2;
146 let num_rope_angles = rope_dim / 2;
147 debug_assert!(
148 num_rope_angles < half,
149 "partial RoPE requires rope_dim < state_rank here"
150 );
151
152 // Split x into the two halves, then within each half separate the
153 // rotation-active prefix from the pass-through suffix.
154 let x_h1 = x.clone().narrow(D - 1, 0, half);
155 let x_h2 = x.narrow(D - 1, half, half);
156 let x_h1_rope = x_h1.clone().narrow(D - 1, 0, num_rope_angles);
157 let x_h1_pass = x_h1.narrow(D - 1, num_rope_angles, half - num_rope_angles);
158 let x_h2_rope = x_h2.clone().narrow(D - 1, 0, num_rope_angles);
159 let x_h2_pass = x_h2.narrow(D - 1, num_rope_angles, half - num_rope_angles);
160
161 // angles: [..., num_rope_angles] — broadcasts element-wise against the rope-active slices.
162 let angles = wrap_angle(angles);
163 let cos = angles.clone().cos();
164 let sin = angles.sin();
165 let x_h1_rot = cos.clone() * x_h1_rope.clone() - sin.clone() * x_h2_rope.clone();
166 let x_h2_rot = sin * x_h1_rope + cos * x_h2_rope;
167
168 // Reassemble: [ first-half-rotated | first-half-passthrough | second-half-rotated | second-half-passthrough ]
169 let x_h1_out = Tensor::cat(vec![x_h1_rot, x_h1_pass], D - 1);
170 let x_h2_out = Tensor::cat(vec![x_h2_rot, x_h2_pass], D - 1);
171 Tensor::cat(vec![x_h1_out, x_h2_out], D - 1)
172}