Recompute-based gradient math (the memory-efficient backward).
§Recompute-based gradient math for the Mamba-3 double-SSD
This module implements the analytic backward of the MIMO-first serial scan
used by each pass of the double-SSD decomposition. The forward intermediates
(K1–K4) are recomputed from the saved leaf inputs rather than stashed. A
reverse per-chunk loop then fuses the K5 and K4 backwards, and the K1/K2/K3
backwards run batched once the loop has gathered the per-chunk slices. The
fused L·M length carries the mimo_rank axis through the intra-chunk products.
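The recompute-then-reverse pattern can be illustrated on a much smaller problem. The sketch below is not the K1–K5 kernels of this module: it assumes a toy scalar recurrence h_t = a_t * h_{t-1} + x_t with outputs y_t = h_t, and the function names scan_forward and scan_backward_recompute are hypothetical. It only shows the general idea of saving the leaf inputs, rebuilding the intermediates inside the backward, and then walking the sequence in reverse.

```rust
/// Toy forward scan (an assumption for illustration, not this module's kernels):
/// h_t = a_t * h_{t-1} + x_t with h_{-1} = 0, and outputs y_t = h_t.
fn scan_forward(a: &[f64], x: &[f64]) -> Vec<f64> {
    let mut h = 0.0;
    let mut out = Vec::with_capacity(a.len());
    for (&a_t, &x_t) in a.iter().zip(x.iter()) {
        h = a_t * h + x_t;
        out.push(h);
    }
    out
}

/// Backward that keeps only the leaf inputs `a` and `x`.
/// The hidden states are recomputed here instead of being stashed by the
/// forward, then a reverse loop accumulates the gradients.
/// Returns (dL/da, dL/dx) given the upstream gradients dL/dy.
fn scan_backward_recompute(a: &[f64], x: &[f64], dy: &[f64]) -> (Vec<f64>, Vec<f64>) {
    // Recompute the forward intermediates from the saved leaves.
    let h = scan_forward(a, x);
    let n = a.len();
    let mut da = vec![0.0; n];
    let mut dx = vec![0.0; n];
    // Gradient arriving at h_t from all later time steps.
    let mut dh = 0.0;
    for t in (0..n).rev() {
        dh += dy[t]; // total gradient with respect to h_t
        dx[t] = dh; // dh_t/dx_t = 1
        let h_prev = if t == 0 { 0.0 } else { h[t - 1] };
        da[t] = dh * h_prev; // dh_t/da_t = h_{t-1}
        dh *= a[t]; // propagate to h_{t-1}
    }
    (da, dx)
}
```

The trade-off is one extra forward pass inside the backward in exchange for not holding the intermediates in memory between the two passes.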
Everything operates on backend primitives through the rank-tagged F wrapper:
the custom Backward node runs with a generic backend B, so the high-level
Tensor is unavailable and the math uses B’s float_* ops. The recomputed
K1/K2/K4 kernels are local primitive ports of the high-level
super::super::serial kernels.
Structs§
- CombinedGrads - Per-input gradients produced by combined_backward (one field per differentiable forward input of the double-SSD scan).
Functions§
- combined_backward - Memory-efficient backward for the Mamba-3 MIMO-first chunkwise SSD.
- k3_ssd_chunk_state_extended - Same as k3_ssd_chunk_state but also returns intermediates needed by the custom backward: