pub fn combined_backward<B: Backend>(
    d_y_bnlhp: Tensor<B, 5>,
    d_final_bhpr: Tensor<B, 4>,
    x_bnlhp: Tensor<B, 5>,
    dt_discretized_bhnl: Tensor<B, 4>,
    b_bnlgr: Tensor<B, 5>,
    c_bnlgr: Tensor<B, 5>,
    d_h: Tensor<B, 1>,
    initial_state_bhpr: Tensor<B, 4>,
    a_decay_h: Tensor<B, 1>,
) -> CombinedGrads<B>
Core gradient computation. All arguments use the same shapes as in the forward pass.
d_y_bnlhp: upstream gradient of the scan output, shape [B, N, L, H, P]
d_final_bhpr: upstream gradient of the final state, shape [B, H, P, R]
Returns a single CombinedGrads struct containing the gradients for all 7 forward inputs (every argument except the two upstream gradients).
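
The argument names appear to encode their shapes: the letters after the last underscore list the axes in order, and the number of letters matches each tensor's rank (for example, `d_y_bnlhp` is a rank-5 tensor of shape [B, N, L, H, P]). The sketch below illustrates that reading in plain Rust. `shape_from_suffix` is a hypothetical helper, not part of this crate, and the axis sizes are made up for illustration.

```rust
use std::collections::HashMap;

/// Hypothetical helper (not part of the documented crate): takes the axis
/// letters after the last underscore of a name such as `d_y_bnlhp` and looks
/// each one up in `sizes`. Returns `None` if any letter has no known size.
fn shape_from_suffix(name: &str, sizes: &HashMap<char, usize>) -> Option<Vec<usize>> {
    let suffix = name.rsplit('_').next()?;
    suffix.chars().map(|axis| sizes.get(&axis).copied()).collect()
}

fn main() {
    // Illustrative sizes only; the real values come from the forward pass.
    let sizes: HashMap<char, usize> =
        [('b', 2), ('n', 4), ('l', 8), ('h', 3), ('p', 16), ('r', 32), ('g', 1)]
            .into_iter()
            .collect();

    // Print the shape implied by each argument name of `combined_backward`.
    for name in [
        "d_y_bnlhp",
        "d_final_bhpr",
        "x_bnlhp",
        "dt_discretized_bhnl",
        "b_bnlgr",
        "c_bnlgr",
        "d_h",
        "initial_state_bhpr",
        "a_decay_h",
    ] {
        println!("{name}: {:?}", shape_from_suffix(name, &sizes));
    }
}
```

Under this reading, each gradient in CombinedGrads would be expected to have the same shape as the input it corresponds to, although the struct's fields are not shown here.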