burn_mamba/modules/misc/gqa.rs
1//! Grouped-Query Attention (GQA) dimension expansion.
2//!
3//! Mamba-2 and Mamba-3 produce B and C projections per-group (size `ngroups`),
4//! but the chunkwise SSD algorithms consume them per-head (size `nheads`).
5//! This helper bridges the two by replicating each group's vector across the
6//! `heads_per_group = nheads / ngroups` heads of that group.
7
8use burn::prelude::*;
9
10/// Expand a tensor's `ngroups` dim at `group_dim` into an `nheads` dim, by
11/// replicating each group's slice across `heads_per_group = nheads / ngroups`
12/// heads of that group.
13///
14/// The const generic `DP1` must equal `D + 1` (the rank used during the
15/// intermediate `unsqueeze`+`expand`). Rust cannot yet express that constraint
16/// directly, so it is the caller's responsibility — supplying a wrong value
17/// produces a compile-time rank mismatch from `unsqueeze_dim::<DP1>` / `reshape`.
18///
19/// # Panics
20/// Panics if `nheads % ngroups != 0` (i.e. `nheads` is not a multiple of the
21/// current group count).
22///
23/// # Example
24/// ```ignore
25/// // b_bnlgr: [batch, nchunks, chunk_len, ngroups, state_rank] (D=5)
26/// // group_dim = 3 (the ngroups axis)
27/// // result: [batch, nchunks, chunk_len, nheads, state_rank]
28/// let b_bnlhr = gqa_expand_to_heads::<_, 5, 6>(b_bnlgr, 3, nheads);
29/// ```
30pub fn gqa_expand_to_heads<const D: usize, const DP1: usize>(
31 t: Tensor<D>,
32 group_dim: usize,
33 nheads: usize,
34) -> Tensor<D> {
35 let dims = t.dims();
36 let ngroups = dims[group_dim];
37 assert!(
38 nheads.is_multiple_of(ngroups),
39 "nheads ({nheads}) must be a multiple of ngroups ({ngroups})"
40 );
41 let heads_per_group = nheads / ngroups;
42
43 // Expanded shape: insert `heads_per_group` immediately after `group_dim`.
44 let mut expanded = [0usize; DP1];
45 expanded[..=group_dim].copy_from_slice(&dims[..=group_dim]);
46 expanded[group_dim + 1] = heads_per_group;
47 expanded[group_dim + 2..].copy_from_slice(&dims[group_dim + 1..]);
48
49 // Final shape: collapse `(ngroups, heads_per_group)` back into `nheads`.
50 let mut final_shape = dims;
51 final_shape[group_dim] = nheads;
52
53 t.unsqueeze_dim::<DP1>(group_dim + 1)
54 .expand(expanded)
55 .reshape(final_shape)
56}