
Module combined_backward


Recompute-based gradient math (the memory-efficient backward).

Recompute-based gradient math for the Mamba-3 Single-SSD

This module implements the analytic backward of the single-pass MIMO-first scan. The forward intermediates (K1–K4) are recomputed from the saved leaf inputs. A reverse per-chunk loop then fuses three backward passes: the K5 state-to-output term (BLUE), the strictly lower-triangular intra-chunk term (LOWER), and the K4 state passing. The γ-weighted same-step (DIAG) term is computed in one batched step, since it involves no recurrence and only tiny m × m tensors. Because this pathway applies the trapezoid weights internally, it additionally returns d_gamma and d_scale. The shared K3 extended helper, along with K1, K2, and K4, is reused from the double-SSD module.
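The recompute-then-reverse pattern can be illustrated on a much smaller problem. The sketch below is not this module's code: it assumes a toy scalar recurrence h[t] = a[t] * h[t-1] + x[t] with output y[t] = h[t], and the names `forward`, `backward_recompute`, and `chunk` are hypothetical. It keeps only the leaf inputs, recomputes the state entering each chunk, and then walks the chunks in reverse while passing the state gradient from one chunk to the previous one.

```rust
/// Toy forward pass: h[t] = a[t] * h[t-1] + x[t], y[t] = h[t], with h[-1] = 0.
/// Nothing but the leaf inputs `a` and `x` needs to be saved for the backward.
fn forward(a: &[f64], x: &[f64]) -> Vec<f64> {
    assert_eq!(a.len(), x.len());
    let mut h = 0.0;
    a.iter()
        .zip(x)
        .map(|(&a_t, &x_t)| {
            h = a_t * h + x_t;
            h
        })
        .collect()
}

/// Toy backward pass: returns (d_a, d_x) for the upstream gradient `d_y`.
/// States are recomputed per chunk instead of being stored by the forward.
fn backward_recompute(a: &[f64], x: &[f64], d_y: &[f64], chunk: usize) -> (Vec<f64>, Vec<f64>) {
    assert!(chunk > 0, "chunk size must be positive");
    assert!(a.len() == x.len() && a.len() == d_y.len());
    let n = a.len();

    // Pass 1: recompute only the state entering each chunk.
    let mut chunk_in = Vec::new();
    let mut h = 0.0;
    for t in 0..n {
        if t % chunk == 0 {
            chunk_in.push(h);
        }
        h = a[t] * h + x[t];
    }

    let mut d_a = vec![0.0; n];
    let mut d_x = vec![0.0; n];
    // Gradient flowing into the current state from all later steps.
    let mut d_h = 0.0;

    // Pass 2: reverse loop over chunks.
    for c in (0..chunk_in.len()).rev() {
        let start = c * chunk;
        let end = (start + chunk).min(n);

        // Recompute the previous-state values h[t-1] for this chunk only.
        let mut h_prev = Vec::with_capacity(end - start);
        let mut h = chunk_in[c];
        for t in start..end {
            h_prev.push(h);
            h = a[t] * h + x[t];
        }

        // Reverse recurrence inside the chunk; d_h carries over to chunk c - 1.
        for t in (start..end).rev() {
            let g = d_y[t] + d_h; // total dL/dh[t]
            d_x[t] = g;
            d_a[t] = g * h_prev[t - start];
            d_h = g * a[t];
        }
    }
    (d_a, d_x)
}
```

In this toy version the extra memory is one state per chunk plus one chunk of recomputed states, and the result does not depend on the chunk size chosen.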

Everything operates on backend primitives through the rank-tagged [F] wrapper: the custom Backward node runs with a generic backend B, so the high-level Tensor is unavailable and the math uses B’s float_* ops.
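As a rough illustration of what a rank-tagged wrapper over backend primitives can look like, the sketch below carries the tensor rank in a const generic and forwards arithmetic to backend-level functions. The real [F] type is not shown on this page, so every name here (`ToyBackend`, `Ranked`, `Cpu`, `CpuPrim`) is a hypothetical stand-in, and the two `float_*` functions only mimic the naming style mentioned above.

```rust
use std::marker::PhantomData;

/// Hypothetical stand-in for a backend: an opaque primitive plus free-function ops.
trait ToyBackend {
    type FloatPrim;
    fn float_add(lhs: Self::FloatPrim, rhs: Self::FloatPrim) -> Self::FloatPrim;
    fn float_shape(prim: &Self::FloatPrim) -> Vec<usize>;
}

/// Rank-tagged wrapper: the const parameter `D` records the rank in the type,
/// so mixing ranks is a compile-time error rather than a runtime surprise.
struct Ranked<B: ToyBackend, const D: usize> {
    prim: B::FloatPrim,
    _backend: PhantomData<B>,
}

impl<B: ToyBackend, const D: usize> Ranked<B, D> {
    /// Wraps a primitive, checking once that its runtime rank matches `D`.
    fn new(prim: B::FloatPrim) -> Self {
        assert_eq!(B::float_shape(&prim).len(), D, "rank mismatch");
        Self { prim, _backend: PhantomData }
    }

    /// Element-wise add, delegated to the backend-level op.
    fn add(self, rhs: Self) -> Self {
        Self::new(B::float_add(self.prim, rhs.prim))
    }

    /// Shape as a fixed-size array whose length is the tagged rank.
    fn dims(&self) -> [usize; D] {
        let mut out = [0usize; D];
        out.copy_from_slice(&B::float_shape(&self.prim));
        out
    }
}

/// Trivial CPU backend used only to exercise the wrapper.
struct Cpu;

struct CpuPrim {
    data: Vec<f32>,
    shape: Vec<usize>,
}

impl ToyBackend for Cpu {
    type FloatPrim = CpuPrim;

    fn float_add(lhs: CpuPrim, rhs: CpuPrim) -> CpuPrim {
        assert_eq!(lhs.shape, rhs.shape, "shape mismatch");
        let data = lhs.data.iter().zip(&rhs.data).map(|(l, r)| l + r).collect();
        CpuPrim { data, shape: lhs.shape }
    }

    fn float_shape(prim: &CpuPrim) -> Vec<usize> {
        prim.shape.clone()
    }
}
```

With this shape, code that is generic over the backend can still state ranks in its signatures, for example taking a `Ranked<B, 4>` and returning a `Ranked<B, 2>`, without depending on a high-level tensor type.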

Structs

CombinedSingleSsdGrads
Per-input gradients produced by combined_backward for the Single-SSD. Extends the double-SSD form, crate::mamba3::double_ssd::ssd::serial_recalculated::combined_backward::CombinedGrads, with d_gamma_bnlh and d_scale_bnlh.

Functions

combined_backward
Memory-efficient backward for the Mamba-3 MIMO-first chunkwise Single-SSD.