burn_mamba/utils/class/mod.rs
1use burn::config::Config;
2use burn::module::Param;
3use burn::nn::Initializer;
4use burn::prelude::*;
5
6#[cfg(all(test, feature = "_dev-test"))]
7mod tests;
8
9// ===========================================================================
10// Class tokens / latents (learnable sequence-inserted tokens)
11// ===========================================================================
12//
13// A *class token* / *class latent* is a learnable embedding spliced into the
14// sequence — a transformer-`[CLS]`-style register the model can read/write
15// through. They are inserted at the input boundary of a container (a network's
16// input for [`ClassToken`], width = the input feature width; a layer's working
17// sequence for [`ClassLatent`], width = `d_model`), permanently lengthening the
18// sequence for everything downstream. A container can carry any number; the
19// markers below say *where* each one lands, while a single `Param<Tensor<2>>`
20// of shape `[num_markers, width]` holds the embeddings (row `i` ↔ marker `i`).
21//
// Placement (all relative to the *original* length `L`): `Start` lands at
// index 0, `Middle` at index `L/2` (rounded down, splitting the original
// sequence), `End` at index `L`, and `Custom(index)` at its explicit index,
// which must not exceed `L`. Markers are emitted in order of that index; among
// markers sharing an index the order is `Start`, `Middle`, `End`, `Custom`, and
// markers of the same kind keep their `Vec` order. Because `Middle`/`End`
// materialise positions that a single-token `step()` cannot reproduce, their
// presence makes `step()` panic; `Start`/`Custom` are a forward-time concern
// and are simply not re-inserted during `step()`.
29
/// Position marker for a learnable class **token** inserted into a *network's*
/// input sequence (embedding width = the network input width / "d_input").
///
/// All indices are measured against the original (pre-insertion) sequence
/// length `L`.
#[derive(Config, Debug)]
pub enum ClassToken {
    /// Prepend before the whole sequence (index 0).
    Start,
    /// Insert at the middle of the original sequence (index `L/2`, rounded down).
    /// Incompatible with `step()` calls.
    Middle,
    /// Append after the whole sequence (index `L`).
    /// Incompatible with `step()` calls.
    End,
    /// Insert at an explicit index into the original sequence. The index must
    /// not exceed `L`; `insert_class_markers` panics otherwise.
    Custom(usize),
}
45
/// Position marker for a learnable class **latent** inserted into a *layer's*
/// working sequence (embedding width = `d_model`).
///
/// All indices are measured against the original (pre-insertion) sequence
/// length `L`.
#[derive(Config, Debug)]
pub enum ClassLatent {
    /// Prepend before the whole sequence (index 0).
    Start,
    /// Insert at the middle of the original sequence (index `L/2`, rounded down).
    /// Incompatible with `step()` calls.
    Middle,
    /// Append after the whole sequence (index `L`).
    /// Incompatible with `step()` calls.
    End,
    /// Insert at an explicit index into the original sequence. The index must
    /// not exceed `L`; `insert_class_markers` panics otherwise.
    Custom(usize),
}
61
/// Common interface of the [`ClassToken`] and [`ClassLatent`] position markers,
/// so that a single generic routine can place either kind.
///
/// Placement orders markers by `(insert_pos, group_rank, index in the Vec)`.
pub trait ClassMarker: Clone {
    /// `true` when this marker cannot be combined with single-token `step()`
    /// (`Middle`/`End` create positions a per-token recurrence cannot reproduce).
    fn forbids_step(&self) -> bool;

    /// Rank that breaks ties between markers landing on the same index
    /// (`Start` < `Middle` < `End` < `Custom`).
    fn group_rank(&self) -> usize;

