burn_mamba/mamba3/rotation/mod.rs
//! # Quaternion (k=4) rotational state — the non-abelian generalisation of RoPE
//!
//! Mamba-3's data-dependent RoPE realises a **complex-valued** SSM: the state
//! transition factors as a per-head scalar decay times a block-diagonal of
//! `2×2` rotations (paper Prop. *Complex-to-Real SSM Equivalence*), and because
//! `SO(2) ≅ U(1)` is **abelian** the cumulative rotation collapses to a
//! `cumsum` of angles and is absorbed into `B`/`C` (the "RoPE trick", Prop.
//! *Complex SSM, Data-Dependent RoPE Equivalence*). See
//! [`crate::mamba3::double_ssd::double_ssd::apply_rope`].
//!
//! This module implements the next rung of the ladder: a **quaternion**
//! (`k = 4`) rotational state, i.e. the transition's rotation lives in the
//! left-isoclinic subgroup `SU(2) ⊂ SO(4)` instead of `SO(2)`. Unit
//! quaternions under multiplication form `SU(2)`, which is **non-abelian** and
//! contains non-solvable finite subgroups (the binary icosahedral group
//! `2I ≅ SL(2,5)`, a double cover of `A₅`). By Barrington's theorem the word
//! problem of a non-solvable group is `NC¹`-complete, so this lifts the layer's
//! reachable state-tracking from the solvable/`TC⁰` regime (parity, mod-k)
//! toward `NC¹`, which abelian rotations cannot reach unless `TC⁰ = NC¹`.
//!
//! ## What survives, what changes
//!
//! The key fact (derivable purely from telescoping + orthogonality, **without**
//! commutativity — see the crate discussion) is that the RoPE *factoring*
//! survives intact: with the **ordered** cumulative rotation
//! `Pₜ = Rₜ Rₜ₋₁ ⋯ R₁`,
//!
//! ```text
//! Cₜᵀ (Rₜ⋯Rᵢ₊₁) Bᵢ = (Pₜᵀ Cₜ)ᵀ (Pᵢᵀ Bᵢ) = C̄ₜᵀ B̄ᵢ ,
//! ```
//!
//! since `(Pₜᵀ Cₜ)ᵀ (Pᵢᵀ Bᵢ) = Cₜᵀ (Pₜ Pᵢᵀ) Bᵢ` and, by orthogonality alone,
//! `Pₜ Pᵢᵀ = (Rₜ⋯Rᵢ₊₁)(Rᵢ⋯R₁)(Rᵢ⋯R₁)ᵀ = Rₜ⋯Rᵢ₊₁`. The scalar-decay SSD core
//! (`L ⊙ C̄B̄ᵀ`) is therefore **unchanged**: only the projections
//! `B̄ᵢ = Pᵢᵀ Bᵢ` and `C̄ₜ = Pₜᵀ Cₜ` are rotated. What is lost is the
//! closed-form `cumsum`: the cumulative rotation must be built by an
//! **associative scan over the per-step quaternions** ([`quat_cumprod`]) rather
//! than a sum of angles. Because a product of unit quaternions is again a unit
//! quaternion, the scan stays orthogonal up to floating-point rounding (the
//! carry is re-normalised when it is stored, and no `wrap_angle` is needed),
//! and the cross-chunk carry is a single quaternion per block/head — the exact
//! analogue of `cum_angle` in the existing caches.
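//!
//! The factoring can be checked numerically. Below is a minimal,
//! dependency-free sketch on plain `[f64; 4]` arrays instead of `Tensor`s; the
//! helpers `qmul`, `conj`, `dot`, `unit` and `window_score` are hypothetical
//! illustrations, not part of this crate. It compares the direct score
//! `Cₜᵀ (Rₜ⋯Rᵢ₊₁) Bᵢ` with the absorbed score `C̄ₜᵀ B̄ᵢ`, where
//! `B̄ᵢ = conj(Pᵢ) ⊗ Bᵢ` and `C̄ₜ = conj(Pₜ) ⊗ Cₜ`.
//!
//! ```rust
//! type Q = [f64; 4]; // (w, x, y, z), w is the real part
//!
//! // Hamilton product a ⊗ b (same formula as `quat_mul`).
//! fn qmul(a: Q, b: Q) -> Q {
//!     [
//!         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
//!         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
//!         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
//!         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
//!     ]
//! }
//!
//! fn conj(q: Q) -> Q {
//!     [q[0], -q[1], -q[2], -q[3]]
//! }
//!
//! fn dot(a: Q, b: Q) -> f64 {
//!     a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
//! }
//!
//! fn unit(q: Q) -> Q {
//!     let n = dot(q, q).sqrt();
//!     q.map(|c| c / n)
//! }
//!
//! // `rs[k]` holds R_{k+1}. Returns (direct, absorbed) for steps i < t.
//! fn window_score(rs: &[Q], t: usize, i: usize, c: Q, b: Q) -> (f64, f64) {
//!     const ID: Q = [1.0, 0.0, 0.0, 0.0];
//!     // Ordered product R_k ⋯ R_1 (newest on the left).
//!     let prefix = |k: usize| rs[..k].iter().fold(ID, |p, &r| qmul(r, p));
//!     // Direct: the window R_t ⋯ R_{i+1} applied to B.
//!     let window = rs[i..t].iter().fold(ID, |w, &r| qmul(r, w));
//!     let direct = dot(c, qmul(window, b));
//!     // Absorbed: rotate C and B by their own inverse prefixes, then dot.
//!     let absorbed = dot(qmul(conj(prefix(t)), c), qmul(conj(prefix(i)), b));
//!     (direct, absorbed)
//! }
//! ```
//!
//! With a few generic (mutually non-commuting) unit quaternions in `rs`, the
//! two returned scores agree to rounding for every `i < t`.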
//!
//! `SO(2)` (today's `apply_rope`) is the abelian collapse: restricting each
//! quaternion to a single fixed axis makes them commute and reduces
//! [`quat_cumprod`] to a `cumsum` of half-angles (asserted in the tests).
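//!
//! A dependency-free sketch of that collapse (again on plain arrays; `qmul`,
//! `z_quat` and `ordered_product` are hypothetical helpers): unit quaternions
//! sharing the fixed `z` axis compose by adding their half-angles, so the
//! ordered product equals the single quaternion built from the summed angle.
//!
//! ```rust
//! type Q = [f64; 4]; // (w, x, y, z)
//!
//! // Hamilton product a ⊗ b.
//! fn qmul(a: Q, b: Q) -> Q {
//!     [
//!         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
//!         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
//!         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
//!         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
//!     ]
//! }
//!
//! // Rotation by `theta` about the fixed z axis, in half-angle form.
//! fn z_quat(theta: f64) -> Q {
//!     let h = 0.5 * theta;
//!     [h.cos(), 0.0, 0.0, h.sin()]
//! }
//!
//! // Ordered product (newest on the left), as `quat_cumprod` builds it.
//! fn ordered_product(thetas: &[f64]) -> Q {
//!     thetas.iter().fold([1.0, 0.0, 0.0, 0.0], |acc, &t| qmul(z_quat(t), acc))
//! }
//! ```
//!
//! Because these factors commute, the result does not depend on their order
//! and matches `z_quat(thetas.iter().sum())`: a `cumsum` of angles is all the
//! `SO(2)` path needs.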
//!
//! ## Pipeline (the `k = 4` instantiation of the rotation block)
//!
//! ```text
//! per-step unit quaternion qₜ   (materialise from the in-projection; caller)
//!        │  quat_cumprod   (assoc. scan, + cross-chunk carry)
//!        ▼
//! cumulative rotation Qₜ
//!        │  rotate_state_rank_blocks(B, conj(Qₜ)), rotate_state_rank_blocks(C, conj(Qₜ))
//!        ▼
//! B̄, C̄  ──►  standard scalar-decay SSD (unchanged)
//! ```
//!
//! Quaternion layout: the last axis has size 4 and holds `(w, x, y, z)` with
//! `w` the real part. A `state_rank` of `r = 4·J` is treated as `J` independent
//! quaternion blocks; the rotation acts within each block, exactly as RoPE acts
//! within each `2`-pair. This module is a self-contained, tested reference for
//! the math; wiring it into the [`Mamba3`](crate::mamba3::mamba3::Mamba3) block
//! is a separate, larger change (the SSD kernels themselves need no edits).

use crate::modules::{apply_rope_partial, wrap_angle};
use burn::module::Module;
use burn::prelude::*;

// ---------------------------------------------------------------------------
// Rotation kind (config switch) and cache accumulator variant
// ---------------------------------------------------------------------------

/// Which rotational-state algebra the block uses for the data-dependent
/// positional rotation of `B`/`C`.
///
/// - [`Complex2D`](RotationKind::Complex2D) — the abelian `SO(2)`/complex RoPE
///   that Mamba-3 ships: cumulative *angles* via `cumsum`, applied by
///   [`apply_rope_partial`]. The default; behaviourally unchanged.
/// - [`Quaternion4D`](RotationKind::Quaternion4D) — the non-abelian
///   `SU(2) ⊂ SO(4)` quaternion rotation of this module: cumulative *product*
///   via [`quat_cumprod`], applied by [`rotate_state_rank_blocks`]. Richer
///   state-tracking; selects the [`RotationState::Quaternion`] cache accumulator.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RotationKind {
    /// Abelian complex (`SO(2)`) RoPE — the current default behaviour.
    #[default]
    Complex2D,
    /// Non-abelian quaternion (`SU(2)`) rotation.
    Quaternion4D,
}

/// The cumulative-rotation accumulator carried between calls in a Mamba-3 cache
/// — the variant matching the block's [`RotationKind`].
///
/// - [`Angle`](RotationState::Angle) — abelian per-pair cumulative RoPE angle,
///   shape `[batch, nheads, num_rope_angles]` (today's `cum_angle`).
/// - [`Quaternion`](RotationState::Quaternion) — per-block cumulative unit
///   quaternion, shape `[batch, nheads, blocks, 4]`, produced by
///   [`quat_cumprod`].
///
/// This is the cache-level counterpart of [`RotationKind`]. It is defined here
/// (the rotation module owns the accumulator type); substituting it for the
/// pathway caches' `cum_angle_bha` field happens together with the forward/step
/// wiring that consumes it.
#[derive(Module, Debug)]
pub enum RotationState {
    /// Abelian RoPE cumulative angle, shape `[batch, nheads, num_rope_angles]`.
    Angle(Tensor<3>),
    /// Quaternion cumulative rotation, shape `[batch, nheads, blocks, 4]`.
    Quaternion(Tensor<4>),
}

impl RotationState {
    /// Zero-initialised abelian angle accumulator `[batch, nheads, num_rope_angles]`.
    pub fn zeros_angle(
        batch: usize,
        nheads: usize,
        num_rope_angles: usize,
        device: &Device,
    ) -> Self {
        RotationState::Angle(Tensor::zeros([batch, nheads, num_rope_angles], device))
    }

    /// Identity-initialised quaternion accumulator `[batch, nheads, blocks, 4]`
    /// (every block is the identity quaternion `(1, 0, 0, 0)`).
    pub fn identity_quaternion(
        batch: usize,
        nheads: usize,
        blocks: usize,
        device: &Device,
    ) -> Self {
        let w = Tensor::ones([batch, nheads, blocks, 1], device);
        let xyz = Tensor::zeros([batch, nheads, blocks, 3], device);
        RotationState::Quaternion(Tensor::cat(vec![w, xyz], 3))
    }

    /// Unwrap the abelian angle accumulator; panics if this is a quaternion.
    pub fn angle(self) -> Tensor<3> {
        match self {
            RotationState::Angle(a) => a,
            RotationState::Quaternion(_) => {
                panic!("RotationState is Quaternion, expected Angle")
            }
        }
    }

    /// Unwrap the quaternion accumulator; panics if this is an angle.
    pub fn quaternion(self) -> Tensor<4> {
        match self {
            RotationState::Quaternion(q) => q,
            RotationState::Angle(_) => panic!("RotationState is Angle, expected Quaternion"),
        }
    }

    /// Run the [`NaN`/`Inf` guards](crate::modules::sanity) on the held tensor.
    pub fn sanity(&self) {
        match self {
            RotationState::Angle(a) => crate::modules::sanity(a),
            RotationState::Quaternion(q) => crate::modules::sanity(q),
        }
    }
}

// ---------------------------------------------------------------------------
// Quaternion algebra on the trailing `(w, x, y, z)` axis
// ---------------------------------------------------------------------------

/// Hamilton product `a ⊗ b` of two quaternion tensors.
///
/// Both inputs have shape `[..., 4]` with the last axis ordered `(w, x, y, z)`;
/// the product is assembled from the sixteen component products and broadcasts
/// over the leading dims. Quaternion multiplication is **non-commutative**
/// (`a ⊗ b ≠ b ⊗ a` in general) but associative.
///
/// Identifying `ℝ⁴` with the quaternions, left-multiplication `v ↦ a ⊗ v` is
/// exactly the action of the `4×4` rotation matrix [`quat_to_rot4`]`(a)`, so
/// this is also how a rotation is *applied* to a state/`B`/`C` block (see
/// [`rotate_state_rank_blocks`]).
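///
/// # Example
///
/// A scalar sketch of the same formula on plain `[f64; 4]` arrays (`qmul` is a
/// hypothetical stand-in for this tensor function, not a crate API). It shows
/// non-commutativity on the basis units: `i ⊗ j = k` but `j ⊗ i = −k`.
///
/// ```rust
/// type Q = [f64; 4]; // (w, x, y, z)
///
/// // Hamilton product a ⊗ b, term for term as in `quat_mul`.
/// fn qmul(a: Q, b: Q) -> Q {
///     [
///         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
///         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
///         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
///         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
///     ]
/// }
///
/// const I: Q = [0.0, 1.0, 0.0, 0.0];
/// const J: Q = [0.0, 0.0, 1.0, 0.0];
/// const K: Q = [0.0, 0.0, 0.0, 1.0];
/// ```
///
/// With these, `qmul(I, J) == K` while `qmul(J, I) == [0.0, 0.0, 0.0, -1.0]`,
/// yet `qmul(qmul(a, b), c)` and `qmul(a, qmul(b, c))` agree for any `a, b, c`
/// up to rounding.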
pub fn quat_mul<const D: usize>(a: Tensor<D>, b: Tensor<D>) -> Tensor<D> {
    let n = D - 1;
    let aw = a.clone().narrow(n, 0, 1);
    let ax = a.clone().narrow(n, 1, 1);
    let ay = a.clone().narrow(n, 2, 1);
    let az = a.narrow(n, 3, 1);
    let bw = b.clone().narrow(n, 0, 1);
    let bx = b.clone().narrow(n, 1, 1);
    let by = b.clone().narrow(n, 2, 1);
    let bz = b.narrow(n, 3, 1);

    // Hamilton product (each term is shape [..., 1]).
    let w = aw.clone() * bw.clone()
        - ax.clone() * bx.clone()
        - ay.clone() * by.clone()
        - az.clone() * bz.clone();
    let x = aw.clone() * bx.clone() + ax.clone() * bw.clone() + ay.clone() * bz.clone()
        - az.clone() * by.clone();
    let y = aw.clone() * by.clone() - ax.clone() * bz.clone()
        + ay.clone() * bw.clone()
        + az.clone() * bx.clone();
    let z = aw * bz + ax * by - ay * bx + az * bw;

    Tensor::cat(vec![w, x, y, z], n)
}

/// Quaternion conjugate `q* = (w, −x, −y, −z)` (shape `[..., 4]`).
///
/// For a **unit** quaternion `q* = q⁻¹`, and the corresponding rotation matrix
/// satisfies `L_{q*} = L_qᵀ = L_q⁻¹`. Hence rotating by the *inverse* cumulative
/// rotation (`B̄ = Pᵀ B`) is `rotate_state_rank_blocks(B, conj(Q))`.
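///
/// # Example
///
/// A plain-array sketch (hypothetical `qmul`/`conj` helpers mirroring
/// [`quat_mul`] and this function): for the unit quaternion
/// `q = (½, ½, ½, ½)`, `q ⊗ q*` is the identity, so rotating by `q` and then by
/// `q*` returns the input.
///
/// ```rust
/// type Q = [f64; 4]; // (w, x, y, z)
///
/// // Hamilton product a ⊗ b.
/// fn qmul(a: Q, b: Q) -> Q {
///     [
///         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
///         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
///         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
///         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
///     ]
/// }
///
/// // Conjugate: negate the vector part.
/// fn conj(q: Q) -> Q {
///     [q[0], -q[1], -q[2], -q[3]]
/// }
///
/// const Q_UNIT: Q = [0.5, 0.5, 0.5, 0.5];
/// ```
///
/// Here `qmul(Q_UNIT, conj(Q_UNIT))` is exactly `[1.0, 0.0, 0.0, 0.0]` (every
/// partial product is `±0.25`, so no rounding occurs), and
/// `qmul(conj(Q_UNIT), qmul(Q_UNIT, v))` reproduces any `v` up to rounding.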
pub fn quat_conj<const D: usize>(q: Tensor<D>) -> Tensor<D> {
    let n = D - 1;
    let w = q.clone().narrow(n, 0, 1);
    let xyz = q.narrow(n, 1, 3);
    Tensor::cat(vec![w, -xyz], n)
}

/// Normalise quaternions to unit norm along the last axis (shape `[..., 4]`).
///
/// The per-step rotation is materialised from a raw, unconstrained projection
/// and normalised here so it is a genuine unit quaternion (an element of
/// `SU(2)`), the analogue of `tanh(θ)·π` bounding the RoPE angle. A tiny floor
/// guards the zero-quaternion.
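///
/// # Example
///
/// A scalar sketch of the clamp-before-`sqrt` pattern on `[f64; 4]` (the name
/// `normalize` and the constant `EPS` are hypothetical; the tensor version uses
/// the dtype-aware `crate::utils::div_eps` instead of a fixed floor):
///
/// ```rust
/// // Stand-in for the dtype-aware floor on the *sum of squares*.
/// const EPS: f64 = 1e-12;
///
/// fn normalize(q: [f64; 4]) -> [f64; 4] {
///     let sumsq: f64 = q.iter().map(|c| c * c).sum();
///     // Clamp before `sqrt`, so the zero quaternion never reaches `0 / 0`.
///     let norm = sumsq.max(EPS).sqrt();
///     q.map(|c| c / norm)
/// }
/// ```
///
/// `normalize([3.0, 0.0, 4.0, 0.0])` gives `[0.6, 0.0, 0.8, 0.0]`, and
/// `normalize([0.0; 4])` stays a finite zero instead of `NaN`.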
pub fn quat_normalize<const D: usize>(q: Tensor<D>) -> Tensor<D> {
    let n = D - 1;
    // Clamp the sum-of-squares *before* `sqrt`: at a zero quaternion the forward
    // `sqrt(0)=0` is fine, but `sqrt`'s backward is `1/(2·0)=∞`, and `∞·(2·0)=NaN`.
    // Clamping pre-`sqrt` puts the degenerate point in `clamp_min`'s flat region,
    // so its gradient is a finite 0 (and a genuine unit quaternion, sumsq=1, is
    // untouched). The floor also keeps `norm` away from 0 for the division.
    //
    // The floor is the dtype-aware `div_eps` applied to the *sum-of-squares*
    // (giving a norm floor of `√div_eps`). It must engage as a representable
    // normal in the working dtype: in f16 a `div_eps²`-sized floor (~5e-7) would
    // underflow below the min-normal (~6.1e-5) and silently no-op, so we floor
    // the squared quantity at `div_eps` itself, which sits above each format's
    // denormal floor by construction.
    let eps = crate::utils::div_eps(q.dtype());
    let norm = (q.clone() * q.clone()).sum_dim(n).clamp_min(eps).sqrt();
    q / norm
}

/// Materialise a unit quaternion from a **scaled rotation vector** `g ∈ ℝ³`
/// (axis · angle) via the exponential map — the data-dependent "materialise
/// `Rₜ`" step, analogous to RoPE's `Δₜ · π · tanh(θₜ)` angle.
///
/// With `‖g‖ = angle` and `ĝ = g / angle` the axis, returns the unit quaternion
/// `q = (cos(angle/2), sin(angle/2)·ĝ)`. A vanishing `g` maps to the identity
/// `(1, 0, 0, 0)`, so scaling `g` by a small `Δₜ` (the discretisation step)
/// yields a near-identity rotation — exactly the regime where a small step
/// barely rotates the state. Writing the vector part as `g · sin(angle/2)/angle`
/// rather than `ĝ · sin(angle/2)` never forms `ĝ = g/0`; with `angle` floored by
/// the clamp, the factor stays finite and tends to `1/2` as `g → 0`.
///
/// # Shapes
/// - `g` : `[..., 3]`
/// - out : `[..., 4]` (ordered `(w, x, y, z)`), unit norm.
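///
/// # Example
///
/// A scalar sketch of the exponential map on plain arrays (`from_scaled_axis`
/// and the fixed `EPS` floor are hypothetical stand-ins for this function and
/// `div_eps`): a quarter turn about `z` gives `(cos π/4, 0, 0, sin π/4)`, and
/// `g = 0` gives the identity.
///
/// ```rust
/// const EPS: f64 = 1e-12; // floor on ‖g‖², as in the clamp above
///
/// fn from_scaled_axis(g: [f64; 3]) -> [f64; 4] {
///     let sumsq = g[0] * g[0] + g[1] * g[1] + g[2] * g[2];
///     let angle = sumsq.max(EPS).sqrt(); // ‖g‖, clamped away from 0
///     let half = 0.5 * angle;
///     let scale = half.sin() / angle; // tends to 1/2 as angle → 0
///     let q = [half.cos(), g[0] * scale, g[1] * scale, g[2] * scale];
///     // Final renormalisation, as `quat_normalize` does for the tensor version.
///     let norm = q.iter().map(|c| c * c).sum::<f64>().sqrt();
///     q.map(|c| c / norm)
/// }
/// ```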
pub fn quat_from_scaled_axis<const D: usize>(g: Tensor<D>) -> Tensor<D> {
    let n = D - 1;
    // Clamp the sum-of-squares *before* `sqrt`: at `g = 0` the forward `sqrt(0)=0`
    // is finite, but `sqrt`'s backward is `1/(2·0)=∞` and `∞·(2·0)=NaN`. Clamping
    // pre-`sqrt` puts `g = 0` in `clamp_min`'s flat (zero-gradient) region, so the
    // near-identity rotation gets a finite 0 gradient instead of a NaN. (This is
    // the FiLM-triggered decoder-backward NaN: a per-position rotation generator
    // hitting exactly zero.) The floor is the dtype-aware `div_eps` on the squared
    // quantity — see [`quat_normalize`] for why it floors `sumsq`, not the norm.
    let eps = crate::utils::div_eps(g.dtype());
    let angle = (g.clone() * g.clone()).sum_dim(n).clamp_min(eps).sqrt(); // [..., 1]
    let half = angle.clone() * 0.5;
    let w = half.clone().cos(); // [..., 1]
    // sin(angle/2) / angle → 1/2 as angle → 0 (no rotation); `angle ≥ √div_eps`
    // after the pre-`sqrt` clamp above, so the division is already guarded.
    let scale = half.sin() / angle; // [..., 1]
    let v = g * scale; // [..., 3]
    quat_normalize(Tensor::cat(vec![w, v], n))
}

/// Materialise the `4×4` orthogonal matrix of left-multiplication by `q`.
///
/// Maps `q` of shape `[..., 4]` to `[..., 4, 4]` such that, for `v` of shape
/// `[..., 4]`, `L_q · v == quat_mul(q, v)`. Concretely (rows = output coords,
/// cols = input coords, all in `(w, x, y, z)` order):
///
/// ```text
/// ⎡ w  -x  -y  -z ⎤
/// ⎢ x   w  -z   y ⎥
/// ⎢ y   z   w  -x ⎥
/// ⎣ z  -y   x   w ⎦
/// ```
///
/// For a unit `q` this is orthogonal with `det = 1` (a left-isoclinic rotation).
/// Provided mainly for the generic / verification path; the cheap way to apply a
/// rotation is [`rotate_state_rank_blocks`] (a quaternion product, no `4×4`
/// materialisation). `DR` must equal `D + 1`.
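///
/// # Example
///
/// A plain-array sketch (hypothetical helpers `qmul`, `rot4`, `matvec`) that
/// builds the matrix above and checks both stated properties: `L_q · v` equals
/// `q ⊗ v`, and for a unit `q` the rows are orthonormal.
///
/// ```rust
/// type Q = [f64; 4]; // (w, x, y, z)
///
/// // Hamilton product a ⊗ b.
/// fn qmul(a: Q, b: Q) -> Q {
///     [
///         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
///         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
///         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
///         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
///     ]
/// }
///
/// // Left-multiplication matrix L_q, rows = output coordinates.
/// fn rot4(q: Q) -> [[f64; 4]; 4] {
///     let [w, x, y, z] = q;
///     [[w, -x, -y, -z], [x, w, -z, y], [y, z, w, -x], [z, -y, x, w]]
/// }
///
/// fn matvec(m: [[f64; 4]; 4], v: Q) -> Q {
///     let mut out = [0.0; 4];
///     for r in 0..4 {
///         for c in 0..4 {
///             out[r] += m[r][c] * v[c];
///         }
///     }
///     out
/// }
/// ```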
pub fn quat_to_rot4<const D: usize, const DR: usize>(q: Tensor<D>) -> Tensor<DR> {
    assert_eq!(D + 1, DR, "quat_to_rot4 maps rank D to rank D+1");
    let n = D - 1;
    let w = q.clone().narrow(n, 0, 1);
    let x = q.clone().narrow(n, 1, 1);
    let y = q.clone().narrow(n, 2, 1);
    let z = q.narrow(n, 3, 1);

    // Each row is a [..., 4] tensor (the four column entries).
    let row0 = Tensor::cat(vec![w.clone(), -x.clone(), -y.clone(), -z.clone()], n);
    let row1 = Tensor::cat(vec![x.clone(), w.clone(), -z.clone(), y.clone()], n);
    let row2 = Tensor::cat(vec![y.clone(), z.clone(), w.clone(), -x.clone()], n);
    let row3 = Tensor::cat(vec![z, -y, x, w], n);

    // Stack the rows along a freshly inserted row axis → [..., 4, 4].
    Tensor::cat(
        vec![
            row0.unsqueeze_dim::<DR>(n),
            row1.unsqueeze_dim::<DR>(n),
            row2.unsqueeze_dim::<DR>(n),
            row3.unsqueeze_dim::<DR>(n),
        ],
        n,
    )
}

// ---------------------------------------------------------------------------
// Rotation application on the state_rank axis
// ---------------------------------------------------------------------------

/// Apply a per-block quaternion rotation to the `state_rank` axis of `v`.
///
/// `v` has shape `[..., state_rank]` with `state_rank = 4·J`, viewed as `J`
/// independent quaternion blocks; `q` has shape `[..., J, 4]` (one unit
/// quaternion per block, same leading dims as `v`). Returns `q ⊗ v` per block,
/// i.e. the rotation `L_q` applied within each `4`-block, reshaped back to
/// `[..., state_rank]`.
///
/// This is the generalisation of RoPE's per-pair `2×2` rotation to per-block
/// `4×4`. To rotate by the *inverse* cumulative rotation when absorbing into
/// `B`/`C` (`B̄ = Pᵀ B`), pass `q = conj(Qcum)`:
/// `rotate_state_rank_blocks(b, conj(qcum))`.
///
/// `DB` must equal `D + 1` (the block-split inserts the `J` axis).
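///
/// # Example
///
/// A plain-slice sketch of the block view (hypothetical helpers `qmul` and
/// `rotate_blocks`; a flat `&[f64]` stands in for the `state_rank` axis):
///
/// ```rust
/// type Q = [f64; 4]; // (w, x, y, z)
///
/// // Hamilton product a ⊗ b.
/// fn qmul(a: Q, b: Q) -> Q {
///     [
///         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
///         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
///         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
///         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
///     ]
/// }
///
/// // Split `v` into 4-blocks and left-multiply block j by `qs[j]`.
/// fn rotate_blocks(v: &[f64], qs: &[Q]) -> Vec<f64> {
///     assert_eq!(v.len() % 4, 0, "state_rank must be a multiple of 4");
///     assert_eq!(v.len() / 4, qs.len(), "one quaternion per block");
///     v.chunks_exact(4)
///         .zip(qs)
///         .flat_map(|(blk, &q)| qmul(q, [blk[0], blk[1], blk[2], blk[3]]))
///         .collect()
/// }
/// ```
///
/// For `v = [1, 2, 3, 4, 5, 6, 7, 8]` with the identity on the first block and
/// `i` on the second, the first block is unchanged and the second becomes
/// `i ⊗ (5, 6, 7, 8) = (−6, 5, −8, 7)`.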
pub fn rotate_state_rank_blocks<const D: usize, const DB: usize>(
    v: Tensor<D>,
    q: Tensor<DB>,
) -> Tensor<D> {
    assert_eq!(
        D + 1,
        DB,
        "rotate_state_rank_blocks splits one axis into (J, 4)"
    );
    let dims = v.dims();
    let state_rank = dims[D - 1];
    assert_eq!(
        state_rank % 4,
        0,
        "state_rank must be a multiple of 4 (quaternion blocks)"
    );
    let blocks = state_rank / 4;

    // Build the block-split shape [..., J, 4] (rank DB) and the flat shape
    // [..., state_rank] (rank D) for the round trip.
    let mut split_shape = [0usize; DB];
    split_shape[..D - 1].copy_from_slice(&dims[..D - 1]);
    split_shape[DB - 2] = blocks;
    split_shape[DB - 1] = 4;

    let v_blocks = v.reshape(split_shape); // [..., J, 4]
    let rotated = quat_mul(q, v_blocks); // L_q applied per block
    rotated.reshape(dims) // [..., state_rank]
}

// ---------------------------------------------------------------------------
// Cumulative rotation scan (the associative, non-abelian replacement for cumsum)
// ---------------------------------------------------------------------------

/// Cumulative (ordered, left-accumulating) quaternion product along the
/// sequence axis, with a cross-chunk carry.
///
/// This is the non-abelian analogue of the cumulative *sum of angles* used by
/// RoPE: where complex rotations compose by adding angles (a `cumsum`),
/// quaternions compose by multiplication, which is order-dependent, so a real
/// scan is required.
///
/// # Shapes
/// - `q_bshj4` : `[batch, sequence, nheads, J, 4]` per-step **unit** quaternions
///   (block count `J = state_rank / 4`).
/// - `init` : optional carry `[batch, nheads, J, 4]` — the cumulative
///   rotation at the end of the previous chunk (identity `(1,0,0,0)` for a fresh
///   start).
/// - returns `(cum, final_carry)` where `cum` is `[batch, sequence, nheads, J, 4]`
///   with `cum[:, t] = qₜ ⊗ qₜ₋₁ ⊗ ⋯ ⊗ q₀ ⊗ init` (newest on the left, matching
///   `Pₜ = Rₜ ⋯ R₁`), and `final_carry` `[batch, nheads, J, 4]` is `cum[:, −1]`
///   to thread into the next chunk.
///
/// Running this over a split sequence while threading `final_carry` matches
/// running it over the whole sequence up to floating-point reassociation
/// (asserted in the tests) — the chunked-prefill / streaming guarantee, here for
/// the rotation accumulator.
///
/// Implemented as a **Hillis–Steele** inclusive associative scan: the quaternion
/// product is associative (just not commutative), so a log-depth scan applies as
/// long as operand order is preserved (newest-on-left). Each doubling step is a
/// single full-tensor [`quat_mul`] plus a sequence shift, so the *sequential
/// dependency depth* is `O(log sequence)` rather than the `O(sequence)` of a
/// token-by-token loop, at the cost of `O(sequence · log sequence)` total
/// quaternion products instead of `O(sequence)`. The result is the same up to
/// rounding, but it comes from a handful of large batched kernels instead of
/// thousands of serialized tiny ones (and a correspondingly shallow autodiff
/// graph). The sequential reference it replaces is kept as a test oracle
/// (`quat_cumprod_sequential` in the tests module) and asserted equal on values
/// **and** gradients.
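///
/// # Example
///
/// A plain-array sketch of the same doubling scan (hypothetical helpers; a
/// `Vec` over the sequence replaces the batched tensor axes), together with a
/// token-by-token oracle it should agree with:
///
/// ```rust
/// type Q = [f64; 4]; // (w, x, y, z)
/// const IDENT: Q = [1.0, 0.0, 0.0, 0.0];
///
/// // Hamilton product a ⊗ b.
/// fn qmul(a: Q, b: Q) -> Q {
///     [
///         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
///         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
///         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
///         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
///     ]
/// }
///
/// // Oracle: cum[t] = q_t ⊗ ⋯ ⊗ q_0 ⊗ init, one multiply per token.
/// fn cumprod_sequential(qs: &[Q], init: Q) -> Vec<Q> {
///     let mut acc = init;
///     qs.iter()
///         .map(|&q| {
///             acc = qmul(q, acc);
///             acc
///         })
///         .collect()
/// }
///
/// // Hillis–Steele: after the step with offset d, a[t] covers [t-2d+1 ..= t].
/// fn cumprod_scan(qs: &[Q], init: Q) -> Vec<Q> {
///     let n = qs.len();
///     let mut a = qs.to_vec();
///     let mut offset = 1;
///     while offset < n {
///         // shifted[t] = a[t - offset], or the identity for t < offset.
///         let shifted: Vec<Q> = (0..n)
///             .map(|t| if t >= offset { a[t - offset] } else { IDENT })
///             .collect();
///         // Newer window on the left, older window on the right.
///         a = a.iter().zip(&shifted).map(|(&x, &y)| qmul(x, y)).collect();
///         offset *= 2;
///     }
///     // Fold the cross-chunk carry once, as the oldest (rightmost) factor.
///     a.into_iter().map(|p| qmul(p, init)).collect()
/// }
/// ```
///
/// Splitting `qs` at any interior point and passing the last output of the
/// first half as `init` for the second half reproduces `cumprod_scan(&qs, IDENT)`
/// up to rounding.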
pub fn quat_cumprod(q_bshj4: Tensor<5>, init: Option<Tensor<4>>) -> (Tensor<5>, Tensor<4>) {
    let [batch, sequence, nheads, blocks, _four] = q_bshj4.dims();
    let device = q_bshj4.device();

    // Pure prefix product Pₜ = qₜ ⊗ qₜ₋₁ ⊗ ⋯ ⊗ q₀ by Hillis–Steele doubling.
    // Invariant after each step with offset `d`: a[t] holds the product of the
    // window [t .. max(t-2d+1, 0)] (newest on the left). After ⌈log₂ sequence⌉
    // doublings the window covers [t .. 0], i.e. a[t] = Pₜ.
    let mut a = q_bshj4;
    let mut offset = 1usize;
    while offset < sequence {
        // shifted[t] = a[t-offset] for t ≥ offset, else the identity quaternion
        // (1,0,0,0) — so the first `offset` prefixes pass through unchanged
        // (a ⊗ identity = a).
        let ident = {
            let w = Tensor::ones([batch, offset, nheads, blocks, 1], &device);
            let xyz = Tensor::zeros([batch, offset, nheads, blocks, 3], &device);
            Tensor::cat(vec![w, xyz], 4)
        };
        let shifted = Tensor::cat(vec![ident, a.clone()], 1).narrow(1, 0, sequence);
        // Recent block (a) on the left, older block (shifted) on the right.
        a = quat_mul(a, shifted);
        offset *= 2;
    }

    // Fold the cross-chunk carry once: cumₜ = Pₜ ⊗ init. `init` (the previous
    // chunk's final cumulative rotation) is the oldest factor, hence on the
    // right; a missing carry is the identity and needs no multiply.
    let cum = match init {
        Some(init_bhj4) => {
            assert_eq!([batch, nheads, blocks, 4], init_bhj4.dims());
            // [batch, 1, nheads, J, 4] broadcasts over the sequence axis.
            quat_mul(a, init_bhj4.unsqueeze_dim::<5>(1))
        }
        None => a,
    };

    let final_carry = cum.clone().narrow(1, sequence - 1, 1).squeeze_dim::<4>(1); // [batch, nheads, J, 4]
    (cum, final_carry)
}

// ---------------------------------------------------------------------------
// Partial block rotation (rope_fraction support)
// ---------------------------------------------------------------------------

/// Apply a per-block quaternion rotation to the first `rope_width` entries of
/// the `state_rank` axis (`rope_width` must be a multiple of 4); the remainder
/// passes through. The quaternion analogue of [`apply_rope_partial`].
///
/// `q` has one quaternion per rotated block (`rope_width / 4` of them). `DB`
/// must equal `D + 1`.
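///
/// # Example
///
/// A plain-slice sketch of the split, rotate, concatenate pattern (hypothetical
/// helpers `qmul` and `rotate_partial`; a flat `&[f64]` stands in for the
/// `state_rank` axis):
///
/// ```rust
/// type Q = [f64; 4]; // (w, x, y, z)
///
/// // Hamilton product a ⊗ b.
/// fn qmul(a: Q, b: Q) -> Q {
///     [
///         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
///         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
///         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
///         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
///     ]
/// }
///
/// fn rotate_partial(v: &[f64], qs: &[Q], rope_width: usize) -> Vec<f64> {
///     assert_eq!(rope_width % 4, 0, "rope_width must be a multiple of 4");
///     assert_eq!(rope_width / 4, qs.len(), "one quaternion per rotated block");
///     let (head, tail) = v.split_at(rope_width);
///     // Rotate the head block by block...
///     let mut out: Vec<f64> = head
///         .chunks_exact(4)
///         .zip(qs)
///         .flat_map(|(b, &q)| qmul(q, [b[0], b[1], b[2], b[3]]))
///         .collect();
///     // ...and pass the tail through untouched.
///     out.extend_from_slice(tail);
///     out
/// }
/// ```
///
/// With `v = [1, 2, 3, 4, 9, 10]`, one block rotated by `i` and
/// `rope_width = 4`, the result is `[−2, 1, −4, 3, 9, 10]`.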
pub fn rotate_blocks_partial<const D: usize, const DB: usize>(
    v: Tensor<D>,
    q: Tensor<DB>,
    rope_width: usize,
) -> Tensor<D> {
    let r = v.dims()[D - 1];
    if rope_width == r {
        rotate_state_rank_blocks::<D, DB>(v, q)
    } else {
        let head = v.clone().narrow(D - 1, 0, rope_width);
        let tail = v.narrow(D - 1, rope_width, r - rope_width);
        let head_rot = rotate_state_rank_blocks::<D, DB>(head, q);
        Tensor::cat(vec![head_rot, tail], D - 1)
    }
}

// ---------------------------------------------------------------------------
// Forward / step rotation of B and C (shared by both SSD pathways)
// ---------------------------------------------------------------------------

/// Rotate `B`/`C` for a **full sequence** by the data-dependent positional
/// rotation, returning the rotated projections and the new cumulative
/// [`RotationState`] to store in the cache.
///
/// Branches on [`RotationKind`]:
/// - [`Complex2D`](RotationKind::Complex2D): the abelian RoPE — cumulative
///   angle `cumsum` continued from `prev`, then [`apply_rope_partial`]. Exactly
///   the original Mamba-3 behaviour.
/// - [`Quaternion4D`](RotationKind::Quaternion4D): per-step unit quaternion
///   [`quat_from_scaled_axis`] (the in-projection generators scaled per-head by
///   `Δ`), composed by [`quat_cumprod`] continuing the cached quaternion, then
///   applied to `B`/`C` as `rotate(·, conj(Qₜ))` over the first `4·blocks`
///   state-rank entries.
///
/// `prev` must be the [`RotationState`] variant matching `kind`; otherwise
/// [`RotationState::angle`] or [`RotationState::quaternion`] panics. `rope_dim`
/// is only read by the `Complex2D` branch; the `Quaternion4D` branch takes its
/// block count from `prev` and ignores it.
///
/// # Shapes
/// - `rot_bsa` : `[batch, sequence, num_rotation_channels]` — the in-projection
///   rotation channels (angles for Complex2D, `3·blocks` quaternion generators
///   for Quaternion4D).
/// - `dt_bsh` : `[batch, sequence, nheads]` (`Δ`).
/// - `b_bsmhr` / `c_bsmhr` : `[batch, sequence, mimo_rank, nheads, state_rank]`.
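///
/// # Example
///
/// A scalar `f32` sketch of why the `Quaternion4D` branch bounds the generator
/// with `tanh·π` before scaling by `Δ` (the helper names and the `EPS` floor are
/// hypothetical stand-ins for the tensor code and `div_eps`). An unbounded
/// activation of `1e30` overflows `‖g‖²` to `inf`, and `cos(inf)` is `NaN`; the
/// bounded generator stays finite, with each component in `[−π·Δ, π·Δ]` and
/// therefore a rotation angle `‖g‖ ≤ √3·π·Δ`.
///
/// ```rust
/// use std::f32::consts::PI;
///
/// const EPS: f32 = 1e-6; // floor on ‖g‖²
///
/// fn from_scaled_axis(g: [f32; 3]) -> [f32; 4] {
///     let angle = (g[0] * g[0] + g[1] * g[1] + g[2] * g[2]).max(EPS).sqrt();
///     let half = 0.5 * angle;
///     let scale = half.sin() / angle;
///     let q = [half.cos(), g[0] * scale, g[1] * scale, g[2] * scale];
///     let norm = q.iter().map(|c| c * c).sum::<f32>().sqrt();
///     q.map(|c| c / norm)
/// }
///
/// // Unbounded generator: raw activation times Δ.
/// fn generator_unbounded(raw: [f32; 3], dt: f32) -> [f32; 3] {
///     raw.map(|r| r * dt)
/// }
///
/// // Bounded generator, mirroring `(rot.tanh() * PI) * Δ`.
/// fn generator_bounded(raw: [f32; 3], dt: f32) -> [f32; 3] {
///     raw.map(|r| r.tanh() * PI * dt)
/// }
/// ```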
pub fn rotate_bc_forward(
    rot_bsa: Tensor<3>,
    dt_bsh: Tensor<3>,
    prev: RotationState,
    b_bsmhr: Tensor<5>,
    c_bsmhr: Tensor<5>,
    kind: RotationKind,
    rope_dim: usize,
) -> (Tensor<5>, Tensor<5>, RotationState) {
    let [batch, sequence, mimo_rank, nheads, _state_rank] = b_bsmhr.dims();
    match kind {
        RotationKind::Complex2D => {
            let prev_angle_bha = prev.angle();
            let num_rope_angles = prev_angle_bha.dims()[2];
            let theta_scaled_bsa = rot_bsa.tanh() * std::f32::consts::PI;
            let raw_angles_bsha =
                dt_bsh.unsqueeze_dim::<4>(3) * theta_scaled_bsa.unsqueeze_dim::<4>(2);
            let cum_angles_bsha = prev_angle_bha.unsqueeze_dim::<4>(1) + raw_angles_bsha.cumsum(1);
            let cum_angles_bsmha = cum_angles_bsha.clone().unsqueeze_dim::<5>(2).expand([
                batch,
                sequence,
                mimo_rank,
                nheads,
                num_rope_angles,
            ]);
            let rotate_pairwise = mimo_rank == 1;
            let b = apply_rope_partial::<5>(
                b_bsmhr,
                cum_angles_bsmha.clone(),
                rope_dim,
                rotate_pairwise,
            );
            let c = apply_rope_partial::<5>(c_bsmhr, cum_angles_bsmha, rope_dim, rotate_pairwise);
            let last = wrap_angle(
                cum_angles_bsha
                    .narrow(1, sequence - 1, 1)
                    .squeeze_dim::<3>(1),
            );
            (b, c, RotationState::Angle(last))
        }
        RotationKind::Quaternion4D => {
            let prev_q_bhj4 = prev.quaternion();
            let blocks = prev_q_bhj4.dims()[2];
            let rope_width = blocks * 4;
            // Generators [b,s,blocks,3] (shared across heads), scaled per-head by Δ.
            //
            // Bound the raw generator with `tanh·π` before scaling by Δ — the
            // direct analogue of the Complex2D path (`rot.tanh()·π`). Without it
            // the generator is unbounded, so a large in-projection activation makes
            // `g = rot·Δ` (or its sum of squares) overflow f32 to `inf`, and
            // `quat_from_scaled_axis`'s `cos(∞)` then yields a forward NaN. The
            // bound caps each generator component to `±π·Δ`, so the per-step
            // rotation angle `‖g‖` is at most `√3·π·Δ` (cos/sin still give the
            // periodicity within range); healthy `O(1)` generators stay in tanh's
            // near-linear region.
            let g_bshj3 = (rot_bsa.tanh() * core::f32::consts::PI)
                .reshape([batch, sequence, blocks, 3])
                .unsqueeze_dim::<5>(2)
                * dt_bsh.unsqueeze_dim::<4>(3).unsqueeze_dim::<5>(4);
            let q_step_bshj4 = quat_from_scaled_axis::<5>(g_bshj3);
            // Memory-efficient scan: a custom recompute backward (saves only the
            // leaf inputs) instead of retaining the scan's intermediates. Equal
            // to [`quat_cumprod`] on values and gradients (asserted in tests).
            let (cum_bshj4, final_bhj4) = crate::mamba3::quat_scan::quat_cumprod_recalculated(
                q_step_bshj4,
                Some(prev_q_bhj4),
            );
            // B̄ = rotate by the inverse cumulative rotation (conjugate), per block,
            // broadcast over the mimo_rank axis.
            let conj_bsmhj4 = quat_conj(cum_bshj4)
                .unsqueeze_dim::<6>(2)
                .expand([batch, sequence, mimo_rank, nheads, blocks, 4]);
            let b = rotate_blocks_partial::<5, 6>(b_bsmhr, conj_bsmhj4.clone(), rope_width);
            let c = rotate_blocks_partial::<5, 6>(c_bsmhr, conj_bsmhj4, rope_width);
            (b, c, RotationState::Quaternion(quat_normalize(final_bhj4)))
        }
    }
}

/// Single-token counterpart of [`rotate_bc_forward`] for the recurrent `step`.
///
/// # Shapes
/// - `rot_ba` : `[batch, num_rotation_channels]`.
/// - `dt_bh` : `[batch, nheads]`.
/// - `b_bmhr` / `c_bmhr` : `[batch, mimo_rank, nheads, state_rank]`.
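///
/// # Example
///
/// A plain-array sketch of the streaming contract for the `Quaternion4D` branch
/// (hypothetical helpers; per-step unit quaternions are given directly instead
/// of being materialised from `rot_ba`/`dt_bh`): applying the single-step update
/// `Qₜ = normalize(qₜ ⊗ Qₜ₋₁)` token by token tracks the batched ordered product
/// `qₜ ⊗ ⋯ ⊗ q₀`, with the renormalisation only removing rounding drift.
///
/// ```rust
/// type Q = [f64; 4]; // (w, x, y, z)
///
/// // Hamilton product a ⊗ b.
/// fn qmul(a: Q, b: Q) -> Q {
///     [
///         a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
///         a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
///         a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
///         a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
///     ]
/// }
///
/// fn normalize(q: Q) -> Q {
///     let n = q.iter().map(|c| c * c).sum::<f64>().sqrt();
///     q.map(|c| c / n)
/// }
///
/// // One recurrent step: the new carry is normalize(q_step ⊗ carry).
/// fn step(carry: Q, q_step: Q) -> Q {
///     normalize(qmul(q_step, carry))
/// }
///
/// // Batched reference: the ordered product without per-step renormalisation.
/// fn ordered_product(qs: &[Q], init: Q) -> Q {
///     qs.iter().fold(init, |acc, &q| qmul(q, acc))
/// }
/// ```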
pub fn rotate_bc_step(
    rot_ba: Tensor<2>,
    dt_bh: Tensor<2>,
    prev: RotationState,
    b_bmhr: Tensor<4>,
    c_bmhr: Tensor<4>,
    kind: RotationKind,
    rope_dim: usize,
) -> (Tensor<4>, Tensor<4>, RotationState) {
    let [batch, mimo_rank, nheads, _state_rank] = b_bmhr.dims();
    match kind {
        RotationKind::Complex2D => {
            let prev_angle_bha = prev.angle();
            let num_rope_angles = prev_angle_bha.dims()[2];
            let theta_scaled_ba = rot_ba.tanh() * std::f32::consts::PI;
            let raw_angle_bha = dt_bh.unsqueeze_dim::<3>(2) * theta_scaled_ba.unsqueeze_dim::<3>(1);
            let new_cum_angle_bha = wrap_angle(prev_angle_bha + raw_angle_bha);
            let new_cum_angle_bmha = new_cum_angle_bha.clone().unsqueeze_dim::<4>(1).expand([
                batch,
                mimo_rank,
                nheads,
                num_rope_angles,
            ]);
            let rotate_pairwise = mimo_rank == 1;
            let b = apply_rope_partial::<4>(
                b_bmhr,
                new_cum_angle_bmha.clone(),
                rope_dim,
                rotate_pairwise,
            );
            let c = apply_rope_partial::<4>(c_bmhr, new_cum_angle_bmha, rope_dim, rotate_pairwise);
            (b, c, RotationState::Angle(new_cum_angle_bha))
        }
        RotationKind::Quaternion4D => {
            let prev_q_bhj4 = prev.quaternion();
            let blocks = prev_q_bhj4.dims()[2];
            let rope_width = blocks * 4;
            // `tanh·π` bound, matching `rotate_bc_forward` (see the note there).
            let g_bhj3 = (rot_ba.tanh() * core::f32::consts::PI)
                .reshape([batch, blocks, 3])
                .unsqueeze_dim::<4>(1)
                * dt_bh.unsqueeze_dim::<3>(2).unsqueeze_dim::<4>(3);
            let q_step_bhj4 = quat_from_scaled_axis::<4>(g_bhj3);
            // Single step: Qₜ = qₜ ⊗ Qₜ₋₁.
            let new_q_bhj4 = quat_normalize(quat_mul(q_step_bhj4, prev_q_bhj4));
            let conj_bmhj4 = quat_conj(new_q_bhj4.clone())
                .unsqueeze_dim::<5>(1)
                .expand([batch, mimo_rank, nheads, blocks, 4]);
            let b = rotate_blocks_partial::<4, 5>(b_bmhr, conj_bmhj4.clone(), rope_width);
            let c = rotate_blocks_partial::<4, 5>(c_bmhr, conj_bmhj4, rope_width);
            (b, c, RotationState::Quaternion(new_q_bhj4))
        }
    }
}

#[cfg(all(test, feature = "_dev-test"))]
mod tests;