Skip to main content

burn_mamba/modules/misc/
gqa.rs

1//! Grouped-Query Attention (GQA) dimension expansion.
2//!
3//! Mamba-2 and Mamba-3 produce B and C projections per-group (size `ngroups`),
4//! but the chunkwise SSD algorithms consume them per-head (size `nheads`).
5//! This helper bridges the two by replicating each group's vector across the
6//! `heads_per_group = nheads / ngroups` heads of that group.
7
8use burn::prelude::*;
9
10/// Expand a tensor's `ngroups` dim at `group_dim` into an `nheads` dim, by
11/// replicating each group's slice across `heads_per_group = nheads / ngroups`
12/// heads of that group.
13///
14/// The const generic `DP1` must equal `D + 1` (the rank used during the
15/// intermediate `unsqueeze`+`expand`). Rust cannot yet express that constraint
16/// directly, so it is the caller's responsibility — supplying a wrong value
17/// produces a compile-time rank mismatch from `unsqueeze_dim::<DP1>` / `reshape`.
18///
19/// # Panics
20/// Panics if `nheads % ngroups != 0` (i.e. `nheads` is not a multiple of the
21/// current group count).
22///
23/// # Example
24/// ```ignore
25/// // b_bnlgr: [batch, nchunks, chunk_len, ngroups, state_rank] (D=5)
26/// // group_dim = 3 (the ngroups axis)
27/// // result:  [batch, nchunks, chunk_len, nheads,  state_rank]
28/// let b_bnlhr = gqa_expand_to_heads::<_, 5, 6>(b_bnlgr, 3, nheads);
29/// ```
30pub fn gqa_expand_to_heads<const D: usize, const DP1: usize>(
31    t: Tensor<D>,
32    group_dim: usize,
33    nheads: usize,
34) -> Tensor<D> {
35    let dims = t.dims();
36    let ngroups = dims[group_dim];
37    assert!(
38        nheads.is_multiple_of(ngroups),
39        "nheads ({nheads}) must be a multiple of ngroups ({ngroups})"
40    );
41    let heads_per_group = nheads / ngroups;
42
43    // Expanded shape: insert `heads_per_group` immediately after `group_dim`.
44    let mut expanded = [0usize; DP1];
45    expanded[..=group_dim].copy_from_slice(&dims[..=group_dim]);
46    expanded[group_dim + 1] = heads_per_group;
47    expanded[group_dim + 2..].copy_from_slice(&dims[group_dim + 1..]);
48
49    // Final shape: collapse `(ngroups, heads_per_group)` back into `nheads`.
50    let mut final_shape = dims;
51    final_shape[group_dim] = nheads;
52
53    t.unsqueeze_dim::<DP1>(group_dim + 1)
54        .expand(expanded)
55        .reshape(final_shape)
56}