// File: burn_mamba/utils/scheduler.rs

// Copied from:
// https://github.com/huy209vn/burn-jepa/blob/588d3654fbcfdcfce2ecdb7bcaf7a2e5bd5a70ea/src/train/scheduler.rs
// Slight adaptations: added `Config` derives and a unified `Lr` enum.
4
//! Learning rate schedulers for controlling the optimization process.
//!
//! This module provides strategies for adjusting the learning rate during training,
//! such as cosine annealing with linear warmup or a constant rate (selectable through
//! the [`Lr`] enum), to improve model convergence and performance.
9
10use burn::prelude::*;
11use std::f64::consts::PI;
12
13#[derive(Config, Debug)]
14pub enum Lr {
15    CosineAnnealing(CosineAnnealingLr),
16    Constant(ConstantLr),
17}
18
19impl Lr {
20    pub fn get_lr(&self, step: usize) -> f64 {
21        match self {
22            Lr::CosineAnnealing(inner) => inner.get_lr(step),
23            Lr::Constant(inner) => inner.get_lr(step),
24        }
25    }
26}
27
28/// # Cosine Annealing Learning Rate Scheduler with Linear Warmup.
29///
30/// This scheduler:
31/// 1. Linearly increases LR from 0 to `max_lr` during warmup phase
32/// 2. Applies cosine annealing from `max_lr` to `min_lr` after warmup
33///
34/// This is a common pattern in modern deep learning training.
35#[derive(Config, Debug)]
36pub struct CosineAnnealingLr {
37    /// The maximum learning rate (reached after warmup)
38    #[config(default = 1e-4)]
39    pub max_lr: f64,
40    /// The minimum learning rate (reached at end of training)
41    #[config(default = 1e-6)]
42    pub min_lr: f64,
43    /// The total number of training steps
44    pub total_steps: usize,
45    /// The number of warmup steps
46    #[config(default = 0)]
47    pub warmup_steps: usize,
48}
49
50impl CosineAnnealingLr {
51    /// Get the learning rate for the current training step.
52    ///
53    /// # Arguments
54    /// * `step` - Current training step (0-indexed)
55    ///
56    /// # Returns
57    /// * Learning rate for this step
58    pub fn get_lr(&self, step: usize) -> f64 {
59        // Warmup phase: linear increase from 0 to max_lr
60        if step < self.warmup_steps {
61            return self.max_lr * (step as f64) / (self.warmup_steps as f64);
62        }
63
64        // After total_steps, return min_lr
65        if step >= self.total_steps {
66            return self.min_lr;
67        }
68
69        // Cosine annealing phase
70        let progress =
71            (step - self.warmup_steps) as f64 / (self.total_steps - self.warmup_steps) as f64;
72        self.min_lr + 0.5 * (self.max_lr - self.min_lr) * (1.0 + (PI * progress).cos())
73    }
74}
75
76/// # Constant Learning Rate Scheduler.
77///
78/// Simply returns a fixed learning rate for all steps.
79/// Useful for simple experiments or when learning rate scheduling is not needed.
80#[derive(Config, Debug)]
81pub struct ConstantLr {
82    #[config(default = 1e-4)]
83    pub lr: f64,
84}
85
86impl ConstantLr {
87    pub fn get_lr(&self, _step: usize) -> f64 {
88        self.lr
89    }
90}
91
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for values that pass through `cos`, whose result is not exact.
    const EPS: f64 = 1e-12;

    /// Shared fixture: 1000 total steps, 100 warmup steps, LR in [0.01, 1.0].
    fn warmup_scheduler() -> CosineAnnealingLr {
        CosineAnnealingLr::new(1000)
            .with_max_lr(1.)
            .with_min_lr(0.01)
            .with_warmup_steps(100)
    }

    #[test]
    fn test_cosine_annealing_warmup() {
        let scheduler = warmup_scheduler();

        // At step 0 the linear warmup starts from an LR of exactly 0.
        assert_eq!(scheduler.get_lr(0), 0.);

        // At half warmup, LR should be max_lr / 2.
        assert_eq!(scheduler.get_lr(50), 0.5);

        // At the end of warmup, LR should be max_lr.
        assert_eq!(scheduler.get_lr(100), 1.0);

        // At total_steps, LR should be min_lr.
        assert_eq!(scheduler.get_lr(1000), 0.01);
    }

    #[test]
    fn test_cosine_annealing_decay() {
        let scheduler = warmup_scheduler();

        // Step 550 is halfway through the 100..1000 decay window. There cos(pi / 2) = 0,
        // so LR is the midpoint of max_lr and min_lr.
        assert!((scheduler.get_lr(550) - 0.505).abs() < EPS);

        // The cosine phase never increases the learning rate.
        for step in 100..1000 {
            assert!(scheduler.get_lr(step + 1) <= scheduler.get_lr(step));
        }

        // Past total_steps, LR stays pinned at min_lr.
        assert_eq!(scheduler.get_lr(5000), 0.01);
    }

    #[test]
    fn test_cosine_annealing_without_warmup() {
        // warmup_steps defaults to 0, so the schedule starts directly at max_lr.
        let scheduler = CosineAnnealingLr::new(10).with_max_lr(2.0).with_min_lr(0.0);
        assert_eq!(scheduler.get_lr(0), 2.0);
        assert_eq!(scheduler.get_lr(10), 0.0);
    }

    #[test]
    fn test_lr_enum_delegates() {
        let reference = warmup_scheduler();
        let cosine = Lr::CosineAnnealing(warmup_scheduler());
        for step in [0, 50, 100, 550, 1000] {
            assert_eq!(cosine.get_lr(step), reference.get_lr(step));
        }

        let constant = Lr::Constant(ConstantLr::new().with_lr(0.1));
        assert_eq!(constant.get_lr(0), 0.1);
        assert_eq!(constant.get_lr(123), 0.1);
    }

    #[test]
    fn test_constant_lr() {
        let scheduler = ConstantLr::new().with_lr(0.001);
        assert_eq!(scheduler.get_lr(0), 0.001);
        assert_eq!(scheduler.get_lr(1000), 0.001);
        assert_eq!(scheduler.get_lr(10000), 0.001);
    }
}