Skip to main content

burn_mamba/modules/
layers.rs

1use crate::modules::{Residuals, ResidualsConfig, RmsNormConfig};
2use crate::prelude::*;
3use crate::utils::ClassLatent;
4use crate::utils::Schedule;
5use crate::utils::class::{
6    assert_step_compatible, class_marker_output_indices, class_step_injections, init_class_emb,
7    insert_class_markers,
8};
9use burn::module::Param;
10use burn::prelude::*;
11
/// A stack of [`Layer`]s with optional virtual-layer scheduling — one struct for
/// every Mamba-x family.
///
/// A pass runs `n_virtual_layers` virtual layers (or `n_real_layers` when no
/// schedule is set). Each virtual layer is backed by one of the real layers
/// via the [`Schedule`], so weights can be shared across depth. The residual
/// around each layer is added by the stack itself, not by [`Layer`]; see
/// [`Layers::forward`].
#[derive(Module, Debug)]
pub struct Layers<M: Module> {
    /// Number of real (weight-bearing) layers.
    pub n_real_layers: usize,
    /// Optional `(n_virtual_layers, schedule)` for weight-sharing. `None` runs
    /// every real layer exactly once, in order.
    #[module(skip)]
    pub n_virtual_layers: Option<(usize, Schedule)>,
    /// The weight-bearing layers, length `n_real_layers`.
    pub real_layers: Vec<Layer<M>>,
    /// Zero the first virtual layer's residual when `true`.
    pub ignore_first_residual: bool,
    /// Zero the last virtual layer's residual when `true`.
    pub ignore_last_residual: bool,
    /// How residuals are threaded between layers (plain additive vs Multi-Gate).
    pub residuals: Residuals,
    /// Positions of the stack-level class latents, spliced into the sequence
    /// once before the first virtual layer (independent of any per-[`Layer`]
    /// class latents). Empty ⇒ none.
    #[module(skip)]
    pub class_latents: Vec<ClassLatent>,
    /// The stack-level class-latent embeddings, `[num_class_latents, d_model]`.
    /// `None` disables stack-level splicing entirely (the input passes through
    /// unchanged).
    pub class_latents_emb: Option<Param<Tensor<2>>>,
}
37
38impl<M: MambaBlock> Layers<M>
39where
40    M::SsdPath: Clone,
41{
42    /// Output positions of the stack-level class latents for an `orig_len` input.
43    pub fn class_latent_output_indices(&self, orig_len: usize) -> Vec<usize> {
44        class_marker_output_indices(&self.class_latents, orig_len)
45    }
46
47    /// Splice this layers' class latents into `x` (no-op when there are none).
48    fn insert_latents(&self, x: Tensor<3>) -> Tensor<3> {
49        if self.class_latents_emb.is_none() {
50            return x;
51        }
52        insert_class_markers(x, &self.class_latents, self.class_latents_emb.as_ref()).0
53    }
54
55    fn n_virtual_count(&self) -> usize {
56        self.n_virtual_layers
57            .as_ref()
58            .map(|(l, _)| *l)
59            .unwrap_or(self.n_real_layers)
60    }
61
62    fn real_idx(&self, virtual_idx: usize) -> usize {
63        if let Some((n, schedule)) = &self.n_virtual_layers {
64            schedule.real_idx(virtual_idx, *n, self.n_real_layers)
65        } else {
66            virtual_idx
67        }
68    }
69
70    /// Whether (virtual) layer `i` of `n` suppresses its residual — the first
71    /// layer when `ignore_first_residual`, the last when `ignore_last_residual`.
72    fn skip_residual(&self, i: usize, n: usize) -> bool {
73        (self.ignore_first_residual && i == 0) || (self.ignore_last_residual && i + 1 == n)
74    }
75
76    /// Full-sequence pass through every (virtual) layer.
77    ///
78    /// [`Layer`] returns only `F_l = Block(RMSNorm(·))`; the residual is added
79    /// here. With [`Residuals::Standard`] each layer adds the input skip (unless
80    /// suppressed). With [`Residuals::MultiGate`] the skip is dropped and
81    /// `n_stream` parallel streams — seeded from `x` — carry the residual: each
82    /// layer reads their attention-pooled aggregate as input and its output is
83    /// gated back into every stream (see [`MultiGate`]).
84    ///
85    /// `ignore_first/last_residual` apply to **both** paths: skipping the first
86    /// restarts the residual carry from the first layer's output (the input is
87    /// read but not carried); skipping the last makes the stack output the last
88    /// layer's transform `F_l` alone (no input-dependent carry). Class latents
89    /// apply to the Standard path only (MultiGate forbids them, panicking if any
90    /// are present).
91    ///
92    /// [`MultiGate`]: crate::modules::MultiGate
93    /// [`Layer`]: crate::modules::Layer
94    pub fn forward(
95        &self,
96        x: Tensor<3>,
97        caches: Option<M::Caches>,
98        ssd_path: M::SsdPath,
99    ) -> (Tensor<3>, M::Caches) {
100        let mut x = self.insert_latents(x);
101        let n = self.n_virtual_count();
102        let caches =
103            caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_3d(&x, n));
104        assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
105        let mut slots = caches.into_slots();
106
107        // MultiGate keeps `n_stream` parallel streams (seeded from the input);
108        // Standard threads the single tensor `x` directly (streams stays `None`).
109        let mut streams = self.multi_gate_streams_seed(&x);
110
111        for i in 0..n {
112            let real = self.real_idx(i);
113            let layer = &self.real_layers[real];
114            let cache = slots[i].take().unwrap();
115            let first = self.ignore_first_residual && i == 0;
116            let last = self.ignore_last_residual && i + 1 == n;
117            match &self.residuals {
118                Residuals::Standard(_noop) => {
119                    // Splice this layer's class latents, then add the residual
120                    // (the lengthened input) here — unless suppressed, in which
121                    // case the input is moved straight in (no clone, no add).
122                    let x_l = layer.insert_latents(x);
123                    let (out, c_) = if first || last {
124                        layer.forward(x_l, Some(cache), ssd_path.clone())
125                    } else {
126                        let (out, c_) = layer.forward(x_l.clone(), Some(cache), ssd_path.clone());
127                        (out + x_l, c_)
128                    };
129                    x = out;
130                    slots[i] = Some(c_);
131                }
132                Residuals::MultiGate(mg) => {
133                    assert!(
134                        layer.class_latents_emb.is_none(),
135                        "MultiGate residuals do not support per-layer class latents"
136                    );
137                    let (out, c_) = layer.forward(x, Some(cache), ssd_path.clone());
138                    slots[i] = Some(c_);
139                    let s = streams.take().unwrap();
140                    // A skipped residual here is equivalent to forcing the MGR
141                    // mixer gate β ≡ 1 (`new_streams = out`): the carried streams
142                    // are dropped, and the aggregator over the resulting identical
143                    // streams collapses to `F_l`. Both branches shortcut that.
144                    if last {
145                        // Output depends purely on the last layer's transform.
146                        x = out;
147                        streams = Some(s);
148                    } else if first {
149                        // Drop the input seed: restart the streams from `F_0`.
150                        let [b, seq, d] = out.dims();
151                        streams = Some(out.clone().unsqueeze_dim::<4>(2).expand([
152                            b,
153                            seq,
154                            mg.n_stream,
155                            d,
156                        ]));
157                        x = out;
158                    } else {
159                        let idx = mg.module_index(i, real);
160                        let (new_h, new_streams) = mg.layers[idx].forward(out, s);
161                        x = new_h;
162                        streams = Some(new_streams);
163                    }
164                }
165            }
166        }
167        (x, M::Caches::from_slots(slots))
168    }
169
170    /// Seed the MultiGate streams from a full-sequence input — `n_stream` copies
171    /// of `x` as `[batch, sequence, n_stream, d_model]` — or `None` for the
172    /// Standard path. Panics if MultiGate is paired with stack-level class latents.
173    fn multi_gate_streams_seed(&self, x: &Tensor<3>) -> Option<Tensor<4>> {
174        let Residuals::MultiGate(mg) = &self.residuals else {
175            return None;
176        };
177        assert!(
178            self.class_latents_emb.is_none(),
179            "MultiGate residuals do not support stack-level class latents"
180        );
181        let [batch, sequence, d_model] = x.dims();
182        Some(
183            x.clone()
184                .unsqueeze_dim::<4>(2)
185                .expand([batch, sequence, mg.n_stream, d_model]),
186        )
187    }
188
189    /// Single-token step through every (virtual) layer.
190    ///
191    /// Two independent class-latent cursors:
192    /// - `own_index` — the stack-level [`Self::class_latents`] (spliced once
193    ///   before the first layer in `forward`). When it lands on a stack position
194    ///   that latent enters the bottom of the stack as an extra input token.
195    /// - `layer_indices` — one cursor **per virtual layer**
196    ///   (`len == n_virtual_layers`), the per-[`Layer`] cursor so each layer
197    ///   splices its own class latents into the token stream it receives.
198    ///
199    /// Because a layer's class latents grow the sequence the *next* layer sees
200    /// (exactly as in `forward`), a single user step is a **cascade**: the bottom
201    /// input stream (any stack latents at `own_index`, then the user token) is
202    /// threaded up the stack, each layer expanding it with its own class latents
203    /// at `layer_indices[i]`. Every layer's recurrence therefore sees the same
204    /// token order as `forward`, so `forward` and `step` agree. Only the user
205    /// token's (fully propagated) output is returned — it is emitted last.
206    ///
207    /// A `None` cursor skips that level's injection; `Middle`/`End` latents panic
208    /// for the cursored level (their positions need the full sequence — use
209    /// `forward`).
210    pub fn step(
211        &self,
212        x: Tensor<2>,
213        caches: Option<M::Caches>,
214        own_index: Option<&mut usize>,
215        mut layer_indices: Option<&mut Vec<usize>>,
216    ) -> (Tensor<2>, M::Caches) {
217        if let Residuals::MultiGate(mg) = &self.residuals {
218            return self.step_multi_gate(x, caches, mg);
219        }
220        let [batch, d_model] = x.dims();
221        let n = self.n_virtual_count();
222        let caches =
223            caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_2d(&x, n));
224        assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
225        if let Some(v) = layer_indices.as_deref() {
226            assert_eq!(v.len(), n, "one class-latent cursor per virtual layer");
227        }
228        let mut slots = caches.into_slots();
229
230        // Bottom input stream for this user step: the stack-level class latents
231        // that fall at the cursor (fed through the whole stack like ordinary
232        // inputs), then the user token. Without a cursor, just the user token.
233        let mut stream: Vec<Tensor<2>> = Vec::new();
234        if let Some(own_cursor) = own_index {
235            let positions = class_step_injections(&self.class_latents, "Layers");
236            let emb = self.class_latents_emb.as_ref();
237            while let Some(i) = positions.iter().position(|&p| p == *own_cursor) {
238                stream.push(emb.unwrap().val().narrow(0, i, 1).expand([batch, d_model]));
239                *own_cursor += 1;
240            }
241            stream.push(x);
242            *own_cursor += 1;
243        } else {
244            assert_step_compatible(&self.class_latents, "Layers");
245            stream.push(x);
246        }
247
248        // Propagate the stream up through each virtual layer, each layer splicing
249        // its own class latents into the stream it receives.
250        for pos in 0..n {
251            let layer = &self.real_layers[self.real_idx(pos)];
252            let skip = self.skip_residual(pos, n);
253            let mut layer_cursor = layer_indices.as_deref_mut().map(|v| &mut v[pos]);
254            let positions = if layer_cursor.is_some() {
255                class_step_injections(&layer.class_latents, "Layer")
256            } else {
257                assert_step_compatible(&layer.class_latents, "Layer");
258                Vec::new()
259            };
260            let emb = layer.class_latents_emb.as_ref();
261            let mut cache = slots[pos].take();
262            let mut next: Vec<Tensor<2>> = Vec::with_capacity(stream.len());
263            // One token through the layer, adding its residual here unless
264            // suppressed (then the token is moved straight in — no clone/add).
265            let run = |token: Tensor<2>, cache: Option<M::Cache>| {
266                if skip {
267                    layer.step(token, cache, None)
268                } else {
269                    let (out, c) = layer.step(token.clone(), cache, None);
270                    (out + token, c)
271                }
272            };
273            for token in stream {
274                // Splice this layer's class latents that fall before this token.
275                if let Some(cursor) = layer_cursor.as_deref_mut() {
276                    while let Some(i) = positions.iter().position(|&p| p == *cursor) {
277                        let row = emb.unwrap().val().narrow(0, i, 1).expand([batch, d_model]);
278                        let (out, c) = run(row, cache);
279                        next.push(out);
280                        cache = Some(c);
281                        *cursor += 1;
282                    }
283                }
284                let (out, c) = run(token, cache);
285                next.push(out);
286                cache = Some(c);
287                if let Some(cursor) = layer_cursor.as_deref_mut() {
288                    *cursor += 1;
289                }
290            }
291            slots[pos] = cache;
292            stream = next;
293        }
294
295        // The user token entered last, so its fully-propagated output is last.
296        let out = stream.pop().expect("the user token is always emitted");
297        (out, M::Caches::from_slots(slots))
298    }
299
300    /// Stationary fixed point of the whole stack under a constant token, with
301    /// **no caches** involved: under a constant input each layer's output
302    /// converges (its decay damps the transient, and the readout phase of the
303    /// rotation cancels), so the downstream layer's input converges too and
304    /// the limit composes **exactly**, layer by layer — even though every
305    /// layer's SSM state keeps rotating forever. Residual handling mirrors
306    /// [`Self::step`]; cursorless (class latents are not injected).
307    pub fn step_infinite(&self, x: Tensor<2>) -> Tensor<2> {
308        if let Residuals::MultiGate(mg) = &self.residuals {
309            return self.step_infinite_multi_gate(x, mg);
310        }
311        assert_step_compatible(&self.class_latents, "Layers");
312        let n = self.n_virtual_count();
313        let mut h = x;
314        for i in 0..n {
315            let layer = &self.real_layers[self.real_idx(i)];
316            h = if self.skip_residual(i, n) {
317                layer.step_infinite(h)
318            } else {
319                layer.step_infinite(h.clone()) + h
320            };
321        }
322        h
323    }
324
325    /// Multi-Gate counterpart of [`Self::step_infinite`]. The streams are a
326    /// per-token depth construct (as in [`Self::step_multi_gate`]), so applying
327    /// the mixers to the layers' fixed-point outputs *is* the fixed point of
328    /// the whole stack.
329    fn step_infinite_multi_gate(&self, x: Tensor<2>, mg: &crate::modules::MultiGate) -> Tensor<2> {
330        assert_step_compatible(&self.class_latents, "Layers");
331        let [batch, d_model] = x.dims();
332        let n = self.n_virtual_count();
333        let mut streams = x
334            .clone()
335            .unsqueeze_dim::<3>(1)
336            .expand([batch, mg.n_stream, d_model]);
337        let mut h = x;
338        for i in 0..n {
339            let real = self.real_idx(i);
340            let layer = &self.real_layers[real];
341            assert_step_compatible(&layer.class_latents, "Layer");
342            let out = layer.step_infinite(h);
343            if self.ignore_last_residual && i + 1 == n {
344                h = out;
345            } else if self.ignore_first_residual && i == 0 {
346                let [b, d] = out.dims();
347                streams = out
348                    .clone()
349                    .unsqueeze_dim::<3>(1)
350                    .expand([b, mg.n_stream, d]);
351                h = out;
352            } else {
353                let idx = mg.module_index(i, real);
354                let (new_h, new_streams) = mg.layers[idx].step(out, streams);
355                h = new_h;
356                streams = new_streams;
357            }
358        }
359        h
360    }
361
362    /// **Approximate** jump of `n_steps` consecutive constant-token
363    /// [`Self::step`] calls (cursorless), in O(1) per layer.
364    ///
365    /// Each (virtual) layer jumps in closed form with its input held constant
366    /// at the *previous layer's step-`n` output*. The first layer's jump is
367    /// exact; deeper layers ignore the upstream transient, an error that
368    /// decays geometrically in `n_steps` (the `n → ∞` limit is exact — see
369    /// [`Self::step_infinite`]). `n_steps = 1` is exactly one `step`.
370    pub fn step_n_approx(
371        &self,
372        x: Tensor<2>,
373        n_steps: usize,
374        caches: Option<M::Caches>,
375    ) -> (Tensor<2>, M::Caches) {
376        if let Residuals::MultiGate(mg) = &self.residuals {
377            return self.step_n_approx_multi_gate(x, n_steps, caches, mg);
378        }
379        assert_step_compatible(&self.class_latents, "Layers");
380        let n = self.n_virtual_count();
381        let caches =
382            caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_2d(&x, n));
383        assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
384        let mut slots = caches.into_slots();
385
386        let mut h = x;
387        for i in 0..n {
388            let layer = &self.real_layers[self.real_idx(i)];
389            let cache = slots[i].take();
390            let (out, c) = if self.skip_residual(i, n) {
391                layer.step_n_approx(h, n_steps, cache)
392            } else {
393                let (out, c) = layer.step_n_approx(h.clone(), n_steps, cache);
394                (out + h, c)
395            };
396            h = out;
397            slots[i] = Some(c);
398        }
399        (h, M::Caches::from_slots(slots))
400    }
401
402    /// Multi-Gate counterpart of [`Self::step_n_approx`] (mirrors
403    /// [`Self::step_multi_gate`]; the mixers see each layer's step-`n` output).
404    fn step_n_approx_multi_gate(
405        &self,
406        x: Tensor<2>,
407        n_steps: usize,
408        caches: Option<M::Caches>,
409        mg: &crate::modules::MultiGate,
410    ) -> (Tensor<2>, M::Caches) {
411        assert_step_compatible(&self.class_latents, "Layers");
412        let [batch, d_model] = x.dims();
413        let n = self.n_virtual_count();
414        let caches =
415            caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_2d(&x, n));
416        assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
417
418        let mut slots = caches.into_slots();
419        let mut streams = x
420            .clone()
421            .unsqueeze_dim::<3>(1)
422            .expand([batch, mg.n_stream, d_model]);
423        let mut h = x;
424        for i in 0..n {
425            let real = self.real_idx(i);
426            let layer = &self.real_layers[real];
427            let cache = slots[i].take();
428            let (out, c_) = layer.step_n_approx(h, n_steps, cache);
429            slots[i] = Some(c_);
430            if self.ignore_last_residual && i + 1 == n {
431                h = out;
432            } else if self.ignore_first_residual && i == 0 {
433                let [b, d] = out.dims();
434                streams = out
435                    .clone()
436                    .unsqueeze_dim::<3>(1)
437                    .expand([b, mg.n_stream, d]);
438                h = out;
439            } else {
440                let idx = mg.module_index(i, real);
441                let (new_h, new_streams) = mg.layers[idx].step(out, streams);
442                h = new_h;
443                streams = new_streams;
444            }
445        }
446        (h, M::Caches::from_slots(slots))
447    }
448
449    /// Single-token Multi-Gate Residual step — the recurrent counterpart of
450    /// [`Self::forward_multi_gate`]. The streams are a per-token *depth*
451    /// construct (rebuilt from `x` each step, never carried between tokens), so
452    /// no extra state crosses steps and `forward`/`step` agree. Class latents are
453    /// unsupported.
454    fn step_multi_gate(
455        &self,
456        x: Tensor<2>,
457        caches: Option<M::Caches>,
458        mg: &crate::modules::MultiGate,
459    ) -> (Tensor<2>, M::Caches) {
460        assert_step_compatible(&self.class_latents, "Layers");
461        let [batch, d_model] = x.dims();
462        let n = self.n_virtual_count();
463        let caches =
464            caches.unwrap_or_else(|| self.real_layers[0].mamba_block.zero_caches_2d(&x, n));
465        assert_eq!(caches.slot_count(), n, "one cache per virtual layer");
466
467        let mut slots = caches.into_slots();
468        let mut streams = x
469            .clone()
470            .unsqueeze_dim::<3>(1)
471            .expand([batch, mg.n_stream, d_model]);
472        let mut h = x;
473        for i in 0..n {
474            let real = self.real_idx(i);
475            let layer = &self.real_layers[real];
476            assert_step_compatible(&layer.class_latents, "Layer");
477            let cache = slots[i].take();
478            let (out, c_) = layer.step(h, cache, None);
479            slots[i] = Some(c_);
480            // As in `forward`, a skipped residual is β ≡ 1 in the mixer
481            // (`new_streams = out`), the aggregator then collapsing to `F_l`.
482            if self.ignore_last_residual && i + 1 == n {
483                // Output depends purely on the last layer's transform.
484                h = out;
485            } else if self.ignore_first_residual && i == 0 {
486                // Drop the input seed: restart the streams from `F_0`.
487                let [b, d] = out.dims();
488                streams = out
489                    .clone()
490                    .unsqueeze_dim::<3>(1)
491                    .expand([b, mg.n_stream, d]);
492                h = out;
493            } else {
494                let idx = mg.module_index(i, real);
495                let (new_h, new_streams) = mg.layers[idx].step(out, streams);
496                h = new_h;
497                streams = new_streams;
498            }
499        }
500        (h, M::Caches::from_slots(slots))
501    }
502}
503
/// Plain (non-serde) factory for [`Layers`]. The serializable surface is the
/// concrete `MambaLatentNetConfig` enum; this is just the generic builder it
/// delegates to.
///
/// Every real layer is built from the same `mamba_block` config, with its own
/// freshly initialised weights and no per-layer class latents.
pub struct LayersBuilder<C> {
    /// Number of real (weight-bearing) layers.
    pub n_real_layers: usize,
    /// Optional `(n_virtual_layers, schedule)` virtual-layer scheduling;
    /// `None` runs each real layer once, in order.
    pub n_virtual_layers: Option<(usize, Schedule)>,
    /// Shared block config, instantiated once per real layer.
    pub mamba_block: C,
    /// Zero the first virtual layer's residual.
    pub ignore_first_residual: bool,
    /// Zero the last virtual layer's residual.
    pub ignore_last_residual: bool,
    /// Stack-level class latents (spliced once before the first virtual layer).
    pub class_latents: Vec<ClassLatent>,
    /// Inter-layer residual scheme (defaults to plain additive).
    pub residuals: ResidualsConfig,
}
523
524impl<C: MambaBlockConfig> LayersBuilder<C> {
525    /// Builder with no virtual scheduling, no class latents, residuals enabled.
526    pub fn new(n_real_layers: usize, mamba_block: C) -> Self {
527        Self {
528            n_real_layers,
529            n_virtual_layers: None,
530            mamba_block,
531            ignore_first_residual: false,
532            ignore_last_residual: false,
533            class_latents: Vec::new(),
534            residuals: ResidualsConfig::Standard,
535        }
536    }
537
538    /// Set the optional virtual-layer scheduling.
539    pub fn with_n_virtual_layers(mut self, n: Option<(usize, Schedule)>) -> Self {
540        self.n_virtual_layers = n;
541        self
542    }
543
544    /// Set the inter-layer residual scheme (plain additive vs Multi-Gate).
545    pub fn with_residuals(mut self, residuals: ResidualsConfig) -> Self {
546        self.residuals = residuals;
547        self
548    }
549
550    /// Suppress the first virtual layer's residual (see [`Layers`]).
551    pub fn with_ignore_first_residual(mut self, ignore: bool) -> Self {
552        self.ignore_first_residual = ignore;
553        self
554    }
555
556    /// Suppress the last virtual layer's residual (see [`Layers`]).
557    pub fn with_ignore_last_residual(mut self, ignore: bool) -> Self {
558        self.ignore_last_residual = ignore;
559        self
560    }
561
562    /// Set the stack-level class latents.
563    #[cfg(test)]
564    pub fn with_class_latents(mut self, class_latents: Vec<ClassLatent>) -> Self {
565        self.class_latents = class_latents;
566        self
567    }
568
569    /// Allocate and initialise the stack on `device`.
570    pub fn init(&self, device: &Device) -> Layers<C::Block> {
571        let d_model = self.mamba_block.d_model();
572        let n_virtual = self
573            .n_virtual_layers
574            .as_ref()
575            .map(|(l, _)| *l)
576            .unwrap_or(self.n_real_layers);
577        let real_layers = (0..self.n_real_layers)
578            .map(|_| Layer {
579                norm: RmsNormConfig::new(d_model).init(device),
580                mamba_block: self.mamba_block.init_block(device),
581                class_latents: Vec::new(),
582                class_latents_emb: None,
583            })
584            .collect();
585        Layers {
586            n_real_layers: self.n_real_layers,
587            n_virtual_layers: self.n_virtual_layers.clone(),
588            real_layers,
589            ignore_first_residual: self.ignore_first_residual,
590            ignore_last_residual: self.ignore_last_residual,
591            residuals: self
592                .residuals
593                .init(d_model, self.n_real_layers, n_virtual, device),
594            class_latents_emb: init_class_emb(self.class_latents.len(), d_model, device),
595            class_latents: self.class_latents.clone(),
596        }
597    }
598}