burn_mamba/utils/
sanity.rs1use crate::{DENY_INF, DENY_NAN};
2use burn::prelude::*;
3
4pub fn sanity<B: Backend, const D: usize>(t: &Tensor<B, D>) {
5 let mut has_nan = false;
6 let mut has_inf = false;
7
8 if DENY_NAN {
9 has_nan = t.clone().contains_nan().into_scalar().to_bool();
10 if has_nan {
11 eprintln!("got a NaN");
12 }
13 }
14 if DENY_INF {
15 has_inf = t.clone().is_inf().any().into_scalar().to_bool();
16 if has_inf {
17 eprintln!("got a INF");
18 }
19 }
20
21 if has_nan || has_inf {
22 panic!("sanity check failed");
23 }
24}
25
26pub fn sanity_nan<B: Backend, const D: usize>(t: &Tensor<B, D>) {
27 let mut has_nan = false;
28
29 if DENY_NAN {
30 has_nan = t.clone().contains_nan().into_scalar().to_bool();
31 if has_nan {
32 eprintln!("got a NaN");
33 }
34 }
35
36 if has_nan {
37 panic!("sanity check failed");
38 }
39}