Skip to main content

burn_mamba/utils/
sanity.rs

1use crate::{DENY_INF, DENY_NAN};
2use burn::prelude::*;
3
4pub fn sanity<B: Backend, const D: usize>(t: &Tensor<B, D>) {
5    let mut has_nan = false;
6    let mut has_inf = false;
7
8    if DENY_NAN {
9        has_nan = t.clone().contains_nan().into_scalar().to_bool();
10        if has_nan {
11            eprintln!("got a NaN");
12        }
13    }
14    if DENY_INF {
15        has_inf = t.clone().is_inf().any().into_scalar().to_bool();
16        if has_inf {
17            eprintln!("got a INF");
18        }
19    }
20
21    if has_nan || has_inf {
22        panic!("sanity check failed");
23    }
24}
25
26pub fn sanity_nan<B: Backend, const D: usize>(t: &Tensor<B, D>) {
27    let mut has_nan = false;
28
29    if DENY_NAN {
30        has_nan = t.clone().contains_nan().into_scalar().to_bool();
31        if has_nan {
32            eprintln!("got a NaN");
33        }
34    }
35
36    if has_nan {
37        panic!("sanity check failed");
38    }
39}