Skip to main content

burn_mamba/modules/misc/
sanity.rs

1//! Optional `NaN`/`Inf` guards for debugging numerical issues.
2//!
3//! Both checks are gated by the compile-time flags [`crate::DENY_NAN`] /
4//! [`crate::DENY_INF`] (both `false` by default).  When the flags are off the
5//! functions compile down to nothing — there is no runtime cost in release
6//! builds — so calls can be sprinkled liberally through the forward passes.
7
8use crate::{DENY_INF, DENY_NAN};
9use burn::prelude::*;
10
11/// Panics if `t` contains a `NaN` (when [`crate::DENY_NAN`] is set) or an `Inf`
12/// (when [`crate::DENY_INF`] is set).  A no-op when both flags are `false`.
13pub fn sanity<const D: usize>(t: &Tensor<D>) {
14    if !crate::DENY_NAN && !crate::DENY_INF {
15        return;
16    }
17
18    let mut has_nan = false;
19    let mut has_inf = false;
20
21    if DENY_NAN {
22        has_nan = t.clone().is_nan().any().into_scalar::<bool>().to_bool();
23        if has_nan {
24            eprintln!("got a NaN");
25        }
26    }
27    if DENY_INF {
28        has_inf = t.clone().is_inf().any().into_scalar::<bool>().to_bool();
29        if has_inf {
30            eprintln!("got a INF");
31        }
32    }
33
34    if has_nan || has_inf {
35        panic!("sanity check failed");
36    }
37}
38
39/// Like [`sanity`] but checks only for `NaN` (ignores `Inf`).
40pub fn sanity_nan<const D: usize>(t: &Tensor<D>) {
41    let mut has_nan = false;
42
43    if DENY_NAN {
44        has_nan = t.clone().is_nan().any().into_scalar::<bool>().to_bool();
45        if has_nan {
46            eprintln!("got a NaN");
47        }
48    }
49
50    if has_nan {
51        panic!("sanity check failed");
52    }
53}