1use burn::backend::Backend;
21use burn::backend::get_device_settings;
22use burn::backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
23use burn::backend::{FloatDType, Scalar, Shape, Slice, SliceArg, TensorMetadata};
24
25pub struct F<B: Backend, const D: usize>(pub FloatTensor<B>);
30
31impl<B: Backend, const D: usize> Clone for F<B, D> {
32 fn clone(&self) -> Self {
33 F(self.0.clone())
34 }
35}
36
37impl<B: Backend, const D: usize> F<B, D> {
38 pub fn new(p: FloatTensor<B>) -> Self {
40 F(p)
41 }
42
43 pub fn inner(self) -> FloatTensor<B> {
45 self.0
46 }
47
48 pub fn dims(&self) -> [usize; D] {
50 self.0.shape().dims()
51 }
52
53 pub fn device(&self) -> Device<B> {
55 B::float_device(&self.0)
56 }
57
58 pub fn dtype(&self) -> FloatDType {
60 self.0.dtype().into()
61 }
62
63 pub fn matmul(self, rhs: Self) -> Self {
65 F(B::float_matmul(self.0, rhs.0))
66 }
67
68 pub fn permute(self, axes: [usize; D]) -> Self {
70 F(B::float_permute(self.0, &axes))
71 }
72
73 pub fn swap_dims(self, dim1: usize, dim2: usize) -> Self {
75 F(B::float_swap_dims(self.0, dim1, dim2))
76 }
77
78 pub fn exp(self) -> Self {
80 F(B::float_exp(self.0))
81 }
82
83 pub fn sum_dim(self, dim: usize) -> Self {
85 F(B::float_sum_dim(self.0, dim))
86 }
87
88 pub fn cumsum(self, dim: usize) -> Self {
90 F(B::float_cumsum(self.0, dim))
91 }
92
93 pub fn flip(self, axes: &[usize]) -> Self {
95 F(B::float_flip(self.0, axes))
96 }
97
98 pub fn slice<S: SliceArg>(self, slices: S) -> Self {
101 let shape = self.0.shape();
102 let slices = slices.into_slices(&shape);
103 F(B::float_slice(self.0, &slices))
104 }
105
106 pub fn narrow(self, dim: usize, start: usize, length: usize) -> Self {
108 let mut slices: Vec<Slice> = (0..dim).map(|_| Slice::from(..)).collect();
109 slices.push(Slice::from(start..start + length));
110 let shape = self.0.shape();
111 let slices = (&slices[..]).into_slices(&shape);
112 F(B::float_slice(self.0, &slices))
113 }
114
115 pub fn reshape<const D2: usize>(self, shape: [usize; D2]) -> F<B, D2> {
117 F(B::float_reshape(self.0, Shape::new(shape)))
118 }
119
120 pub fn expand(self, shape: [usize; D]) -> Self {
122 F(B::float_expand(self.0, Shape::new(shape)))
123 }
124
125 pub fn squeeze_dim<const D2: usize>(self, dim: usize) -> F<B, D2> {
127 let current = self.0.shape().dims::<D>();
128 let mut new_dims = [0usize; D2];
129 new_dims[..dim].copy_from_slice(¤t[..dim]);
130 new_dims[dim..].copy_from_slice(¤t[dim + 1..]);
131 F(B::float_reshape(self.0, Shape::new(new_dims)))
132 }
133
134 pub fn unsqueeze_dim<const D2: usize>(self, dim: usize) -> F<B, D2> {
136 let shape = self.0.shape().dims::<D>();
137 let mut dims = [1usize; D2];
138 dims[0..dim].copy_from_slice(&shape[0..dim]);
139 if dim < D {
140 dims[dim] = 1;
141 dims[(dim + 1)..].copy_from_slice(&shape[dim..]);
142 } else {
143 dims[dim] = 1;
144 }
145 F(B::float_reshape(self.0, Shape::new(dims)))
146 }
147
148 pub fn unsqueeze_dims<const D2: usize>(self, axes: &[isize]) -> F<B, D2> {
153 let old_dims = self.0.shape().dims::<D>();
154 let mut new_dims = [1usize; D2];
155
156 let mut neg_offset = D2;
158 let mut dim_indices = axes
159 .iter()
160 .map(|&d| {
161 (if d < 0 {
162 neg_offset -= 1;
163 d + neg_offset as isize + 1
164 } else {
165 d
166 }) as usize
167 })
168 .collect::<Vec<usize>>();
169 dim_indices.sort_unstable();
170 for i in 1..dim_indices.len() {
172 if dim_indices[i] <= dim_indices[i - 1] {
173 dim_indices[i] = dim_indices[i - 1] + 1;
174 }
175 }
176
177 let mut dim_indices_curr = 0usize;
178 let mut old_dims_curr = 0usize;
179 for new_dims_curr in 0..D2 {
180 if dim_indices_curr == dim_indices.len() {
181 new_dims[new_dims_curr..].copy_from_slice(&old_dims[old_dims_curr..]);
182 break;
183 }
184 if new_dims_curr == dim_indices[dim_indices_curr] {
185 dim_indices_curr += 1;
186 } else {
187 new_dims[new_dims_curr] = old_dims[old_dims_curr];
188 old_dims_curr += 1;
189 }
190 }
191
192 F(B::float_reshape(self.0, Shape::new(new_dims)))
193 }
194
195 pub fn triu(self, diagonal: i64) -> Self {
201 let dims = self.0.shape().dims::<D>();
202 let rows = dims[D - 2];
203 let cols = dims[D - 1];
204 let device = B::float_device(&self.0);
205
206 let mask2 = tri_bool::<B>(rows, cols, diagonal, false, &device);
207 let mut lead = [1usize; D];
208 lead[D - 2] = rows;
209 lead[D - 1] = cols;
210 let mask = B::bool_reshape(mask2, Shape::new(lead));
211 let mask = B::bool_expand(mask, Shape::new(dims));
212 F(B::float_mask_fill(self.0, mask, Scalar::from(0.0f32)))
213 }
214
215 pub fn mask_fill(self, mask: Mask<B>, value: f32) -> Self {
217 F(B::float_mask_fill(self.0, mask.0, Scalar::from(value)))
218 }
219
220 pub fn cat(tensors: Vec<F<B, D>>, dim: usize) -> Self {
222 F(B::float_cat(
223 tensors.into_iter().map(|t| t.0).collect(),
224 dim,
225 ))
226 }
227
228 pub fn stack<const D2: usize>(tensors: Vec<F<B, D>>, dim: usize) -> F<B, D2> {
230 let unsqueezed = tensors
231 .into_iter()
232 .map(|t| {
233 let current = t.0.shape().dims::<D>();
234 let mut new_dims = [1usize; D2];
235 new_dims[0..dim].copy_from_slice(¤t[0..dim]);
236 new_dims[dim] = 1;
237 new_dims[(dim + 1)..].copy_from_slice(¤t[dim..]);
238 B::float_reshape(t.0, Shape::new(new_dims))
239 })
240 .collect::<Vec<_>>();
241 F(B::float_cat(unsqueezed, dim))
242 }
243
244 pub fn zeros(shape: [usize; D], device: &Device<B>, dtype: FloatDType) -> Self {
246 F(B::float_zeros(Shape::new(shape), device, dtype))
247 }
248
249 pub fn full(shape: [usize; D], value: f32, device: &Device<B>, dtype: FloatDType) -> Self {
251 F(B::float_full(
252 Shape::new(shape),
253 Scalar::from(value),
254 device,
255 dtype,
256 ))
257 }
258}
259
260impl<B: Backend, const D: usize> core::ops::Add for F<B, D> {
261 type Output = Self;
262 fn add(self, rhs: Self) -> Self {
263 F(B::float_add(self.0, rhs.0))
264 }
265}
266
267impl<B: Backend, const D: usize> core::ops::Sub for F<B, D> {
268 type Output = Self;
269 fn sub(self, rhs: Self) -> Self {
270 F(B::float_sub(self.0, rhs.0))
271 }
272}
273
274impl<B: Backend, const D: usize> core::ops::Mul for F<B, D> {
275 type Output = Self;
276 fn mul(self, rhs: Self) -> Self {
277 F(B::float_mul(self.0, rhs.0))
278 }
279}
280
281impl<B: Backend, const D: usize> core::ops::Neg for F<B, D> {
282 type Output = Self;
283 fn neg(self) -> Self {
284 F(B::float_neg(self.0))
285 }
286}
287
288pub struct Mask<B: Backend>(pub BoolTensor<B>);
293
294impl<B: Backend> Clone for Mask<B> {
295 fn clone(&self) -> Self {
296 Mask(self.0.clone())
297 }
298}
299
300impl<B: Backend> Mask<B> {
301 pub fn tril_mask(rows: usize, cols: usize, offset: i64, device: &Device<B>) -> Self {
306 Mask(tri_bool::<B>(rows, cols, offset, true, device))
307 }
308
309 pub fn reshape<const N: usize>(self, shape: [usize; N]) -> Self {
311 Mask(B::bool_reshape(self.0, Shape::new(shape)))
312 }
313
314 pub fn expand<const N: usize>(self, shape: [usize; N]) -> Self {
316 Mask(B::bool_expand(self.0, Shape::new(shape)))
317 }
318}
319
320fn tri_bool<B: Backend>(
332 rows: usize,
333 cols: usize,
334 offset: i64,
335 lower: bool,
336 device: &Device<B>,
337) -> BoolTensor<B> {
338 let settings = get_device_settings::<B>(device);
339 let shape = Shape::new([rows, cols]);
340
341 let rows_i: IntTensor<B> = B::int_reshape(
342 B::int_arange(0..rows as i64, device, settings.int_dtype),
343 Shape::new([rows, 1]),
344 );
345 let cols_i: IntTensor<B> = B::int_reshape(
346 B::int_arange(0..cols as i64, device, settings.int_dtype),
347 Shape::new([1, cols]),
348 );
349 let matrix = B::int_sub(
351 B::int_expand(rows_i, shape.clone()),
352 B::int_expand(cols_i, shape),
353 );
354 let threshold = Scalar::from(-offset);
355 if lower {
356 B::int_lower_elem(matrix, threshold, settings.bool_dtype)
357 } else {
358 B::int_greater_elem(matrix, threshold, settings.bool_dtype)
359 }
360}
361
362pub fn san<B: Backend, const D: usize>(t: &F<B, D>) {
369 if !crate::DENY_NAN && !crate::DENY_INF {
370 return;
371 }
372 let data = burn::tensor::read_sync(B::float_into_data(t.0.clone()))
373 .expect("sanity check: failed to read tensor data");
374 let mut has_nan = false;
375 let mut has_inf = false;
376 for v in data.iter::<f64>() {
377 if crate::DENY_NAN && v.is_nan() {
378 has_nan = true;
379 }
380 if crate::DENY_INF && v.is_infinite() {
381 has_inf = true;
382 }
383 }
384 if has_nan {
385 eprintln!("got a NaN");
386 }
387 if has_inf {
388 eprintln!("got a INF");
389 }
390 if has_nan || has_inf {
391 panic!("sanity check failed");
392 }
393}