Function k2_ssd_bmm
pub fn k2_ssd_bmm<B: Backend>(
c_bnlrhn: Tensor<B, 6>,
b_bnlrhn: Tensor<B, 6>,
) -> Tensor<B, 5>
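Compute the intra-chunk CB matrix on fused (R-into-L) tensors, i.e. with the MIMO-rank axis R folded into the chunk-length axis so that each chunk has L = chunk_len * R positions.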
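Reading the argument and return shapes together, each entry of the output is presumably a dot product over state_rank between a C row and a B row of the same batch, chunk and head. The formula below is a sketch under stated assumptions, not taken from the implementation: the fused index is taken to be chunk_len-major (i = t_i R + r_i), and no causal mask or decay factor is assumed to be applied inside this function.

```latex
\mathrm{CB}_{b,n,h,i,j} \;=\; \sum_{s=0}^{S-1} C_{b,n,t_i,r_i,h,s}\, B_{b,n,t_j,r_j,h,s},
\qquad i = t_i R + r_i,\quad j = t_j R + r_j,\quad 0 \le i,j < L = \text{chunk\_len}\cdot R,
```

where S = state_rank and R = mimo_rank.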
Arguments
c_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
b_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
Returns
cb_bnhLL: [batch, nchunks, nheads, L, L], where L = chunk_len * R and R is the MIMO rank (mimo_rank)
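A naive reference can help when checking the shapes above against a real backend. The sketch below is a plain-Rust model of the documented contract, not the crate's Burn implementation. Its assumptions: row-major (contiguous) f32 buffers, a fused index l = t * R + r, no masking or decay inside the function, and a `Dims` struct and `cb_reference` function that are hypothetical names introduced only for illustration.

```rust
/// Hypothetical shape descriptor for the documented layout (not part of the crate).
#[derive(Clone, Copy, Debug)]
struct Dims {
    batch: usize,
    nchunks: usize,
    chunk_len: usize,
    mimo_rank: usize, // R
    nheads: usize,
    state_rank: usize,
}

/// Hypothetical naive reference: inputs are row-major
/// [batch, nchunks, chunk_len, R, nheads, state_rank]; output is row-major
/// [batch, nchunks, nheads, L, L] with L = chunk_len * R.
/// Assumes the fused index is l = t * R + r and that no mask or decay is applied.
fn cb_reference(c: &[f32], b: &[f32], d: Dims) -> Vec<f32> {
    let in_len = d.batch * d.nchunks * d.chunk_len * d.mimo_rank * d.nheads * d.state_rank;
    assert_eq!(c.len(), in_len, "c has the wrong number of elements");
    assert_eq!(b.len(), in_len, "b has the wrong number of elements");

    let l = d.chunk_len * d.mimo_rank;
    let mut out = vec![0.0f32; d.batch * d.nchunks * d.nheads * l * l];

    // Flat offset of element [bi, ni, t, r, h, 0] in an input buffer.
    let in_idx = |bi: usize, ni: usize, t: usize, r: usize, h: usize| {
        ((((bi * d.nchunks + ni) * d.chunk_len + t) * d.mimo_rank + r) * d.nheads + h)
            * d.state_rank
    };

    for bi in 0..d.batch {
        for ni in 0..d.nchunks {
            for h in 0..d.nheads {
                for i in 0..l {
                    // Un-fuse the row index into (chunk position, rank slot).
                    let ci = in_idx(bi, ni, i / d.mimo_rank, i % d.mimo_rank, h);
                    for j in 0..l {
                        let bj = in_idx(bi, ni, j / d.mimo_rank, j % d.mimo_rank, h);
                        // Contract over state_rank: CB[i, j] = C[i, :] . B[j, :].
                        let dot: f32 = (0..d.state_rank).map(|s| c[ci + s] * b[bj + s]).sum();
                        let o = (((bi * d.nchunks + ni) * d.nheads + h) * l + i) * l + j;
                        out[o] = dot;
                    }
                }
            }
        }
    }
    out
}
```

For example, with batch = nchunks = nheads = mimo_rank = 1, chunk_len = 2 and state_rank = 2, inputs c = [1, 2, 3, 4] and b = [5, 6, 7, 8] give the 2 x 2 matrix [17, 23, 39, 53], i.e. each row i holds the dot products of C row i with every B row.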