burn_mamba/modules/misc/split.rs
1//! Typed-array variants of [`Tensor::split_with_sizes`].
2
3use burn::prelude::*;
4
5/// Like [`Tensor::split_with_sizes`] but returns a fixed-size array, enabling
6/// `let [a, b, c, ...] = split_into::<…, N>(t, [size_a, size_b, size_c, ...], dim);`
7/// destructuring at the call site instead of a fragile `parts.next().unwrap()` chain.
8///
9/// Panics if the underlying split does not produce exactly `N` parts (which
10/// would indicate that the requested sizes do not cover the dimension).
11pub fn split_into<const D: usize, const N: usize>(
12 t: Tensor<D>,
13 sizes: [usize; N],
14 dim: usize,
15) -> [Tensor<D>; N] {
16 let parts = t.split_with_sizes(sizes.to_vec(), dim);
17 let got = parts.len();
18 parts
19 .try_into()
20 .unwrap_or_else(|_| panic!("split_into: expected {N} parts, got {got}"))
21}