Module combined_backward

Recompute-based gradient math (the memory-efficient backward).

Recompute-based gradient math for the Mamba-3 double-SSD

This module implements the analytic backward of the MIMO-first serial scan used by each pass of the double-SSD decomposition. The forward intermediates (K1–K4) are not stashed; they are recomputed from the saved leaf inputs. A reverse per-chunk loop then fuses the K5 and K4 backwards, and the K1/K2/K3 backwards run batched once the loop has gathered the per-chunk slices. The fused L·M length carries the mimo_rank axis through the intra-chunk products.
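The recompute-then-reverse pattern described above can be illustrated on a much simpler problem. The following is a minimal sketch, not this module's code: it assumes a scalar recurrence `h[t] = a[t] * h[t-1] + x[t]` in place of the K1–K5 kernels, and the names `scan_forward`, `scan_backward_recompute`, and `ScanGrads` are hypothetical. Only the leaf inputs `a` and `x` are kept; the backward first recomputes the state entering each chunk, then walks the chunks in reverse, recomputing the in-chunk states and carrying the state gradient across chunk boundaries.

```rust
/// Forward scan: h[t] = a[t] * h[t-1] + x[t], with h[-1] = 0.
/// Returns every h[t]; nothing is stashed for the backward.
pub fn scan_forward(a: &[f64], x: &[f64]) -> Vec<f64> {
    let mut h = 0.0;
    a.iter()
        .zip(x)
        .map(|(&at, &xt)| {
            h = at * h + xt;
            h
        })
        .collect()
}

/// Gradients with respect to the two leaf inputs (hypothetical struct).
pub struct ScanGrads {
    pub da: Vec<f64>,
    pub dx: Vec<f64>,
}

/// Backward that recomputes states from the leaf inputs `a` and `x`.
/// `dy[t]` is the upstream gradient with respect to h[t].
pub fn scan_backward_recompute(a: &[f64], x: &[f64], dy: &[f64], chunk: usize) -> ScanGrads {
    assert!(chunk > 0, "chunk size must be positive");
    assert!(a.len() == x.len() && x.len() == dy.len(), "length mismatch");
    let t_len = a.len();
    let n_chunks = (t_len + chunk - 1) / chunk;

    // Pass 1: recompute the state entering each chunk (one value per chunk).
    let mut entry = Vec::with_capacity(n_chunks);
    let mut h = 0.0;
    for c in 0..n_chunks {
        entry.push(h);
        let end = ((c + 1) * chunk).min(t_len);
        for t in c * chunk..end {
            h = a[t] * h + x[t];
        }
    }

    // Pass 2: reverse per-chunk loop.
    let mut da = vec![0.0; t_len];
    let mut dx = vec![0.0; t_len];
    let mut carry = 0.0; // a[t+1] * dh[t+1], flowing in from later steps
    for c in (0..n_chunks).rev() {
        let start = c * chunk;
        let end = ((c + 1) * chunk).min(t_len);

        // Recompute h[t-1] for every step t in this chunk only.
        let mut prev = Vec::with_capacity(end - start);
        let mut h = entry[c];
        for t in start..end {
            prev.push(h);
            h = a[t] * h + x[t];
        }

        // Reverse walk inside the chunk.
        for t in (start..end).rev() {
            let dh = dy[t] + carry;
            dx[t] = dh;
            da[t] = dh * prev[t - start];
            carry = a[t] * dh;
        }
    }
    ScanGrads { da, dx }
}
```

In this sketch the extra memory is one entry state per chunk plus one chunk of recomputed states at a time, instead of every intermediate state for the whole sequence. The result does not depend on the chunk size, which makes the chunked version easy to check against a single-chunk run.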

Everything operates on backend primitives through the rank-tagged [F] wrapper: the custom Backward node runs with a generic backend B, so the high-level Tensor is unavailable and the math uses B’s float_* ops. The recomputed K1/K2/K4 kernels are local primitive ports of the high-level [super::super::serial] kernels.
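To make the idea of a rank-tagged wrapper over backend primitives concrete, here is a minimal sketch under stated assumptions. `ToyBackend`, `Cpu`, `CpuPrim`, and `affine` are hypothetical stand-ins invented for illustration, and the `F` shown here is a simplified guess at the shape of such a wrapper, not the crate's actual type. The point is only that generic code can be written against a backend's `float_*` functions while the tensor rank is still tracked in the type.

```rust
/// Hypothetical stand-in for a backend trait (not a real crate's API).
pub trait ToyBackend {
    type FloatPrim: Clone;
    fn float_add(lhs: Self::FloatPrim, rhs: Self::FloatPrim) -> Self::FloatPrim;
    fn float_mul(lhs: Self::FloatPrim, rhs: Self::FloatPrim) -> Self::FloatPrim;
    fn shape(p: &Self::FloatPrim) -> Vec<usize>;
}

/// Rank-tagged wrapper: `D` records the rank at the type level,
/// while the payload is the raw backend primitive.
pub struct F<B: ToyBackend, const D: usize> {
    prim: B::FloatPrim,
}

impl<B: ToyBackend, const D: usize> F<B, D> {
    /// Wraps a primitive, checking that its runtime rank matches `D`.
    pub fn new(prim: B::FloatPrim) -> Self {
        assert_eq!(B::shape(&prim).len(), D, "rank mismatch");
        Self { prim }
    }
    pub fn add(self, rhs: Self) -> Self {
        Self { prim: B::float_add(self.prim, rhs.prim) }
    }
    pub fn mul(self, rhs: Self) -> Self {
        Self { prim: B::float_mul(self.prim, rhs.prim) }
    }
    pub fn into_prim(self) -> B::FloatPrim {
        self.prim
    }
}

/// Toy CPU primitive: flat data plus a shape.
#[derive(Clone, Debug, PartialEq)]
pub struct CpuPrim {
    pub data: Vec<f32>,
    pub shape: Vec<usize>,
}

/// Toy backend operating elementwise on `CpuPrim`.
pub struct Cpu;

impl ToyBackend for Cpu {
    type FloatPrim = CpuPrim;
    fn float_add(l: CpuPrim, r: CpuPrim) -> CpuPrim {
        assert_eq!(l.shape, r.shape, "shape mismatch");
        let data = l.data.iter().zip(&r.data).map(|(a, b)| a + b).collect();
        CpuPrim { data, shape: l.shape }
    }
    fn float_mul(l: CpuPrim, r: CpuPrim) -> CpuPrim {
        assert_eq!(l.shape, r.shape, "shape mismatch");
        let data = l.data.iter().zip(&r.data).map(|(a, b)| a * b).collect();
        CpuPrim { data, shape: l.shape }
    }
    fn shape(p: &CpuPrim) -> Vec<usize> {
        p.shape.clone()
    }
}

/// Backend-generic math written only against the wrapper:
/// computes x * w + b elementwise for any `ToyBackend`.
pub fn affine<B: ToyBackend, const D: usize>(x: F<B, D>, w: F<B, D>, b: F<B, D>) -> F<B, D> {
    x.mul(w).add(b)
}
```

With this layout, passing a rank-1 value where a rank-2 one is expected is a compile-time type error, and `F::new` catches a primitive whose runtime shape disagrees with the tag.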

Structs

CombinedGrads
Per-input gradients produced by combined_backward (one field per differentiable forward input of the double-SSD scan).

Functions

combined_backward
Memory-efficient backward for the Mamba-3 MIMO-first chunkwise SSD.
k3_ssd_chunk_state_extended
Same as k3_ssd_chunk_state but also returns intermediates needed by the custom backward: