pub struct RmsNormGated {
    pub gamma: Param<Tensor<1>>,
    pub norm_before_gate: bool,
}
Applies gated RMS normalization over an input tensor along the last dimension.

- If `norm_before_gate = true`: `Y = (X / sqrt(mean(X^2) + eps) * gamma) * SiLU(z)`
- If `norm_before_gate = false`: `Y = (X * SiLU(z)) / sqrt(mean((X * SiLU(z))^2) + eps) * gamma`
Where:

- `X` is the input tensor
- `Y` is the output tensor
- `z` is the gating tensor
- `gamma` is the learnable weight
- `mean` is the mean operation, taken over the last dimension
- `eps` is a small value to avoid division by zero
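A minimal plain-Rust sketch of the two formulas for a single row (one slice along the last dimension), assuming `f32` values and a scalar `eps`. It does not use Burn's tensor API; `silu`, `rms`, and `rms_norm_gated_row` are hypothetical helper names chosen for illustration.

```rust
/// SiLU (swish): z * sigmoid(z), written as z / (1 + e^(-z)).
fn silu(z: f32) -> f32 {
    z / (1.0 + (-z).exp())
}

/// sqrt(mean(v^2) + eps) over one row.
fn rms(v: &[f32], eps: f32) -> f32 {
    let mean_sq = v.iter().map(|a| a * a).sum::<f32>() / v.len() as f32;
    (mean_sq + eps).sqrt()
}

/// Hypothetical helper: gated RMS norm of a single row (the last dimension).
fn rms_norm_gated_row(
    x: &[f32],
    z: &[f32],
    gamma: &[f32],
    eps: f32,
    norm_before_gate: bool,
) -> Vec<f32> {
    assert!(
        x.len() == z.len() && x.len() == gamma.len(),
        "row length mismatch"
    );
    if norm_before_gate {
        // Y = (X / sqrt(mean(X^2) + eps) * gamma) * SiLU(z)
        let r = rms(x, eps);
        x.iter()
            .zip(z)
            .zip(gamma)
            .map(|((&xi, &zi), &g)| (xi / r * g) * silu(zi))
            .collect()
    } else {
        // H = X * SiLU(z);  Y = H / sqrt(mean(H^2) + eps) * gamma
        let h: Vec<f32> = x.iter().zip(z).map(|(&xi, &zi)| xi * silu(zi)).collect();
        let r = rms(&h, eps);
        h.iter().zip(gamma).map(|(&hi, &g)| hi / r * g).collect()
    }
}
```

For a batched input, the same function would be applied independently to every row of the last dimension, with the same `gamma` shared across rows.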
Should be created using the RmsNormGatedConfig configuration.
Fields

`gamma: Param<Tensor<1>>`
The learnable per-channel scale γ, with shape `[d_model]`.

`norm_before_gate: bool`
Whether to normalize before applying the gating.
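To see why the ordering matters, here is a short derivation from the two formulas (worked out for illustration, not a documented property of the module). Suppose the gate is constant along the last dimension, so `SiLU(z_i) = s` for every channel `i`, and write `eps` as $\epsilon$:

```math
\begin{aligned}
\texttt{norm\_before\_gate = true:}\quad Y_i &= \frac{X_i}{\sqrt{\operatorname{mean}(X^2) + \epsilon}}\,\gamma_i\,s \\
\texttt{norm\_before\_gate = false:}\quad Y_i &= \frac{s\,X_i}{\sqrt{s^2\operatorname{mean}(X^2) + \epsilon}}\,\gamma_i
= \frac{\operatorname{sign}(s)\,X_i}{\sqrt{\operatorname{mean}(X^2) + \epsilon / s^2}}\,\gamma_i \qquad (s \neq 0)
\end{aligned}
```

With `norm_before_gate = false`, the gate's magnitude is largely divided back out (only its sign survives, as long as $\epsilon / s^2$ is small next to $\operatorname{mean}(X^2)$), whereas with `norm_before_gate = true` it scales the normalized output directly.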
Implementations
Trait Implementations
impl AutodiffModule for RmsNormGated
impl Clone for RmsNormGated
impl Debug for RmsNormGated
impl Display for RmsNormGated
impl Module for RmsNormGated
type Record = RmsNormGatedRecord
Type to save and load the module.
fn load_record(self, record: Self::Record) -> Self
Load the module state from a record.
fn into_record(self) -> Self::Record
Convert the module into a record containing the state.
fn num_params(&self) -> usize
Get the number of parameters the module has, including all of its sub-modules.
fn visit<Visitor: ModuleVisitor>(&self, visitor: &mut Visitor)
Visit each tensor parameter in the module with a visitor.
fn map<Mapper: ModuleMapper>(self, mapper: &mut Mapper) -> Self
Map each tensor parameter in the module with a mapper.
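`visit` and `map` follow the visitor pattern over a module's tensor parameters: `visit` reads each parameter, `map` consumes the module and rebuilds it from transformed parameters. The sketch below mirrors that shape with hypothetical plain-Rust traits (`ParamVisitor`, `ParamMapper`) and a toy struct; these are not Burn's `ModuleVisitor`/`ModuleMapper` signatures. It also shows that `norm_before_gate`, being a configuration flag rather than a tensor parameter, is carried through untouched.

```rust
/// Hypothetical read-only visitor over named parameters (not burn's ModuleVisitor).
trait ParamVisitor {
    fn visit_param(&mut self, name: &str, values: &[f32]);
}

/// Hypothetical by-value mapper over named parameters (not burn's ModuleMapper).
trait ParamMapper {
    fn map_param(&mut self, name: &str, values: Vec<f32>) -> Vec<f32>;
}

/// Toy stand-in for RmsNormGated with gamma stored as a plain Vec.
struct ToyRmsNormGated {
    gamma: Vec<f32>,
    norm_before_gate: bool,
}

impl ToyRmsNormGated {
    /// Only tensor parameters are visited; the bool flag is not a parameter.
    fn visit<V: ParamVisitor>(&self, visitor: &mut V) {
        visitor.visit_param("gamma", &self.gamma);
    }

    /// Rebuild the module from mapped parameters, keeping the flag as-is.
    fn map<M: ParamMapper>(self, mapper: &mut M) -> Self {
        let gamma = mapper.map_param("gamma", self.gamma);
        Self { gamma, norm_before_gate: self.norm_before_gate }
    }
}

/// Visitor that counts scalar parameters.
struct CountParams(usize);

impl ParamVisitor for CountParams {
    fn visit_param(&mut self, _name: &str, values: &[f32]) {
        self.0 += values.len();
    }
}

/// Mapper that multiplies every parameter by a constant.
struct Scale(f32);

impl ParamMapper for Scale {
    fn map_param(&mut self, _name: &str, values: Vec<f32>) -> Vec<f32> {
        values.into_iter().map(|v| v * self.0).collect()
    }
}
```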
fn collect_devices(&self, devices: Devices) -> Devices
Return all the devices found in the underlying module tree, added to the given vector without duplicates.
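The "added to the given vector without duplicates" contract can be illustrated with a small plain-Rust sketch. `ToyDevice` and `collect_unique` are hypothetical stand-ins, not Burn types or functions; going by the public fields listed above, `gamma` appears to be this module's only tensor parameter, so it would contribute at most one device.

```rust
/// Hypothetical device type for illustration only (not burn's `Device`).
#[derive(Clone, Debug, PartialEq)]
enum ToyDevice {
    Cpu,
    Gpu(usize),
}

/// Append each device used by the parameters to `devices`, skipping any
/// that are already present. A linear scan is fine for a handful of devices.
fn collect_unique(param_devices: &[ToyDevice], mut devices: Vec<ToyDevice>) -> Vec<ToyDevice> {
    for d in param_devices {
        if !devices.contains(d) {
            devices.push(d.clone());
        }
    }
    devices
}
```

Existing entries keep their order and new devices are appended at the end.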
fn to_device(self, device: &Device) -> Self
Move the module and all of its sub-modules to the given device.
fn fork(self, device: &Device) -> Self
Fork the module and all of its sub-modules to the given device.
fn devices(&self) -> Vec<Device>
Return all the devices found in the underlying module tree, without duplicates.
fn train(self) -> Self
where
    Self: AutodiffModule,
Move the module and all of its sub-modules to the autodiff backend.
§fn quantize_weights(self, quantizer: &mut Quantizer) -> Self
fn quantize_weights(self, quantizer: &mut Quantizer) -> Self
Quantize the weights of the module.
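The `Quantizer` type is not documented on this page, so the following is only a generic illustration of what quantizing a small weight vector such as `gamma` can look like: a symmetric per-tensor int8 scheme with `scale = max|v| / 127`. The function names are hypothetical, and this is not necessarily the scheme `quantize_weights` applies.

```rust
/// Symmetric per-tensor int8 quantization: q = round(v / scale),
/// with scale = max|v| / 127 (falls back to 1.0 for an all-zero tensor).
fn quantize_i8(values: &[f32]) -> (Vec<i8>, f32) {
    let max_abs = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));
    let scale = if max_abs == 0.0 { 1.0 } else { max_abs / 127.0 };
    let q: Vec<i8> = values
        .iter()
        .map(|v| (v / scale).round().clamp(-127.0, 127.0) as i8)
        .collect();
    (q, scale)
}

/// Inverse mapping: v ≈ q * scale (rounding error is at most scale / 2).
fn dequantize_i8(q: &[i8], scale: f32) -> Vec<f32> {
    q.iter().map(|&x| x as f32 * scale).collect()
}
```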
impl ModuleDisplay for RmsNormGated
fn custom_settings(&self) -> Option<DisplaySettings>
Custom display settings for the module.
fn custom_content(&self, content: Content) -> Option<Content>
Custom attributes for the module.
Auto Trait Implementations
impl !Freeze for RmsNormGated
impl !RefUnwindSafe for RmsNormGated
impl !UnwindSafe for RmsNormGated
impl Send for RmsNormGated
impl Sync for RmsNormGated
impl Unpin for RmsNormGated
impl UnsafeUnpin for RmsNormGated
Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value.