pub struct BidiLayerPair<M: Module> {
    pub straight_norm: RmsNorm,
    pub reverse_norm: RmsNorm,
    pub straight_block: M,
    pub reverse_block: M,
    pub output_merge: OutputMerge,
    pub class_latents: Vec<ClassLatent>,
    pub class_latents_emb: Option<Param<Tensor<2>>>,
}
A single bidirectional pair: a straight (→) and a reversed (←) Pre-LN block
whose outputs are merged. The residual is not applied here — the
enclosing BidiLayers adds it (or suppresses it on the first/last pair),
mirroring the Layer / Layers split.
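The data flow of one pair can be illustrated on a single batch element. This is a minimal sketch under several assumptions: the pre-norms are omitted, a running sum over time stands in for the causal block `M`, and the merge is an elementwise sum (the real `OutputMerge` may use a different strategy). `Seq`, `causal_prefix_sum` and `bidi_pair` are hypothetical names, not part of this crate.

```rust
/// One batch element laid out as `seq[t][c]`: time step `t`, channel `c`.
type Seq = Vec<Vec<f32>>;

/// Hypothetical stand-in for a causal block `M`: a running sum over time,
/// so the output at step `t` depends only on inputs `0..=t`.
fn causal_prefix_sum(x: &Seq) -> Seq {
    let mut acc = vec![0.0f32; x.first().map_or(0, |row| row.len())];
    x.iter()
        .map(|row| {
            for (a, v) in acc.iter_mut().zip(row) {
                *a += *v;
            }
            acc.clone()
        })
        .collect()
}

/// Hypothetical pair forward: the straight block runs left-to-right, the
/// reverse block runs over the flipped sequence and its output is flipped
/// back so positions line up, then the two are summed elementwise
/// (an assumed merge). Pre-norms are omitted and no residual is added.
fn bidi_pair(
    x: &Seq,
    straight_block: impl Fn(&Seq) -> Seq,
    reverse_block: impl Fn(&Seq) -> Seq,
) -> Seq {
    let straight = straight_block(x);
    let mut flipped = x.clone();
    flipped.reverse();
    let mut reverse = reverse_block(&flipped);
    reverse.reverse();
    straight
        .iter()
        .zip(&reverse)
        .map(|(s, r)| s.iter().zip(r).map(|(a, b)| a + b).collect::<Vec<f32>>())
        .collect()
}
```

With inputs `[1, 2, 3]`, the straight pass yields prefix sums `[1, 3, 6]` and the reversed pass yields suffix sums `[6, 5, 3]`, so every merged position sees context from both sides.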
Fields
straight_norm: RmsNorm
    Pre-norm for the straight pass.
reverse_norm: RmsNorm
    Pre-norm for the reversed pass.
straight_block: M
    The block run left-to-right.
reverse_block: M
    The block run right-to-left (over the flipped sequence).
output_merge: OutputMerge
    Merge strategy combining the two directions.
class_latents: Vec<ClassLatent>
    Positions of this pair's class latents. They are spliced into the sequence before either direction runs, so both directions and the residual see the lengthened sequence. Empty means the pair has no class latents.
class_latents_emb: Option<Param<Tensor<2>>>
    This pair's class-latent embeddings, with shape [num_class_latents, d_model].
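Each direction has its own pre-norm. As a reference point, the sketch below shows standard RMSNorm over one `d_model`-wide row, `y = x / sqrt(mean(x^2) + eps) * weight`. The struct name `RmsNormSketch`, the `eps` value and the all-ones initial `weight` are illustrative assumptions, not this crate's `RmsNorm`.

```rust
/// Hypothetical RMSNorm over a single row of `d_model` channels.
struct RmsNormSketch {
    weight: Vec<f32>, // per-channel gain (assumed initialised to 1.0)
    eps: f32,         // assumed small constant for numerical stability
}

impl RmsNormSketch {
    fn new(d_model: usize) -> Self {
        Self { weight: vec![1.0; d_model], eps: 1e-6 }
    }

    /// Divides the row by its root-mean-square, then applies the gain.
    fn forward_row(&self, x: &[f32]) -> Vec<f32> {
        let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
        let inv_rms = 1.0 / (mean_sq + self.eps).sqrt();
        x.iter().zip(&self.weight).map(|(v, w)| v * inv_rms * w).collect()
    }
}
```

Unlike LayerNorm, no mean is subtracted; the output row keeps the input's direction and has a root-mean-square close to 1 before the gain is applied.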
Implementations
impl<M: MambaBlock> BidiLayerPair<M>
fn insert_latents(&self, x: Tensor<3>) -> Tensor<3>
Splice this bidi-layer-pair’s class latents into x (no-op when there are none).
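The splice can be sketched on one batch element. This assumes, without confirmation from the source, that each `ClassLatent` reduces to an index into the lengthened output sequence, that the indices are strictly increasing, and that row `i` of the embedding table goes to the `i`-th position. `insert_latents` here is a free-standing hypothetical function, not the crate's method.

```rust
/// Hypothetical splice: `positions` are assumed to index the lengthened
/// output sequence, and `latents[i]` is placed at `positions[i]`.
fn insert_latents(x: &[Vec<f32>], positions: &[usize], latents: &[Vec<f32>]) -> Vec<Vec<f32>> {
    assert_eq!(positions.len(), latents.len(), "one embedding row per position");
    if positions.is_empty() {
        return x.to_vec(); // no-op when the pair has no class latents
    }
    let out_len = x.len() + positions.len();
    assert!(
        positions.windows(2).all(|w| w[0] < w[1]) && positions.iter().all(|&p| p < out_len),
        "positions must be strictly increasing and inside the lengthened sequence"
    );
    let mut out = Vec::with_capacity(out_len);
    let (mut next_tok, mut next_lat) = (0, 0);
    for t in 0..out_len {
        if next_lat < positions.len() && positions[next_lat] == t {
            out.push(latents[next_lat].clone()); // latent slot
            next_lat += 1;
        } else {
            out.push(x[next_tok].clone()); // original token, order preserved
            next_tok += 1;
        }
    }
    out
}
```

For three tokens and latents at positions `0` and `2`, the output has five rows: latent, token 1, latent, token 2, token 3.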
pub fn forward(
    &self,
    x: Tensor<3>,
    straight_cache: Option<M::Cache>,
    reverse_cache: Option<M::Cache>,
    ssd_path: M::SsdPath,
) -> (Tensor<3>, M::Cache, M::Cache)
Maps [batch, sequence, d_model] to [batch, sequence + num_class_latents, d_model] (with no class latents the shape is unchanged) and also returns the two updated direction caches.
The returned tensor is the merge of the two directions without the residual; the enclosing
BidiLayers adds it.
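How an enclosing stack might consume these residual-free outputs can be sketched with plain vectors. Assumptions: "first/last pair" is read as both the first and the last pair having the residual suppressed, each pair is reduced to a function pointer, and class latents are ignored (in the real model the residual would be taken over the lengthened sequence). `run_bidi_layers` and `scale_by_two` are hypothetical names.

```rust
/// Hypothetical toy stand-in for one pair's merged (residual-free) output.
fn scale_by_two(x: &[f32]) -> Vec<f32> {
    x.iter().map(|&v| 2.0 * v).collect()
}

/// Hypothetical outer loop: interior pairs compute `h + pair(h)`; the first
/// and last pairs (assumed) pass their merged output through unchanged.
fn run_bidi_layers(x: Vec<f32>, pairs: &[fn(&[f32]) -> Vec<f32>]) -> Vec<f32> {
    let last = pairs.len().saturating_sub(1);
    let mut h = x;
    for (i, pair) in pairs.iter().enumerate() {
        let merged = pair(&h);
        h = if i == 0 || i == last {
            merged // residual suppressed
        } else {
            h.iter().zip(&merged).map(|(a, b)| a + b).collect() // residual added
        };
    }
    h
}
```

With three doubling pairs and input `[1.0]`, the sketch goes `[2.0]` (first, no residual), `[6.0]` (middle, `2 + 4`), then `[12.0]` (last, no residual). Keeping the residual outside the pair lets the stack decide per position in the stack whether to add it.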