    /// Index at which this marker is inserted, measured against the *original*
    /// sequence length `orig_len`.
    fn insert_pos(&self, orig_len: usize) -> usize;
}
73
74macro_rules! impl_class_marker {
75 ($ty:ty) => {
76 impl ClassMarker for $ty {
77 fn insert_pos(&self, orig_len: usize) -> usize {
78 match self {
79 Self::Start => 0,
80 Self::Middle => orig_len / 2,
81 Self::End => orig_len,
82 Self::Custom(index) => *index,
83 }
84 }
85 fn group_rank(&self) -> usize {
86 match self {
87 Self::Start => 0,
88 Self::Middle => 1,
89 Self::End => 2,
90 Self::Custom(_) => 3,
91 }
92 }
93 fn forbids_step(&self) -> bool {
94 matches!(self, Self::Middle | Self::End)
95 }
96 }
97 };
98}
99impl_class_marker!(ClassToken);
100impl_class_marker!(ClassLatent);
101
102/// Insert the learnable class tokens `emb` (`[k, width]`, row `i` ↔ `markers[i]`)
103/// into `x` (`[batch, orig_len, width]`) per the `markers`, returning the
104/// lengthened sequence (`[batch, orig_len + k, width]`) and, for each marker in
105/// `Vec` order, its position in the output sequence.
106///
107/// `markers` empty ⇒ `x` is returned unchanged with an empty index vector.
108pub(crate) fn insert_class_markers<M: ClassMarker>(
109 x: Tensor<3>,
110 markers: &[M],
111 emb: Option<&Param<Tensor<2>>>,
112) -> (Tensor<3>, Vec<usize>) {
113 let [batch, orig_len, width] = x.dims();
114 let k = markers.len();
115 if k == 0 {
116 return (x, Vec::new());
117 }
118 let emb = emb
119 .expect("class-token markers present but no embedding param")
120 .val();
121 assert_eq!(emb.dims(), [k, width], "one embedding row per class marker");
122
123 // Emit in (insert_pos, group_rank, vec order) order.
124 let mut order: Vec<usize> = (0..k).collect();
125 order.sort_by_key(|&i| (markers[i].insert_pos(orig_len), markers[i].group_rank(), i));
126
127 let mut segments: Vec<Tensor<3>> = Vec::new();
128 let mut cursor = 0usize; // consumed prefix of the original sequence
129 let mut out_len = 0usize; // length emitted so far
130 let mut out_index = vec![0usize; k];
131 for &i in &order {
132 let p = markers[i].insert_pos(orig_len);
133 assert!(
134 p <= orig_len,
135 "class-token insert index {p} > sequence length {orig_len}"
136 );
137 if p > cursor {
138 segments.push(x.clone().narrow(1, cursor, p - cursor));
139 out_len += p - cursor;
140 cursor = p;
141 }
142 let row = emb
143 .clone()
144 .narrow(0, i, 1) // [1, width]
145 .unsqueeze_dim::<3>(0) // [1, 1, width]
146 .expand([batch, 1, width]);
147 segments.push(row);
148 out_index[i] = out_len;
149 out_len += 1;
150 }
151 if cursor < orig_len {
152 segments.push(x.narrow(1, cursor, orig_len - cursor));
153 }
154 let out = Tensor::cat(segments, 1);
155 assert_eq!(out.dims(), [batch, orig_len + k, width]);
156 (out, out_index)
157}
158
159/// The output-sequence position of each marker (in `Vec` order) for an input of
160/// length `orig_len`, without materialising any tensor. Mirrors the placement in
161/// [`insert_class_markers`] — useful for reading a class token back out.
162pub(crate) fn class_marker_output_indices<M: ClassMarker>(
163 markers: &[M],
164 orig_len: usize,
165) -> Vec<usize> {
166 let k = markers.len();
167 let mut order: Vec<usize> = (0..k).collect();
168 order.sort_by_key(|&i| (markers[i].insert_pos(orig_len), markers[i].group_rank(), i));
169 let mut cursor = 0usize;
170 let mut out_len = 0usize;
171 let mut out_index = vec![0usize; k];
172 for &i in &order {
173 let p = markers[i].insert_pos(orig_len).min(orig_len);
174 if p > cursor {
175 out_len += p - cursor;
176 cursor = p;
177 }
178 out_index[i] = out_len;
179 out_len += 1;
180 }
181 out_index
182}
183
184/// Build the embedding param for `n` class markers of the given `width`
185/// (`None` when there are none — Burn has no zero-width tensors).
186pub(crate) fn init_class_emb(n: usize, width: usize, device: &Device) -> Option<Param<Tensor<2>>> {
187 (n > 0).then(|| {
188 Initializer::Normal {
189 mean: 0.0,
190 std: 0.02,
191 }
192 .init([n, width], device)
193 })
194}
195
196/// Panic if any marker is incompatible with single-token `step()`.
197pub(crate) fn assert_step_compatible<M: ClassMarker>(markers: &[M], who: &str) {
198 assert!(
199 !markers.iter().any(|m| m.forbids_step()),
200 "{who}: Middle/End class tokens are not compatible with step()"
201 );
202}
203
204/// The output-sequence position of each step-injectable marker (in `Vec` order),
205/// for use by `step`'s cursor. Asserts no `Middle`/`End` (those need the full
206/// length — `forward` only). `Start`/`Custom` positions are length-independent,
207/// so an unbounded `orig_len` resolves them exactly.
208pub(crate) fn class_step_injections<M: ClassMarker>(markers: &[M], who: &str) -> Vec<usize> {
209 assert_step_compatible(markers, who);
210 class_marker_output_indices(markers, usize::MAX)
211}