Rank-tagged FloatTensor primitive wrapper mirroring the Tensor method
API, used by the custom-backward gradient math.
§Rank-tagged primitive tensor wrapper for the custom backward math
F is a thin newtype over a backend’s FloatTensor primitive. It mirrors the
subset of the high-level Tensor method API used by the recompute-backward
gradient math (*/serial_recalculated/combined_backward.rs).
Why it exists: in Burn 0.22 the high-level Tensor is pinned to the global
Dispatch backend, so it cannot be built from an arbitrary backend B’s
primitive. A custom Backward
node runs with a generic B, so its gradient math must operate directly on
B’s primitives via the B::float_* ops. This wrapper keeps that
primitive-level math reading like the original Tensor code (method
chaining, shape-suffixed names) instead of deeply nested free-function calls.
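The readability difference can be illustrated with a minimal, self-contained sketch. Everything in it is an assumption made for illustration: the `Backend` trait, its `FloatPrim` type, the `Cpu` backend, and the three `float_*` functions are toy stand-ins, not Burn's actual API, and this `F` is not the crate's real definition. It only shows the general pattern of a newtype that forwards methods to backend free functions.

```rust
/// Hypothetical stand-in for a backend: a primitive type plus free-function ops.
/// (Not Burn's real `Backend` trait.)
trait Backend {
    type FloatPrim: Clone;
    fn float_add(a: Self::FloatPrim, b: Self::FloatPrim) -> Self::FloatPrim;
    fn float_mul_scalar(a: Self::FloatPrim, s: f32) -> Self::FloatPrim;
    fn float_neg(a: Self::FloatPrim) -> Self::FloatPrim;
}

/// Toy CPU backend whose primitive is a flat `Vec<f32>`.
struct Cpu;

impl Backend for Cpu {
    type FloatPrim = Vec<f32>;
    fn float_add(a: Vec<f32>, b: Vec<f32>) -> Vec<f32> {
        a.iter().zip(b.iter()).map(|(x, y)| x + y).collect()
    }
    fn float_mul_scalar(a: Vec<f32>, s: f32) -> Vec<f32> {
        a.iter().map(|x| x * s).collect()
    }
    fn float_neg(a: Vec<f32>) -> Vec<f32> {
        a.iter().map(|x| -x).collect()
    }
}

/// Thin newtype over the primitive, tagged with a compile-time rank `D`.
struct F<B: Backend, const D: usize> {
    prim: B::FloatPrim,
}

impl<B: Backend, const D: usize> F<B, D> {
    fn new(prim: B::FloatPrim) -> Self {
        Self { prim }
    }
    // Each method forwards to the matching backend free function.
    fn add(self, other: Self) -> Self {
        Self::new(B::float_add(self.prim, other.prim))
    }
    fn mul_scalar(self, s: f32) -> Self {
        Self::new(B::float_mul_scalar(self.prim, s))
    }
    fn neg(self) -> Self {
        Self::new(B::float_neg(self.prim))
    }
    fn into_primitive(self) -> B::FloatPrim {
        self.prim
    }
}

/// Computes -((a + b) * 2) with nested free-function calls.
fn nested<B: Backend>(a: B::FloatPrim, b: B::FloatPrim) -> B::FloatPrim {
    B::float_neg(B::float_mul_scalar(B::float_add(a, b), 2.0))
}

/// The same computation written as a method chain through the wrapper.
fn chained<B: Backend>(a: B::FloatPrim, b: B::FloatPrim) -> B::FloatPrim {
    F::<B, 1>::new(a)
        .add(F::<B, 1>::new(b))
        .mul_scalar(2.0)
        .neg()
        .into_primitive()
}
```

Both functions compute the same value; the chained form reads left to right in the order the operations are applied, which is the property the wrapper is meant to preserve.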
The rank D is a compile-time tag for parity with the ported code and to
catch rank mistakes; every operation ultimately defers to B’s
runtime-shaped primitive ops.
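As a rough illustration of a compile-time rank tag sitting on top of a runtime-shaped primitive, consider the sketch below. The names `Prim` and `Ranked` and all of their methods are hypothetical and exist only for this example; they are not the crate's types. The point is that mixing ranks becomes a type error, while the actual shape check still happens at run time.

```rust
/// Hypothetical runtime-shaped primitive: its rank is only known at run time.
#[derive(Clone, Debug, PartialEq)]
struct Prim {
    data: Vec<f32>,
    shape: Vec<usize>,
}

/// Wrapper that records the rank `D` in the type.
#[derive(Clone, Debug, PartialEq)]
struct Ranked<const D: usize>(Prim);

impl<const D: usize> Ranked<D> {
    /// Builds a rank-`D` value; the array length fixes `D` at compile time.
    fn from_data(data: Vec<f32>, shape: [usize; D]) -> Self {
        assert_eq!(
            data.len(),
            shape.iter().product::<usize>(),
            "element count mismatch"
        );
        Ranked(Prim { data, shape: shape.to_vec() })
    }

    /// Returns the runtime shape as a fixed-size array of length `D`.
    fn dims(&self) -> [usize; D] {
        let mut out = [0usize; D];
        out.copy_from_slice(&self.0.shape);
        out
    }

    /// Changing rank changes the type: the output rank `D2` is explicit.
    fn reshape<const D2: usize>(self, shape: [usize; D2]) -> Ranked<D2> {
        Ranked::<D2>::from_data(self.0.data, shape)
    }

    /// Element-wise add; both operands must share rank `D` to type-check,
    /// while the shape equality itself is verified at run time.
    fn add(self, other: Ranked<D>) -> Ranked<D> {
        assert_eq!(self.0.shape, other.0.shape, "shape mismatch");
        let data = self
            .0
            .data
            .iter()
            .zip(&other.0.data)
            .map(|(a, b)| a + b)
            .collect();
        Ranked(Prim { data, shape: self.0.shape })
    }
}
```

With these definitions, adding a `Ranked<2>` to a `Ranked<1>` is rejected by the compiler, whereas adding two `Ranked<1>` values of different lengths only fails when the runtime shape assertion runs.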
Structs§
- F: A backend float-tensor primitive tagged with a compile-time rank D.
- Mask: A boolean mask primitive used with F::mask_fill.
Functions§
- san: Primitive analogue of crate::modules::sanity for F.
- tri_bool 🔒: Build a [rows, cols] triangular boolean mask on-device.
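For readers unfamiliar with triangular masks, the sketch below shows what a [rows, cols] mask contains. It is a plain host-side illustration, not the on-device implementation, and it assumes a lower-triangular convention (true on and below the main diagonal); the page does not state which triangle or diagonal offset tri_bool uses. The function name `tri_mask_sketch` is hypothetical.

```rust
/// Builds a row-major `[rows, cols]` boolean mask that is `true` where
/// `col <= row` (lower triangle including the diagonal).
/// Assumption: the lower-triangular convention is chosen for illustration only.
fn tri_mask_sketch(rows: usize, cols: usize) -> Vec<bool> {
    let mut mask = Vec::with_capacity(rows * cols);
    for r in 0..rows {
        for c in 0..cols {
            // Compare the column index against the row index, the same
            // comparison an index-grid approach would perform element-wise.
            mask.push(c <= r);
        }
    }
    mask
}
```

For a 2 by 3 mask this yields the rows `[true, false, false]` and `[true, true, false]`, stored contiguously in row-major order.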