Module rotation

Quaternion (k = 4) rotational state: the non-abelian generalisation of RoPE

Mamba-3’s data-dependent RoPE realises a complex-valued SSM: the state transition factors as a per-head scalar decay times a block-diagonal of 2×2 rotations (paper Prop. Complex-to-Real SSM Equivalence), and because SO(2) ≅ U(1) is abelian the cumulative rotation collapses to a cumsum of angles and is absorbed into B/C (the “RoPE trick”, Prop. Complex SSM, Data-Dependent RoPE Equivalence). See [crate::mamba3::double_ssd::double_ssd::apply_rope].

This module implements the next rung of the ladder: a quaternion (k = 4) rotational state, i.e. the transition's rotation lives in the left-isoclinic subgroup SU(2) ⊂ SO(4) instead of SO(2). Unit quaternions under multiplication form a group isomorphic to SU(2), which is non-abelian and contains non-solvable finite subgroups (the binary icosahedral group 2I ≅ SL(2,5), a double cover of A₅). By Barrington's theorem the word problem of a non-solvable finite group is NC¹-complete, so this lifts the layer's reachable state-tracking from the solvable regime (parity, mod-k counting), which sits in TC⁰, toward NC¹. Abelian-rotation layers stay within TC⁰ under the usual bounded-precision assumptions, so they cannot solve these problems unless TC⁰ = NC¹, a separation that is open but widely conjectured.
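The non-commutativity is easy to see concretely. The sketch below is a standalone version on plain `[f64; 4]` arrays in the module's (w, x, y, z) layout, not the crate's tensor-based `quat_mul`; the `quat_norm` helper is a hypothetical addition used only for the check. It exhibits i ⊗ j = k while j ⊗ i = −k, so the order of the per-step rotations matters, and the product of two unit quaternions stays unit.

```rust
/// A quaternion in the module's (w, x, y, z) layout, w the real part.
type Quat = [f64; 4];

/// Hamilton product a ⊗ b (non-commutative in general).
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [
        a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3,
        a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2,
        a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1,
        a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0,
    ]
}

/// Euclidean norm (hypothetical helper); unit-norm quaternions form the group SU(2).
fn quat_norm(q: Quat) -> f64 {
    q.iter().map(|c| c * c).sum::<f64>().sqrt()
}
```

For generic unit quaternions p and q, p ⊗ q − q ⊗ p = 2 (p_v × q_v) on the vector parts, which vanishes only when the axes are parallel.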

What survives, what changes

The key fact (it follows from telescoping and orthogonality alone, without commutativity; see the crate discussion) is that the RoPE factoring survives intact. With the ordered cumulative rotation Pₜ = Rₜ Rₜ₋₁ ⋯ R₁, telescoping gives Rₜ⋯Rᵢ₊₁ = Pₜ Pᵢ⁻¹ = Pₜ Pᵢᵀ (each Rₛ is orthogonal, hence so is Pᵢ), and therefore

  Cₜᵀ (Rₜ⋯Rᵢ₊₁) Bᵢ  =  (Pₜᵀ Cₜ)ᵀ (Pᵢᵀ Bᵢ)  =  C̄ₜᵀ B̄ᵢ ,

so the scalar-decay SSD core (L ⊙ C̄B̄ᵀ) is unchanged; only the projections B̄ᵢ = Pᵢᵀ Bᵢ and C̄ₜ = Pₜᵀ Cₜ are rotated. What is lost is the closed-form cumsum: the cumulative rotation must be built by an associative scan over the per-step quaternions (quat_cumprod) rather than a sum of angles. Because a product of unit quaternions is again a unit quaternion, the scan stays orthogonal by construction (up to floating-point rounding, which quat_normalize can correct) and needs no wrap_angle. The cross-chunk carry is a single quaternion per block/head, the exact analogue of cum_angle in the existing caches.
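A sequential sketch of the left-accumulating scan with a cross-chunk carry, on plain arrays rather than tensors. The module builds the same prefix products with an associative scan; this loop computes them one step at a time, and its signature is an illustrative assumption. Seeding each chunk with the previous chunk's last Qₜ reproduces the unchunked result, which is why one carried quaternion per block/head suffices.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

/// Hamilton product a ⊗ b.
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [
        a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3,
        a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2,
        a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1,
        a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0,
    ]
}

/// Left-accumulating prefix product Qₜ = qₜ ⊗ Qₜ₋₁, seeded with `carry`
/// (the cumulative rotation at the end of the previous chunk, or the identity
/// (1, 0, 0, 0) for the first chunk). The last output is the next chunk's carry.
fn quat_cumprod(qs: &[Quat], carry: Quat) -> Vec<Quat> {
    let mut acc = carry;
    qs.iter()
        .map(|&q| {
            acc = quat_mul(q, acc); // the newest step multiplies on the left
            acc
        })
        .collect()
}

/// Rescale to unit norm, removing any floating-point drift from long scans.
fn quat_normalize(q: Quat) -> Quat {
    let n = q.iter().map(|c| c * c).sum::<f64>().sqrt();
    q.map(|c| c / n)
}
```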

SO(2) (today’s apply_rope) is the abelian collapse: restricting each quaternion to a single fixed axis makes them commute and reduces quat_cumprod to a cumsum of half-angles (asserted in the tests).
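The abelian collapse can be checked directly. Assuming the half-angle convention q(θ) = (cos(θ/2), sin(θ/2)·n) for a rotation by θ about a fixed unit axis n (the helper `axis_angle` is hypothetical), same-axis quaternions commute and their ordered product is q(θ₁ + ⋯ + θₜ), so the scan degenerates to a running sum of angles whose halves appear in the components.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

/// Hamilton product a ⊗ b.
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [
        a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3,
        a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2,
        a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1,
        a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0,
    ]
}

/// Hypothetical helper: unit quaternion for a rotation by `theta` about the
/// fixed unit axis `n`, in the half-angle convention (cos(θ/2), sin(θ/2)·n).
fn axis_angle(n: [f64; 3], theta: f64) -> Quat {
    let (s, c) = (0.5 * theta).sin_cos();
    [c, s * n[0], s * n[1], s * n[2]]
}
```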

Pipeline (the k = 4 instantiation of the rotation block)

  per-step unit quaternion qₜ      (materialise from the in-projection; caller)
       │  quat_cumprod (assoc. scan, + cross-chunk carry)
       ▼
  cumulative rotation Qₜ
       │  rotate_state_rank_blocks(B, conj(Qₜ)) , rotate_state_rank_blocks(C, conj(Qₜ))
       ▼
  B̄, C̄  ──►  standard scalar-decay SSD  (unchanged)
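An end-to-end check of this pipeline on a single 4-wide block, as a plain-array sketch; `rotate_bc` and `dot` are hypothetical names, not the crate's `rotate_bc_forward`. For a unit quaternion, left-multiplication by its conjugate is the transpose of left-multiplication by it, so B̄ₜ = conj(Qₜ) ⊗ Bₜ realises Pₜᵀ Bₜ. The assertions compare C̄ₜ·B̄ᵢ against Cₜᵀ(Rₜ⋯Rᵢ₊₁)Bᵢ computed by applying the intermediate steps directly.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

/// Hamilton product a ⊗ b.
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [
        a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3,
        a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2,
        a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1,
        a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0,
    ]
}

/// Conjugate q* = (w, −x, −y, −z).
fn quat_conj(q: Quat) -> Quat {
    [q[0], -q[1], -q[2], -q[3]]
}

/// Euclidean inner product of two 4-vectors.
fn dot(a: Quat, b: Quat) -> f64 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

/// Per-step qₜ -> cumulative Qₜ -> (B̄ₜ, C̄ₜ) for one 4-wide block,
/// with B̄ₜ = conj(Qₜ) ⊗ Bₜ = Pₜᵀ Bₜ and likewise for C.
fn rotate_bc(qs: &[Quat], b: &[Quat], c: &[Quat]) -> (Vec<Quat>, Vec<Quat>) {
    let mut q_cum: Quat = [1.0, 0.0, 0.0, 0.0];
    let (mut b_bar, mut c_bar) = (Vec::new(), Vec::new());
    for ((&q, &bt), &ct) in qs.iter().zip(b).zip(c) {
        q_cum = quat_mul(q, q_cum); // Qₜ = qₜ ⊗ Qₜ₋₁
        let inv = quat_conj(q_cum); // transpose of left-multiplication by Qₜ
        b_bar.push(quat_mul(inv, bt));
        c_bar.push(quat_mul(inv, ct));
    }
    (b_bar, c_bar)
}
```

Everything downstream only sees B̄ and C̄, which is why the SSD kernels need no changes.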

Quaternion layout: the last axis has size 4 and holds (w, x, y, z) with w the real part. A state_rank of r = 4·J is treated as J independent quaternion blocks; the rotation acts within each block, exactly as RoPE acts within each 2-pair. This module is a self-contained, tested reference for the math; wiring it into the Mamba3 block is a separate, larger change (the SSD kernels themselves need no edits).
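A flat-slice sketch of this layout: entries 4j..4j+4 form quaternion block j, only the first `rope_width` entries are rotated, and the tail passes through. It assumes one quaternion per rotated block and plain `Vec<f64>` storage; the crate's `rotate_blocks_partial` and `rotate_state_rank_blocks` operate on tensors, and their exact shapes may differ.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

/// Hamilton product a ⊗ b.
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [
        a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3,
        a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2,
        a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1,
        a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0,
    ]
}

/// Rotate the first `rope_width` entries of `v` in 4-wide blocks: block j
/// (entries 4j..4j+4, read as (w, x, y, z)) becomes qs[j] ⊗ block.
/// Entries at or beyond `rope_width` pass through unchanged.
/// Assumption: exactly one quaternion per rotated block.
fn rotate_blocks_partial(v: &[f64], qs: &[Quat], rope_width: usize) -> Vec<f64> {
    assert!(rope_width % 4 == 0 && rope_width <= v.len(), "rope_width must be a multiple of 4 and fit in v");
    assert_eq!(qs.len(), rope_width / 4, "expected one quaternion per rotated block");
    let mut out = v.to_vec();
    for (j, &q) in qs.iter().enumerate() {
        let s = 4 * j;
        let block = [v[s], v[s + 1], v[s + 2], v[s + 3]];
        out[s..s + 4].copy_from_slice(&quat_mul(q, block));
    }
    out
}

/// Full-width case: state_rank r = 4·J, so every entry belongs to a block.
fn rotate_state_rank_blocks(v: &[f64], qs: &[Quat]) -> Vec<f64> {
    rotate_blocks_partial(v, qs, v.len())
}
```

With q = i, the block (1, 2, 3, 4) maps to i ⊗ (1 + 2i + 3j + 4k) = (−2, 1, −4, 3), and unit quaternions leave each block's norm unchanged.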

Enums

RotationKind
Which rotational-state algebra the block uses for the data-dependent positional rotation of B/C.
RotationState
The cumulative-rotation accumulator carried between calls in a Mamba-3 cache — the variant matching the block’s RotationKind.
RotationStateRecord
The record type for the module.
RotationStateRecordItem
The record item type for the module.
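To make the carry concrete, here is a hypothetical shape for the two enums; all variant and field names below are illustrative assumptions rather than the crate's definitions (which also come with generated record types). The RoPE variant carries one cumulative angle per 2-pair, the quaternion variant one cumulative unit quaternion per 4-block, and both start from the identity.

```rust
/// Hypothetical sketch: which rotation algebra a block uses.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum RotationKind {
    /// Abelian SO(2) RoPE on 2-pairs.
    Rope,
    /// Non-abelian SU(2) quaternion rotation on 4-blocks.
    Quaternion,
}

/// Hypothetical sketch of the cross-chunk carry: one variant per RotationKind.
#[derive(Clone, Debug, PartialEq)]
enum RotationState {
    /// One cumulative angle per rotated 2-pair.
    Rope { cum_angle: Vec<f64> },
    /// One cumulative unit quaternion (w, x, y, z) per rotated 4-block.
    Quaternion { cum_quat: Vec<[f64; 4]> },
}

impl RotationState {
    /// Identity carry for `n` pairs/blocks: zero angles, or the quaternion (1, 0, 0, 0).
    fn identity(kind: RotationKind, n: usize) -> Self {
        match kind {
            RotationKind::Rope => RotationState::Rope { cum_angle: vec![0.0; n] },
            RotationKind::Quaternion => RotationState::Quaternion {
                cum_quat: vec![[1.0, 0.0, 0.0, 0.0]; n],
            },
        }
    }

    /// The algebra this carry belongs to, so a cache can check it matches the block.
    fn kind(&self) -> RotationKind {
        match self {
            RotationState::Rope { .. } => RotationKind::Rope,
            RotationState::Quaternion { .. } => RotationKind::Quaternion,
        }
    }
}
```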

Functions

quat_conj
Quaternion conjugate q* = (w, −x, −y, −z) (shape [..., 4]).
quat_cumprod
Cumulative (ordered, left-accumulating) quaternion product along the sequence axis, with a cross-chunk carry.
quat_from_scaled_axis
Materialise a unit quaternion from a scaled rotation vector g ∈ ℝ³ (axis · angle) via the exponential map — the data-dependent “materialise Rₜ” step, analogous to RoPE’s Δₜ · π · tanh(θₜ) angle.
quat_mul
Hamilton product a ⊗ b of two quaternion tensors.
quat_normalize
Normalise quaternions to unit norm along the last axis (shape [..., 4]).
quat_to_rot4
Materialise the 4×4 orthogonal matrix of left-multiplication by q.
rotate_bc_forward
Rotate B/C for a full sequence by the data-dependent positional rotation, returning the rotated projections and the new cumulative RotationState to store in the cache.
rotate_bc_step
Single-token counterpart of rotate_bc_forward for the recurrent step.
rotate_blocks_partial
Apply a per-block quaternion rotation to the first rope_width entries of the state_rank axis (a multiple of 4); the remainder passes through. The quaternion analogue of apply_rope_partial.
rotate_state_rank_blocks
Apply a per-block quaternion rotation to the state_rank axis of v.
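Two of the helpers above, sketched on plain arrays. This `quat_from_scaled_axis` assumes the half-angle convention q = exp(g/2) = (cos(θ/2), sin(θ/2)·g/θ) with θ = |g| (the listing does not pin the convention down), plus a series fallback so g = 0 maps to the identity without a 0/0. `quat_to_rot4` builds the matrix whose product with v equals q ⊗ v; `mat_vec` is a hypothetical helper. The assertions check unit norm, orthogonality, and agreement with the Hamilton product.

```rust
type Quat = [f64; 4]; // (w, x, y, z)

/// Hamilton product a ⊗ b.
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [a0, a1, a2, a3] = a;
    let [b0, b1, b2, b3] = b;
    [
        a0 * b0 - a1 * b1 - a2 * b2 - a3 * b3,
        a0 * b1 + a1 * b0 + a2 * b3 - a3 * b2,
        a0 * b2 - a1 * b3 + a2 * b0 + a3 * b1,
        a0 * b3 + a1 * b2 - a2 * b1 + a3 * b0,
    ]
}

/// Exponential map under the assumed half-angle convention.
fn quat_from_scaled_axis(g: [f64; 3]) -> Quat {
    let theta = (g[0] * g[0] + g[1] * g[1] + g[2] * g[2]).sqrt();
    let half = 0.5 * theta;
    // k = sin(θ/2)/θ; near zero use its series 1/2 − θ²/48 to avoid 0/0.
    let k = if theta < 1e-6 { 0.5 - theta * theta / 48.0 } else { half.sin() / theta };
    [half.cos(), k * g[0], k * g[1], k * g[2]]
}

/// 4×4 matrix of left-multiplication by q: quat_to_rot4(q)·v == q ⊗ v.
/// It is orthogonal when q has unit norm.
fn quat_to_rot4(q: Quat) -> [[f64; 4]; 4] {
    let [w, x, y, z] = q;
    [
        [w, -x, -y, -z],
        [x, w, -z, y],
        [y, z, w, -x],
        [z, -y, x, w],
    ]
}

/// Matrix-vector product (hypothetical helper for the check).
fn mat_vec(m: &[[f64; 4]; 4], v: Quat) -> Quat {
    let mut out = [0.0; 4];
    for (o, row) in out.iter_mut().zip(m.iter()) {
        *o = row.iter().zip(v.iter()).map(|(a, b)| a * b).sum();
    }
    out
}
```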