pub struct BidiLayers<M: Module> {
pub n_real_layers: usize,
pub n_virtual_layers: Option<(usize, BidiSchedule)>,
pub real_layers: Vec<Layer<M>>,
pub ignore_first_residual: bool,
pub ignore_last_residual: bool,
pub outputs_merge: Vec<OutputMerge>,
pub residuals: Residuals,
pub class_latents: Vec<ClassLatent>,
pub class_latents_emb: Option<Param<Tensor<2>>>,
}
A stack of bidirectional Layer pairs with optional virtual-layer
scheduling — one struct for every Mamba-x family.
Fields
n_real_layers: usize
Number of real (weight-bearing) layers; must be even (used in pairs).
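n_virtual_layers: Option<(usize, BidiSchedule)>
Optional (n_virtual_layers, schedule) for weight sharing: when set, the real layers are reused according to the schedule to form n_virtual_layers virtual layers.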
real_layers: Vec<Layer<M>>
The weight-bearing layers, length n_real_layers.
ignore_first_residual: bool
Zero the first virtual pair's residual when true.
ignore_last_residual: bool
Zero the last virtual pair's residual when true.
outputs_merge: Vec<OutputMerge>
One direction-merge per pair, length n_real_layers / 2.
residuals: Residuals
How residuals are threaded between pairs (plain additive vs. Multi-Gate). The Multi-Gate residual (MGR) unit is the pair: there is one module per real/virtual pair.
class_latents: Vec<ClassLatent>
Positions of the stack-level class latents, spliced into the sequence once before the first pair (independent of any per-pair class latents).
class_latents_emb: Option<Param<Tensor<2>>>
The stack-level class-latent embeddings, [num_class_latents, d_model].
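The length invariants stated in the field docs (even n_real_layers, one real layer per slot, one merge per pair) can be checked up front. This is a minimal sketch, assuming a hypothetical StackShape type that records only the relevant lengths; it is not the crate's own validation code.

```rust
/// Hypothetical, simplified stand-in for BidiLayers: only the lengths
/// that the field docs constrain are modeled here.
struct StackShape {
    n_real_layers: usize,
    real_layers_len: usize,
    outputs_merge_len: usize,
}

/// Checks the length invariants described in the field docs.
fn check_shape(s: &StackShape) -> Result<(), String> {
    // Layers are used in (forward, backward) pairs, so the count must be even.
    if s.n_real_layers % 2 != 0 {
        return Err(format!("n_real_layers must be even, got {}", s.n_real_layers));
    }
    // real_layers has length n_real_layers.
    if s.real_layers_len != s.n_real_layers {
        return Err(format!(
            "real_layers has {} entries, expected {}",
            s.real_layers_len, s.n_real_layers
        ));
    }
    // outputs_merge has one entry per pair.
    let n_pairs = s.n_real_layers / 2;
    if s.outputs_merge_len != n_pairs {
        return Err(format!(
            "outputs_merge has {} entries, expected {} (one per pair)",
            s.outputs_merge_len, n_pairs
        ));
    }
    Ok(())
}
```

For example, a 4-layer stack with 2 merges passes, while an odd layer count or a merge count that does not equal n_real_layers / 2 is rejected with a descriptive message.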
Implementations
impl<M: MambaBlock + Clone> BidiLayers<M>
pub fn class_latent_output_indices(&self, orig_len: usize) -> Vec<usize>
Output positions of the stack-level class latents for an orig_len input.
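One way such output indices could be computed is sketched below. It assumes (hypothetically) that each stack-level class latent is described by an insertion position p in 0..=orig_len of the original sequence, meaning "insert before original token p", with p == orig_len meaning "append". The real ClassLatent type may encode positions differently. Under that assumption, the k-th latent in position order lands at p + k, because every latent placed before it shifts it right by one slot.

```rust
/// Hypothetical model: latent i is inserted before original token
/// positions[i] (0..=orig_len). Returns output indices in latent order.
fn class_latent_output_indices(positions: &[usize], orig_len: usize) -> Vec<usize> {
    // Order latents by insertion position; sort_by_key is stable, so
    // latents sharing a position keep their declaration order.
    let mut order: Vec<usize> = (0..positions.len()).collect();
    order.sort_by_key(|&i| positions[i]);

    let mut out = vec![0; positions.len()];
    for (k, &i) in order.iter().enumerate() {
        let p = positions[i];
        assert!(p <= orig_len, "insertion position {p} exceeds orig_len {orig_len}");
        // k latents were placed before this one, each shifting it right by one.
        out[i] = p + k;
    }
    out
}
```

For instance, with orig_len = 3 and positions [0, 3], the latents end up at output indices 0 and 4 of a length-5 sequence.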
fn insert_latents(&self, x: Tensor<3>) -> Tensor<3>
Splice this stack's class latents into x (no-op when there are none).
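The splicing step can be illustrated on a single batch element ([sequence, d_model] as nested vectors). This is a sketch under the same hypothetical position model as above (latent j goes before original token positions[j]); the function name, argument layout, and position encoding are assumptions, not the crate's tensor-based implementation.

```rust
/// Hypothetical sketch: splice class-latent embedding rows into one
/// sequence. emb has one [d_model] row per latent.
fn insert_latents(seq: &[Vec<f32>], positions: &[usize], emb: &[Vec<f32>]) -> Vec<Vec<f32>> {
    assert_eq!(positions.len(), emb.len(), "one embedding row per latent");
    if positions.is_empty() {
        return seq.to_vec(); // no-op when there are no class latents
    }
    // Stable order of latents by insertion position.
    let mut order: Vec<usize> = (0..positions.len()).collect();
    order.sort_by_key(|&j| positions[j]);

    let mut out = Vec::with_capacity(seq.len() + emb.len());
    let mut next = 0; // next latent (index into order) still to emit
    for (t, token) in seq.iter().enumerate() {
        // Emit every latent that belongs before original token t.
        while next < order.len() && positions[order[next]] == t {
            out.push(emb[order[next]].clone());
            next += 1;
        }
        out.push(token.clone());
    }
    // Any remaining latents must be appends (position == seq.len()).
    while next < order.len() {
        assert_eq!(positions[order[next]], seq.len(), "position out of range");
        out.push(emb[order[next]].clone());
        next += 1;
    }
    out
}
```

With tokens [[1.0], [2.0]], positions [0, 2] and embeddings [[9.0], [8.0]], the result is [[9.0], [1.0], [2.0], [8.0]], so each latent lands at position p + k as in the index rule.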
fn multi_gate_streams_seed(&self, x: &Tensor<3>) -> Option<Tensor<4>>
Seed the MultiGate streams from a full-sequence input: returns n_stream copies of x stacked as [batch, sequence, n_stream, d_model], or None on the Standard path. Panics if MultiGate is combined with stack-level class latents.
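The seeding rule can be sketched on a flat, row-major buffer: each [d_model] row of x is repeated n_stream times, which yields the [batch, sequence, n_stream, d_model] layout. In this hypothetical sketch, n_stream: None stands in for Residuals::Standard and a boolean flag stands in for "has stack-level class latents"; the real method works on Tensor values instead.

```rust
/// Hypothetical sketch of MultiGate stream seeding on a flat buffer.
/// shape is (batch, sequence, d_model); n_stream == None models the
/// Standard residual path.
fn multi_gate_streams_seed(
    x: &[f32],
    shape: (usize, usize, usize),
    n_stream: Option<usize>,
    has_stack_class_latents: bool,
) -> Option<Vec<f32>> {
    let n_stream = n_stream?; // Standard path: nothing to seed
    assert!(
        !has_stack_class_latents,
        "MultiGate cannot be combined with stack-level class latents"
    );
    let (batch, seq, d_model) = shape;
    assert!(
        d_model > 0 && x.len() == batch * seq * d_model,
        "x must be [batch, sequence, d_model]"
    );
    let mut out = Vec::with_capacity(x.len() * n_stream);
    // Repeat each (b, t) row n_stream times, so the stream axis sits
    // between sequence and d_model.
    for row in x.chunks_exact(d_model) {
        for _ in 0..n_stream {
            out.extend_from_slice(row);
        }
    }
    Some(out)
}
```

For x = [1, 2, 3, 4] with shape (1, 2, 2) and two streams, the seed is [1, 2, 1, 2, 3, 4, 3, 4].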
pub fn forward(
&self,
x: Tensor<3>,
caches: Option<M::Caches>,
ssd_path: M::SsdPath,
) -> (Tensor<3>, M::Caches)
[batch, sequence, d_model] → [batch, sequence, d_model]
(sequence grows by the stack-level class-latent count).
Each pair returns its merged transform F_l without a residual. With
Residuals::Standard, the pair's input is added back as a skip connection
(unless suppressed by ignore_first_residual or ignore_last_residual). With
Residuals::MultiGate, the skip is dropped and n_stream parallel streams,
seeded from x, carry the residual between pairs: each pair takes the
streams' attention-pooled aggregate as its input, and its merged output is
gated back into every stream (see MultiGate).
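The two residual modes can be contrasted with a toy sketch over plain vectors, where each pair is a closure standing in for F_l. The Standard path follows the description directly (x ← F_l(x) + x, with the skip dropped on the first or last pair when the matching flag is set). The MultiGate step is hypothetical: its pooling (softmax over fixed per-stream logits) and gating (a sigmoid scalar per stream) are illustrative assumptions, since the actual rules live in MultiGate.

```rust
/// Toy stand-in for a pair's merged transform F_l.
type Pair = Box<dyn Fn(&[f32]) -> Vec<f32>>;

fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Residuals::Standard: h <- F_l(h) + h, except that the skip is dropped
/// on the first pair if ignore_first and on the last pair if ignore_last.
fn forward_standard(x: &[f32], pairs: &[Pair], ignore_first: bool, ignore_last: bool) -> Vec<f32> {
    let n = pairs.len();
    let mut h = x.to_vec();
    for (l, f) in pairs.iter().enumerate() {
        let out = f(h.as_slice());
        let suppressed = (l == 0 && ignore_first) || (l + 1 == n && ignore_last);
        h = if suppressed { out } else { add(&out, &h) };
    }
    h
}

/// One hypothetical MultiGate step: the pair reads a softmax-weighted
/// aggregate of the streams; its output is added into each stream,
/// scaled by that stream's sigmoid gate.
fn multi_gate_step(streams: &mut [Vec<f32>], f: &Pair, pool_logits: &[f32], gate_logits: &[f32]) {
    assert!(!streams.is_empty());
    assert_eq!(streams.len(), pool_logits.len());
    assert_eq!(streams.len(), gate_logits.len());
    // Softmax over streams (max-subtracted for numerical stability).
    let m = pool_logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let w: Vec<f32> = pool_logits.iter().map(|a| (*a - m).exp()).collect();
    let z: f32 = w.iter().sum();
    let mut agg = vec![0.0f32; streams[0].len()];
    for (s, ws) in streams.iter().zip(&w) {
        for (a, v) in agg.iter_mut().zip(s) {
            *a += *ws / z * *v;
        }
    }
    let out = f(agg.as_slice());
    // Gate the pair's output back into every stream.
    for (s, g) in streams.iter_mut().zip(gate_logits) {
        let gate = 1.0 / (1.0 + (-*g).exp());
        for (v, o) in s.iter_mut().zip(&out) {
            *v += gate * *o;
        }
    }
}
```

With pairs "double" then "add one" and input [1.0], the Standard path gives [7.0]; setting ignore_first drops the first skip and gives [5.0]. In the MultiGate step, note that no plain skip is added: the streams themselves carry the residual.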