burn_mamba/modules/layer.rs
1use crate::modules::RmsNorm;
2use crate::prelude::*;
3use crate::utils::ClassLatent;
4use crate::utils::class::{assert_step_compatible, class_step_injections, insert_class_markers};
5use burn::module::Param;
6use burn::prelude::*;
7
/// A single Pre-LN block wrapper computing `M(RMSNorm(x))` — the residual is
/// **not** applied here. The enclosing [`Layers`](crate::modules::Layers) owns
/// that decision (add the input back, suppress it on the first/last layer, or
/// thread it through Multi-Gate streams), so no input clone / zero-add is wasted
/// when no residual is wanted.
///
/// May carry its own [`ClassLatent`]s. In `step` they are spliced via the
/// `index` cursor; in `forward` the caller splices them first (via
/// [`Self::insert_latents`]) so the residual it adds sees the same lengthened
/// sequence. They are independent of any class latents on the enclosing
/// [`Layers`](crate::modules::Layers).
///
/// `class_latents` and `class_latents_emb` are meant to be set together (both
/// empty/`None`, or both populated). `insert_latents` inserts nothing when
/// `class_latents_emb` is `None`, regardless of `class_latents`. A cursored
/// `step` panics if it reaches a latent position without embeddings.
#[derive(Module, Debug)]
pub struct Layer<M: Module> {
    /// Pre-norm applied to the input before the inner block.
    pub norm: RmsNorm,
    /// The inner Mamba-x SSM block (the `M` in `M(RMSNorm(x))`).
    pub mamba_block: M,
    /// Positions of this layer's class latents (empty ⇒ none).
    ///
    /// Marked `#[module(skip)]`, so the `Module` derive carries it as plain
    /// data rather than as a sub-module.
    #[module(skip)]
    pub class_latents: Vec<ClassLatent>,
    /// The class-latent embeddings, `[num_class_latents, d_model]` (`None` ⇒ none).
    ///
    /// NOTE(review): presumably row `i` embeds `class_latents[i]`. `step`
    /// selects a row by the latent's index in the list returned by
    /// `class_step_injections`; confirm the two orderings stay aligned.
    pub class_latents_emb: Option<Param<Tensor<2>>>,
}
31
32impl<M: MambaBlock> Layer<M> {
33 /// Splice this layer's class latents into `x` (no-op when there are none).
34 /// Public to the crate so [`Layers`](crate::modules::Layers) can lengthen the
35 /// sequence itself (and add the matching residual) before calling
36 /// [`Self::forward`].
37 pub(crate) fn insert_latents(&self, x: Tensor<3>) -> Tensor<3> {
38 if self.class_latents_emb.is_none() {
39 return x;
40 }
41 insert_class_markers(x, &self.class_latents, self.class_latents_emb.as_ref()).0
42 }
43
44 /// Full-sequence Pre-LN block **without** the residual: `M(RMSNorm(x))`.
45 ///
46 /// The caller owns any class-latent insertion ([`Self::insert_latents`]) and
47 /// the residual.
48 pub fn forward(
49 &self,
50 x: Tensor<3>,
51 cache: Option<M::Cache>,
52 ssd_path: M::SsdPath,
53 ) -> (Tensor<3>, M::Cache) {
54 let normed = self.norm.forward(x);
55 self.mamba_block.block_forward(normed, cache, ssd_path)
56 }
57
58 /// Single-token Pre-LN block step **without** the residual.
59 ///
60 /// `index` is the running cursor into this layer's *output* sequence. With
61 /// `Some`, whenever it lands on one of this layer's class-latent positions
62 /// those latents are stepped first (each advancing `index`, recursing with
63 /// `None`); only the user token's output and cache are returned. With `None`
64 /// no class latents are injected — and `Middle`/`End` latents panic (their
65 /// positions need the full sequence; use `forward`). The residual is the
66 /// caller's responsibility.
67 pub fn step(
68 &self,
69 x: Tensor<2>,
70 cache: Option<M::Cache>,
71 index: Option<&mut usize>,
72 ) -> (Tensor<2>, M::Cache) {
73 let Some(cursor) = index else {
74 // The actual one-token work (no class injection, no residual).
75 assert_step_compatible(&self.class_latents, "Layer");
76 let normed = self.norm.forward(x);
77 return self.mamba_block.block_step(normed, cache);
78 };
79 let [batch, d_model] = x.dims();
80 let inj = class_step_injections(&self.class_latents, "Layer");
81 let emb = self.class_latents_emb.as_ref();
82 let mut cache = cache;
83 while let Some(i) = inj.iter().position(|&p| p == *cursor) {
84 let row = emb.unwrap().val().narrow(0, i, 1).expand([batch, d_model]);
85 let (_discard, c) = self.step(row, cache, None);
86 cache = Some(c);
87 *cursor += 1;
88 }
89 let (out, cache) = self.step(x, cache, None);
90 *cursor += 1;
91 (out, cache)
92 }
93
94 /// Stationary fixed point of the Pre-LN block under a constant token,
95 /// **without** the residual: the `step` counterpart of infinitely many
96 /// identical tokens (closed form, no cache — see
97 /// [`MambaBlock::block_step_infinite`]). Cursorless: class latents are not
98 /// injected (`Middle`/`End` latents panic, as in a `None`-cursor `step`).
99 pub fn step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
100 assert_step_compatible(&self.class_latents, "Layer");
101 self.mamba_block.block_step_infinite(self.norm.forward(x))
102 }
103
104 /// Closed-form jump equivalent to `n` cursorless [`Self::step`] calls on
105 /// the same constant token, **without** the residual (see
106 /// [`MambaBlock::block_step_n_approx`]).
107 pub fn step_n_approx(
108 &self,
109 x: Tensor<2>,
110 n: usize,
111 cache: Option<M::Cache>,
112 ) -> (Tensor<2>, M::Cache) {
113 assert_step_compatible(&self.class_latents, "Layer");
114 self.mamba_block
115 .block_step_n_approx(self.norm.forward(x), n, cache)
116 }
117}