

Function combined_backward 

pub fn combined_backward<B: Backend>(
    d_y_bnlmhp: F<B, 6>,
    d_final_bhpr: F<B, 4>,
    v_bnlmhp: F<B, 6>,
    da_bnlh: F<B, 4>,
    b_bnlmhr: F<B, 6>,
    c_bnlmhr: F<B, 6>,
    gamma_bnlh: F<B, 4>,
    scale_bnlh: F<B, 4>,
    initial_state_bhpr: F<B, 4>,
) -> CombinedSingleSsdGrads<B>

Memory-efficient backward pass for the Mamba-3 MIMO-first chunkwise Single-SSD.

Recomputes the forward intermediates (K1–K4) from the saved inputs, then:

  • runs a reverse per-chunk loop that fuses the K5 BLUE (state-to-output) backward and the strictly lower-triangular LOWER (intra-chunk) backward with the K4 state-passing backward, and
  • computes the γ-weighted same-step DIAG backward as a single batched op, since it has no recurrence and its m × m working tensors are tiny.
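The fusion described in the first bullet can be illustrated with a scalar toy model. This is a minimal sketch under stated assumptions, not this function's implementation: every chunk is collapsed to one scalar state with h_c = exp(da[c]) · h_{c-1} + u[c], and a hypothetical read y[c] = q[c] · h_{c-1} stands in for the BLUE (state-to-output) term. ToyGrads, recompute_states, toy_backward, u, and q are illustrative names only. The chunk-entry states are recomputed from the saved inputs, and one running gradient carries both the state-passing and the state-to-output contributions backward through the chunks.

```rust
/// Hypothetical per-input gradients of the toy model below.
pub struct ToyGrads {
    pub d_da: Vec<f64>,
    pub d_u: Vec<f64>,
    pub d_q: Vec<f64>,
    pub d_h0: f64,
}

/// Toy forward recompute: h_c = exp(da[c]) * h_{c-1} + u[c], with h_{-1} = h0.
/// Returns the state entering each chunk and the final state.
pub fn recompute_states(da: &[f64], u: &[f64], h0: f64) -> (Vec<f64>, f64) {
    let mut entering = Vec::with_capacity(da.len());
    let mut h = h0;
    for c in 0..da.len() {
        entering.push(h);
        h = da[c].exp() * h + u[c];
    }
    (entering, h)
}

/// Reverse per-chunk loop. `g` holds dL/dh for the state leaving chunk c;
/// the state-passing gradient and the toy state-to-output gradient
/// (from y[c] = q[c] * h_{c-1}) are both folded into it, so one loop serves both.
pub fn toy_backward(
    d_y: &[f64], d_final: f64, da: &[f64], u: &[f64], q: &[f64], h0: f64,
) -> ToyGrads {
    let n = da.len();
    // Recompute intermediates from saved inputs instead of storing them.
    let (entering, _final_state) = recompute_states(da, u, h0);
    let (mut d_da, mut d_u, mut d_q) = (vec![0.0; n], vec![0.0; n], vec![0.0; n]);
    let mut g = d_final; // only the loss reads the final state
    for c in (0..n).rev() {
        let a = da[c].exp();
        let h_prev = entering[c];
        d_u[c] = g;
        d_da[c] = g * a * h_prev; // d(exp(da) * h_prev)/d(da) = exp(da) * h_prev
        d_q[c] = d_y[c] * h_prev;
        // Gradient into the state entering chunk c: via the decay and via y[c].
        g = a * g + d_y[c] * q[c];
    }
    ToyGrads { d_da, d_u, d_q, d_h0: g }
}
```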

The K3, K2, and K1 backward passes then run as single batched ops once the loop has collected all per-chunk slices.
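The DIAG bullet relies on the same-step term having no recurrence. The sketch below assumes a hypothetical toy form for it, y[t][i] = gamma[t] · Σ_j (c[t][i] · b[t][j]) · v[t][j], where the m × m matrix S[i][j] = c[t][i] · b[t][j] plays the role of the small working tensor; the real term's exact form and layout may differ. DiagGrads, s_entry, diag_forward, and diag_backward are illustrative names, not this crate's API. Because no iteration reads another step's data, the loop over t could be replaced by one batched op over all steps.

```rust
/// Hypothetical gradients of the toy DIAG term (flattened row-major buffers).
pub struct DiagGrads {
    pub d_v: Vec<f64>,     // [T][m]
    pub d_b: Vec<f64>,     // [T][m][r]
    pub d_c: Vec<f64>,     // [T][m][r]
    pub d_gamma: Vec<f64>, // [T]
}

/// S[i][j] = c[t][i] . b[t][j] for the step whose [m][r] block starts at `off`.
fn s_entry(b: &[f64], c: &[f64], off: usize, i: usize, j: usize, r: usize) -> f64 {
    (0..r).map(|k| c[off + i * r + k] * b[off + j * r + k]).sum()
}

/// Toy forward: y[t][i] = gamma[t] * sum_j S[i][j] * v[t][j]; no state crosses steps.
pub fn diag_forward(v: &[f64], b: &[f64], c: &[f64], gamma: &[f64], m: usize, r: usize) -> Vec<f64> {
    let mut y = vec![0.0; gamma.len() * m];
    for t in 0..gamma.len() {
        for i in 0..m {
            for j in 0..m {
                y[t * m + i] += gamma[t] * s_entry(b, c, t * m * r, i, j, r) * v[t * m + j];
            }
        }
    }
    y
}

/// Toy backward. Each outer iteration touches only step t's data,
/// so the loop over t could run as one batched op over all steps.
pub fn diag_backward(
    d_y: &[f64], v: &[f64], b: &[f64], c: &[f64], gamma: &[f64], m: usize, r: usize,
) -> DiagGrads {
    let steps = gamma.len();
    let mut g = DiagGrads {
        d_v: vec![0.0; steps * m],
        d_b: vec![0.0; steps * m * r],
        d_c: vec![0.0; steps * m * r],
        d_gamma: vec![0.0; steps],
    };
    for t in 0..steps {
        let off = t * m * r;
        for i in 0..m {
            for j in 0..m {
                let s = s_entry(b, c, off, i, j, r);
                let (dy_i, v_j) = (d_y[t * m + i], v[t * m + j]);
                g.d_gamma[t] += dy_i * s * v_j;
                g.d_v[t * m + j] += gamma[t] * dy_i * s;
                let d_s = gamma[t] * dy_i * v_j; // dL/dS[i][j]
                for k in 0..r {
                    g.d_c[off + i * r + k] += d_s * b[off + j * r + k];
                    g.d_b[off + j * r + k] += d_s * c[off + i * r + k];
                }
            }
        }
    }
    g
}
```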

Arguments

  • d_y_bnlmhp — upstream gradient of the SSD output
  • d_final_bhpr — upstream gradient of the final SSM state
  • v_bnlmhp, da_bnlh, b_bnlmhr, c_bnlmhr, gamma_bnlh, scale_bnlh, initial_state_bhpr — the seven saved forward inputs

Returns

A single CombinedSingleSsdGrads holding the gradients for all seven inputs.