Skip to main content

burn_mamba/modules/misc/
split.rs

1//! Typed-array variants of [`Tensor::split_with_sizes`].
2
3use burn::prelude::*;
4
5/// Like [`Tensor::split_with_sizes`] but returns a fixed-size array, enabling
6/// `let [a, b, c, ...] = split_into::<…, N>(t, [size_a, size_b, size_c, ...], dim);`
7/// destructuring at the call site instead of a fragile `parts.next().unwrap()` chain.
8///
9/// Panics if the underlying split does not produce exactly `N` parts (which
10/// would indicate that the requested sizes do not cover the dimension).
11pub fn split_into<const D: usize, const N: usize>(
12    t: Tensor<D>,
13    sizes: [usize; N],
14    dim: usize,
15) -> [Tensor<D>; N] {
16    let parts = t.split_with_sizes(sizes.to_vec(), dim);
17    let got = parts.len();
18    parts
19        .try_into()
20        .unwrap_or_else(|_| panic!("split_into: expected {N} parts, got {got}"))
21}