Function k2_ssd_bmm 

pub fn k2_ssd_bmm<B: Backend>(
    c_bnlrhn: Tensor<B, 6>,
    b_bnlrhn: Tensor<B, 6>,
) -> Tensor<B, 5>

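Compute the intra-chunk CB matrix on fused (R-into-L) tensors. The chunk_len and R axes are merged into one axis of length L = chunk_len * R. C is then contracted with B over state_rank, which gives one L × L matrix per (batch, chunk, head).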

Arguments

  • c_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]
  • b_bnlrhn: [batch, nchunks, chunk_len, R, nheads, state_rank]

R is the MIMO rank (mimo_rank). Both tensors share this shape.
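The "R-into-L" fusion can be made concrete with an index mapping. This sketch assumes that R is the MIMO rank and that the adjacent chunk_len and R axes are merged by a contiguous row-major reshape. Under that assumption, chunk-local step t and rank index r land at fused position t * R + r. `fused_l_index` and `unfuse_l_index` are hypothetical helpers used only for illustration. They are not part of this crate.

```rust
/// Hypothetical helper: position on the fused L axis for chunk-local step `t`
/// and MIMO rank index `r`, assuming (chunk_len, R) is fused row-major.
pub fn fused_l_index(t: usize, r: usize, mimo_rank: usize) -> usize {
    assert!(r < mimo_rank, "rank index out of range");
    t * mimo_rank + r
}

/// Hypothetical inverse: recover (t, r) from a fused L index under the same
/// row-major assumption.
pub fn unfuse_l_index(l: usize, mimo_rank: usize) -> (usize, usize) {
    assert!(mimo_rank > 0, "mimo_rank must be positive");
    (l / mimo_rank, l % mimo_rank)
}
```

With mimo_rank = 4, for example, step 2 at rank 3 maps to fused index 11. All R rank slots of one time step stay adjacent on the L axis.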

Returns

  • cb_bnhLL: [batch, nchunks, nheads, L, L] where L = chunk_len * mimo_rank
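The following naive, std-only reference sketches what this function is assumed to compute. It rests on several assumptions:

  • The inputs are flat row-major f32 buffers.
  • The fusion is the contiguous reshape of (chunk_len, R) into L.
  • The result is the unmasked product CB[i, j] = sum over n of C[i, n] * B[j, n] for each (batch, chunk, head).

Any causal mask or decay is assumed to be applied elsewhere. `Dims` and `cb_reference` are hypothetical names. They are not part of this crate or of Burn.

```rust
/// Hypothetical shape descriptor for the 6-D inputs (not part of the crate's API).
#[derive(Clone, Copy, Debug)]
pub struct Dims {
    pub batch: usize,
    pub nchunks: usize,
    pub chunk_len: usize,
    pub mimo_rank: usize,
    pub nheads: usize,
    pub state_rank: usize,
}

/// Naive reference for the intra-chunk CB matrix on row-major buffers.
/// Output layout: [batch, nchunks, nheads, L, L] with L = chunk_len * mimo_rank.
/// cb[b, c, h, i, j] = sum_n C[b, c, i, h, n] * B[b, c, j, h, n]
pub fn cb_reference(c: &[f32], b: &[f32], d: Dims) -> Vec<f32> {
    let l = d.chunk_len * d.mimo_rank;
    let expected = d.batch * d.nchunks * l * d.nheads * d.state_rank;
    assert_eq!(c.len(), expected, "C has the wrong number of elements");
    assert_eq!(b.len(), expected, "B has the wrong number of elements");

    // chunk_len and R are adjacent in row-major order, so the 6-D input is
    // already laid out as [batch, nchunks, L, nheads, state_rank].
    let in_idx = |bi: usize, ci: usize, li: usize, hi: usize, ni: usize| {
        (((bi * d.nchunks + ci) * l + li) * d.nheads + hi) * d.state_rank + ni
    };

    let mut out = vec![0.0f32; d.batch * d.nchunks * d.nheads * l * l];
    for bi in 0..d.batch {
        for ci in 0..d.nchunks {
            for hi in 0..d.nheads {
                for i in 0..l {
                    for j in 0..l {
                        // Dot product of row i of C with row j of B over state_rank.
                        let mut acc = 0.0f32;
                        for ni in 0..d.state_rank {
                            acc += c[in_idx(bi, ci, i, hi, ni)] * b[in_idx(bi, ci, j, hi, ni)];
                        }
                        let o = (((bi * d.nchunks + ci) * d.nheads + hi) * l + i) * l + j;
                        out[o] = acc;
                    }
                }
            }
        }
    }
    out
}
```

The nested loops trade speed for an index layout that is easy to check. A tensor-library version would typically reach the same result with a reshape that fuses R into L, a swap of the L and nheads axes, and a batched matmul against the transposed B.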