pub fn combined_backward<B: Backend>(
    d_y_bnlmhp: F<B, 6>,
    d_final_bhpr: F<B, 4>,
    v_bnlmhp: F<B, 6>,
    da_bnlh: F<B, 4>,
    b_bnlmhr: F<B, 6>,
    c_bnlmhr: F<B, 6>,
    initial_state_bhpr: F<B, 4>,
) -> CombinedGrads<B>
Memory-efficient backward for the Mamba-3 MIMO-first chunkwise SSD.
Recomputes the forward intermediates (K1-K4) from the saved inputs, then runs a reverse per-chunk loop that fuses the K5 (BLUE + ORANGE) backward with the K4 state-passing backward. Once the loop has collected all per-chunk slices, the K3, K2, and K1 backward passes each run as a single batched op.
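The recompute-then-reverse-loop structure can be illustrated with a toy scalar analog. This is a minimal sketch under stated assumptions, not this crate's implementation: each chunk is collapsed to a single decay factor, a single state contribution, and a single readout weight, so the state-passing recurrence is h[c + 1] = decay[c] * h[c] + contrib[c] and each chunk's output is y[c] = query[c] * h[c]. The names `recompute_states`, `toy_backward`, and `ToyGrads` are hypothetical. As in the function above, the boundary states are recomputed from the saved inputs instead of being stored, and the reverse loop handles the output backward and the state-passing backward together for each chunk.

```rust
/// Gradients of the toy model, one field per saved input (hypothetical names).
#[derive(Debug)]
pub struct ToyGrads {
    pub d_decay: Vec<f64>,
    pub d_contrib: Vec<f64>,
    pub d_query: Vec<f64>,
    pub d_initial: f64,
}

/// Toy forward state passing: h[0] = initial, h[c + 1] = decay[c] * h[c] + contrib[c].
/// Returns all n + 1 chunk-boundary states; the backward calls this instead of
/// reading states saved by the forward pass.
pub fn recompute_states(decay: &[f64], contrib: &[f64], initial: f64) -> Vec<f64> {
    let mut h = Vec::with_capacity(decay.len() + 1);
    h.push(initial);
    for c in 0..decay.len() {
        let next = decay[c] * h[c] + contrib[c];
        h.push(next);
    }
    h
}

/// Reverse per-chunk loop for the toy model, where y[c] = query[c] * h[c].
/// `d_y` is the upstream gradient of each chunk output and `d_final` the upstream
/// gradient of the final state h[n].
pub fn toy_backward(
    d_y: &[f64],
    d_final: f64,
    decay: &[f64],
    contrib: &[f64],
    query: &[f64],
    initial: f64,
) -> ToyGrads {
    let n = decay.len();
    assert!(d_y.len() == n && contrib.len() == n && query.len() == n);

    // Recompute forward intermediates from the saved inputs.
    let h = recompute_states(decay, contrib, initial);

    let mut d_decay = vec![0.0; n];
    let mut d_contrib = vec![0.0; n];
    let mut d_query = vec![0.0; n];

    // Gradient w.r.t. the state after the last chunk is seeded by d_final.
    let mut d_h = d_final;
    for c in (0..n).rev() {
        // At this point d_h holds dL/dh[c + 1].
        // State-passing backward for h[c + 1] = decay[c] * h[c] + contrib[c].
        d_decay[c] = d_h * h[c];
        d_contrib[c] = d_h;
        // Output backward for y[c] = query[c] * h[c].
        d_query[c] = d_y[c] * h[c];
        // Both paths contribute to dL/dh[c].
        d_h = decay[c] * d_h + d_y[c] * query[c];
    }

    // After the loop, d_h is dL/dh[0], the gradient of the initial state.
    ToyGrads { d_decay, d_contrib, d_query, d_initial: d_h }
}
```

For example, with decay = [0.5, 0.25], contrib = [1.0, 2.0], query = [3.0, 4.0], initial = 2.0, d_y = [1.0, 1.0], and d_final = 1.0, the recomputed states are [2.0, 2.0, 2.5] and the initial-state gradient is 3 + (4 + 0.25) * 0.5 = 5.125. Recomputing the states means the forward pass only has to keep its inputs alive. The cost is one extra forward sweep in exchange for not storing every intermediate.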
Arguments
- d_y_bnlmhp: upstream gradient of the SSD output
- d_final_bhpr: upstream gradient of the final SSM state
- v_bnlmhp, da_bnlh, b_bnlmhr, c_bnlmhr, initial_state_bhpr: the five saved forward inputs
Returns
A single CombinedGrads struct holding the gradients for all five saved forward inputs.