1use crate::modules::{Residuals, ResidualsConfig, RmsNormConfig};
2use crate::prelude::*;
3use crate::utils::ClassLatent;
4use crate::utils::Schedule;
5use crate::utils::class::{
6 assert_step_compatible, class_marker_output_indices, class_step_injections, init_class_emb,
7 insert_class_markers,
8};
9use burn::module::Param;
10use burn::prelude::*;
11
/// A stack of [`Layer`]s plus the residual wiring that connects them.
///
/// The stack is walked over "virtual" positions; each virtual position is
/// mapped onto one entry of `real_layers`, so a real layer may be reused.
#[derive(Module, Debug)]
pub struct Layers<M: Module> {
    /// Number of real layers; the builder creates exactly this many entries
    /// in `real_layers`.
    pub n_real_layers: usize,
    /// Optional `(virtual depth, schedule)` pair. When `None`, the virtual
    /// depth equals `n_real_layers` and virtual index `i` maps to real layer `i`.
    #[module(skip)]
    pub n_virtual_layers: Option<(usize, Schedule)>,
    /// The real layers, indexed by the result of the virtual-to-real mapping.
    pub real_layers: Vec<Layer<M>>,
    /// When set, the first virtual position runs without its residual connection.
    pub ignore_first_residual: bool,
    /// When set, the last virtual position runs without its residual connection.
    pub ignore_last_residual: bool,
    /// Residual strategy; the methods of this type branch on
    /// `Residuals::Standard` versus `Residuals::MultiGate`.
    pub residuals: Residuals,
    /// Stack-level class-latent descriptors handed to the `utils::class` helpers.
    #[module(skip)]
    pub class_latents: Vec<ClassLatent>,
    /// Learned embedding for the stack-level class latents, if any.
    /// NOTE(review): presumably one row per entry of `class_latents` with
    /// `d_model` columns (see `init_class_emb` call in the builder) — confirm.
    pub class_latents_emb: Option<Param<Tensor<2>>>,
}
37
38impl<M: MambaBlock> Layers<M>
39where
40 M::SsdPath: Clone,
41{
42 pub fn class_latent_output_indices(&self, orig_len: usize) -> Vec<usize> {
44 class_marker_output_indices(&self.class_latents, orig_len)
45 }
46
47 fn insert_latents(&self, x: Tensor<3>) -> Tensor<3> {
49 if self.class_latents_emb.is_none() {
50 return x;
51 }
52 insert_class_markers(x, &self.class_latents, self.class_latents_emb.as_ref()).0
53 }
54
55 fn n_virtual_count(&self) -> usize {
56 self.n_virtual_layers
57 .as_ref()
58 .map(|(l, _)| *l)
59 .unwrap_or(self.n_real_layers)
60 }
61
62 fn real_idx(&self, virtual_idx: usize) -> usize {
63 if let Some((n, schedule)) = &self.n_virtual_layers {
64 schedule.real_idx(virtual_idx, *n, self.n_real_layers)
65 } else {
66 virtual_idx
67 }
68 }
69
70 fn skip_residual(&self, i: usize, n: usize) -> bool {
73 (self.ignore_first_residual && i == 0) || (self.ignore_last_residual && i + 1 == n)
74 }
75
76 pub fn forward(
95 &self,
96 x: Tensor<3>,
97 caches: Option<M::Caches>,
98 ssd_path: M::SsdPath,
99 ) -> (Tensor<3>, M::Caches) {
100 let mut x = self.insert_latents(x);
101 let n = self.n_virtual_count();
102 let caches =
103 caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_3d(&x, n));
104 assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
105 let mut slots = caches.into_slots();
106
107 let mut streams = self.multi_gate_streams_seed(&x);
110
111 for i in 0..n {
112 let real = self.real_idx(i);
113 let layer = &self.real_layers[real];
114 let cache = slots[i].take().unwrap();
115 let first = self.ignore_first_residual && i == 0;
116 let last = self.ignore_last_residual && i + 1 == n;
117 match &self.residuals {
118 Residuals::Standard(_noop) => {
119 let x_l = layer.insert_latents(x);
123 let (out, c_) = if first || last {
124 layer.forward(x_l, Some(cache), ssd_path.clone())
125 } else {
126 let (out, c_) = layer.forward(x_l.clone(), Some(cache), ssd_path.clone());
127 (out + x_l, c_)
128 };
129 x = out;
130 slots[i] = Some(c_);
131 }
132 Residuals::MultiGate(mg) => {
133 assert!(
134 layer.class_latents_emb.is_none(),
135 "MultiGate residuals do not support per-layer class latents"
136 );
137 let (out, c_) = layer.forward(x, Some(cache), ssd_path.clone());
138 slots[i] = Some(c_);
139 let s = streams.take().unwrap();
140 if last {
145 x = out;
147 streams = Some(s);
148 } else if first {
149 let [b, seq, d] = out.dims();
151 streams = Some(out.clone().unsqueeze_dim::<4>(2).expand([
152 b,
153 seq,
154 mg.n_stream,
155 d,
156 ]));
157 x = out;
158 } else {
159 let idx = mg.module_index(i, real);
160 let (new_h, new_streams) = mg.layers[idx].forward(out, s);
161 x = new_h;
162 streams = Some(new_streams);
163 }
164 }
165 }
166 }
167 (x, M::Caches::from_slots(slots))
168 }
169
170 fn multi_gate_streams_seed(&self, x: &Tensor<3>) -> Option<Tensor<4>> {
174 let Residuals::MultiGate(mg) = &self.residuals else {
175 return None;
176 };
177 assert!(
178 self.class_latents_emb.is_none(),
179 "MultiGate residuals do not support stack-level class latents"
180 );
181 let [batch, sequence, d_model] = x.dims();
182 Some(
183 x.clone()
184 .unsqueeze_dim::<4>(2)
185 .expand([batch, sequence, mg.n_stream, d_model]),
186 )
187 }
188
    /// Advances the stack by one user token.
    ///
    /// With `Residuals::MultiGate` the call is forwarded to `step_multi_gate`;
    /// `own_index` and `layer_indices` are not used on that path.
    ///
    /// Otherwise a small "stream" of tokens is built and pushed through every
    /// virtual position in order:
    /// * `own_index` — cursor for stack-level class latents. While the cursor
    ///   equals one of the positions reported by `class_step_injections`, the
    ///   matching embedding row is emitted ahead of `x` and the cursor is
    ///   advanced; it is advanced once more for `x` itself. When `None`, the
    ///   stack-level latents must pass `assert_step_compatible`.
    /// * `layer_indices` — one cursor per virtual position, handled the same
    ///   way for the mapped layer's own class latents. When `None`, each
    ///   layer's latents must pass `assert_step_compatible`.
    ///
    /// Returns the output for the user token (the last stream element) and
    /// the updated caches; outputs of injected latent rows are discarded.
    ///
    /// # Panics
    /// Panics if the cache slot count or the length of `layer_indices` differs
    /// from the virtual depth, or if an injection position matches while the
    /// corresponding embedding is `None`.
    pub fn step(
        &self,
        x: Tensor<2>,
        caches: Option<M::Caches>,
        own_index: Option<&mut usize>,
        mut layer_indices: Option<&mut Vec<usize>>,
    ) -> (Tensor<2>, M::Caches) {
        if let Residuals::MultiGate(mg) = &self.residuals {
            return self.step_multi_gate(x, caches, mg);
        }
        let [batch, d_model] = x.dims();
        let n = self.n_virtual_count();
        // Missing caches are created from the first real layer's block.
        let caches =
            caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_2d(&x, n));
        assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
        if let Some(v) = layer_indices.as_deref() {
            assert_eq!(v.len(), n, "one class-latent cursor per virtual layer");
        }
        let mut slots = caches.into_slots();

        // Tokens fed to the first layer: injected stack-level latent rows (if
        // any) followed by the user token, which is always last.
        let mut stream: Vec<Tensor<2>> = Vec::new();
        if let Some(own_cursor) = own_index {
            let positions = class_step_injections(&self.class_latents, "Layers");
            let emb = self.class_latents_emb.as_ref();
            // Emit every latent whose injection position equals the cursor.
            while let Some(i) = positions.iter().position(|&p| p == *own_cursor) {
                stream.push(emb.unwrap().val().narrow(0, i, 1).expand([batch, d_model]));
                *own_cursor += 1;
            }
            stream.push(x);
            *own_cursor += 1;
        } else {
            assert_step_compatible(&self.class_latents, "Layers");
            stream.push(x);
        }

        for pos in 0..n {
            let layer = &self.real_layers[self.real_idx(pos)];
            let skip = self.skip_residual(pos, n);
            // Cursor for this virtual position, if the caller tracks them.
            let mut layer_cursor = layer_indices.as_deref_mut().map(|v| &mut v[pos]);
            let positions = if layer_cursor.is_some() {
                class_step_injections(&layer.class_latents, "Layer")
            } else {
                assert_step_compatible(&layer.class_latents, "Layer");
                Vec::new()
            };
            let emb = layer.class_latents_emb.as_ref();
            // The slot's cache is threaded through every token of this layer.
            let mut cache = slots[pos].take();
            let mut next: Vec<Tensor<2>> = Vec::with_capacity(stream.len());
            // Runs one token through the layer, adding the residual unless skipped.
            let run = |token: Tensor<2>, cache: Option<M::Cache>| {
                if skip {
                    layer.step(token, cache, None)
                } else {
                    let (out, c) = layer.step(token.clone(), cache, None);
                    (out + token, c)
                }
            };
            for token in stream {
                if let Some(cursor) = layer_cursor.as_deref_mut() {
                    // Layer-level latents due at the cursor go in before the token.
                    while let Some(i) = positions.iter().position(|&p| p == *cursor) {
                        let row = emb.unwrap().val().narrow(0, i, 1).expand([batch, d_model]);
                        let (out, c) = run(row, cache);
                        next.push(out);
                        cache = Some(c);
                        *cursor += 1;
                    }
                }
                let (out, c) = run(token, cache);
                next.push(out);
                cache = Some(c);
                if let Some(cursor) = layer_cursor.as_deref_mut() {
                    *cursor += 1;
                }
            }
            slots[pos] = cache;
            stream = next;
        }

        // Only the user token's output is returned; it is always the last element.
        let out = stream.pop().expect("the user token is always emitted");
        (out, M::Caches::from_slots(slots))
    }
