
Module fprim


Rank-tagged FloatTensor primitive wrapper mirroring the Tensor method API, used by the custom-backward gradient math.

Rank-tagged primitive tensor wrapper for the custom backward math

F is a thin newtype over a backend’s FloatTensor primitive. It mirrors the subset of the high-level Tensor method API that the recompute-backward gradient math (*/serial_recalculated/combined_backward.rs) uses.

Why it exists: in Burn 0.22 the high-level Tensor is pinned to the global Dispatch backend, so it cannot be built from an arbitrary backend B’s primitive. A custom Backward node runs with a generic B, so its gradient math must operate directly on B’s primitives through the B::float_* ops. This wrapper lets that primitive-level math read like the original Tensor code (method chaining, shape-suffixed names) rather than as deeply nested free-function calls.
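The chaining-versus-nesting trade-off can be sketched with a toy setup. This is a minimal, self-contained illustration under stated assumptions: the Backend trait, the Cpu backend, and the op names float_add, float_mul_scalar and float_exp are simplified hypothetical stand-ins loosely modeled on the B::float_* naming, not Burn's actual trait or signatures. It shows the same expression, exp(2x) + y, written once as nested backend calls and once through a rank-tagged newtype whose methods only forward to the backend.

```rust
use std::marker::PhantomData;

/// HYPOTHETICAL stand-in for a backend's primitive-level float ops.
/// Burn's real backend traits are much larger and use different signatures.
trait Backend {
    type FloatPrim;
    fn float_add(lhs: Self::FloatPrim, rhs: Self::FloatPrim) -> Self::FloatPrim;
    fn float_mul_scalar(t: Self::FloatPrim, s: f32) -> Self::FloatPrim;
    fn float_exp(t: Self::FloatPrim) -> Self::FloatPrim;
}

/// Toy CPU backend whose primitive is a flat Vec<f32>.
struct Cpu;

impl Backend for Cpu {
    type FloatPrim = Vec<f32>;
    fn float_add(lhs: Vec<f32>, rhs: Vec<f32>) -> Vec<f32> {
        lhs.iter().zip(&rhs).map(|(a, b)| a + b).collect()
    }
    fn float_mul_scalar(t: Vec<f32>, s: f32) -> Vec<f32> {
        t.into_iter().map(|x| x * s).collect()
    }
    fn float_exp(t: Vec<f32>) -> Vec<f32> {
        t.into_iter().map(f32::exp).collect()
    }
}

/// Newtype over the backend primitive; `D` is only a compile-time rank tag.
struct F<B: Backend, const D: usize> {
    prim: B::FloatPrim,
    _backend: PhantomData<B>,
}

impl<B: Backend, const D: usize> F<B, D> {
    fn new(prim: B::FloatPrim) -> Self {
        Self { prim, _backend: PhantomData }
    }
    fn into_prim(self) -> B::FloatPrim {
        self.prim
    }
    // Each method simply forwards to the matching backend op.
    fn add(self, rhs: Self) -> Self {
        Self::new(B::float_add(self.prim, rhs.prim))
    }
    fn mul_scalar(self, s: f32) -> Self {
        Self::new(B::float_mul_scalar(self.prim, s))
    }
    fn exp(self) -> Self {
        Self::new(B::float_exp(self.prim))
    }
}

/// exp(2 * x) + y as nested free-function calls (read inside out).
fn nested<B: Backend>(x: B::FloatPrim, y: B::FloatPrim) -> B::FloatPrim {
    B::float_add(B::float_exp(B::float_mul_scalar(x, 2.0)), y)
}

/// The same expression through the wrapper (read left to right).
fn chained<B: Backend>(x: B::FloatPrim, y: B::FloatPrim) -> B::FloatPrim {
    F::<B, 1>::new(x).mul_scalar(2.0).exp().add(F::new(y)).into_prim()
}
```

Both functions perform the same backend calls in the same order, so they produce identical results; the wrapper changes only how the code reads.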

The rank D is a compile-time tag that keeps parity with the ported code and catches rank mistakes; every operation ultimately defers to B’s runtime-shaped primitive ops.
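How a compile-time rank tag can coexist with runtime-shaped primitives is sketched below. The Prim type, from_data, reshape and dims are hypothetical illustrations, not this module's API. The primitive stores its shape as a runtime Vec<usize>, while the wrapper's const parameter D fixes the rank. Taking shapes as [usize; D] arrays ties the two together: a rank mismatch is a type error, and only element counts are checked at runtime.

```rust
use std::convert::TryInto;

/// HYPOTHETICAL toy primitive: a runtime shape plus flat row-major data.
#[derive(Debug, Clone, PartialEq)]
struct Prim {
    shape: Vec<usize>,
    data: Vec<f32>,
}

/// Rank-tagged wrapper: `D` lives only in the type; the primitive carries the real shape.
#[derive(Debug, Clone, PartialEq)]
struct F<const D: usize> {
    prim: Prim,
}

impl<const D: usize> F<D> {
    /// A `[usize; D]` shape forces the runtime shape length to match the rank tag.
    fn from_data(data: Vec<f32>, shape: [usize; D]) -> Self {
        assert_eq!(data.len(), shape.iter().product::<usize>(), "shape/data mismatch");
        F { prim: Prim { shape: shape.to_vec(), data } }
    }

    /// The output rank `D2` is fixed by the array length at compile time;
    /// only the element count is checked at runtime.
    fn reshape<const D2: usize>(self, shape: [usize; D2]) -> F<D2> {
        F::<D2>::from_data(self.prim.data, shape)
    }

    /// Read the runtime shape back as a fixed-size array of length `D`.
    fn dims(&self) -> [usize; D] {
        self.prim.shape.clone().try_into().expect("rank tag out of sync with shape")
    }
}

// A rank mistake is a type error; for example, this line would not compile:
// let bad: F<3> = F::<1>::from_data(vec![0.0; 6], [6]).reshape([2, 3]);
```

Because D never reaches the primitive, the tag costs nothing at runtime; it only narrows which calls type-check.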

Structs

F
A backend float-tensor primitive tagged with a compile-time rank D.
Mask
A boolean mask primitive used with F::mask_fill.
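The masking operation can be pictured with a host-side sketch. It assumes F::mask_fill mirrors Burn's Tensor::mask_fill, which writes value wherever the mask is true and leaves other elements unchanged. The flat Vec buffers and the by-reference mask argument below are hypothetical simplifications of the real device primitives.

```rust
/// HYPOTHETICAL host-side stand-ins: flat row-major buffers of equal length.
#[derive(Debug, Clone, PartialEq)]
struct F {
    data: Vec<f32>,
}

#[derive(Debug, Clone, PartialEq)]
struct Mask {
    data: Vec<bool>,
}

impl F {
    /// Replace every element whose mask entry is `true` with `value`
    /// (assumed semantics, mirroring Tensor::mask_fill).
    fn mask_fill(self, mask: &Mask, value: f32) -> F {
        assert_eq!(self.data.len(), mask.data.len(), "mask shape mismatch");
        let data = self
            .data
            .into_iter()
            .zip(&mask.data)
            .map(|(x, &m)| if m { value } else { x })
            .collect();
        F { data }
    }
}
```

A common pattern is filling masked positions with f32::NEG_INFINITY so they contribute nothing after a subsequent exp or softmax.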

Functions

san
Primitive analogue of crate::modules::sanity for F.
tri_bool (private)
Build a [rows, cols] triangular boolean mask on-device.
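The layout of a [rows, cols] triangular mask can be sketched on the host. The description does not say which triangle tri_bool selects or how it handles the diagonal. This sketch therefore assumes a strict upper triangle shifted by a hypothetical offset parameter and builds a row-major Vec<bool> on the CPU rather than on-device; tri_bool_host is an illustrative name, not part of the module.

```rust
/// HYPOTHETICAL host-side sketch of a triangular mask.
/// Row-major [rows, cols]; entry (i, j) is true when j > i + offset.
/// ASSUMPTION: strict upper triangle; the real tri_bool's orientation and
/// diagonal handling are not documented here.
fn tri_bool_host(rows: usize, cols: usize, offset: i64) -> Vec<bool> {
    let mut out = Vec::with_capacity(rows * cols);
    for i in 0..rows {
        for j in 0..cols {
            out.push((j as i64) > (i as i64) + offset);
        }
    }
    out
}
```

With offset 0 on a square shape this marks every position strictly above the diagonal. That is the shape of mask typically handed to a mask_fill-style op to hide "future" columns.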