Skip to main content

burn_mamba/utils/class/
mod.rs

1use burn::config::Config;
2use burn::module::Param;
3use burn::nn::Initializer;
4use burn::prelude::*;
5
6#[cfg(all(test, feature = "_dev-test"))]
7mod tests;
8
9// ===========================================================================
10// Class tokens / latents (learnable sequence-inserted tokens)
11// ===========================================================================
12//
13// A *class token* / *class latent* is a learnable embedding spliced into the
14// sequence — a transformer-`[CLS]`-style register the model can read/write
15// through. They are inserted at the input boundary of a container (a network's
16// input for [`ClassToken`], width = the input feature width; a layer's working
17// sequence for [`ClassLatent`], width = `d_model`), permanently lengthening the
18// sequence for everything downstream. A container can carry any number; the
19// markers below say *where* each one lands, while a single `Param<Tensor<2>>`
20// of shape `[num_markers, width]` holds the embeddings (row `i` ↔ marker `i`).
21//
// Insertion order (all relative to the *original* length `L`): every `Start`
// first (index 0), then `Middle` (index `L/2`, splitting the original
// sequence), then `End` (index `L`), then `Custom(index)` (explicit index,
// inserted last). Markers sharing an index are ordered by kind
// (`Start` < `Middle` < `End` < `Custom`) and, within one kind, keep their
// `Vec` order. Because `Middle`/`End` materialise positions that a
// single-token `step()` cannot reproduce, their presence makes `step()` panic;
// `Start`/`Custom` are a forward-time concern and are simply not re-inserted
// during `step()`.
29
30/// Position marker for a learnable class **token** inserted into a *network's*
31/// input sequence (embedding width = the network input width / "d_input").
32#[derive(Config, Debug)]
33pub enum ClassToken {
34    /// Prepend before the whole sequence (index 0).
35    Start,
36    /// Insert at the middle of the original sequence (index `L/2`).  
37    /// Incompatible with `step()` calls.
38    Middle,
39    /// Append after the whole sequence (index `L`).  
40    /// Incompatible with `step()` calls.
41    End,
42    /// Insert at an explicit index into the original sequence.
43    Custom(usize),
44}
45
46/// Position marker for a learnable class **latent** inserted into a *layer's*
47/// working sequence (embedding width = `d_model`).
48#[derive(Config, Debug)]
49pub enum ClassLatent {
50    /// Prepend before the whole sequence (index 0).
51    Start,
52    /// Insert at the middle of the original sequence (index `L/2`).
53    /// Incompatible with `step()` calls.  
54    Middle,
55    /// Append after the whole sequence (index `L`).
56    /// Incompatible with `step()` calls.  
57    End,
58    /// Insert at an explicit index into the original sequence.
59    Custom(usize),
60}
61
/// Common interface of the [`ClassToken`] / [`ClassLatent`] position markers,
/// so that a single generic helper can place either kind.
pub trait ClassMarker: Clone {
    /// Index at which this marker is inserted, relative to the *original*
    /// sequence length `orig_len`.
    fn insert_pos(&self, orig_len: usize) -> usize;

    /// Rank used to order markers that share an insertion index
    /// (`Start` < `Middle` < `End` < `Custom`).
    fn group_rank(&self) -> usize;

