

Function rotate_bc_forward 

pub fn rotate_bc_forward(
    rot_bsa: Tensor<3>,
    dt_bsh: Tensor<3>,
    prev: RotationState,
    b_bsmhr: Tensor<5>,
    c_bsmhr: Tensor<5>,
    kind: RotationKind,
    rope_dim: usize,
) -> (Tensor<5>, Tensor<5>, RotationState)

Rotate B/C for a full sequence by the data-dependent positional rotation, returning the rotated projections and the new cumulative RotationState to store in the cache.
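The cache contract in the paragraph above (continue from prev, hand back the new cumulative state) can be illustrated with a minimal sketch that uses the abelian Complex2D-style rotation, where the state is simply one running angle per rotated pair. Everything here is hypothetical: `AngleState`, `cumsum_from` and `rope_partial` are plain-`Vec` stand-ins, not this crate's `RotationState` or `apply_rope_partial`. The adjacent (even, odd) pairing and the assumption that per-step angle increments arrive already scaled are guesses about conventions this page does not state.

```rust
/// Hypothetical stand-in for the Complex2D part of `RotationState`:
/// one cumulative angle per rotated pair.
#[derive(Clone, Debug, PartialEq)]
struct AngleState {
    cum: Vec<f32>,
}

/// Continue the cumulative sum of per-step angle increments from `prev`.
/// `steps[t][i]` is the increment for pair `i` at step `t`.
/// Returns the cumulative angles at every step plus the final state to cache.
fn cumsum_from(prev: &AngleState, steps: &[Vec<f32>]) -> (Vec<Vec<f32>>, AngleState) {
    let mut acc = prev.cum.clone();
    let mut out = Vec::with_capacity(steps.len());
    for step in steps {
        assert_eq!(step.len(), acc.len(), "one increment per rotated pair");
        for (a, d) in acc.iter_mut().zip(step) {
            *a += *d;
        }
        out.push(acc.clone());
    }
    (out, AngleState { cum: acc })
}

/// Rotate the first `rope_dim` entries of `x` as adjacent (even, odd) pairs
/// by the given angles; entries past `rope_dim` are left untouched.
fn rope_partial(x: &mut [f32], angles: &[f32], rope_dim: usize) {
    assert!(rope_dim % 2 == 0 && rope_dim <= x.len());
    assert_eq!(angles.len(), rope_dim / 2);
    for (i, &theta) in angles.iter().enumerate() {
        let (s, c) = theta.sin_cos();
        let (x0, x1) = (x[2 * i], x[2 * i + 1]);
        x[2 * i] = x0 * c - x1 * s;
        x[2 * i + 1] = x0 * s + x1 * c;
    }
}
```

Feeding a sequence in two chunks, passing the state returned by the first call as `prev` to the second, yields the same cumulative angles as a single call over the whole sequence; that equivalence is what makes the returned state safe to store in the cache.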

Branches on RotationKind:

  • Complex2D: the abelian RoPE. Angles are accumulated by a cumulative sum continuing from prev, then applied with apply_rope_partial; this is exactly the original Mamba-3 behaviour.
  • Quaternion4D: each step's unit quaternion is built with quat_from_scaled_axis from the in-projection generators scaled per head by Δ. The step quaternions are composed with quat_cumprod, continuing from the cached quaternion, and the result Qₜ is applied to B/C as rotate(·, conj(Qₜ)) over the first 4·blocks state-rank entries.
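A minimal sketch of the Quaternion4D branch for one batch element and one head, on plain arrays. The helpers are hypothetical stand-ins for quat_from_scaled_axis, quat_cumprod and rotate. The [w, x, y, z] layout, the composition order Qₜ = Qₜ₋₁ ⊗ qₜ, and reading rotate(x, q) as the right product x ⊗ q are assumptions, not something this page confirms.

```rust
/// Quaternion as [w, x, y, z] (scalar first); this layout is an assumption.
type Quat = [f32; 4];

/// Hamilton product p ⊗ q.
fn qmul(p: Quat, q: Quat) -> Quat {
    let [pw, px, py, pz] = p;
    let [qw, qx, qy, qz] = q;
    [
        pw * qw - px * qx - py * qy - pz * qz,
        pw * qx + px * qw + py * qz - pz * qy,
        pw * qy - px * qz + py * qw + pz * qx,
        pw * qz + px * qy - py * qx + pz * qw,
    ]
}

/// Quaternion conjugate (the inverse, for unit quaternions).
fn qconj(q: Quat) -> Quat {
    [q[0], -q[1], -q[2], -q[3]]
}

/// Unit quaternion rotating by |v| radians about v / |v|
/// (the usual scaled-axis convention).
fn from_scaled_axis(v: [f32; 3]) -> Quat {
    let angle = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
    if angle == 0.0 {
        return [1.0, 0.0, 0.0, 0.0];
    }
    let (s, c) = (0.5 * angle).sin_cos();
    let k = s / angle;
    [c, v[0] * k, v[1] * k, v[2] * k]
}

/// One unit quaternion per block for a single (step, head): the 3·blocks
/// generators are scaled by that head's Δ before conversion.
fn step_quats(gens: &[f32], dt: f32) -> Vec<Quat> {
    assert_eq!(gens.len() % 3, 0);
    gens.chunks_exact(3)
        .map(|g| from_scaled_axis([dt * g[0], dt * g[1], dt * g[2]]))
        .collect()
}

/// Cumulative product continued from the cached quaternions, per block:
/// Q_t = Q_{t-1} ⊗ q_t. Returns every Q_t and the final state to cache.
fn cumprod_from(prev: &[Quat], steps: &[Vec<Quat>]) -> (Vec<Vec<Quat>>, Vec<Quat>) {
    let mut acc = prev.to_vec();
    let mut out = Vec::with_capacity(steps.len());
    for step in steps {
        assert_eq!(step.len(), acc.len());
        for (a, &q) in acc.iter_mut().zip(step) {
            *a = qmul(*a, q);
        }
        out.push(acc.clone());
    }
    (out, acc)
}

/// Rotate the first 4·blocks entries of one state-rank vector: each 4-entry
/// block x becomes x ⊗ conj(Q). ASSUMPTION: rotate(x, q) = x ⊗ q.
/// Entries past 4·blocks are left untouched.
fn rotate_conj(x: &mut [f32], qs: &[Quat]) {
    assert!(4 * qs.len() <= x.len());
    for (chunk, &q) in x.chunks_exact_mut(4).zip(qs) {
        let r = qmul([chunk[0], chunk[1], chunk[2], chunk[3]], qconj(q));
        chunk.copy_from_slice(&r);
    }
}
```

Under these assumptions, the inner product between C rotated at step t and B rotated at step s depends only on conj(Qₛ) ⊗ Qₜ, in which the cached prefix cancels, so the scores are the same whatever quaternion the cache held before the sequence started.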

§Shapes

  • rot_bsa : [batch, sequence, num_rotation_channels] — the in-projection rotation channels (angles for Complex2D, 3·blocks quaternion generators for Quaternion4D).
  • dt_bsh : [batch, sequence, nheads] (Δ).
  • b_bsmhr / c_bsmhr : [batch, sequence, mimo_rank, nheads, state_rank].
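The shape relations listed above can be checked up front. This hypothetical sketch validates plain dimension arrays rather than real Tensor values; `Kind` and `check_shapes` are illustrative names only. The rules that Complex2D rotates the first rope_dim entries in pairs and that Quaternion4D needs 4·(channels / 3) ≤ state_rank are assumptions read off the branch descriptions, not confirmed constraints.

```rust
/// Hypothetical stand-in for `RotationKind`, for illustration only.
#[derive(Clone, Copy, Debug)]
enum Kind {
    Complex2D,
    Quaternion4D,
}

/// Validate the documented shapes, given as plain dimension arrays, and
/// return how many leading state-rank entries would be rotated.
fn check_shapes(
    rot_bsa: [usize; 3],
    dt_bsh: [usize; 3],
    b_bsmhr: [usize; 5],
    c_bsmhr: [usize; 5],
    kind: Kind,
    rope_dim: usize,
) -> Result<usize, String> {
    let [batch, seq, channels] = rot_bsa;
    let [_, _, _, nheads, state_rank] = b_bsmhr;
    if b_bsmhr != c_bsmhr {
        return Err(format!("B {b_bsmhr:?} and C {c_bsmhr:?} differ"));
    }
    // Δ is per head, so dt_bsh must agree with B/C on batch, sequence and nheads.
    if dt_bsh != [batch, seq, nheads] || b_bsmhr[..2] != [batch, seq] {
        return Err("batch/sequence/nheads mismatch".to_string());
    }
    match kind {
        // ASSUMPTION: rope_dim counts rotated state entries, rotated in pairs.
        Kind::Complex2D => {
            if rope_dim % 2 != 0 || rope_dim > state_rank {
                return Err(format!("rope_dim {rope_dim} invalid for state_rank {state_rank}"));
            }
            Ok(rope_dim)
        }
        // 3 generators per block, each block rotating 4 state entries.
        Kind::Quaternion4D => {
            if channels % 3 != 0 || 4 * (channels / 3) > state_rank {
                return Err(format!("{channels} generators do not fit state_rank {state_rank}"));
            }
            Ok(4 * (channels / 3))
        }
    }
}
```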