pub fn gqa_expand_to_heads<const D: usize, const DP1: usize>(
    t: Tensor<D>,
    group_dim: usize,
    nheads: usize,
) -> Tensor<D>
Expand the ngroups dimension of t at group_dim into an nheads dimension by
replicating each group's slice across the heads_per_group = nheads / ngroups
heads that belong to that group. All other dimensions are unchanged.
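Functionally, the expansion is a repeat-interleave along group_dim. The sketch below reproduces it on a flat, row-major Vec<f32> with no tensor library, assuming (as "heads of that group" suggests) that heads g * heads_per_group .. (g + 1) * heads_per_group all read group g. That ordering is what unsqueezing a new axis right after group_dim, expanding it, and reshaping would produce. expand_groups_to_heads is a hypothetical reference helper, not this crate's API.

```rust
/// Hypothetical reference implementation on a flat, row-major buffer.
/// Returns the expanded data together with its new shape.
fn expand_groups_to_heads(
    data: &[f32],
    shape: &[usize],
    group_dim: usize,
    nheads: usize,
) -> (Vec<f32>, Vec<usize>) {
    assert!(group_dim < shape.len(), "group_dim out of range");
    let ngroups = shape[group_dim];
    // Mirrors the documented panic: nheads must be a multiple of ngroups.
    assert!(nheads % ngroups == 0, "nheads must be a multiple of ngroups");
    let heads_per_group = nheads / ngroups;

    // Elements before / after the group axis (empty product = 1).
    let outer: usize = shape[..group_dim].iter().product();
    let inner: usize = shape[group_dim + 1..].iter().product();
    assert_eq!(data.len(), outer * ngroups * inner, "data does not match shape");

    let mut out = Vec::with_capacity(outer * nheads * inner);
    for o in 0..outer {
        for g in 0..ngroups {
            let start = (o * ngroups + g) * inner;
            let slice = &data[start..start + inner];
            // Each of the heads_per_group heads of group g gets its own copy.
            for _ in 0..heads_per_group {
                out.extend_from_slice(slice);
            }
        }
    }

    let mut out_shape = shape.to_vec();
    out_shape[group_dim] = nheads;
    (out, out_shape)
}
```

For shape [1, 2, 3] with group_dim = 1 and nheads = 4, group 0 ([0, 1, 2]) feeds heads 0 and 1 and group 1 ([3, 4, 5]) feeds heads 2 and 3.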
The const generic DP1 must equal D + 1, the rank of the intermediate tensor
produced by the unsqueeze + expand step. Stable Rust cannot yet express that
constraint in the signature (it would need the unstable generic_const_exprs
feature), so upholding it is the caller's responsibility. A wrong value is not
rejected by the type checker; it surfaces as a rank mismatch from
unsqueeze_dim::<DP1> / reshape.
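Although the relation cannot appear in the signature, an inline const block (stable since Rust 1.79) can turn a mismatched pair into a build error when the generic function is instantiated. A minimal sketch, assuming a hypothetical shape-only helper unsqueeze_shape that is not part of this crate:

```rust
/// Hypothetical helper: inserts a size-1 axis at `dim`, returning a rank-DP1 shape.
fn unsqueeze_shape<const D: usize, const DP1: usize>(shape: [usize; D], dim: usize) -> [usize; DP1] {
    // Evaluated once per instantiation; e.g. unsqueeze_shape::<5, 7> fails the build.
    const { assert!(DP1 == D + 1, "DP1 must equal D + 1") }
    assert!(dim <= D, "dim out of range");
    let mut out = [1usize; DP1];
    // Axes before `dim` keep their position; axes from `dim` on shift right by one.
    out[..dim].copy_from_slice(&shape[..dim]);
    out[dim + 1..].copy_from_slice(&shape[dim..]);
    out
}
```

Because the check is a post-monomorphization error, it fires only for instantiations that are actually compiled, and tools that stop before code generation may not report it.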
Panics
Panics if nheads % ngroups != 0, where ngroups is the current size of
dimension group_dim (i.e. if nheads is not a multiple of the group count).
Example
// b_bnlgr: [batch, nchunks, chunk_len, ngroups, state_rank] (D=5)
// group_dim = 3 (the ngroups axis)
// result: [batch, nchunks, chunk_len, nheads, state_rank]
let b_bnlhr = gqa_expand_to_heads::<5, 6>(b_bnlgr, 3, nheads);
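To make the DP1 = D + 1 bookkeeping concrete, the sketch below traces only the shapes through an assumed unsqueeze, expand, reshape pipeline for the rank-5 example. The helper trace_shapes and the sizes (batch = 2, nchunks = 4, chunk_len = 64, ngroups = 2, state_rank = 128, nheads = 8) are hypothetical, and the crate's actual implementation may insert the new axis differently.

```rust
/// Hypothetical shape trace: returns (unsqueezed, expanded, final) shapes.
fn trace_shapes<const D: usize, const DP1: usize>(
    shape: [usize; D],
    group_dim: usize,
    nheads: usize,
) -> ([usize; DP1], [usize; DP1], [usize; D]) {
    assert_eq!(DP1, D + 1, "DP1 must equal D + 1");
    let ngroups = shape[group_dim];
    assert!(nheads % ngroups == 0, "nheads must be a multiple of ngroups");
    let heads_per_group = nheads / ngroups;

    // 1. Unsqueeze: insert a size-1 axis right after group_dim (rank D + 1).
    let mut unsqueezed = [1usize; DP1];
    unsqueezed[..=group_dim].copy_from_slice(&shape[..=group_dim]);
    unsqueezed[group_dim + 2..].copy_from_slice(&shape[group_dim + 1..]);

    // 2. Expand: broadcast the new axis to heads_per_group.
    let mut expanded = unsqueezed;
    expanded[group_dim + 1] = heads_per_group;

    // 3. Reshape: merge (ngroups, heads_per_group) into nheads (rank D again).
    let mut out = shape;
    out[group_dim] = nheads;
    (unsqueezed, expanded, out)
}
```

Here trace_shapes::<5, 6>([2, 4, 64, 2, 128], 3, 8) yields [2, 4, 64, 2, 1, 128], then [2, 4, 64, 2, 4, 128], then [2, 4, 64, 8, 128]. The intermediate carries both ngroups and heads_per_group as separate axes, which is why its rank is D + 1 = 6.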