    /// `true` when the marker cannot be combined with single-token `step()`:
    /// `Middle`/`End` create positions a per-token recurrence cannot reproduce.
    fn forbids_step(&self) -> bool;
}
73
74macro_rules! impl_class_marker {
75    ($ty:ty) => {
76        impl ClassMarker for $ty {
77            fn insert_pos(&self, orig_len: usize) -> usize {
78                match self {
79                    Self::Start => 0,
80                    Self::Middle => orig_len / 2,
81                    Self::End => orig_len,
82                    Self::Custom(index) => *index,
83                }
84            }
85            fn group_rank(&self) -> usize {
86                match self {
87                    Self::Start => 0,
88                    Self::Middle => 1,
89                    Self::End => 2,
90                    Self::Custom(_) => 3,
91                }
92            }
93            fn forbids_step(&self) -> bool {
94                matches!(self, Self::Middle | Self::End)
95            }
96        }
97    };
98}
99impl_class_marker!(ClassToken);
100impl_class_marker!(ClassLatent);
101
102/// Insert the learnable class tokens `emb` (`[k, width]`, row `i` ↔ `markers[i]`)
103/// into `x` (`[batch, orig_len, width]`) per the `markers`, returning the
104/// lengthened sequence (`[batch, orig_len + k, width]`) and, for each marker in
105/// `Vec` order, its position in the output sequence.
106///
107/// `markers` empty ⇒ `x` is returned unchanged with an empty index vector.
108pub(crate) fn insert_class_markers<M: ClassMarker>(
109    x: Tensor<3>,
110    markers: &[M],
111    emb: Option<&Param<Tensor<2>>>,
112) -> (Tensor<3>, Vec<usize>) {
113    let [batch, orig_len, width] = x.dims();
114    let k = markers.len();
115    if k == 0 {
116        return (x, Vec::new());
117    }
118    let emb = emb
119        .expect("class-token markers present but no embedding param")
120        .val();
121    assert_eq!(emb.dims(), [k, width], "one embedding row per class marker");
122
123    // Emit in (insert_pos, group_rank, vec order) order.
124    let mut order: Vec<usize> = (0..k).collect();
125    order.sort_by_key(|&i| (markers[i].insert_pos(orig_len), markers[i].group_rank(), i));
126
127    let mut segments: Vec<Tensor<3>> = Vec::new();
128    let mut cursor = 0usize; // consumed prefix of the original sequence
129    let mut out_len = 0usize; // length emitted so far
130    let mut out_index = vec![0usize; k];
131    for &i in &order {
132        let p = markers[i].insert_pos(orig_len);
133        assert!(
134            p <= orig_len,
135            "class-token insert index {p} > sequence length {orig_len}"
136        );
137        if p > cursor {
138            segments.push(x.clone().narrow(1, cursor, p - cursor));
139            out_len += p - cursor;
140            cursor = p;
141        }
142        let row = emb
143            .clone()
144            .narrow(0, i, 1) // [1, width]
145            .unsqueeze_dim::<3>(0) // [1, 1, width]
146            .expand([batch, 1, width]);
147        segments.push(row);
148        out_index[i] = out_len;
149        out_len += 1;
150    }
151    if cursor < orig_len {
152        segments.push(x.narrow(1, cursor, orig_len - cursor));
153    }
154    let out = Tensor::cat(segments, 1);
155    assert_eq!(out.dims(), [batch, orig_len + k, width]);
156    (out, out_index)
157}
158
159/// The output-sequence position of each marker (in `Vec` order) for an input of
160/// length `orig_len`, without materialising any tensor. Mirrors the placement in
161/// [`insert_class_markers`] — useful for reading a class token back out.
162pub(crate) fn class_marker_output_indices<M: ClassMarker>(
163    markers: &[M],
164    orig_len: usize,
165) -> Vec<usize> {
166    let k = markers.len();
167    let mut order: Vec<usize> = (0..k).collect();
168    order.sort_by_key(|&i| (markers[i].insert_pos(orig_len), markers[i].group_rank(), i));
169    let mut cursor = 0usize;
170    let mut out_len = 0usize;
171    let mut out_index = vec![0usize; k];
172    for &i in &order {
173        let p = markers[i].insert_pos(orig_len).min(orig_len);
174        if p > cursor {
175            out_len += p - cursor;
176            cursor = p;
177        }
178        out_index[i] = out_len;
179        out_len += 1;
180    }
181    out_index
182}
183
184/// Build the embedding param for `n` class markers of the given `width`
185/// (`None` when there are none — Burn has no zero-width tensors).
186pub(crate) fn init_class_emb(n: usize, width: usize, device: &Device) -> Option<Param<Tensor<2>>> {
187    (n > 0).then(|| {
188        Initializer::Normal {
189            mean: 0.0,
190            std: 0.02,
191        }
192        .init([n, width], device)
193    })
194}
195
196/// Panic if any marker is incompatible with single-token `step()`.
197pub(crate) fn assert_step_compatible<M: ClassMarker>(markers: &[M], who: &str) {
198    assert!(
199        !markers.iter().any(|m| m.forbids_step()),
200        "{who}: Middle/End class tokens are not compatible with step()"
201    );
202}
203
204/// The output-sequence position of each step-injectable marker (in `Vec` order),
205/// for use by `step`'s cursor. Asserts no `Middle`/`End` (those need the full
206/// length — `forward` only). `Start`/`Custom` positions are length-independent,
207/// so an unbounded `orig_len` resolves them exactly.
208pub(crate) fn class_step_injections<M: ClassMarker>(markers: &[M], who: &str) -> Vec<usize> {
209    assert_step_compatible(markers, who);
210    class_marker_output_indices(markers, usize::MAX)
211}