Skip to main content

burn_mamba/utils/
fprim.rs

1//! # Rank-tagged primitive tensor wrapper for the custom backward math
2//!
3//! [`F`] is a thin newtype over a backend's [`FloatTensor`] primitive that
4//! mirrors the subset of the high-level [`Tensor`](burn::tensor::Tensor) method
5//! API used by the recompute-backward gradient math
6//! (`*/serial_recalculated/combined_backward.rs`).
7//!
8//! Why it exists: in Burn 0.22 the high-level `Tensor` is pinned to the global
9//! `Dispatch` backend, so it cannot be built from an arbitrary backend `B`'s
10//! primitive.  A custom [`Backward`](burn::backend::autodiff::ops::Backward)
11//! node runs with a *generic* `B`, so its gradient math must operate directly on
12//! `B`'s primitives via the `B::float_*` ops.  This wrapper keeps that
13//! primitive-level math reading like the original `Tensor` code (method
14//! chaining, shape-suffixed names) instead of deeply nested free-function calls.
15//!
16//! The rank `D` is a compile-time tag for parity with the ported code and to
17//! catch rank mistakes; every operation ultimately defers to `B`'s
18//! runtime-shaped primitive ops.
19
20use burn::backend::Backend;
21use burn::backend::get_device_settings;
22use burn::backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
23use burn::backend::{FloatDType, Scalar, Shape, Slice, SliceArg, TensorMetadata};
24
25/// A backend float-tensor primitive tagged with a compile-time rank `D`.
26///
27/// Mirrors the slice of [`Tensor`](burn::tensor::Tensor)'s method API needed by
28/// the custom backward gradient math, operating directly on `B`'s primitives.
29pub struct F<B: Backend, const D: usize>(pub FloatTensor<B>);
30
31impl<B: Backend, const D: usize> Clone for F<B, D> {
32    fn clone(&self) -> Self {
33        F(self.0.clone())
34    }
35}
36
37impl<B: Backend, const D: usize> F<B, D> {
38    /// Wrap a raw primitive.
39    pub fn new(p: FloatTensor<B>) -> Self {
40        F(p)
41    }
42
43    /// Unwrap to the raw primitive.
44    pub fn inner(self) -> FloatTensor<B> {
45        self.0
46    }
47
48    /// Runtime shape as a `[usize; D]` array.
49    pub fn dims(&self) -> [usize; D] {
50        self.0.shape().dims()
51    }
52
53    /// Device the tensor lives on.
54    pub fn device(&self) -> Device<B> {
55        B::float_device(&self.0)
56    }
57
58    /// Float dtype of the tensor.
59    pub fn dtype(&self) -> FloatDType {
60        self.0.dtype().into()
61    }
62
63    /// Batched matrix multiplication over the last two dims.
64    pub fn matmul(self, rhs: Self) -> Self {
65        F(B::float_matmul(self.0, rhs.0))
66    }
67
68    /// Permute the axes (rank-preserving).
69    pub fn permute(self, axes: [usize; D]) -> Self {
70        F(B::float_permute(self.0, &axes))
71    }
72
73    /// Swap two axes.
74    pub fn swap_dims(self, dim1: usize, dim2: usize) -> Self {
75        F(B::float_swap_dims(self.0, dim1, dim2))
76    }
77
78    /// Element-wise `exp`.
79    pub fn exp(self) -> Self {
80        F(B::float_exp(self.0))
81    }
82
83    /// Sum along `dim`, keeping it as a size-1 axis (rank-preserving).
84    pub fn sum_dim(self, dim: usize) -> Self {
85        F(B::float_sum_dim(self.0, dim))
86    }
87
88    /// Cumulative sum along `dim`.
89    pub fn cumsum(self, dim: usize) -> Self {
90        F(B::float_cumsum(self.0, dim))
91    }
92
93    /// Reverse the order of elements along the given `axes` (rank-preserving).
94    pub fn flip(self, axes: &[usize]) -> Self {
95        F(B::float_flip(self.0, axes))
96    }
97
98    /// Slice the tensor (rank-preserving), accepting the same `s![..]` args as
99    /// the high-level API.
100    pub fn slice<S: SliceArg>(self, slices: S) -> Self {
101        let shape = self.0.shape();
102        let slices = slices.into_slices(&shape);
103        F(B::float_slice(self.0, &slices))
104    }
105
106    /// Narrow `dim` to `[start, start + length)` (rank-preserving).
107    pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
108        let mut slices: Vec<Slice> = (0..dim).map(|_| Slice::from(..)).collect();
109        slices.push(Slice::from(start..start + length));
110        let shape = self.0.shape();
111        let slices = (&slices[..]).into_slices(&shape);
112        F(B::float_slice(self.0, &slices))
113    }
114
115    /// Reshape to a new rank `D2`.
116    pub fn reshape<const D2: usize>(self, shape: [usize; D2]) -> F<B, D2> {
117        F(B::float_reshape(self.0, Shape::new(shape)))
118    }
119
120    /// Broadcast-expand to `shape` (rank-preserving).
121    pub fn expand(self, shape: [usize; D]) -> Self {
122        F(B::float_expand(self.0, Shape::new(shape)))
123    }
124
125    /// Remove the size-1 axis at `dim`, yielding rank `D2 = D - 1`.
126    pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> F<B, D2> {
127        let current = self.0.shape().dims::<D>();
128        let mut new_dims = [0usize; D2];
129        new_dims[..dim].copy_from_slice(&current[..dim]);
130        new_dims[dim..].copy_from_slice(&current[dim + 1..]);
131        F(B::float_reshape(self.0, Shape::new(new_dims)))
132    }
133
134    /// Insert a size-1 axis at `dim`, yielding rank `D2 = D + 1`.
135    pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> F<B, D2> {
136        let shape = self.0.shape().dims::<D>();
137        let mut dims = [1usize; D2];
138        dims[0..dim].copy_from_slice(&shape[0..dim]);
139        if dim < D {
140            dims[dim] = 1;
141            dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
142        } else {
143            dims[dim] = 1;
144        }
145        F(B::float_reshape(self.0, Shape::new(dims)))
146    }
147
148    /// Insert size-1 axes at the given output positions, yielding rank `D2`.
149    ///
150    /// Mirrors [`Tensor::unsqueeze_dims`](burn::tensor::Tensor::unsqueeze_dims):
151    /// negative indices count from the back and duplicates insert multiple axes.
152    pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> F<B, D2> {
153        let old_dims = self.0.shape().dims::<D>();
154        let mut new_dims = [1usize; D2];
155
156        // Resolve negative indices (counting from the back, in reverse order).
157        let mut neg_offset = D2;
158        let mut dim_indices = axes
159            .iter()
160            .map(|&d| {
161                (if d < 0 {
162                    neg_offset -= 1;
163                    d + neg_offset as isize + 1
164                } else {
165                    d
166                }) as usize
167            })
168            .collect::<Vec<usize>>();
169        dim_indices.sort_unstable();
170        // Duplicate axes mean "insert N dims at that index": bump duplicates.
171        for i in 1..dim_indices.len() {
172            if dim_indices[i] <= dim_indices[i - 1] {
173                dim_indices[i] = dim_indices[i - 1] + 1;
174            }
175        }
176
177        let mut dim_indices_curr = 0usize;
178        let mut old_dims_curr = 0usize;
179        for new_dims_curr in 0..D2 {
180            if dim_indices_curr == dim_indices.len() {
181                new_dims[new_dims_curr..].copy_from_slice(&old_dims[old_dims_curr..]);
182                break;
183            }
184            if new_dims_curr == dim_indices[dim_indices_curr] {
185                dim_indices_curr += 1;
186            } else {
187                new_dims[new_dims_curr] = old_dims[old_dims_curr];
188                old_dims_curr += 1;
189            }
190        }
191
192        F(B::float_reshape(self.0, Shape::new(new_dims)))
193    }
194
195    /// Zero everything strictly below the `diagonal` (keeps the upper triangle).
196    ///
197    /// Equivalent to [`Tensor::triu`](burn::tensor::Tensor::triu): builds the
198    /// triangular bool mask over the last two dims and fills the masked region
199    /// with `0`.
200    pub fn triu(self, diagonal: i64) -> Self {
201        let dims = self.0.shape().dims::<D>();
202        let rows = dims[D - 2];
203        let cols = dims[D - 1];
204        let device = B::float_device(&self.0);
205
206        let mask2 = tri_bool::<B>(rows, cols, diagonal, false, &device);
207        let mut lead = [1usize; D];
208        lead[D - 2] = rows;
209        lead[D - 1] = cols;
210        let mask = B::bool_reshape(mask2, Shape::new(lead));
211        let mask = B::bool_expand(mask, Shape::new(dims));
212        F(B::float_mask_fill(self.0, mask, Scalar::from(0.0f32)))
213    }
214
215    /// Fill the positions where `mask` is `true` with `value`.
216    pub fn mask_fill(self, mask: Mask<B>, value: f32) -> Self {
217        F(B::float_mask_fill(self.0, mask.0, Scalar::from(value)))
218    }
219
220    /// Concatenate same-rank tensors along `dim` (rank-preserving).
221    pub fn cat(tensors: Vec<F<B, D>>, dim: usize) -> Self {
222        F(B::float_cat(
223            tensors.into_iter().map(|t| t.0).collect(),
224            dim,
225        ))
226    }
227
228    /// Stack same-rank tensors along a fresh axis `dim`, yielding rank `D2 = D + 1`.
229    pub fn stack<const D2: usize>(tensors: Vec<F<B, D>>, dim: usize) -> F<B, D2> {
230        let unsqueezed = tensors
231            .into_iter()
232            .map(|t| {
233                let current = t.0.shape().dims::<D>();
234                let mut new_dims = [1usize; D2];
235                new_dims[0..dim].copy_from_slice(&current[0..dim]);
236                new_dims[dim] = 1;
237                new_dims[(dim + 1)..].copy_from_slice(&current[dim..]);
238                B::float_reshape(t.0, Shape::new(new_dims))
239            })
240            .collect::<Vec<_>>();
241        F(B::float_cat(unsqueezed, dim))
242    }
243
244    /// All-zeros tensor of the given shape, dtype and device.
245    pub fn zeros(shape: [usize; D], device: &Device<B>, dtype: FloatDType) -> Self {
246        F(B::float_zeros(Shape::new(shape), device, dtype))
247    }
248
249    /// Constant-filled tensor of the given shape, dtype and device.
250    pub fn full(shape: [usize; D], value: f32, device: &Device<B>, dtype: FloatDType) -> Self {
251        F(B::float_full(
252            Shape::new(shape),
253            Scalar::from(value),
254            device,
255            dtype,
256        ))
257    }
258}
259
260impl<B: Backend, const D: usize> core::ops::Add for F<B, D> {
261    type Output = Self;
262    fn add(self, rhs: Self) -> Self {
263        F(B::float_add(self.0, rhs.0))
264    }
265}
266
267impl<B: Backend, const D: usize> core::ops::Sub for F<B, D> {
268    type Output = Self;
269    fn sub(self, rhs: Self) -> Self {
270        F(B::float_sub(self.0, rhs.0))
271    }
272}
273
274impl<B: Backend, const D: usize> core::ops::Mul for F<B, D> {
275    type Output = Self;
276    fn mul(self, rhs: Self) -> Self {
277        F(B::float_mul(self.0, rhs.0))
278    }
279}
280
281impl<B: Backend, const D: usize> core::ops::Neg for F<B, D> {
282    type Output = Self;
283    fn neg(self) -> Self {
284        F(B::float_neg(self.0))
285    }
286}
287
288/// A boolean mask primitive used with [`F::mask_fill`].
289///
290/// Mirrors the slice of the bool-tensor API needed to build and broadcast the
291/// causal masks in the custom backward (construct → reshape → expand).
292pub struct Mask<B: Backend>(pub BoolTensor<B>);
293
294impl<B: Backend> Clone for Mask<B> {
295    fn clone(&self) -> Self {
296        Mask(self.0.clone())
297    }
298}
299
300impl<B: Backend> Mask<B> {
301    /// `[rows, cols]` mask that is `true` strictly above the `offset` diagonal.
302    ///
303    /// Matches [`Tensor::tril_mask`](burn::tensor::Tensor::tril_mask): the
304    /// `true` entries are the region a `tril` would fill.
305    pub fn tril_mask(rows: usize, cols: usize, offset: i64, device: &Device<B>) -> Self {
306        Mask(tri_bool::<B>(rows, cols, offset, true, device))
307    }
308
309    /// Reshape the mask to a new rank.
310    pub fn reshape<const N: usize>(self, shape: [usize; N]) -> Self {
311        Mask(B::bool_reshape(self.0, Shape::new(shape)))
312    }
313
314    /// Broadcast-expand the mask.
315    pub fn expand<const N: usize>(self, shape: [usize; N]) -> Self {
316        Mask(B::bool_expand(self.0, Shape::new(shape)))
317    }
318}
319
320/// Build a `[rows, cols]` triangular boolean mask on-device.
321///
322/// Mirrors `tri_mask` in Burn: with `matrix = row - col`, the result is
323/// `matrix < -offset` when `lower` (the `tril_mask`/lower-triangle region) and
324/// `matrix > -offset` otherwise (the `triu_mask`/upper-triangle region) — i.e.
325/// the original `row - col + offset ≷ 0` test with `offset` folded into the
326/// comparison threshold.
327///
328/// Built from on-device `arange`/comparison rather than uploaded host `bool`
329/// data: cubecl backends reject a `Bool(Native)` `bool_from_data`, and the
330/// comparison output here adopts the device's configured bool store.
331fn tri_bool<B: Backend>(
332    rows: usize,
333    cols: usize,
334    offset: i64,
335    lower: bool,
336    device: &Device<B>,
337) -> BoolTensor<B> {
338    let settings = get_device_settings::<B>(device);
339    let shape = Shape::new([rows, cols]);
340
341    let rows_i: IntTensor<B> = B::int_reshape(
342        B::int_arange(0..rows as i64, device, settings.int_dtype),
343        Shape::new([rows, 1]),
344    );
345    let cols_i: IntTensor<B> = B::int_reshape(
346        B::int_arange(0..cols as i64, device, settings.int_dtype),
347        Shape::new([1, cols]),
348    );
349    // matrix_rc = row - col (broadcast to [rows, cols]).
350    let matrix = B::int_sub(
351        B::int_expand(rows_i, shape.clone()),
352        B::int_expand(cols_i, shape),
353    );
354    let threshold = Scalar::from(-offset);
355    if lower {
356        B::int_lower_elem(matrix, threshold, settings.bool_dtype)
357    } else {
358        B::int_greater_elem(matrix, threshold, settings.bool_dtype)
359    }
360}
361
362/// Primitive analogue of [`crate::modules::sanity`] for [`F`].
363///
364/// Panics if `t` contains a `NaN` (when [`crate::DENY_NAN`] is set) or an `Inf`
365/// (when [`crate::DENY_INF`] is set).  A no-op — with no device read — when both
366/// flags are `false` (the default), so it can be sprinkled through the backward
367/// math at no release-build cost.
368pub fn san<B: Backend, const D: usize>(t: &F<B, D>) {
369    if !crate::DENY_NAN && !crate::DENY_INF {
370        return;
371    }
372    let data = burn::tensor::read_sync(B::float_into_data(t.0.clone()))
373        .expect("sanity check: failed to read tensor data");
374    let mut has_nan = false;
375    let mut has_inf = false;
376    for v in data.iter::<f64>() {
377        if crate::DENY_NAN && v.is_nan() {
378            has_nan = true;
379        }
380        if crate::DENY_INF && v.is_infinite() {
381            has_inf = true;
382        }
383    }
384    if has_nan {
385        eprintln!("got a NaN");
386    }
387    if has_inf {
388        eprintln!("got a INF");
389    }
390    if has_nan || has_inf {
391        panic!("sanity check failed");
392    }
393}