//! # Quaternion (k=4) rotational state — the non-abelian generalisation of RoPE
//!
//! Mamba-3's data-dependent RoPE realises a **complex-valued** SSM: the state
//! transition factors as a per-head scalar decay times a block-diagonal of
//! `2×2` rotations (paper Prop. *Complex-to-Real SSM Equivalence*), and because
//! `SO(2) ≅ U(1)` is **abelian** the cumulative rotation collapses to a
//! `cumsum` of angles and is absorbed into `B`/`C` (the "RoPE trick", Prop.
//! *Complex SSM, Data-Dependent RoPE Equivalence*).  See
//! [`crate::mamba3::double_ssd::double_ssd::apply_rope`].
//!
//! This module implements the next rung of the ladder: a **quaternion**
//! (`k = 4`) rotational state, i.e. the transition's rotation lives in the
//! left-isoclinic subgroup `SU(2) ⊂ SO(4)` instead of `SO(2)`.  Unit
//! quaternions under multiplication are `SU(2)`, which is **non-abelian** and
//! contains non-solvable finite subgroups (the binary icosahedral group
//! `2I = SL(2,5)`, a double cover of `A₅`).  By Barrington's theorem this lifts
//! the layer's reachable state-tracking from the solvable/`TC⁰` regime (parity,
//! mod-k) toward `NC¹`, which abelian rotations provably cannot reach.
//!
//! ## What survives, what changes
//!
//! The key fact (derivable purely from telescoping + orthogonality, **without**
//! commutativity — see the crate discussion) is that the RoPE *factoring*
//! survives intact: with the **ordered** cumulative rotation
//! `Pₜ = Rₜ Rₜ₋₁ ⋯ R₁`,
//!
//! ```text
//!   Cₜᵀ (Rₜ⋯Rᵢ₊₁) Bᵢ  =  (Pₜᵀ Cₜ)ᵀ (Pᵢᵀ Bᵢ)  =  C̄ₜᵀ B̄ᵢ ,
//! ```
//!
//! so the scalar-decay SSD core (`L ⊙ C̄B̄ᵀ`) is **unchanged** — only the
//! projections `B̄ᵢ = Pᵢᵀ Bᵢ`, `C̄ₜ = Pₜᵀ Cₜ` are rotated.  What is lost is the
//! closed-form `cumsum`: the cumulative rotation must be built by an
//! **associative scan over the per-step quaternions** ([`quat_cumprod`]) rather
//! than a sum of angles.  Because a product of unit quaternions is again a unit
//! quaternion, the scan stays orthogonal up to floating-point rounding (no
//! `wrap_angle` needed; the cached carry is re-normalised with
//! [`quat_normalize`] to remove drift), and the cross-chunk carry is a single
//! quaternion per block/head, the exact analogue of `cum_angle` in the
//! existing caches.
//!
//! `SO(2)` (today's `apply_rope`) is the abelian collapse: restricting each
//! quaternion to a single fixed axis makes them commute and reduces
//! [`quat_cumprod`] to a `cumsum` of half-angles (asserted in the tests).
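/*!
A minimal numeric sketch of the factoring identity above, written against plain
`[f64; 4]` arrays rather than the crate's tensors. The helper names (`mul`,
`conj`, `dot`, `unit`, `factoring_sides`) and the sample values are illustrative
assumptions, not part of this module's API. It fixes `t = 3`, `i = 1` and
returns both sides of `C₃ᵀ (R₃ R₂) B₁ = (P₃ᵀ C₃)ᵀ (P₁ᵀ B₁)`.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

// Hamilton product a ⊗ b, i.e. the left-multiplication matrix L(a) applied to b.
fn mul(a: Quat, b: Quat) -> Quat {
    [
        a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
        a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
        a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
        a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
    ]
}

// Conjugate; for a unit quaternion this realises the transpose L(q)ᵀ.
fn conj(q: Quat) -> Quat {
    [q[0], -q[1], -q[2], -q[3]]
}

// Euclidean inner product on ℝ⁴.
fn dot(a: Quat, b: Quat) -> f64 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

// Scale to unit norm (the inputs below are far from zero, so no floor is used).
fn unit(q: Quat) -> Quat {
    let n = dot(q, q).sqrt();
    [q[0] / n, q[1] / n, q[2] / n, q[3] / n]
}

// Both sides of the identity for t = 3, i = 1.
fn factoring_sides() -> (f64, f64) {
    let r1 = unit([1.0, 2.0, 0.5, -1.0]);
    let r2 = unit([0.3, -1.0, 2.0, 0.7]);
    let r3 = unit([-0.4, 0.1, 0.9, 1.5]);
    let b1 = [0.5, -1.0, 2.0, 0.25];
    let c3 = [1.5, 0.75, -0.5, 1.0];

    let p1 = r1; // P₁ = R₁
    let p3 = mul(r3, mul(r2, r1)); // P₃ = R₃ R₂ R₁ (newest on the left)

    let lhs = dot(c3, mul(mul(r3, r2), b1)); // C₃ᵀ (R₃ R₂) B₁
    let rhs = dot(mul(conj(p3), c3), mul(conj(p1), b1)); // (P₃ᵀ C₃)ᵀ (P₁ᵀ B₁)
    (lhs, rhs)
}
```

The two returned values should agree to within floating-point rounding even
though `r1`, `r2`, `r3` do not commute; the argument uses only associativity
and `L(q)ᵀ = L(q*)` for unit `q`.
*/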
//!
//! ## Pipeline (the `k = 4` instantiation of the rotation block)
//!
//! ```text
//!   per-step unit quaternion qₜ      (materialise from the in-projection; caller)
//!        │  quat_cumprod (assoc. scan, + cross-chunk carry)
//!        ▼
//!   cumulative rotation Qₜ
//!        │  rotate_state_rank_blocks(B, conj(Qₜ)) , rotate_state_rank_blocks(C, conj(Qₜ))
//!        ▼
//!   B̄, C̄  ──►  standard scalar-decay SSD  (unchanged)
//! ```
//!
//! Quaternion layout: the last axis has size 4 and holds `(w, x, y, z)` with
//! `w` the real part.  A `state_rank` of `r = 4·J` is treated as `J` independent
//! quaternion blocks; the rotation acts within each block, exactly as RoPE acts
//! within each `2`-pair.  This module is a self-contained, tested reference for
//! the math; wiring it into the [`Mamba3`](crate::mamba3::mamba3::Mamba3) block
//! is a separate, larger change (the SSD kernels themselves need no edits).

use crate::modules::{apply_rope_partial, wrap_angle};
use burn::module::Module;
use burn::prelude::*;

// ---------------------------------------------------------------------------
// Rotation kind (config switch) and cache accumulator variant
// ---------------------------------------------------------------------------

/// Which rotational-state algebra the block uses for the data-dependent
/// positional rotation of `B`/`C`.
///
/// - [`Complex2D`](RotationKind::Complex2D) — the abelian `SO(2)`/complex RoPE
///   that Mamba-3 ships: cumulative *angles* via `cumsum`, applied by
///   [`apply_rope_partial`]. The default; behaviourally unchanged.
/// - [`Quaternion4D`](RotationKind::Quaternion4D) — the non-abelian
///   `SU(2) ⊂ SO(4)` quaternion rotation of this module: cumulative *product*
///   via [`quat_cumprod`], applied by [`rotate_state_rank_blocks`]. Richer
///   state-tracking; selects the [`RotationState::Quaternion`] cache accumulator.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RotationKind {
    /// Abelian complex (`SO(2)`) RoPE — the current default behaviour.
    #[default]
    Complex2D,
    /// Non-abelian quaternion (`SU(2)`) rotation.
    Quaternion4D,
}

/// The cumulative-rotation accumulator carried between calls in a Mamba-3 cache
/// — the variant matching the block's [`RotationKind`].
///
/// - [`Angle`](RotationState::Angle) — abelian per-pair cumulative RoPE angle,
///   shape `[batch, nheads, num_rope_angles]` (today's `cum_angle`).
/// - [`Quaternion`](RotationState::Quaternion) — per-block cumulative unit
///   quaternion, shape `[batch, nheads, blocks, 4]`, produced by
///   [`quat_cumprod`].
///
/// This is the cache-level counterpart of [`RotationKind`]. It is defined here
/// (the rotation module owns the accumulator type); substituting it for the
/// pathway caches' `cum_angle_bha` field happens together with the forward/step
/// wiring that consumes it.
#[derive(Module, Debug)]
pub enum RotationState {
    /// Abelian RoPE cumulative angle, shape `[batch, nheads, num_rope_angles]`.
    Angle(Tensor<3>),
    /// Quaternion cumulative rotation, shape `[batch, nheads, blocks, 4]`.
    Quaternion(Tensor<4>),
}

impl RotationState {
    /// Zero-initialised abelian angle accumulator `[batch, nheads, num_rope_angles]`.
    pub fn zeros_angle(
        batch: usize,
        nheads: usize,
        num_rope_angles: usize,
        device: &Device,
    ) -> Self {
        RotationState::Angle(Tensor::zeros([batch, nheads, num_rope_angles], device))
    }

    /// Identity-initialised quaternion accumulator `[batch, nheads, blocks, 4]`
    /// (every block is the identity quaternion `(1, 0, 0, 0)`).
    pub fn identity_quaternion(
        batch: usize,
        nheads: usize,
        blocks: usize,
        device: &Device,
    ) -> Self {
        let w = Tensor::ones([batch, nheads, blocks, 1], device);
        let xyz = Tensor::zeros([batch, nheads, blocks, 3], device);
        RotationState::Quaternion(Tensor::cat(vec![w, xyz], 3))
    }

    /// Unwrap the abelian angle accumulator; panics if this is a quaternion.
    pub fn angle(self) -> Tensor<3> {
        match self {
            RotationState::Angle(a) => a,
            RotationState::Quaternion(_) => {
                panic!("RotationState is Quaternion, expected Angle")
            }
        }
    }

    /// Unwrap the quaternion accumulator; panics if this is an angle.
    pub fn quaternion(self) -> Tensor<4> {
        match self {
            RotationState::Quaternion(q) => q,
            RotationState::Angle(_) => panic!("RotationState is Angle, expected Quaternion"),
        }
    }

    /// Run the [`NaN`/`Inf` guards](crate::modules::sanity) on the held tensor.
    pub fn sanity(&self) {
        match self {
            RotationState::Angle(a) => crate::modules::sanity(a),
            RotationState::Quaternion(q) => crate::modules::sanity(q),
        }
    }
}

// ---------------------------------------------------------------------------
// Quaternion algebra on the trailing `(w, x, y, z)` axis
// ---------------------------------------------------------------------------

/// Hamilton product `a ⊗ b` of two quaternion tensors.
///
/// Both inputs have shape `[..., 4]` with the last axis ordered `(w, x, y, z)`;
/// the product is computed component-wise and broadcasts over the leading dims.
/// Quaternion multiplication is **non-commutative** (`a ⊗ b ≠ b ⊗ a` in
/// general) but associative.
///
/// Identifying `ℝ⁴` with the quaternions, left-multiplication `v ↦ a ⊗ v` is
/// exactly the action of the `4×4` rotation matrix [`quat_to_rot4`]`(a)`, so
/// this is also how a rotation is *applied* to a state/`B`/`C` block (see
/// [`rotate_state_rank_blocks`]).
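/**
# Example

A scalar sketch of the same product on plain `[f64; 4]` arrays, included to make
the non-commutativity concrete. The names `hamilton` and `basis_products` are
illustrative and not part of this module; the formula mirrors the four
component lines of the function body.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

// Hamilton product a ⊗ b, component by component.
fn hamilton(a: Quat, b: Quat) -> Quat {
    [
        a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
        a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
        a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
        a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
    ]
}

// Returns (i ⊗ j, j ⊗ i) for the basis quaternions i = (0,1,0,0), j = (0,0,1,0).
fn basis_products() -> (Quat, Quat) {
    let i = [0.0, 1.0, 0.0, 0.0];
    let j = [0.0, 0.0, 1.0, 0.0];
    (hamilton(i, j), hamilton(j, i))
}
```

Here `i ⊗ j` is `k = (0, 0, 0, 1)` and `j ⊗ i` is `−k`, so swapping the operands
flips the sign of the result.
*/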
pub fn quat_mul<const D: usize>(a: Tensor<D>, b: Tensor<D>) -> Tensor<D> {
    let n = D - 1;
    let aw = a.clone().narrow(n, 0, 1);
    let ax = a.clone().narrow(n, 1, 1);
    let ay = a.clone().narrow(n, 2, 1);
    let az = a.narrow(n, 3, 1);
    let bw = b.clone().narrow(n, 0, 1);
    let bx = b.clone().narrow(n, 1, 1);
    let by = b.clone().narrow(n, 2, 1);
    let bz = b.narrow(n, 3, 1);

    // Hamilton product (each term is shape [..., 1]).
    let w = aw.clone() * bw.clone()
        - ax.clone() * bx.clone()
        - ay.clone() * by.clone()
        - az.clone() * bz.clone();
    let x = aw.clone() * bx.clone() + ax.clone() * bw.clone() + ay.clone() * bz.clone()
        - az.clone() * by.clone();
    let y = aw.clone() * by.clone() - ax.clone() * bz.clone()
        + ay.clone() * bw.clone()
        + az.clone() * bx.clone();
    let z = aw * bz + ax * by - ay * bx + az * bw;

    Tensor::cat(vec![w, x, y, z], n)
}

/// Quaternion conjugate `q* = (w, −x, −y, −z)` (shape `[..., 4]`).
///
/// For a **unit** quaternion, `q* = q⁻¹`, and the corresponding
/// left-multiplication matrix (see [`quat_to_rot4`]) satisfies
/// `L(q*) = L(q)ᵀ = L(q)⁻¹`.  Hence rotating by the *inverse* cumulative
/// rotation (`B̄ = Pᵀ B`) is `rotate_state_rank_blocks(B, quat_conj(Q))`.
pub fn quat_conj<const D: usize>(q: Tensor<D>) -> Tensor<D> {
    let n = D - 1;
    let w = q.clone().narrow(n, 0, 1);
    let xyz = q.narrow(n, 1, 3);
    Tensor::cat(vec![w, -xyz], n)
}

/// Normalise quaternions to unit norm along the last axis (shape `[..., 4]`).
///
/// The per-step rotation is materialised from a raw, unconstrained projection
/// and normalised here so it is a genuine unit quaternion (an element of
/// `SU(2)`), the analogue of `tanh(θ)·π` bounding the RoPE angle.  A tiny floor
/// guards the zero-quaternion.
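/**
# Example

A forward-only scalar sketch of the floor, assuming a hypothetical `eps` chosen
by the caller (the crate's `div_eps` value is not reproduced here) and a made-up
helper name `normalize_floored`. Gradients are not modelled; the sketch only
shows that the zero quaternion yields finite output and that unit-scale input
is unaffected by the floor.

```rust
// Normalise a (w, x, y, z) array, flooring the sum of squares at `eps`
// before the square root so the divisor can never be zero.
fn normalize_floored(q: [f32; 4], eps: f32) -> [f32; 4] {
    let sumsq: f32 = q.iter().map(|c| c * c).sum();
    let norm = sumsq.max(eps).sqrt();
    [q[0] / norm, q[1] / norm, q[2] / norm, q[3] / norm]
}
```

With `eps = 1e-6`, `normalize_floored([0.0; 4], eps)` is all zeros rather than
`NaN`, and `normalize_floored([3.0, 0.0, 4.0, 0.0], eps)` is approximately
`[0.6, 0.0, 0.8, 0.0]`.
*/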
pub fn quat_normalize<const D: usize>(q: Tensor<D>) -> Tensor<D> {
    let n = D - 1;
    // Clamp the sum-of-squares *before* `sqrt`: at a zero quaternion the forward
    // `sqrt(0)=0` is fine, but `sqrt`'s backward is `1/(2·0)=∞`, and `∞·(2·0)=NaN`.
    // Clamping pre-`sqrt` puts the degenerate point in `clamp_min`'s flat region,
    // so its gradient is a finite 0 (and a genuine unit quaternion, sumsq=1, is
    // untouched). The floor also keeps `norm` away from 0 for the division.
    //
    // The floor is the dtype-aware `div_eps` applied to the *sum-of-squares*
    // (giving a norm floor of `√div_eps`). It must engage as a representable
    // normal in the working dtype: in f16 a `div_eps²`-sized floor (~5e-7) would
    // underflow below the min-normal (~6.1e-5) and silently no-op, so we floor
    // the squared quantity at `div_eps` itself, which sits above each format's
    // denormal floor by construction.
    let eps = crate::utils::div_eps(q.dtype());
    let norm = (q.clone() * q.clone()).sum_dim(n).clamp_min(eps).sqrt();
    q / norm
}

/// Materialise a unit quaternion from a **scaled rotation vector** `g ∈ ℝ³`
/// (axis · angle) via the exponential map — the data-dependent "materialise
/// `Rₜ`" step, analogous to RoPE's `Δₜ · π · tanh(θₜ)` angle.
///
/// With `‖g‖ = angle` and `ĝ = g / angle` the axis, returns the unit quaternion
/// `q = (cos(angle/2), sin(angle/2)·ĝ)`.  A vanishing `g` maps to the identity
/// `(1, 0, 0, 0)`, so scaling `g` by a small `Δₜ` (the discretisation step)
/// yields a near-identity rotation — exactly the regime where a small step
/// barely rotates the state.  The vector part is computed as
/// `g · sin(angle/2)/angle`; that scale tends to `1/2` as `g → 0`, and the
/// floor on `‖g‖²` keeps the division away from `0/0`.
///
/// # Shapes
/// - `g` : `[..., 3]`
/// - out : `[..., 4]` (ordered `(w, x, y, z)`), unit norm.
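/**
# Example

A scalar sketch of the exponential map on a single `[f64; 3]` generator. The
name `exp_map` and the explicit `eps` argument are assumptions made for the
illustration; the tensor version takes its floor from `div_eps` instead.

```rust
// Exponential map from a scaled axis g (axis · angle) to (w, x, y, z).
// `eps` floors ‖g‖² before the square root; its value is the caller's choice.
fn exp_map(g: [f64; 3], eps: f64) -> [f64; 4] {
    let angle = (g[0] * g[0] + g[1] * g[1] + g[2] * g[2]).max(eps).sqrt();
    let half = 0.5 * angle;
    let scale = half.sin() / angle; // tends to 1/2 as angle → 0
    let q = [half.cos(), g[0] * scale, g[1] * scale, g[2] * scale];
    // Final renormalisation, mirroring the closing call to `quat_normalize`.
    let n = (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt();
    [q[0] / n, q[1] / n, q[2] / n, q[3] / n]
}
```

`exp_map([0.0, 0.0, 0.0], 1e-12)` is the identity up to rounding, and
`exp_map([0.0, 0.0, std::f64::consts::PI], 1e-12)` is close to `(0, 0, 0, 1)`,
a half-turn about the `z` axis.
*/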
pub fn quat_from_scaled_axis<const D: usize>(g: Tensor<D>) -> Tensor<D> {
    let n = D - 1;
    // Clamp the sum-of-squares *before* `sqrt`: at `g = 0` the forward `sqrt(0)=0`
    // is finite, but `sqrt`'s backward is `1/(2·0)=∞` and `∞·(2·0)=NaN`. Clamping
    // pre-`sqrt` puts `g = 0` in `clamp_min`'s flat (zero-gradient) region, so the
    // near-identity rotation gets a finite 0 gradient instead of a NaN. (This is
    // the FiLM-triggered decoder-backward NaN: a per-position rotation generator
    // hitting exactly zero.) The floor is the dtype-aware `div_eps` on the squared
    // quantity — see [`quat_normalize`] for why it floors `sumsq`, not the norm.
    let eps = crate::utils::div_eps(g.dtype());
    let angle = (g.clone() * g.clone()).sum_dim(n).clamp_min(eps).sqrt(); // [..., 1]
    let half = angle.clone() * 0.5;
    let w = half.clone().cos(); // [..., 1]
    // sin(angle/2) / angle  → 1/2 as angle → 0 (no rotation); `angle ≥ √div_eps`
    // after the pre-`sqrt` clamp above, so the division is already guarded.
    let scale = half.sin() / angle; // [..., 1]
    let v = g * scale; // [..., 3]
    quat_normalize(Tensor::cat(vec![w, v], n))
}

/// Materialise the `4×4` orthogonal matrix of left-multiplication by `q`.
///
/// Maps `q` of shape `[..., 4]` to `[..., 4, 4]` such that, for `v` of shape
/// `[..., 4]`, `Lq · v == quat_mul(q, v)`.  Concretely (rows = output coords,
/// cols = input coords, all in `(w, x, y, z)` order):
///
/// ```text
///   ⎡ w  -x  -y  -z ⎤
///   ⎢ x   w  -z   y ⎥
///   ⎢ y   z   w  -x ⎥
///   ⎣ z  -y   x   w ⎦
/// ```
///
/// For a unit `q` this is orthogonal with `det = 1` (a left-isoclinic rotation).
/// Provided mainly for the generic / verification path; the cheap way to apply a
/// rotation is [`rotate_state_rank_blocks`] (a quaternion product, no `4×4`
/// materialisation).  `DR` must equal `D + 1`.
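/**
# Example

A scalar sketch checking that the matrix in the table agrees with the Hamilton
product, using nested `[f64; 4]` arrays. The helpers `hamilton`, `left_matrix`,
`matvec` and `max_mismatch` are illustrative names, not items of this module.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

// Hamilton product a ⊗ b.
fn hamilton(a: Quat, b: Quat) -> Quat {
    [
        a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
        a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
        a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
        a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
    ]
}

// The 4×4 left-multiplication matrix (rows = output coordinates).
fn left_matrix(q: Quat) -> [[f64; 4]; 4] {
    let [w, x, y, z] = q;
    [
        [w, -x, -y, -z],
        [x, w, -z, y],
        [y, z, w, -x],
        [z, -y, x, w],
    ]
}

// Plain matrix-vector product.
fn matvec(m: [[f64; 4]; 4], v: Quat) -> Quat {
    let mut out = [0.0; 4];
    for (r, row) in m.iter().enumerate() {
        out[r] = row.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
    }
    out
}

// Largest absolute difference between L(q)·v and q ⊗ v.
fn max_mismatch(q: Quat, v: Quat) -> f64 {
    let lhs = matvec(left_matrix(q), v);
    let rhs = hamilton(q, v);
    lhs.iter()
        .zip(rhs.iter())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0, f64::max)
}
```

For any `q` and `v`, `max_mismatch(q, v)` should be zero up to rounding, which
is the scalar counterpart of `Lq · v == quat_mul(q, v)`.
*/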
pub fn quat_to_rot4<const D: usize, const DR: usize>(q: Tensor<D>) -> Tensor<DR> {
    assert_eq!(D + 1, DR, "quat_to_rot4 maps rank D to rank D+1");
    let n = D - 1;
    let w = q.clone().narrow(n, 0, 1);
    let x = q.clone().narrow(n, 1, 1);
    let y = q.clone().narrow(n, 2, 1);
    let z = q.narrow(n, 3, 1);

    // Each row is a [..., 4] tensor (the four column entries).
    let row0 = Tensor::cat(vec![w.clone(), -x.clone(), -y.clone(), -z.clone()], n);
    let row1 = Tensor::cat(vec![x.clone(), w.clone(), -z.clone(), y.clone()], n);
    let row2 = Tensor::cat(vec![y.clone(), z.clone(), w.clone(), -x.clone()], n);
    let row3 = Tensor::cat(vec![z, -y, x, w], n);

    // Stack the rows along a freshly inserted row axis → [..., 4, 4].
    Tensor::cat(
        vec![
            row0.unsqueeze_dim::<DR>(n),
            row1.unsqueeze_dim::<DR>(n),
            row2.unsqueeze_dim::<DR>(n),
            row3.unsqueeze_dim::<DR>(n),
        ],
        n,
    )
}

// ---------------------------------------------------------------------------
// Rotation application on the state_rank axis
// ---------------------------------------------------------------------------

/// Apply a per-block quaternion rotation to the `state_rank` axis of `v`.
///
/// `v` has shape `[..., state_rank]` with `state_rank = 4·J`, viewed as `J`
/// independent quaternion blocks; `q` has shape `[..., J, 4]` (one unit
/// quaternion per block, same leading dims as `v`).  Returns `q ⊗ v` per block,
/// i.e. the rotation `L_q` applied within each `4`-block, reshaped back to
/// `[..., state_rank]`.
///
/// This is the generalisation of RoPE's per-pair `2×2` rotation to per-block
/// `4×4`.  To rotate by the *inverse* cumulative rotation when absorbing into
/// `B`/`C` (`B̄ = Pᵀ B`), pass `q = conj(Qcum)`:
/// `rotate_state_rank_blocks(b, conj(qcum))`.
///
/// `DB` must equal `D + 1` (the block-split inserts the `J` axis).
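/**
# Example

A scalar sketch of the block view on a flat slice, assuming `state_rank = 4·J`
entries and one quaternion per block. The helpers `hamilton`, `conj`,
`rotate_blocks` and `round_trip` are illustrative stand-ins for the tensor
operations and are not defined by this module.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

// Hamilton product a ⊗ b.
fn hamilton(a: Quat, b: Quat) -> Quat {
    [
        a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
        a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
        a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
        a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
    ]
}

// Conjugate (the inverse for a unit quaternion).
fn conj(q: Quat) -> Quat {
    [q[0], -q[1], -q[2], -q[3]]
}

// Rotate each consecutive 4-entry block of `v` by the matching quaternion in `q`.
fn rotate_blocks(v: &[f64], q: &[Quat]) -> Vec<f64> {
    assert_eq!(v.len(), 4 * q.len(), "state_rank must equal 4·J");
    v.chunks_exact(4)
        .zip(q.iter())
        .flat_map(|(block, &qj)| hamilton(qj, [block[0], block[1], block[2], block[3]]))
        .collect()
}

// Rotate by q, then by conj(q); for unit q this should return the input.
fn round_trip(v: &[f64], q: &[Quat]) -> Vec<f64> {
    let q_inv: Vec<Quat> = q.iter().map(|&qj| conj(qj)).collect();
    rotate_blocks(&rotate_blocks(v, q), &q_inv)
}
```

With unit quaternions in `q`, `round_trip(v, q)` reproduces `v` up to rounding,
which is the scalar form of undoing a rotation by passing the conjugate.
*/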
pub fn rotate_state_rank_blocks<const D: usize, const DB: usize>(
    v: Tensor<D>,
    q: Tensor<DB>,
) -> Tensor<D> {
    assert_eq!(
        D + 1,
        DB,
        "rotate_state_rank_blocks splits one axis into (J, 4)"
    );
    let dims = v.dims();
    let state_rank = dims[D - 1];
    assert_eq!(
        state_rank % 4,
        0,
        "state_rank must be a multiple of 4 (quaternion blocks)"
    );
    let blocks = state_rank / 4;

    // Build the block-split shape [..., J, 4] (rank DB) and the flat shape
    // [..., state_rank] (rank D) for the round trip.
    let mut split_shape = [0usize; DB];
    split_shape[..D - 1].copy_from_slice(&dims[..D - 1]);
    split_shape[DB - 2] = blocks;
    split_shape[DB - 1] = 4;

    let v_blocks = v.reshape(split_shape); // [..., J, 4]
    let rotated = quat_mul(q, v_blocks); // L_q applied per block
    rotated.reshape(dims) // [..., state_rank]
}

// ---------------------------------------------------------------------------
// Cumulative rotation scan (the associative, non-abelian replacement for cumsum)
// ---------------------------------------------------------------------------

/// Cumulative (ordered, left-accumulating) quaternion product along the
/// sequence axis, with a cross-chunk carry.
///
/// This is the non-abelian analogue of the cumulative *sum of angles* used by
/// RoPE: where complex rotations compose by adding angles (a `cumsum`),
/// quaternions compose by multiplication, which is order-dependent, so a real
/// scan is required.
///
/// # Shapes
/// - `q_bshj4` : `[batch, sequence, nheads, J, 4]` per-step **unit** quaternions
///   (block count `J = state_rank / 4`).
/// - `init`    : optional carry `[batch, nheads, J, 4]` — the cumulative
///   rotation at the end of the previous chunk (identity `(1,0,0,0)` for a fresh
///   start).
/// - returns `(cum, final_carry)` where `cum` is `[batch, sequence, nheads, J, 4]`
///   with `cum[:, t] = qₜ ⊗ qₜ₋₁ ⊗ ⋯ ⊗ q₀ ⊗ init` (newest on the left, matching
///   `Pₜ = Rₜ ⋯ R₁`), and `final_carry` `[batch, nheads, J, 4]` is `cum[:, −1]`
///   to thread into the next chunk.
///
/// Running this over a split sequence while threading `final_carry` is exactly
/// equal to running it over the whole sequence (asserted in the tests) — the
/// chunked-prefill / streaming guarantee, here for the rotation accumulator.
///
/// Implemented as a **Hillis–Steele** inclusive associative scan: the quaternion
/// product is associative (just not commutative), so a log-depth scan applies as
/// long as operand order is preserved (newest-on-left). Each doubling step is a
/// single full-tensor [`quat_mul`] plus a sequence shift, so the *sequential
/// dependency depth* is `O(log sequence)` rather than the `O(sequence)` of a
/// token-by-token loop — the same values, but a handful of large batched kernels
/// instead of thousands of serialized tiny ones (and a correspondingly shallow
/// autodiff graph). The sequential reference it replaces is kept as a test oracle
/// (`quat_cumprod_sequential` in the tests module) and asserted equal on values
/// **and** gradients.
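/**
# Example

A scalar sketch of the doubling scan next to a token-by-token loop, on a
`Vec<[f64; 4]>` for a single batch/head/block. All names (`hamilton`,
`cumprod_sequential`, `cumprod_scan`, `sample_steps`, `max_gap`) and the sample
values are assumptions for illustration; this is not the crate's test oracle.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

// Hamilton product a ⊗ b.
fn hamilton(a: Quat, b: Quat) -> Quat {
    [
        a[0] * b[0] - a[1] * b[1] - a[2] * b[2] - a[3] * b[3],
        a[0] * b[1] + a[1] * b[0] + a[2] * b[3] - a[3] * b[2],
        a[0] * b[2] - a[1] * b[3] + a[2] * b[0] + a[3] * b[1],
        a[0] * b[3] + a[1] * b[2] - a[2] * b[1] + a[3] * b[0],
    ]
}

// Token-by-token reference: cum[t] = q[t] ⊗ cum[t-1], seeded with `init`.
fn cumprod_sequential(q: &[Quat], init: Quat) -> Vec<Quat> {
    let mut acc = init;
    q.iter()
        .map(|&step| {
            acc = hamilton(step, acc);
            acc
        })
        .collect()
}

// Hillis–Steele doubling, newest factor on the left, carry folded in at the end.
fn cumprod_scan(q: &[Quat], init: Quat) -> Vec<Quat> {
    let mut a = q.to_vec();
    let mut offset = 1;
    while offset < a.len() {
        let prev = a.clone();
        for t in offset..a.len() {
            a[t] = hamilton(prev[t], prev[t - offset]);
        }
        offset *= 2;
    }
    a.into_iter().map(|p| hamilton(p, init)).collect()
}

// Deterministic unit quaternions for the comparison (arbitrary values).
fn sample_steps(n: usize) -> Vec<Quat> {
    (0..n)
        .map(|t| {
            let s = t as f64;
            let raw = [1.0, 0.3 * s - 1.0, 0.5 - 0.2 * s, 0.1 * s * s];
            let norm = raw.iter().map(|c| c * c).sum::<f64>().sqrt();
            [raw[0] / norm, raw[1] / norm, raw[2] / norm, raw[3] / norm]
        })
        .collect()
}

// Largest component-wise gap between two prefix-product sequences.
fn max_gap(a: &[Quat], b: &[Quat]) -> f64 {
    a.iter()
        .zip(b.iter())
        .flat_map(|(x, y)| x.iter().zip(y.iter()).map(|(p, r)| (p - r).abs()))
        .fold(0.0, f64::max)
}
```

Running `cumprod_scan` on a prefix of the steps and passing its last element as
`init` for the remainder should match `cumprod_sequential` over the whole
sequence up to rounding, mirroring the chunked-prefill guarantee described above.
*/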
pub fn quat_cumprod(q_bshj4: Tensor<5>, init: Option<Tensor<4>>) -> (Tensor<5>, Tensor<4>) {
    let [batch, sequence, nheads, blocks, _four] = q_bshj4.dims();
    let device = q_bshj4.device();

    // Pure prefix product Pₜ = qₜ ⊗ qₜ₋₁ ⊗ ⋯ ⊗ q₀ by Hillis–Steele doubling.
    // Invariant after each step with offset `d`: a[t] holds the product of the
    // window [t .. max(t-2d+1, 0)] (newest on the left). After ⌈log₂ sequence⌉
    // doublings the window covers [t .. 0], i.e. a[t] = Pₜ.
    let mut a = q_bshj4;
    let mut offset = 1usize;
    while offset < sequence {
        // shifted[t] = a[t-offset] for t ≥ offset, else the identity quaternion
        // (1,0,0,0) — so the first `offset` prefixes pass through unchanged
        // (a ⊗ identity = a).
        let ident = {
            let w = Tensor::ones([batch, offset, nheads, blocks, 1], &device);
            let xyz = Tensor::zeros([batch, offset, nheads, blocks, 3], &device);
            Tensor::cat(vec![w, xyz], 4)
        };
        let shifted = Tensor::cat(vec![ident, a.clone()], 1).narrow(1, 0, sequence);
        // Recent block (a) on the left, older block (shifted) on the right.
        a = quat_mul(a, shifted);
        offset *= 2;
    }

    // Fold the cross-chunk carry once: cumₜ = Pₜ ⊗ init. `init` (the previous
    // chunk's final cumulative rotation) is the oldest factor, hence on the
    // right; a missing carry is the identity and needs no multiply.
    let cum = match init {
        Some(init_bhj4) => {
            assert_eq!([batch, nheads, blocks, 4], init_bhj4.dims());
            quat_mul(a, init_bhj4.unsqueeze_dim::<5>(1)) // [batch, 1, nheads, J, 4] broadcasts over seq
        }
        None => a,
    };

    let final_carry = cum.clone().narrow(1, sequence - 1, 1).squeeze_dim::<4>(1); // [batch, nheads, J, 4]
    (cum, final_carry)
}

// ---------------------------------------------------------------------------
// Partial block rotation (rope_fraction support)
// ---------------------------------------------------------------------------

/// Apply a per-block quaternion rotation to the first `rope_width` entries of
/// the `state_rank` axis; the remainder passes through unchanged. `rope_width`
/// must be a multiple of 4. The quaternion analogue of [`apply_rope_partial`].
///
/// `q` has one quaternion per rotated block (`rope_width / 4` of them). `DB`
/// must equal `D + 1`.
pub fn rotate_blocks_partial<const D: usize, const DB: usize>(
    v: Tensor<D>,
    q: Tensor<DB>,
    rope_width: usize,
) -> Tensor<D> {
    let r = v.dims()[D - 1];
    if rope_width == r {
        rotate_state_rank_blocks::<D, DB>(v, q)
    } else {
        let head = v.clone().narrow(D - 1, 0, rope_width);
        let tail = v.narrow(D - 1, rope_width, r - rope_width);
        let head_rot = rotate_state_rank_blocks::<D, DB>(head, q);
        Tensor::cat(vec![head_rot, tail], D - 1)
    }
}

// ---------------------------------------------------------------------------
// Forward / step rotation of B and C (shared by both SSD pathways)
// ---------------------------------------------------------------------------

/// Rotate `B`/`C` for a **full sequence** by the data-dependent positional
/// rotation, returning the rotated projections and the new cumulative
/// [`RotationState`] to store in the cache.
///
/// Branches on [`RotationKind`]:
/// - [`Complex2D`](RotationKind::Complex2D): the abelian RoPE — cumulative
///   angle `cumsum` continued from `prev`, then [`apply_rope_partial`]. Exactly
///   the original Mamba-3 behaviour.
/// - [`Quaternion4D`](RotationKind::Quaternion4D): per-step unit quaternion
///   [`quat_from_scaled_axis`] (the in-projection generators scaled per-head by
///   `Δ`), composed by [`quat_cumprod`] continuing the cached quaternion, then
///   applied to `B`/`C` as `rotate(·, conj(Qₜ))` over the first `4·blocks`
///   state-rank entries.
///
/// # Shapes
/// - `rot_bsa` : `[batch, sequence, num_rotation_channels]` — the in-projection
///   rotation channels (angles for Complex2D, `3·blocks` quaternion generators
///   for Quaternion4D).
/// - `dt_bsh`  : `[batch, sequence, nheads]` (`Δ`).
/// - `prev`    : the cached accumulator; it must be the variant matching `kind`
///   ([`RotationState::angle`] / [`RotationState::quaternion`] panic otherwise).
/// - `b_bsmhr` / `c_bsmhr` : `[batch, sequence, mimo_rank, nheads, state_rank]`.
/// - `rope_dim` : forwarded to [`apply_rope_partial`] on the Complex2D path; the
///   Quaternion4D path ignores it and rotates the first `4·blocks` entries,
///   with `blocks` read from the cached quaternion's shape.
pub fn rotate_bc_forward(
    rot_bsa: Tensor<3>,
    dt_bsh: Tensor<3>,
    prev: RotationState,
    b_bsmhr: Tensor<5>,
    c_bsmhr: Tensor<5>,
    kind: RotationKind,
    rope_dim: usize,
) -> (Tensor<5>, Tensor<5>, RotationState) {
    let [batch, sequence, mimo_rank, nheads, _state_rank] = b_bsmhr.dims();
    match kind {
        RotationKind::Complex2D => {
            let prev_angle_bha = prev.angle();
            let num_rope_angles = prev_angle_bha.dims()[2];
            let theta_scaled_bsa = rot_bsa.tanh() * std::f32::consts::PI;
            let raw_angles_bsha =
                dt_bsh.unsqueeze_dim::<4>(3) * theta_scaled_bsa.unsqueeze_dim::<4>(2);
            let cum_angles_bsha = prev_angle_bha.unsqueeze_dim::<4>(1) + raw_angles_bsha.cumsum(1);
            let cum_angles_bsmha = cum_angles_bsha.clone().unsqueeze_dim::<5>(2).expand([
                batch,
                sequence,
                mimo_rank,
                nheads,
                num_rope_angles,
            ]);
            let rotate_pairwise = mimo_rank == 1;
            let b = apply_rope_partial::<5>(
                b_bsmhr,
                cum_angles_bsmha.clone(),
                rope_dim,
                rotate_pairwise,
            );
            let c = apply_rope_partial::<5>(c_bsmhr, cum_angles_bsmha, rope_dim, rotate_pairwise);
            let last = wrap_angle(
                cum_angles_bsha
                    .narrow(1, sequence - 1, 1)
                    .squeeze_dim::<3>(1),
            );
            (b, c, RotationState::Angle(last))
        }
        RotationKind::Quaternion4D => {
            let prev_q_bhj4 = prev.quaternion();
            let blocks = prev_q_bhj4.dims()[2];
            let rope_width = blocks * 4;
            // Generators [b,s,blocks,3] (shared across heads), scaled per-head by Δ.
            //
            // Bound the raw generator with `tanh·π` before scaling by Δ — the
            // direct analogue of the Complex2D path (`rot.tanh()·π`). Without it
            // the generator is unbounded, so a large in-projection activation makes
            // `g = rot·Δ` overflow f32 to `inf`, and `quat_from_scaled_axis`'s
            // `cos(∞)` then yields a forward NaN. The bound caps each generator
            // component at `±π·Δ`, so the per-step angle `‖g‖` is at most `√3·π·Δ`
            // (cos/sin still give the periodicity within range); healthy `O(1)`
            // generators stay in tanh's near-linear region.
            let g_bshj3 = (rot_bsa.tanh() * std::f32::consts::PI)
                .reshape([batch, sequence, blocks, 3])
                .unsqueeze_dim::<5>(2)
                * dt_bsh.unsqueeze_dim::<4>(3).unsqueeze_dim::<5>(4);
            let q_step_bshj4 = quat_from_scaled_axis::<5>(g_bshj3);
            // Memory-efficient scan: a custom recompute backward (saves only the
            // leaf inputs) instead of retaining the scan's intermediates. Equal
            // to [`quat_cumprod`] on values and gradients (asserted in tests).
            let (cum_bshj4, final_bhj4) = crate::mamba3::quat_scan::quat_cumprod_recalculated(
                q_step_bshj4,
                Some(prev_q_bhj4),
            );
            // B̄ = rotate by the inverse cumulative rotation (conjugate), per block,
            // broadcast over the mimo_rank axis.
            let conj_bsmhj4 = quat_conj(cum_bshj4)
                .unsqueeze_dim::<6>(2)
                .expand([batch, sequence, mimo_rank, nheads, blocks, 4]);
            let b = rotate_blocks_partial::<5, 6>(b_bsmhr, conj_bsmhj4.clone(), rope_width);
            let c = rotate_blocks_partial::<5, 6>(c_bsmhr, conj_bsmhj4, rope_width);
            (b, c, RotationState::Quaternion(quat_normalize(final_bhj4)))
        }
    }
}

/// Single-token counterpart of [`rotate_bc_forward`] for the recurrent `step`.
///
/// # Shapes
/// - `rot_ba`  : `[batch, num_rotation_channels]`.
/// - `dt_bh`   : `[batch, nheads]`.
/// - `b_bmhr` / `c_bmhr` : `[batch, mimo_rank, nheads, state_rank]`.
pub fn rotate_bc_step(
    rot_ba: Tensor<2>,
    dt_bh: Tensor<2>,
    prev: RotationState,
    b_bmhr: Tensor<4>,
    c_bmhr: Tensor<4>,
    kind: RotationKind,
    rope_dim: usize,
) -> (Tensor<4>, Tensor<4>, RotationState) {
    let [batch, mimo_rank, nheads, _state_rank] = b_bmhr.dims();
    match kind {
        RotationKind::Complex2D => {
            let prev_angle_bha = prev.angle();
            let num_rope_angles = prev_angle_bha.dims()[2];
            let theta_scaled_ba = rot_ba.tanh() * std::f32::consts::PI;
            let raw_angle_bha = dt_bh.unsqueeze_dim::<3>(2) * theta_scaled_ba.unsqueeze_dim::<3>(1);
            let new_cum_angle_bha = wrap_angle(prev_angle_bha + raw_angle_bha);
            let new_cum_angle_bmha = new_cum_angle_bha.clone().unsqueeze_dim::<4>(1).expand([
                batch,
                mimo_rank,
                nheads,
                num_rope_angles,
            ]);
            let rotate_pairwise = mimo_rank == 1;
            let b = apply_rope_partial::<4>(
                b_bmhr,
                new_cum_angle_bmha.clone(),
                rope_dim,
                rotate_pairwise,
            );
            let c = apply_rope_partial::<4>(c_bmhr, new_cum_angle_bmha, rope_dim, rotate_pairwise);
            (b, c, RotationState::Angle(new_cum_angle_bha))
        }
        RotationKind::Quaternion4D => {
            let prev_q_bhj4 = prev.quaternion();
            let blocks = prev_q_bhj4.dims()[2];
            let rope_width = blocks * 4;
            // `tanh·π` bound, matching `rotate_bc_forward` (see the note there).
            let g_bhj3 = (rot_ba.tanh() * std::f32::consts::PI)
                .reshape([batch, blocks, 3])
                .unsqueeze_dim::<4>(1)
                * dt_bh.unsqueeze_dim::<3>(2).unsqueeze_dim::<4>(3);
            let q_step_bhj4 = quat_from_scaled_axis::<4>(g_bhj3);
            // Single step: Qₜ = qₜ ⊗ Qₜ₋₁.
            let new_q_bhj4 = quat_normalize(quat_mul(q_step_bhj4, prev_q_bhj4));
            let conj_bmhj4 = quat_conj(new_q_bhj4.clone())
                .unsqueeze_dim::<5>(1)
                .expand([batch, mimo_rank, nheads, blocks, 4]);
            let b = rotate_blocks_partial::<4, 5>(b_bmhr, conj_bmhj4.clone(), rope_width);
            let c = rotate_blocks_partial::<4, 5>(c_bmhr, conj_bmhj4, rope_width);
            (b, c, RotationState::Quaternion(new_q_bhj4))
        }
    }
}

#[cfg(all(test, feature = "_dev-test"))]
mod tests;