Function bidi_pair_forward 

fn bidi_pair_forward<M: MambaBlock>(
    straight_norm: &RmsNorm,
    reverse_norm: &RmsNorm,
    straight_block: &M,
    reverse_block: &M,
    output_merge: &OutputMerge,
    x: Tensor<3>,
    straight_cache: Option<M::Cache>,
    reverse_cache: Option<M::Cache>,
    ssd_path: M::SsdPath,
) -> (Tensor<3>, M::Cache, M::Cache)
where
    M::SsdPath: Clone,
The straight + reverse + merge computation of a bidirectional pair, over borrowed sub-modules.

Taking references rather than owned clones is load-bearing: a Burn Param that is still lazily initialised re-runs its random initialiser on every clone, so cloning a not-yet-materialised block on each forward pass would resample fresh random weights on every call. BidiLayers therefore calls this function directly on its real layers instead of building a transient, cloned BidiLayerPair.
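
The hazard described above can be reproduced with a toy model. This is a minimal sketch and does not use Burn: LazyParam, counting_init, forward_borrowed and forward_cloned are hypothetical names invented for illustration, and the "random" initialiser is replaced by a counter so the effect is deterministic. It also assumes, as a simplification, that the clone runs the initialiser on first read rather than at clone time.

```rust
use std::cell::{Cell, OnceCell};
use std::rc::Rc;

/// Toy stand-in for a lazily initialised parameter.
/// Hypothetical type for illustration only; this is not Burn's `Param`.
struct LazyParam {
    value: OnceCell<f32>,
    init: Rc<dyn Fn() -> f32>,
}

impl LazyParam {
    fn new(init: Rc<dyn Fn() -> f32>) -> Self {
        Self { value: OnceCell::new(), init }
    }

    /// Materialises the value on first access, then returns the stored value.
    fn val(&self) -> f32 {
        *self.value.get_or_init(|| (self.init)())
    }
}

impl Clone for LazyParam {
    fn clone(&self) -> Self {
        // A materialised parameter copies its value. An unmaterialised one
        // copies only the initialiser, so the clone will run it again.
        Self {
            value: self.value.clone(),
            init: Rc::clone(&self.init),
        }
    }
}

/// Stand-in for a random initialiser: returns 1.0, 2.0, 3.0, ... on
/// successive calls, so every run of the initialiser yields a new "sample".
fn counting_init() -> Rc<dyn Fn() -> f32> {
    let calls = Rc::new(Cell::new(0u32));
    Rc::new(move || {
        calls.set(calls.get() + 1);
        calls.get() as f32
    })
}

/// Forward over a borrowed parameter: materialises the real weight in place.
fn forward_borrowed(w: &LazyParam, x: f32) -> f32 {
    w.val() * x
}

/// Forward over a transient clone: the real weight is never materialised,
/// so each call initialises a fresh copy.
fn forward_cloned(w: &LazyParam, x: f32) -> f32 {
    let transient = w.clone();
    transient.val() * x
}

fn main() {
    let w = LazyParam::new(counting_init());

    // Cloning before materialisation: two calls see two different weights.
    let a = forward_cloned(&w, 1.0);
    let b = forward_cloned(&w, 1.0);
    assert_ne!(a, b);

    // Borrowing: the first call materialises the weight, later calls reuse it.
    let c = forward_borrowed(&w, 1.0);
    let d = forward_borrowed(&w, 1.0);
    assert_eq!(c, d);

    // Once materialised, a clone copies the value and is stable again.
    assert_eq!(forward_cloned(&w, 1.0), c);

    println!("cloned: {a} then {b}; borrowed: {c} then {d}");
}
```

In this sketch the borrowed path is the only one that gives the same result across calls while the parameter is still unmaterialised, which mirrors why bidi_pair_forward takes its norms and blocks by reference.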