Skip to main content

burn_mamba/modules/
layer.rs

1use crate::modules::RmsNorm;
2use crate::prelude::*;
3use crate::utils::ClassLatent;
4use crate::utils::class::{assert_step_compatible, class_step_injections, insert_class_markers};
5use burn::module::Param;
6use burn::prelude::*;
7
8/// A single Pre-LN block wrapper computing `M(RMSNorm(x))` — the residual is
9/// **not** applied here. The enclosing [`Layers`](crate::modules::Layers) owns
10/// that decision (add the input back, suppress it on the first/last layer, or
11/// thread it through Multi-Gate streams), so no input clone / zero-add is wasted
12/// when no residual is wanted.
13///
14/// May carry its own [`ClassLatent`]s. In `step` they are spliced via the
15/// `index` cursor; in `forward` the caller splices them first (via
16/// [`Self::insert_latents`]) so the residual it adds sees the same lengthened
17/// sequence. They are independent of any class latents on the enclosing
18/// [`Layers`].
19#[derive(Module, Debug)]
20pub struct Layer<M: Module> {
21    /// Pre-norm applied before the inner block.
22    pub norm: RmsNorm,
23    /// The inner Mamba-x SSM block.
24    pub mamba_block: M,
25    /// Positions of this layer's class latents (empty ⇒ none).
26    #[module(skip)]
27    pub class_latents: Vec<ClassLatent>,
28    /// The class-latent embeddings, `[num_class_latents, d_model]` (`None` ⇒ none).
29    pub class_latents_emb: Option<Param<Tensor<2>>>,
30}
31
32impl<M: MambaBlock> Layer<M> {
33    /// Splice this layer's class latents into `x` (no-op when there are none).
34    /// Public to the crate so [`Layers`](crate::modules::Layers) can lengthen the
35    /// sequence itself (and add the matching residual) before calling
36    /// [`Self::forward`].
37    pub(crate) fn insert_latents(&self, x: Tensor<3>) -> Tensor<3> {
38        if self.class_latents_emb.is_none() {
39            return x;
40        }
41        insert_class_markers(x, &self.class_latents, self.class_latents_emb.as_ref()).0
42    }
43
44    /// Full-sequence Pre-LN block **without** the residual: `M(RMSNorm(x))`.
45    ///
46    /// The caller owns any class-latent insertion ([`Self::insert_latents`]) and
47    /// the residual.
48    pub fn forward(
49        &self,
50        x: Tensor<3>,
51        cache: Option<M::Cache>,
52        ssd_path: M::SsdPath,
53    ) -> (Tensor<3>, M::Cache) {
54        let normed = self.norm.forward(x);
55        self.mamba_block.block_forward(normed, cache, ssd_path)
56    }
57
58    /// Single-token Pre-LN block step **without** the residual.
59    ///
60    /// `index` is the running cursor into this layer's *output* sequence. With
61    /// `Some`, whenever it lands on one of this layer's class-latent positions
62    /// those latents are stepped first (each advancing `index`, recursing with
63    /// `None`); only the user token's output and cache are returned. With `None`
64    /// no class latents are injected — and `Middle`/`End` latents panic (their
65    /// positions need the full sequence; use `forward`). The residual is the
66    /// caller's responsibility.
67    pub fn step(
68        &self,
69        x: Tensor<2>,
70        cache: Option<M::Cache>,
71        index: Option<&mut usize>,
72    ) -> (Tensor<2>, M::Cache) {
73        let Some(cursor) = index else {
74            // The actual one-token work (no class injection, no residual).
75            assert_step_compatible(&self.class_latents, "Layer");
76            let normed = self.norm.forward(x);
77            return self.mamba_block.block_step(normed, cache);
78        };
79        let [batch, d_model] = x.dims();
80        let inj = class_step_injections(&self.class_latents, "Layer");
81        let emb = self.class_latents_emb.as_ref();
82        let mut cache = cache;
83        while let Some(i) = inj.iter().position(|&p| p == *cursor) {
84            let row = emb.unwrap().val().narrow(0, i, 1).expand([batch, d_model]);
85            let (_discard, c) = self.step(row, cache, None);
86            cache = Some(c);
87            *cursor += 1;
88        }
89        let (out, cache) = self.step(x, cache, None);
90        *cursor += 1;
91        (out, cache)
92    }
93
94    /// Stationary fixed point of the Pre-LN block under a constant token,
95    /// **without** the residual: the `step` counterpart of infinitely many
96    /// identical tokens (closed form, no cache — see
97    /// [`MambaBlock::block_step_infinite`]). Cursorless: class latents are not
98    /// injected (`Middle`/`End` latents panic, as in a `None`-cursor `step`).
99    pub fn step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
100        assert_step_compatible(&self.class_latents, "Layer");
101        self.mamba_block.block_step_infinite(self.norm.forward(x))
102    }
103
104    /// Closed-form jump equivalent to `n` cursorless [`Self::step`] calls on
105    /// the same constant token, **without** the residual (see
106    /// [`MambaBlock::block_step_n_approx`]).
107    pub fn step_n_approx(
108        &self,
109        x: Tensor<2>,
110        n: usize,
111        cache: Option<M::Cache>,
112    ) -> (Tensor<2>, M::Cache) {
113        assert_step_compatible(&self.class_latents, "Layer");
114        self.mamba_block
115            .block_step_n_approx(self.norm.forward(x), n, cache)
116    }
117}