Skip to main content

burn_mamba/utils/schedule/
mod.rs

1//! # Virtual-layer → real-weight scheduling
2//!
3//! A `{Model}Layers` stack can run `n_virtual_layers` logical passes over only
4//! `n_real_layers` weight sets (e.g. 48 logical from 12 real); each virtual
5//! layer keeps its own cache but shares parameters.  A [`Schedule`] maps a
6//! virtual layer index to the real weight index to use.
7//!
8//! For **bidirectional** stacks, [`BidiSchedule`] additionally interleaves the
9//! two directions: even virtual indices run the straight (→) pass and odd
10//! indices run the reverse (←) pass.
11//!
12//! Each variant is documented with a worked virtual→real mapping example.
13
/// How a unidirectional layer stack maps virtual layer indices to real
/// (weight-bearing) layer indices.
///
/// The mapping itself is computed by [`Schedule::real_idx`]; each variant
/// below shows the formula it applies and a worked example.
#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum Schedule {
    /// Fills virtual positions by wrapping around the real schedule in a looping fashion.
    ///
    /// Virtual layer `i` uses real layer `i % real_len`.
    ///
    /// # Example
    /// virtual len = 8, real len = 3:
    /// ```text
    ///   →    →    →      →    →    →      →    →
    /// (0⇒0, 1⇒1, 2⇒2), (3⇒0, 4⇒1, 5⇒2), (6⇒0, 7⇒1, ...)
    /// ```
    #[default]
    Cyclic,
    /// Fills virtual positions by stretching the real schedule, so runs of
    /// consecutive virtual layers share one real layer.
    ///
    /// Virtual layer `i` uses real layer `(i * real_len) / virtual_len`
    /// (integer division).
    ///
    /// # Example
    /// virtual len = 8, real len = 3:
    /// ```text
    ///   →    →    →      →    →    →      →    →
    /// (0⇒0, 1⇒0, 2⇒0), (3⇒1, 4⇒1, 5⇒1), (6⇒2, 7⇒2, ...)
    /// ```
    Stretched,
    /// Fills virtual positions by referring to the index vector.
    ///
    /// Entry `i` of the vector is the real layer used by virtual layer `i`.
    /// The vector needs an entry for every virtual index that is looked up:
    /// [`Schedule::real_idx`] panics on a missing entry.  The entries are
    /// returned as stored; they are not checked against the real length there.
    ///
    /// # Example
    /// virtual len = 8, real len = 3, custom = `[0, 1, 2, 2, 1, 0, 0, 0]`:
    /// ```text
    ///   →    →    →    →    →    →    →    →
    /// (0⇒0, 1⇒1, 2⇒2, 3⇒2, 4⇒1, 5⇒0, 6⇒0, 7⇒0, ...)
    /// ```
    Custom(Vec<usize>),
}
41
42impl Schedule {
43    /// Map `virtual_idx` (in `0..virtual_len`) to a real layer index in
44    /// `0..real_len` according to this schedule.
45    pub fn real_idx(&self, virtual_idx: usize, virtual_len: usize, real_len: usize) -> usize {
46        match self {
47            Schedule::Cyclic => virtual_idx % real_len,
48            Schedule::Stretched => (virtual_idx * real_len) / virtual_len,
49            Schedule::Custom(map) => *map.get(virtual_idx).unwrap(),
50        }
51    }
52}
53
/// How a bidirectional layer stack maps virtual layer indices to real layer
/// indices, interleaving the straight (→, even indices) and reverse (←, odd
/// indices) directions.
///
/// The mapping itself is computed by [`BidiSchedule::real_idx`].  In the
/// examples below, `virtual_idx / 2` is the position of a layer within its
/// own direction.
#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum BidiSchedule {
    /// Use even virtual positions for straight-direction (→), and odd virtual positions for
    /// reverse-direction (←), wrapping around for each schedule.
    ///
    /// The straight direction only receives even real indices and the reverse
    /// direction only odd real indices.
    ///
    /// # Example
    /// virtual len = 10, real len = 4:
    /// ```text
    ///    →    ←      →    ←        →    ←      →    ←        →    ←
    /// [(0⇒0, 1⇒1), (2⇒2, 3⇒3)], [(4⇒0, 5⇒1), (6⇒2, 7⇒3)], [(8⇒0, 9⇒1), (...)]
    /// ```
    #[default]
    StridedCyclic,
    /// Use even virtual positions for straight-direction (→), and odd virtual positions for
    /// reverse-direction (←), stretching for each schedule.
    ///
    /// The straight direction only receives even real indices and the reverse
    /// direction only odd real indices.
    ///
    /// # Example
    /// virtual len = 10, real len = 4:
    /// ```text
    ///    →    ←      →    ←      →    ←        →    ←      →    ←
    /// [(0⇒0, 1⇒1), (2⇒0, 3⇒1), (4⇒0, 5⇒1)], [(6⇒2, 7⇒3), (8⇒2, 9⇒3), (...)]
    /// ```
    StridedStretched,
    /// Fills virtual positions by wrapping around the real schedule in a looping fashion,
    /// replicating between the straight (→) and reverse (←) directions.
    ///
    /// Both layers of a (→, ←) pair receive the same real index.
    ///
    /// # Example
    /// virtual len = 10, real len = 4:
    /// ```text
    ///    →    ←      →    ←      →    ←      →    ←        →    ←
    /// [(0⇒0, 1⇒0), (2⇒1, 3⇒1), (4⇒2, 5⇒2), (6⇒3, 7⇒3)], [(8⇒0, 9⇒0), (...)]
    /// ```
    SymmetricCyclic,
    /// Fills virtual positions by stretching the real schedule, replicating between
    /// the straight (→) and reverse (←) directions.
    ///
    /// Both layers of a (→, ←) pair receive the same real index.
    ///
    /// # Example
    /// virtual len = 10, real len = 4:
    /// ```text
    ///    →    ←      →    ←       →    ←               →    ←        →    ←
    /// [(0⇒0, 1⇒0), (2⇒0, 3⇒0)],[(4⇒1, 5⇒1), (...)], [(6⇒2, 7⇒2)], [(8⇒3, 9⇒3)]
    /// ```
    SymmetricStretched,
    /// Fills virtual positions by referring to the index vector.
    ///
    /// Entry `i` of the vector is the real layer used by virtual layer `i`
    /// (the vector is indexed by the full virtual index, not by the pair).
    /// [`BidiSchedule::real_idx`] panics when the vector has no such entry.
    ///
    /// # Example
    /// virtual len = 10, real len = 4, custom = `[0, 1, 2, 2, 1, 0, 0, 0, 3, 2]`:
    /// ```text
    ///    →    ←        →    ←        →    ←        →    ←        →    ←
    /// [(0⇒0, 1⇒1)], [(2⇒2, 3⇒2)], [(4⇒1, 5⇒0)], [(6⇒0, 7⇒0)], [(8⇒3, 9⇒2)], [(...)]
    /// ```
    Custom(Vec<usize>),
}
100
101impl BidiSchedule {
102    /// Map `virtual_idx` (in `0..virtual_len`) to a real layer index in
103    /// `0..real_len`.  Even/odd `virtual_idx` selects the straight/reverse
104    /// direction; the outer index `virtual_idx / 2` is what the schedule cycles
105    /// or stretches over.
106    pub fn real_idx(&self, virtual_idx: usize, virtual_len: usize, real_len: usize) -> usize {
107        let virtual_outer_idx = virtual_idx / 2;
108        let virtual_outer_len = virtual_len / 2;
109        match self {
110            BidiSchedule::StridedCyclic => {
111                let odd_len = real_len / 2;
112                let even_len = odd_len + real_len % 2;
113                let is_even = virtual_idx.is_multiple_of(2);
114                if is_even {
115                    (virtual_outer_idx % even_len) * 2
116                } else {
117                    (virtual_outer_idx % odd_len) * 2 + 1
118                }
119            }
120            BidiSchedule::StridedStretched => {
121                let odd_len = real_len / 2;
122                let even_len = odd_len + real_len % 2;
123                let is_even = virtual_idx.is_multiple_of(2);
124                if is_even {
125                    ((virtual_outer_idx * even_len) / virtual_outer_len) * 2
126                } else {
127                    ((virtual_outer_idx * odd_len) / virtual_outer_len) * 2 + 1
128                }
129            }
130            BidiSchedule::SymmetricCyclic => virtual_outer_idx % real_len,
131            BidiSchedule::SymmetricStretched => (virtual_outer_idx * real_len) / virtual_outer_len,
132            BidiSchedule::Custom(map) => *map.get(virtual_idx).unwrap(),
133        }
134    }
135}
136
137#[cfg(test)]
138mod tests;