

Function gqa_expand_to_heads 

```rust
pub fn gqa_expand_to_heads<const D: usize, const DP1: usize>(
    t: Tensor<D>,
    group_dim: usize,
    nheads: usize,
) -> Tensor<D>
```

Expands the ngroups dimension of a tensor (at index group_dim) into an nheads dimension, replicating each group's slice across the heads_per_group = nheads / ngroups heads that belong to that group.
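For reference, the same semantics can be sketched without a tensor library on a flat, row-major buffer. The function expand_groups_flat below is hypothetical and not part of this crate; it also assumes the contiguous GQA layout in which output head h reads group h / heads_per_group, so check the real implementation if your code depends on head ordering.

```rust
/// Hypothetical reference version of the expansion on a flat, row-major buffer.
/// Assumption: heads of one group are contiguous, i.e. output head h copies
/// group h / heads_per_group.
fn expand_groups_flat(
    data: &[f32],
    shape: &[usize],
    group_dim: usize,
    nheads: usize,
) -> (Vec<f32>, Vec<usize>) {
    let ngroups = shape[group_dim];
    // Same precondition as the documented panic: nheads must be a multiple of ngroups.
    assert!(
        nheads % ngroups == 0,
        "nheads ({nheads}) must be a multiple of ngroups ({ngroups})"
    );
    let heads_per_group = nheads / ngroups;
    // Axes before group_dim form `outer` independent blocks; axes after it form
    // one contiguous `inner` slice per (outer, group) pair.
    let outer: usize = shape[..group_dim].iter().product();
    let inner: usize = shape[group_dim + 1..].iter().product();
    assert_eq!(data.len(), outer * ngroups * inner, "data length must match shape");

    let mut out = Vec::with_capacity(outer * nheads * inner);
    for o in 0..outer {
        for g in 0..ngroups {
            let start = (o * ngroups + g) * inner;
            let slice = &data[start..start + inner];
            // Emit this group's slice once per head in the group.
            for _ in 0..heads_per_group {
                out.extend_from_slice(slice);
            }
        }
    }
    let mut new_shape = shape.to_vec();
    new_shape[group_dim] = nheads;
    (out, new_shape)
}
```

With shape [1, 2, 2], data [0, 1, 10, 11], group_dim = 1 and nheads = 4, each of the two groups is repeated twice, giving shape [1, 4, 2] and data [0, 1, 0, 1, 10, 11, 10, 11].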

The const generic DP1 must equal D + 1, the rank of the intermediate tensor produced by the unsqueeze + expand step. Rust cannot yet express that constraint directly, so upholding it is the caller's responsibility: supplying a wrong value produces a rank mismatch from unsqueeze_dim::<DP1> / reshape.
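A shape-only sketch of where DP1 enters, assuming the new heads_per_group axis is inserted directly after group_dim (the exact axis position used by the real function is an assumption). intermediate_shape and merged_shape are hypothetical helpers. While DP1 cannot be derived from D on stable Rust, a mismatch can at least be rejected at build time with an inline const assertion (stable since Rust 1.79), which is evaluated when the function is instantiated with concrete D and DP1.

```rust
/// Hypothetical: shape after unsqueeze + expand, rank DP1 = D + 1.
/// Assumption: the new axis sits right after group_dim, giving
/// [.., ngroups, heads_per_group, ..].
fn intermediate_shape<const D: usize, const DP1: usize>(
    shape: [usize; D],
    group_dim: usize,
    heads_per_group: usize,
) -> [usize; DP1] {
    // Compile-time check once instantiated: a mismatched DP1 fails the build.
    const { assert!(DP1 == D + 1, "DP1 must equal D + 1") }
    let mut out = [0usize; DP1];
    out[..=group_dim].copy_from_slice(&shape[..=group_dim]);
    out[group_dim + 1] = heads_per_group;
    out[group_dim + 2..].copy_from_slice(&shape[group_dim + 1..]);
    out
}

/// Hypothetical: the reshape back to rank D, merging
/// (ngroups, heads_per_group) into a single nheads axis.
fn merged_shape<const D: usize, const DP1: usize>(mid: [usize; DP1], group_dim: usize) -> [usize; D] {
    const { assert!(DP1 == D + 1, "DP1 must equal D + 1") }
    let mut out = [0usize; D];
    out[..group_dim].copy_from_slice(&mid[..group_dim]);
    out[group_dim] = mid[group_dim] * mid[group_dim + 1];
    out[group_dim + 1..].copy_from_slice(&mid[group_dim + 2..]);
    out
}
```

For a rank-5 input [2, 4, 64, 8, 16] with group_dim = 3 and heads_per_group = 4, the intermediate shape is [2, 4, 64, 8, 4, 16] (rank 6) and the merged result is [2, 4, 64, 32, 16].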

Panics

Panics if nheads % ngroups != 0, i.e. if nheads is not a multiple of ngroups (the size of t along group_dim).
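Callers who would rather surface an error than panic can validate the sizes before calling. checked_heads_per_group is a hypothetical helper that mirrors the documented condition; it additionally rejects ngroups == 0, for which nheads % ngroups would itself panic with a division by zero.

```rust
/// Hypothetical pre-flight check mirroring the documented panic condition.
/// Returns heads_per_group = nheads / ngroups, or an error instead of panicking.
fn checked_heads_per_group(ngroups: usize, nheads: usize) -> Result<usize, String> {
    // Guard the divisor first: `nheads % 0` would panic.
    if ngroups == 0 {
        return Err("ngroups must be non-zero".to_string());
    }
    if nheads % ngroups != 0 {
        return Err(format!(
            "nheads ({nheads}) is not a multiple of ngroups ({ngroups})"
        ));
    }
    Ok(nheads / ngroups)
}
```

Returning a Result keeps the shape validation in the caller's error path, which can be preferable when ngroups and nheads come from a config file rather than from constants.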

Example

```rust
// b_bnlgr: [batch, nchunks, chunk_len, ngroups, state_rank] (D = 5, so DP1 = D + 1 = 6)
// group_dim = 3 (the ngroups axis)
// result:  [batch, nchunks, chunk_len, nheads,  state_rank]
let b_bnlhr = gqa_expand_to_heads::<_, 5, 6>(b_bnlgr, 3, nheads);
```