burn_mamba/utils/scheduler.rs
1// copied from:
2// https://github.com/huy209vn/burn-jepa/blob/588d3654fbcfdcfce2ecdb7bcaf7a2e5bd5a70ea/src/train/scheduler.rs
// slight adaptations: added Config derives and a unified enum.
4
5//! Learning rate schedulers for controlling the optimization process.
6//!
7//! This module provides various strategies to adjust the learning rate during training,
8//! such as cosine annealing with linear warmup, to improve model convergence and performance.
9
10use burn::prelude::*;
11use std::f64::consts::PI;
12
13#[derive(Config, Debug)]
14pub enum Lr {
15 CosineAnnealing(CosineAnnealingLr),
16 Constant(ConstantLr),
17}
18
19impl Lr {
20 pub fn get_lr(&self, step: usize) -> f64 {
21 match self {
22 Lr::CosineAnnealing(inner) => inner.get_lr(step),
23 Lr::Constant(inner) => inner.get_lr(step),
24 }
25 }
26}
27
28/// # Cosine Annealing Learning Rate Scheduler with Linear Warmup.
29///
30/// This scheduler:
31/// 1. Linearly increases LR from 0 to `max_lr` during warmup phase
32/// 2. Applies cosine annealing from `max_lr` to `min_lr` after warmup
33///
34/// This is a common pattern in modern deep learning training.
35#[derive(Config, Debug)]
36pub struct CosineAnnealingLr {
37 /// The maximum learning rate (reached after warmup)
38 #[config(default = 1e-4)]
39 pub max_lr: f64,
40 /// The minimum learning rate (reached at end of training)
41 #[config(default = 1e-6)]
42 pub min_lr: f64,
43 /// The total number of training steps
44 pub total_steps: usize,
45 /// The number of warmup steps
46 #[config(default = 0)]
47 pub warmup_steps: usize,
48}
49
50impl CosineAnnealingLr {
51 /// Get the learning rate for the current training step.
52 ///
53 /// # Arguments
54 /// * `step` - Current training step (0-indexed)
55 ///
56 /// # Returns
57 /// * Learning rate for this step
58 pub fn get_lr(&self, step: usize) -> f64 {
59 // Warmup phase: linear increase from 0 to max_lr
60 if step < self.warmup_steps {
61 return self.max_lr * (step as f64) / (self.warmup_steps as f64);
62 }
63
64 // After total_steps, return min_lr
65 if step >= self.total_steps {
66 return self.min_lr;
67 }
68
69 // Cosine annealing phase
70 let progress =
71 (step - self.warmup_steps) as f64 / (self.total_steps - self.warmup_steps) as f64;
72 self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + (PI * progress).cos())
73 }
74}
75
76/// # Constant Learning Rate Scheduler.
77///
78/// Simply returns a fixed learning rate for all steps.
79/// Useful for simple experiments or when learning rate scheduling is not needed.
80#[derive(Config, Debug)]
81pub struct ConstantLr {
82 #[config(default = 1e-4)]
83 pub lr: f64,
84}
85
86impl ConstantLr {
87 pub fn get_lr(&self, _step: usize) -> f64 {
88 self.lr
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed learning rates.
    const TOLERANCE: f64 = 1e-12;

    /// Asserts that `actual` matches `expected` up to `TOLERANCE`.
    ///
    /// The schedulers compute their result with floating-point arithmetic, so
    /// exact equality would make the tests depend on rounding details.
    #[track_caller]
    fn assert_lr_eq(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected lr {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_cosine_annealing_warmup() {
        let scheduler = CosineAnnealingLr::new(1000)
            .with_max_lr(1.)
            .with_min_lr(0.01)
            .with_warmup_steps(100);

        // At step 0, warmup has not progressed yet, so LR should be 0
        assert_lr_eq(scheduler.get_lr(0), 0.);

        // At half warmup, LR should be max_lr / 2
        assert_lr_eq(scheduler.get_lr(50), 0.5);

        // At end of warmup, LR should be max_lr
        assert_lr_eq(scheduler.get_lr(100), 1.0);

        // After total steps, LR should be min_lr
        assert_lr_eq(scheduler.get_lr(1000), 0.01);
    }

    #[test]
    fn test_cosine_annealing_decay() {
        let scheduler = CosineAnnealingLr::new(1000)
            .with_max_lr(1.)
            .with_min_lr(0.01)
            .with_warmup_steps(100);

        // Halfway through the annealing phase (step 550 of 100..1000) the
        // cosine term is 0, so LR should be the mean of max_lr and min_lr
        assert_lr_eq(scheduler.get_lr(550), 0.505);

        // Well past total steps, LR should stay at min_lr
        assert_lr_eq(scheduler.get_lr(5000), 0.01);
    }

    #[test]
    fn test_constant_lr() {
        let scheduler = ConstantLr::new().with_lr(0.001);
        assert_lr_eq(scheduler.get_lr(0), 0.001);
        assert_lr_eq(scheduler.get_lr(1000), 0.001);
        assert_lr_eq(scheduler.get_lr(10000), 0.001);
    }
}
123}