

Function combined_backward 

pub fn combined_backward<B: Backend>(
    d_y_bnlmhp: F<B, 6>,
    d_final_bhpr: F<B, 4>,
    v_bnlmhp: F<B, 6>,
    da_bnlh: F<B, 4>,
    b_bnlmhr: F<B, 6>,
    c_bnlmhr: F<B, 6>,
    initial_state_bhpr: F<B, 4>,
) -> CombinedGrads<B>

Memory-efficient backward for the Mamba-3 MIMO-first chunkwise SSD.

Recomputes the forward intermediates (K1-K4) from the saved inputs, then runs a reverse per-chunk loop that fuses the K5 (BLUE + ORANGE) backward with the K4 state-passing backward. K3/K2/K1 backwards run as single batched ops once the loop has collected all per-chunk slices.
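The recompute-then-reverse-loop pattern described above can be illustrated with a deliberately tiny scalar analogue. This is only a sketch under stated assumptions, not this crate's implementation: `ToyGrads`, `toy_forward`, and `toy_backward` are hypothetical names, each "chunk" is collapsed to a single scalar step, and the recurrence `s[k+1] = a[k]*s[k] + b[k]*v[k]`, `y[k] = c[k]*s[k+1]` stands in for the real K1-K5 kernels.

```rust
/// Gradients of the toy recurrence.
/// Hypothetical struct for illustration; it is not the crate's `CombinedGrads`.
#[derive(Debug, Clone, PartialEq)]
pub struct ToyGrads {
    pub d_v: Vec<f64>,
    pub d_a: Vec<f64>,
    pub d_b: Vec<f64>,
    pub d_c: Vec<f64>,
    pub d_initial: f64,
}

/// Toy forward pass: s[k+1] = a[k]*s[k] + b[k]*v[k], y[k] = c[k]*s[k+1].
/// Returns the outputs and every state s[0..=n].
pub fn toy_forward(
    v: &[f64],
    a: &[f64],
    b: &[f64],
    c: &[f64],
    initial: f64,
) -> (Vec<f64>, Vec<f64>) {
    let n = v.len();
    let mut states = Vec::with_capacity(n + 1);
    states.push(initial);
    let mut y = Vec::with_capacity(n);
    for k in 0..n {
        let next = a[k] * states[k] + b[k] * v[k];
        y.push(c[k] * next);
        states.push(next);
    }
    (y, states)
}

/// Toy backward pass: only the inputs are "saved"; the states are recomputed,
/// then a single reverse loop carries the state gradient from step to step.
pub fn toy_backward(
    d_y: &[f64],
    d_final: f64,
    v: &[f64],
    a: &[f64],
    b: &[f64],
    c: &[f64],
    initial: f64,
) -> ToyGrads {
    let n = v.len();
    // Recompute the forward intermediates instead of storing them.
    let (_y, states) = toy_forward(v, a, b, c, initial);

    let mut g = ToyGrads {
        d_v: vec![0.0; n],
        d_a: vec![0.0; n],
        d_b: vec![0.0; n],
        d_c: vec![0.0; n],
        d_initial: 0.0,
    };

    // Gradient flowing into the state after the last step.
    let mut d_s = d_final;
    for k in (0..n).rev() {
        // Output backward: y[k] = c[k] * s[k+1].
        g.d_c[k] = d_y[k] * states[k + 1];
        d_s += d_y[k] * c[k];
        // State-passing backward: s[k+1] = a[k]*s[k] + b[k]*v[k].
        g.d_a[k] = d_s * states[k];
        g.d_b[k] = d_s * v[k];
        g.d_v[k] = d_s * b[k];
        // Propagate to the previous state.
        d_s *= a[k];
    }
    g.d_initial = d_s;
    g
}
```

Calling `toy_backward` with the upstream gradients first and the saved inputs after mirrors the argument order of `combined_backward`. The trade-off is the same in spirit: one extra forward pass in exchange for not keeping per-step states alive between the forward and backward calls.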

§Arguments

  • d_y_bnlmhp — upstream gradient of the SSD output
  • d_final_bhpr — upstream gradient of the final SSM state
  • v_bnlmhp, da_bnlh, b_bnlmhr, c_bnlmhr, initial_state_bhpr — the five saved forward inputs

§Returns

A CombinedGrads struct holding the gradients with respect to all five saved forward inputs (v_bnlmhp, da_bnlh, b_bnlmhr, c_bnlmhr, initial_state_bhpr).