Function combined_backward 

pub fn combined_backward<B: Backend>(
    d_y_bnlhp: Tensor<B, 5>,
    d_final_bhpr: Tensor<B, 4>,
    x_bnlhp: Tensor<B, 5>,
    dt_discretized_bhnl: Tensor<B, 4>,
    b_bnlgr: Tensor<B, 5>,
    c_bnlgr: Tensor<B, 5>,
    d_h: Tensor<B, 1>,
    initial_state_bhpr: Tensor<B, 4>,
    a_decay_h: Tensor<B, 1>,
) -> CombinedGrads<B>

Core gradient computation. All arguments use the same shapes as in the forward pass; the suffix of each argument name spells out its axis order.

d_y_bnlhp: upstream gradient of the scan output, shape [B,N,L,H,P]

d_final_bhpr: upstream gradient of the final state, shape [B,H,P,R]

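Returns a single CombinedGrads struct containing the gradients with respect to the 7 forward inputs: x_bnlhp, dt_discretized_bhnl, b_bnlgr, c_bnlgr, d_h, initial_state_bhpr, and a_decay_h.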
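
The argument names encode their layouts, so a caller can check shapes before invoking the function. Below is a minimal sketch of that idea using only the standard library. The `Dims` struct, its `shape_of` helper, and the meaning attached to each letter (batch, chunks, chunk length, heads, per-head channels, groups, state size) are assumptions made for illustration; they are not part of this crate and are not confirmed by the documentation above.

```rust
/// Hypothetical bundle of axis sizes. The meaning given to each letter is an
/// assumption read off the argument-name suffixes, not something this crate defines.
#[derive(Clone, Copy, Debug)]
pub struct Dims {
    pub b: usize, // assumed: batch size
    pub n: usize, // assumed: number of chunks
    pub l: usize, // assumed: chunk length
    pub h: usize, // assumed: number of heads
    pub p: usize, // assumed: channels per head
    pub g: usize, // assumed: number of groups shared by b_bnlgr / c_bnlgr
    pub r: usize, // assumed: state size
}

impl Dims {
    /// Expand a name suffix such as "bnlhp" into the shape it implies.
    /// Returns `None` if the suffix contains a letter with no known axis.
    pub fn shape_of(&self, suffix: &str) -> Option<Vec<usize>> {
        suffix
            .chars()
            .map(|c| match c {
                'b' => Some(self.b),
                'n' => Some(self.n),
                'l' => Some(self.l),
                'h' => Some(self.h),
                'p' => Some(self.p),
                'g' => Some(self.g),
                'r' => Some(self.r),
                _ => None,
            })
            .collect()
    }
}
```

With such a helper, `shape_of("bnlhp")` gives the shape expected for `d_y_bnlhp` and `x_bnlhp`, and `shape_of("bhpr")` the one expected for `d_final_bhpr` and `initial_state_bhpr`. The result can then be compared against each tensor's actual dimensions before calling `combined_backward`.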