pub fn quat_cumprod(
q_bshj4: Tensor<5>,
init: Option<Tensor<4>>,
) -> (Tensor<5>, Tensor<4>)
Cumulative (ordered, left-accumulating) quaternion product along the sequence axis, with a cross-chunk carry.
This is the non-abelian analogue of the cumulative sum of angles used by
RoPE: where complex rotations compose by adding angles (a cumsum),
quaternions compose by multiplication, which is order-dependent and does not
reduce to adding angles, so an ordered product scan is required.
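The contrast can be written out as a short worked equation. It uses the standard quaternion basis units i, j, k; `cum` follows the definition given under Shapes.

```latex
% RoPE (complex rotations): composition commutes and reduces to a cumsum of angles
e^{i\theta_a}\, e^{i\theta_b} = e^{i(\theta_a + \theta_b)} = e^{i\theta_b}\, e^{i\theta_a},
\qquad
\prod_{s=0}^{t} e^{i\theta_s} = \exp\Big(i \sum_{s=0}^{t} \theta_s\Big)

% Quaternions: composition is order-dependent, e.g.
\mathbf{i} \otimes \mathbf{j} = \mathbf{k} \neq -\mathbf{k} = \mathbf{j} \otimes \mathbf{i},
\qquad\text{so}\qquad
\mathrm{cum}_t = q_t \otimes q_{t-1} \otimes \cdots \otimes q_0 \otimes \mathrm{init}
```

Because the factors do not commute, the product has to be accumulated in a fixed order instead of being collapsed into a sum of angles.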
Shapes
- `q_bshj4`: `[batch, sequence, nheads, J, 4]` per-step unit quaternions (block count `J = state_rank / 4`).
- `init`: optional carry `[batch, nheads, J, 4]`, the cumulative rotation at the end of the previous chunk (identity `(1, 0, 0, 0)` for a fresh start).
- returns `(cum, final_carry)`, where `cum` is `[batch, sequence, nheads, J, 4]` with `cum[:, t] = qₜ ⊗ qₜ₋₁ ⊗ ⋯ ⊗ q₀ ⊗ init` (newest on the left, matching `Pₜ = Rₜ ⋯ R₁`), and `final_carry` (`[batch, nheads, J, 4]`) is `cum[:, −1]`, to be threaded into the next chunk.
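The recurrence behind `cum` can be sketched for a single `(batch, head, block)` slice, with the batch, head and `J` axes dropped. This is a minimal sketch under stated assumptions, not the crate's code: `Quat` is assumed to be a plain `[w, x, y, z]` array, `quat_mul` here is a scalar stand-in for the tensor-level `quat_mul`, and `cumprod_sequential` and `rot_z` are hypothetical helpers. The demo uses rotations about a single axis, where the product collapses to a rotation by the summed angle, which is exactly the RoPE special case.

```rust
/// Quaternion as [w, x, y, z]; the identity rotation is [1, 0, 0, 0].
type Quat = [f64; 4];

const IDENTITY: Quat = [1.0, 0.0, 0.0, 0.0];

/// Hamilton product a ⊗ b (scalar stand-in for the tensor-level quat_mul).
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [aw, ax, ay, az] = a;
    let [bw, bx, by, bz] = b;
    [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]
}

/// Hypothetical sequential reference for one (batch, head, block) slice:
/// cum[t] = q[t] ⊗ q[t-1] ⊗ ... ⊗ q[0] ⊗ init, newest on the left.
/// Returns (cum, final_carry); final_carry is init when `qs` is empty.
fn cumprod_sequential(qs: &[Quat], init: Option<Quat>) -> (Vec<Quat>, Quat) {
    let mut carry = init.unwrap_or(IDENTITY);
    let mut cum = Vec::with_capacity(qs.len());
    for &q in qs {
        // The newest step multiplies on the left of everything before it.
        carry = quat_mul(q, carry);
        cum.push(carry);
    }
    (cum, carry)
}

/// Unit quaternion for a rotation by `theta` radians about the z axis.
fn rot_z(theta: f64) -> Quat {
    [(theta / 2.0).cos(), 0.0, 0.0, (theta / 2.0).sin()]
}

fn main() {
    // Rotations about one shared axis commute, so the cumulative product
    // collapses to a rotation by the cumulative sum of angles (the RoPE case).
    let qs: Vec<Quat> = [0.1, 0.2, 0.3].iter().map(|&a| rot_z(a)).collect();
    let (cum, final_carry) = cumprod_sequential(&qs, None);
    println!("cum = {:?}", cum);
    println!("final_carry = {:?}", final_carry);
}
```

With non-commuting steps the order shows up directly: for the sequence `[i, j]` the second entry is `j ⊗ i = −k`, not `i ⊗ j = k`.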
Running this over a split sequence while threading `final_carry` into the
next call gives the same result as running it over the whole sequence
(asserted in the tests). This is the chunked-prefill / streaming guarantee,
applied here to the rotation accumulator.
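The chunking guarantee can be illustrated on one slice with a sequential recurrence. This is a minimal sketch under the same assumptions as a plain CPU model: `Quat`, the scalar `quat_mul`, `cumprod_sequential` and the hypothetical driver `cumprod_chunked` are illustrative names, not the crate's API. Because the sequential recurrence performs the same multiplications in the same order whether or not the input is split, the chunked output here is bit-identical to the unsplit one.

```rust
/// Quaternion as [w, x, y, z]; the identity rotation is [1, 0, 0, 0].
type Quat = [f64; 4];

const IDENTITY: Quat = [1.0, 0.0, 0.0, 0.0];

/// Hamilton product a ⊗ b (scalar stand-in for the tensor-level quat_mul).
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [aw, ax, ay, az] = a;
    let [bw, bx, by, bz] = b;
    [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]
}

/// Hypothetical sequential cumprod: cum[t] = q[t] ⊗ ... ⊗ q[0] ⊗ init.
fn cumprod_sequential(qs: &[Quat], init: Option<Quat>) -> (Vec<Quat>, Quat) {
    let mut carry = init.unwrap_or(IDENTITY);
    let mut cum = Vec::with_capacity(qs.len());
    for &q in qs {
        carry = quat_mul(q, carry);
        cum.push(carry);
    }
    (cum, carry)
}

/// Hypothetical driver: process `qs` in pieces of `chunk` steps, threading
/// each piece's final carry into the next, and concatenate the outputs.
/// `slice::chunks` panics if `chunk == 0`.
fn cumprod_chunked(qs: &[Quat], chunk: usize) -> (Vec<Quat>, Quat) {
    let mut carry: Option<Quat> = None; // None = fresh start (identity)
    let mut out = Vec::with_capacity(qs.len());
    for piece in qs.chunks(chunk) {
        let (cum, final_carry) = cumprod_sequential(piece, carry);
        out.extend(cum);
        carry = Some(final_carry);
    }
    (out, carry.unwrap_or(IDENTITY))
}

fn main() {
    // Ten unit quaternions cycling through the x, y and z axes, so that
    // consecutive steps do not commute.
    let qs: Vec<Quat> = (0..10)
        .map(|t| {
            let a: f64 = 0.1 * (t + 1) as f64;
            let mut q = [a.cos(), 0.0, 0.0, 0.0];
            q[1 + t % 3] = a.sin();
            q
        })
        .collect();
    let (whole, whole_final) = cumprod_sequential(&qs, None);
    let (pieces, pieces_final) = cumprod_chunked(&qs, 4);
    println!("chunked == whole: {}", pieces == whole && pieces_final == whole_final);
}
```

A parallel scan groups the products differently, so a comparison between a chunked scan and a whole-sequence scan would normally use a floating-point tolerance rather than exact equality.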
Implemented as a Hillis–Steele inclusive associative scan. The quaternion
product is associative (just not commutative), so a log-depth scan applies as
long as operand order is preserved (newest on the left). Each doubling step is
a single full-tensor `quat_mul` plus a sequence shift, so the sequential
dependency depth is O(log sequence) rather than the O(sequence) of a
token-by-token loop. The values are the same (up to floating-point rounding
from the different grouping), but they are computed by a handful of large
batched kernels instead of thousands of serialized tiny ones, with a
correspondingly shallow autodiff graph. The trade-off is
O(sequence · log sequence) total quaternion products instead of O(sequence).
The sequential reference it replaces is kept as a test oracle
(`quat_cumprod_sequential` in the tests module) and asserted to match on
values and gradients.
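The doubling scan can be sketched on a single slice stored in a `Vec`, which mirrors the tensor structure: each loop iteration is one shift with identity padding followed by one elementwise product over the whole sequence. This is a minimal sketch under stated assumptions, not the crate's implementation: `Quat`, the scalar `quat_mul` and `cumprod_scan` are hypothetical, and `init` is folded in on the right after the scan because it is older than every step.

```rust
/// Quaternion as [w, x, y, z]; the identity rotation is [1, 0, 0, 0].
type Quat = [f64; 4];

const IDENTITY: Quat = [1.0, 0.0, 0.0, 0.0];

/// Hamilton product a ⊗ b (scalar stand-in for the tensor-level quat_mul).
fn quat_mul(a: Quat, b: Quat) -> Quat {
    let [aw, ax, ay, az] = a;
    let [bw, bx, by, bz] = b;
    [
        aw * bw - ax * bx - ay * by - az * bz,
        aw * bx + ax * bw + ay * bz - az * by,
        aw * by - ax * bz + ay * bw + az * bx,
        aw * bz + ax * by - ay * bx + az * bw,
    ]
}

/// Hypothetical Hillis–Steele inclusive scan over one (batch, head, block)
/// slice, newest on the left: cum[t] = q[t] ⊗ ... ⊗ q[0] ⊗ init.
fn cumprod_scan(qs: &[Quat], init: Option<Quat>) -> (Vec<Quat>, Quat) {
    let n = qs.len();
    // Invariant: before the step with offset d,
    // x[t] = q[t] ⊗ ... ⊗ q[max(0, t - d + 1)].
    let mut x: Vec<Quat> = qs.to_vec();
    let mut offset = 1;
    while offset < n {
        // Shift the sequence right by `offset`, padding with the identity.
        let shifted: Vec<Quat> = (0..n)
            .map(|t| if t >= offset { x[t - offset] } else { IDENTITY })
            .collect();
        // One whole-sequence product; the newer block stays on the left.
        x = x.iter().zip(&shifted).map(|(&a, &b)| quat_mul(a, b)).collect();
        offset *= 2;
    }
    // The carry is older than every step, so it multiplies on the right.
    let init = init.unwrap_or(IDENTITY);
    let cum: Vec<Quat> = x.into_iter().map(|p| quat_mul(p, init)).collect();
    let final_carry = cum.last().copied().unwrap_or(init);
    (cum, final_carry)
}

fn main() {
    // Thirteen non-commuting unit quaternions (axes cycle through x, y, z).
    let qs: Vec<Quat> = (0..13)
        .map(|t| {
            let a: f64 = 0.2 + 0.05 * t as f64;
            let mut q = [a.cos(), 0.0, 0.0, 0.0];
            q[1 + t % 3] = a.sin();
            q
        })
        .collect();
    let (cum, final_carry) = cumprod_scan(&qs, None);
    println!("{} steps, final carry = {:?}", cum.len(), final_carry);
}
```

For 13 steps the loop runs 4 times (offsets 1, 2, 4, 8), each doing 13 products, which is where the O(sequence · log sequence) work comes from. Since the grouping differs from the token-by-token recurrence, checking it against a sequential oracle calls for a small tolerance.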