pub fn combined_backward<B: Backend>(
    d_y_bnlmhp: F<B, 6>,
    d_final_bhpr: F<B, 4>,
    v_bnlmhp: F<B, 6>,
    da_bnlh: F<B, 4>,
    b_bnlmhr: F<B, 6>,
    c_bnlmhr: F<B, 6>,
    gamma_bnlh: F<B, 4>,
    scale_bnlh: F<B, 4>,
    initial_state_bhpr: F<B, 4>,
) -> CombinedSingleSsdGrads<B>
Memory-efficient backward for the Mamba-3 MIMO-first chunkwise Single-SSD.
Recomputes the forward intermediates (K1–K4) from the saved inputs, then:
- runs a reverse per-chunk loop that fuses the K5 BLUE (state-to-output) and the strict lower-triangular LOWER (intra-chunk) backward with the K4 state-passing backward, and
- computes the γ-weighted same-step DIAG backward as one batched op (it has no recurrence, and the m × m working tensors are tiny).
K3/K2/K1 backwards run as single batched ops once the loop has collected all per-chunk slices.
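The recompute-then-reverse pattern described above can be illustrated on a much smaller problem. The following is a minimal sketch, assuming a scalar recurrence `h_t = a_t * h_{t-1} + b_t * x_t`, `y_t = c_t * h_t` in place of the real tensor kernels. `ToyGrads`, `toy_forward`, and `toy_backward` are hypothetical names invented for this illustration; they are not part of this crate and do not reproduce K1 to K5. The sketch only shows the memory trade-off: nothing but the inputs is saved, chunk-entry states are recomputed first, and a reverse per-chunk loop carries the state gradient backward, starting from the upstream gradient of the final state.

```rust
/// Gradients of the toy scalar recurrence (hypothetical, for illustration only).
#[derive(Debug)]
struct ToyGrads {
    d_x: Vec<f64>,
    d_a: Vec<f64>,
    d_b: Vec<f64>,
    d_c: Vec<f64>,
    d_h0: f64,
}

/// Forward pass: h_t = a_t * h_{t-1} + b_t * x_t, y_t = c_t * h_t.
/// Returns the outputs and the final state.
fn toy_forward(x: &[f64], a: &[f64], b: &[f64], c: &[f64], h0: f64) -> (Vec<f64>, f64) {
    let mut h = h0;
    let mut y = Vec::with_capacity(x.len());
    for t in 0..x.len() {
        h = a[t] * h + b[t] * x[t];
        y.push(c[t] * h);
    }
    (y, h)
}

/// Backward pass that stores no forward intermediates: it recomputes them
/// from the saved inputs, one chunk at a time, walking the chunks in reverse.
fn toy_backward(
    d_y: &[f64],
    d_final: f64,
    x: &[f64],
    a: &[f64],
    b: &[f64],
    c: &[f64],
    h0: f64,
    chunk: usize,
) -> ToyGrads {
    assert!(chunk > 0, "chunk length must be positive");
    let len = x.len();

    // Recompute the state entering each chunk (one scalar per chunk).
    let mut entry = Vec::new();
    let mut h = h0;
    for t in 0..len {
        if t % chunk == 0 {
            entry.push(h);
        }
        h = a[t] * h + b[t] * x[t];
    }

    let mut g = ToyGrads {
        d_x: vec![0.0; len],
        d_a: vec![0.0; len],
        d_b: vec![0.0; len],
        d_c: vec![0.0; len],
        d_h0: 0.0,
    };

    // Gradient flowing into the state from later steps; starts at d_final.
    let mut carry = d_final;
    for (n, &h_in) in entry.iter().enumerate().rev() {
        let start = n * chunk;
        let end = (start + chunk).min(len);

        // Recompute this chunk's states: states[i] is the state before step start + i.
        let mut states = vec![h_in];
        for t in start..end {
            let prev = states[t - start];
            states.push(a[t] * prev + b[t] * x[t]);
        }

        // Reverse sweep inside the chunk.
        for t in (start..end).rev() {
            let i = t - start;
            let dh = carry + c[t] * d_y[t]; // total gradient w.r.t. h_t
            g.d_c[t] = d_y[t] * states[i + 1];
            g.d_b[t] = dh * x[t];
            g.d_x[t] = dh * b[t];
            g.d_a[t] = dh * states[i];
            carry = dh * a[t]; // pass the gradient to h_{t-1}
        }
    }
    g.d_h0 = carry;
    g
}
```

In this toy version the peak extra memory is one state per chunk plus the states of a single chunk, rather than every state of the sequence. The result should not depend on the chunk length, which gives a convenient consistency check alongside a finite-difference comparison.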
§Arguments
- d_y_bnlmhp: upstream gradient of the SSD output
- d_final_bhpr: upstream gradient of the final SSM state
- v_bnlmhp, da_bnlh, b_bnlmhr, c_bnlmhr, gamma_bnlh, scale_bnlh, initial_state_bhpr: the seven saved forward inputs
§Returns
One CombinedSingleSsdGrads with gradients for all 7 inputs.