299
300 pub fn step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
308 if let Residuals::MultiGate(mg) = &self.residuals {
309 return self.step_infinite_multi_gate(x, mg);
310 }
311 assert_step_compatible(&self.class_latents, "Layers");
312 let n = self.n_virtual_count();
313 let mut h = x;
314 for i in 0..n {
315 let layer = &self.real_layers[self.real_idx(i)];
316 h = if self.skip_residual(i, n) {
317 layer.step_infinite(h)
318 } else {
319 layer.step_infinite(h.clone()) + h
320 };
321 }
322 h
323 }
324
325 fn step_infinite_multi_gate(&self, x: Tensor<2>, mg: &crate::modules::MultiGate) -> Tensor<2> {
330 assert_step_compatible(&self.class_latents, "Layers");
331 let [batch, d_model] = x.dims();
332 let n = self.n_virtual_count();
333 let mut streams = x
334 .clone()
335 .unsqueeze_dim::<3>(1)
336 .expand([batch, mg.n_stream, d_model]);
337 let mut h = x;
338 for i in 0..n {
339 let real = self.real_idx(i);
340 let layer = &self.real_layers[real];
341 assert_step_compatible(&layer.class_latents, "Layer");
342 let out = layer.step_infinite(h);
343 if self.ignore_last_residual && i + 1 == n {
344 h = out;
345 } else if self.ignore_first_residual && i == 0 {
346 let [b, d] = out.dims();
347 streams = out
348 .clone()
349 .unsqueeze_dim::<3>(1)
350 .expand([b, mg.n_stream, d]);
351 h = out;
352 } else {
353 let idx = mg.module_index(i, real);
354 let (new_h, new_streams) = mg.layers[idx].step(out, streams);
355 h = new_h;
356 streams = new_streams;
357 }
358 }
359 h
360 }
361
362 pub fn step_n_approx(
371 &self,
372 x: Tensor<2>,
373 n_steps: usize,
374 caches: Option<M::Caches>,
375 ) -> (Tensor<2>, M::Caches) {
376 if let Residuals::MultiGate(mg) = &self.residuals {
377 return self.step_n_approx_multi_gate(x, n_steps, caches, mg);
378 }
379 assert_step_compatible(&self.class_latents, "Layers");
380 let n = self.n_virtual_count();
381 let caches =
382 caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_2d(&x, n));
383 assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
384 let mut slots = caches.into_slots();
385
386 let mut h = x;
387 for i in 0..n {
388 let layer = &self.real_layers[self.real_idx(i)];
389 let cache = slots[i].take();
390 let (out, c) = if self.skip_residual(i, n) {
391 layer.step_n_approx(h, n_steps, cache)
392 } else {
393 let (out, c) = layer.step_n_approx(h.clone(), n_steps, cache);
394 (out + h, c)
395 };
396 h = out;
397 slots[i] = Some(c);
398 }
399 (h, M::Caches::from_slots(slots))
400 }
401
402 fn step_n_approx_multi_gate(
405 &self,
406 x: Tensor<2>,
407 n_steps: usize,
408 caches: Option<M::Caches>,
409 mg: &crate::modules::MultiGate,
410 ) -> (Tensor<2>, M::Caches) {
411 assert_step_compatible(&self.class_latents, "Layers");
412 let [batch, d_model] = x.dims();
413 let n = self.n_virtual_count();
414 let caches =
415 caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_2d(&x, n));
416 assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
417
418 let mut slots = caches.into_slots();
419 let mut streams = x
420 .clone()
421 .unsqueeze_dim::<3>(1)
422 .expand([batch, mg.n_stream, d_model]);
423 let mut h = x;
424 for i in 0..n {
425 let real = self.real_idx(i);
426 let layer = &self.real_layers[real];
427 let cache = slots[i].take();
428 let (out, c_) = layer.step_n_approx(h, n_steps, cache);
429 slots[i] = Some(c_);
430 if self.ignore_last_residual && i + 1 == n {
431 h = out;
432 } else if self.ignore_first_residual && i == 0 {
433 let [b, d] = out.dims();
434 streams = out
435 .clone()
436 .unsqueeze_dim::<3>(1)
437 .expand([b, mg.n_stream, d]);
438 h = out;
439 } else {
440 let idx = mg.module_index(i, real);
441 let (new_h, new_streams) = mg.layers[idx].step(out, streams);
442 h = new_h;
443 streams = new_streams;
444 }
445 }
446 (h, M::Caches::from_slots(slots))
447 }
448
449 fn step_multi_gate(
455 &self,
456 x: Tensor<2>,
457 caches: Option<M::Caches>,
458 mg: &crate::modules::MultiGate,
459 ) -> (Tensor<2>, M::Caches) {
460 assert_step_compatible(&self.class_latents, "Layers");
461 let [batch, d_model] = x.dims();
462 let n = self.n_virtual_count();
463 let caches =
464 caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_2d(&x, n));
465 assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
466
467 let mut slots = caches.into_slots();
468 let mut streams = x
469 .clone()
470 .unsqueeze_dim::<3>(1)
471 .expand([batch, mg.n_stream, d_model]);
472 let mut h = x;
473 for i in 0..n {
474 let real = self.real_idx(i);
475 let layer = &self.real_layers[real];
476 assert_step_compatible(&layer.class_latents, "Layer");
477 let cache = slots[i].take();
478 let (out, c_) = layer.step(h, cache, None);
479 slots[i] = Some(c_);
480 if self.ignore_last_residual && i + 1 == n {
483 h = out;
485 } else if self.ignore_first_residual && i == 0 {
486 let [b, d] = out.dims();
488 streams = out
489 .clone()
490 .unsqueeze_dim::<3>(1)
491 .expand([b, mg.n_stream, d]);
492 h = out;
493 } else {
494 let idx = mg.module_index(i, real);
495 let (new_h, new_streams) = mg.layers[idx].step(out, streams);
496 h = new_h;
497 streams = new_streams;
498 }
499 }
500 (h, M::Caches::from_slots(slots))
501 }
502}
503
/// Configuration from which a [`Layers`] stack is initialised.
///
/// `C` is the configuration of the block placed in every real layer.
pub struct LayersBuilder<C> {
    /// Number of real layers to create.
    pub n_real_layers: usize,
    /// Optional `(virtual depth, schedule)` pair; `None` (the default from
    /// `new`) makes the virtual depth equal to `n_real_layers`.
    pub n_virtual_layers: Option<(usize, Schedule)>,
    /// Block configuration; supplies `d_model` and one block per real layer.
    pub mamba_block: C,
    /// Run the first virtual position without its residual (default `false`).
    pub ignore_first_residual: bool,
    /// Run the last virtual position without its residual (default `false`).
    pub ignore_last_residual: bool,
    /// Stack-level class latents; empty by default.
    pub class_latents: Vec<ClassLatent>,
    /// Residual strategy configuration; `ResidualsConfig::Standard` by default.
    pub residuals: ResidualsConfig,
}
523
524impl<C: MambaBlockConfig> LayersBuilder<C> {
525 pub fn new(n_real_layers: usize, mamba_block: C) -> Self {
527 Self {
528 n_real_layers,
529 n_virtual_layers: None,
530 mamba_block,
531 ignore_first_residual: false,
532 ignore_last_residual: false,
533 class_latents: Vec::new(),
534 residuals: ResidualsConfig::Standard,
535 }
536 }
537
538 pub fn with_n_virtual_layers(mut self, n: Option<(usize, Schedule)>) -> Self {
540 self.n_virtual_layers = n;
541 self
542 }
543
544 pub fn with_residuals(mut self, residuals: ResidualsConfig) -> Self {
546 self.residuals = residuals;
547 self
548 }
549
550 pub fn with_ignore_first_residual(mut self, ignore: bool) -> Self {
552 self.ignore_first_residual = ignore;
553 self
554 }
555
556 pub fn with_ignore_last_residual(mut self, ignore: bool) -> Self {
558 self.ignore_last_residual = ignore;
559 self
560 }
561
562 #[cfg(test)]
564 pub fn with_class_latents(mut self, class_latents: Vec<ClassLatent>) -> Self {
565 self.class_latents = class_latents;
566 self
567 }
568
569 pub fn init(&self, device: &Device) -> Layers<C::Block> {
571 let d_model = self.mamba_block.d_model();
572 let n_virtual = self
573 .n_virtual_layers
574 .as_ref()
575 .map(|(l, _)| *l)
576 .unwrap_or(self.n_real_layers);
577 let real_layers = (0..self.n_real_layers)
578 .map(|_| Layer {
579 norm: RmsNormConfig::new(d_model).init(device),
580 mamba_block: self.mamba_block.init_block(device),
581 class_latents: Vec::new(),
582 class_latents_emb: None,
583 })
584 .collect();
585 Layers {
586 n_real_layers: self.n_real_layers,
587 n_virtual_layers: self.n_virtual_layers.clone(),
588 real_layers,
589 ignore_first_residual: self.ignore_first_residual,
590 ignore_last_residual: self.ignore_last_residual,
591 residuals: self
592 .residuals
593 .init(d_model, self.n_real_layers, n_virtual, device),
594 class_latents_emb: init_class_emb(self.class_latents.len(), d_model, device),
595 class_latents: self.class_latents.clone(),
596 }
597 }
598}