

Function quat_cumprod 

pub fn quat_cumprod(
    q_bshj4: Tensor<5>,
    init: Option<Tensor<4>>,
) -> (Tensor<5>, Tensor<4>)

Cumulative (ordered, left-accumulating) quaternion product along the sequence axis, with a cross-chunk carry.

This is the non-abelian analogue of the cumulative sum of angles used by RoPE. Complex (2-D) rotations compose by adding their angles, so their running composition is a plain cumsum. Quaternions compose by multiplication, which is order-dependent, so the running composition has to be computed as an ordered scan.
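A minimal sketch of why the accumulation must be an ordered product, assuming quaternions are stored as plain `[f64; 4]` arrays in (w, x, y, z) order with the Hamilton convention (consistent with the identity (1,0,0,0) used below). `Quat`, `quat_mul` and `axis_angle` are illustrative stand-ins, not this crate's `Tensor` API. Rotations about a shared axis multiply exactly like angle addition, but rotations about different axes do not commute.

```rust
/// Quaternion stored as (w, x, y, z); an assumed layout for illustration.
type Quat = [f64; 4];

/// Hamilton product a ⊗ b (assumed convention: i ⊗ j = k).
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [aw, ax, ay, az] = a;
    let [bw, bx, by, bz] = b;
    [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]
}

/// Hypothetical helper: unit quaternion for a rotation of `angle` radians
/// about the unit vector `axis` (half-angle encoding).
fn axis_angle(axis: [f64; 3], angle: f64) -> Quat {
    let (s, c) = (angle / 2.0).sin_cos();
    [c, axis[0] * s, axis[1] * s, axis[2] * s]
}
```

For example, `quat_mul(i, j)` gives k while `quat_mul(j, i)` gives -k, whereas two rotations about the same axis multiply to a single rotation by the summed angle. Only that same-axis (abelian) case reduces to a cumsum.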

§Shapes

  • q_bshj4 : [batch, sequence, nheads, J, 4] per-step unit quaternions (block count J = state_rank / 4).
  • init : optional carry [batch, nheads, J, 4]: the cumulative rotation at the end of the previous chunk (the identity (1,0,0,0) for a fresh start).
  • returns (cum, final_carry): cum is [batch, sequence, nheads, J, 4] with cum[:, t] = qₜ ⊗ qₜ₋₁ ⊗ ⋯ ⊗ q₀ ⊗ init (newest on the left, matching Pₜ = Rₜ ⋯ R₁), and final_carry [batch, nheads, J, 4] equals cum[:, -1], ready to be passed as init to the next chunk.
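The per-step semantics can be sketched sequentially for a single (batch, head, block) lane, assuming `[f64; 4]` quaternions in (w, x, y, z) order with the Hamilton product. `quat_cumprod_lane` is a hypothetical helper that mirrors the shape contract above; it is not the real `Tensor<5>` function. Each new step multiplies on the left of the running product, and the final carry is the last prefix.

```rust
/// Quaternion stored as (w, x, y, z); an assumed layout for illustration.
type Quat = [f64; 4];

/// Identity rotation, the implicit init for a fresh start.
const IDENTITY: Quat = [1.0, 0.0, 0.0, 0.0];

/// Hamilton product a ⊗ b (assumed convention).
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [aw, ax, ay, az] = a;
    let [bw, bx, by, bz] = b;
    [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]
}

/// Hypothetical one-lane version: cum[t] = q[t] ⊗ q[t-1] ⊗ ... ⊗ q[0] ⊗ init.
/// Returns (cum, final_carry); final_carry is the last prefix, or init if q is empty.
fn quat_cumprod_lane(q: &[Quat], init: Option<Quat>) -> (Vec<Quat>, Quat) {
    let mut acc = init.unwrap_or(IDENTITY);
    let mut cum = Vec::with_capacity(q.len());
    for &qt in q {
        acc = quat_mul(qt, acc); // newest step goes on the left
        cum.push(acc);
    }
    (cum, acc)
}
```

With `q = [i, j]` and no init, this yields `cum = [i, j ⊗ i] = [i, -k]` and a final carry of `-k`.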

Running this over a split sequence, while passing each chunk's final_carry as the init of the next, produces the same cum and final_carry as running it over the whole sequence (asserted in the tests). This is the chunked-prefill / streaming guarantee, here applied to the rotation accumulator.
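The chunked-streaming guarantee can be illustrated with a sequential recurrence on one lane of `[f64; 4]` quaternions (assumed (w, x, y, z) order, Hamilton product). `cumprod_lane` and `cumprod_in_chunks` are hypothetical names used only for this sketch. Because the carry is the full prefix product, starting the next chunk from it continues the recurrence exactly where the previous chunk stopped.

```rust
/// Quaternion stored as (w, x, y, z); an assumed layout for illustration.
type Quat = [f64; 4];

const IDENTITY: Quat = [1.0, 0.0, 0.0, 0.0];

/// Hamilton product a ⊗ b (assumed convention).
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [aw, ax, ay, az] = a;
    let [bw, bx, by, bz] = b;
    [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]
}

/// Sequential cumulative product for one lane, newest factor on the left.
fn cumprod_lane(q: &[Quat], init: Option<Quat>) -> (Vec<Quat>, Quat) {
    let mut acc = init.unwrap_or(IDENTITY);
    let mut cum = Vec::with_capacity(q.len());
    for &qt in q {
        acc = quat_mul(qt, acc);
        cum.push(acc);
    }
    (cum, acc)
}

/// Process `q` in chunks of `chunk_len` (must be > 0, or `chunks` panics),
/// threading each chunk's final carry into the next chunk's init.
fn cumprod_in_chunks(q: &[Quat], chunk_len: usize) -> (Vec<Quat>, Quat) {
    let mut carry: Option<Quat> = None; // first chunk starts fresh (identity)
    let mut cum = Vec::with_capacity(q.len());
    for chunk in q.chunks(chunk_len) {
        let (part, next) = cumprod_lane(chunk, carry);
        cum.extend(part);
        carry = Some(next); // end-of-chunk rotation seeds the next chunk
    }
    (cum, carry.unwrap_or(IDENTITY))
}
```

With this token-by-token recurrence the chunked and whole results are bitwise identical, since the same multiplications happen in the same order. A log-depth scan regroups the products, so its chunked output can differ from the whole-sequence output at the level of floating-point rounding.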

Implemented as a Hillis–Steele inclusive associative scan. The quaternion product is associative (though not commutative), so a log-depth scan applies as long as operand order is preserved, newest on the left. Each doubling step is a single full-tensor quat_mul plus a shift along the sequence axis, so the sequential dependency depth is ⌈log₂ sequence⌉ steps rather than the O(sequence) of a token-by-token loop. The values are the same, but the work runs as a handful of large batched kernels instead of thousands of serialized tiny ones, and the autodiff graph is correspondingly shallow.

The sequential reference it replaces is kept as a test oracle (quat_cumprod_sequential in the tests module) and is asserted equal to the scan on both values and gradients.
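A scalar sketch of the Hillis–Steele scan on one lane, assuming `[f64; 4]` quaternions in (w, x, y, z) order with the Hamilton product. `quat_cumprod_scan` is a hypothetical stand-in for the batched tensor version, and the local `quat_cumprod_sequential` here is only an illustrative oracle, not the one in the tests module. Before the step with offset d, element t holds the ordered product of the (up to) d most recent steps ending at t; multiplying it on the left of the element d positions earlier (identity-padded by the shift) doubles that window while keeping newer factors on the left. The optional init is the oldest factor, so it is applied once at the end, on the right.

```rust
/// Quaternion stored as (w, x, y, z); an assumed layout for illustration.
type Quat = [f64; 4];

const IDENTITY: Quat = [1.0, 0.0, 0.0, 0.0];

/// Hamilton product a ⊗ b (assumed convention).
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [aw, ax, ay, az] = a;
    let [bw, bx, by, bz] = b;
    [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]
}

/// Illustrative oracle: cum[t] = q[t] ⊗ cum[t-1], starting from init.
fn quat_cumprod_sequential(q: &[Quat], init: Option<Quat>) -> (Vec<Quat>, Quat) {
    let mut acc = init.unwrap_or(IDENTITY);
    let cum: Vec<Quat> = q
        .iter()
        .map(|&qt| {
            acc = quat_mul(qt, acc);
            acc
        })
        .collect();
    (cum, acc)
}

/// Hypothetical log-depth version: same result up to floating-point rounding.
fn quat_cumprod_scan(q: &[Quat], init: Option<Quat>) -> (Vec<Quat>, Quat) {
    let n = q.len();
    let mut x: Vec<Quat> = q.to_vec();
    let mut d = 1;
    while d < n {
        // Shift right by d along the sequence, padding the front with identity.
        let shifted: Vec<Quat> = (0..n)
            .map(|t| if t >= d { x[t - d] } else { IDENTITY })
            .collect();
        // One full-width product: newer window (x[t]) on the left, older on the right.
        let next: Vec<Quat> = x.iter().zip(&shifted).map(|(&a, &b)| quat_mul(a, b)).collect();
        x = next;
        d *= 2;
    }
    let init = init.unwrap_or(IDENTITY);
    // init is the oldest factor, so it multiplies on the right of every prefix.
    let cum: Vec<Quat> = x.into_iter().map(|p| quat_mul(p, init)).collect();
    let carry = cum.last().copied().unwrap_or(init);
    (cum, carry)
}
```

The `while` loop runs ⌈log₂ n⌉ times. In the tensor version, each iteration would correspond to one shift along the sequence axis plus one batched quat_mul over every batch, head, and block entry at once.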