Skip to main content

burn_mamba/
schedule.rs

1#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
2pub enum Schedule {
3    /// Fills virtual positions by wrapping around the real schedule in a looping fashion.
4    ///
5    /// # Example
6    /// - virtual len = 8, real len = 3:  
7    ///   `  →    →    →      →    →    →      →    →       `
8    ///   `(0⇒0, 1⇒1, 2⇒2), (3⇒0, 4⇒1, 5⇒2), (6⇒0, 7⇒1, ...)`
9    #[default]
10    Cyclic,
11    /// Fills virtual positions by stretching the real schedule.
12    ///
13    /// # Example
14    /// - virtual len = 8, real len = 3:  
15    ///   `  →    →    →      →    →    →      →    →       `
16    ///   `(0⇒0, 1⇒0, 2⇒0), (3⇒1, 4⇒1, 5⇒1), (6⇒2, 7⇒2, ...)`
17    Stretched,
18    /// Fills virtual positions by referring to the index vector.
19    ///
20    /// # Example
21    /// - virtual len = 8, real len = 3, custom = [0, 1, 2, 2, 1, 0, 0, 0]:  
22    ///   `  →    →    →    →    →    →    →    →       `
23    ///   `(0⇒0, 1⇒1, 2⇒2, 3⇒2, 4⇒1, 5⇒0, 6⇒0, 7⇒0, ...)`
24    Custom(Vec<usize>),
25}
26
27impl Schedule {
28    pub fn real_idx(&self, virtual_idx: usize, virtual_len: usize, real_len: usize) -> usize {
29        match self {
30            Schedule::Cyclic => virtual_idx % real_len,
31            Schedule::Stretched => (virtual_idx * real_len) / virtual_len,
32            Schedule::Custom(map) => *map.get(virtual_idx).unwrap(),
33        }
34    }
35}
36
37#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
38pub enum BidiSchedule {
39    /// Use even virtual positions for straight-direction (→), and odd virtual positions for
40    /// reverse-direction (←), wrapping around for each schedule.
41    //
42    /// # Example
43    /// - virtual len = 10, real len = 4:  
44    ///   `   →    ←      →    ←        →    ←      →    ←        →    ←          `
45    ///   `[(0⇒0, 1⇒1), (2⇒2, 3⇒3)], [(4⇒0, 5⇒1), (6⇒2, 7⇒3)], [(8⇒0, 9⇒1), (...)]`
46    #[default]
47    StridedCyclic,
48    /// Use even virtual positions for straight-direction (→), and odd virtual positions for
49    /// reverse-direction (←), stretching for each schedule.
50    ///
51    /// # Example
52    /// - virtual len = 10, real len = 4:  
53    ///   `   →    ←      →    ←      →    ←        →    ←      →    ←          `
54    ///   `[(0⇒0, 1⇒1), (2⇒0, 3⇒1), (4⇒0, 5⇒1)], [(6⇒2, 7⇒3), (8⇒2, 9⇒3), (...)]`
55    StridedStretched,
56    /// Fills virtual positions by wrapping around the real schedule in a looping fashion,
57    /// replicating between the straight (→) and reverse (←) directions.
58    ///
59    /// # Example
60    /// - virtual len = 10, real len = 4:  
61    ///   `   →    ←      →    ←      →    ←      →    ←        →    ←          `
62    ///   `[(0⇒0, 1⇒0), (2⇒1, 3⇒1), (4⇒2, 5⇒2), (6⇒3, 7⇒3)], [(8⇒0, 9⇒0), (...)]`
63    SymmetricCyclic,
64    /// Fills virtual positions by stretching the real schedule, replicating between
65    /// the straight (→) and reverse (←) directions.
66    ///
67    /// # Example
68    /// - virtual len = 10, real len = 4:  
69    ///   `   →    ←      →    ←       →    ←               →    ←        →    ←   `
70    ///   `[(0⇒0, 1⇒0), (2⇒0, 3⇒0)],[(4⇒1, 5⇒1), (...)], [(6⇒2, 7⇒2)], [(8⇒3, 9⇒3)]`
71    SymmetricStretched,
72    /// Fills virtual positions by referring to the index vector.
73    ///
74    /// # Example
75    /// - virtual len = 10, real len = 4, custom = [0, 1, 2, 2, 1, 0, 0, 0, 3, 2]:  
76    ///   `   →    ←        →    ←        →    ←        →    ←        →    ←            `
77    ///   `[(0⇒0, 1⇒1)], [(2⇒2, 3⇒2)], [(4⇒1, 5⇒0)], [(6⇒0, 7⇒0)], [(8⇒3, 9⇒2)], [(...)]`
78    Custom(Vec<usize>),
79}
80
81impl BidiSchedule {
82    pub fn real_idx(&self, virtual_idx: usize, virtual_len: usize, real_len: usize) -> usize {
83        let virtual_outer_idx = virtual_idx / 2;
84        let virtual_outer_len = virtual_len / 2;
85        match self {
86            BidiSchedule::StridedCyclic => {
87                let odd_len = real_len / 2;
88                let even_len = odd_len + real_len % 2;
89                let is_even = virtual_idx.is_multiple_of(2);
90                if is_even {
91                    (virtual_outer_idx % even_len) * 2
92                } else {
93                    (virtual_outer_idx % odd_len) * 2 + 1
94                }
95            }
96            BidiSchedule::StridedStretched => {
97                let odd_len = real_len / 2;
98                let even_len = odd_len + real_len % 2;
99                let is_even = virtual_idx.is_multiple_of(2);
100                if is_even {
101                    ((virtual_outer_idx * even_len) / virtual_outer_len) * 2
102                } else {
103                    ((virtual_outer_idx * odd_len) / virtual_outer_len) * 2 + 1
104                }
105            }
106            BidiSchedule::SymmetricCyclic => virtual_outer_idx % real_len,
107            BidiSchedule::SymmetricStretched => (virtual_outer_idx * real_len) / virtual_outer_len,
108            BidiSchedule::Custom(map) => *map.get(virtual_idx).unwrap(),
109        }
110    }
111}
112
#[cfg(test)]
mod test {
    use super::*;

    /// Real index of every virtual position `0..virtual_len` under a one-directional schedule.
    fn plain_indices(schedule: &Schedule, virtual_len: usize, real_len: usize) -> Vec<usize> {
        (0..virtual_len)
            .map(|i| schedule.real_idx(i, virtual_len, real_len))
            .collect()
    }

    /// Real index of every virtual position `0..virtual_len` under a bidirectional schedule.
    fn bidi_indices(schedule: &BidiSchedule, virtual_len: usize, real_len: usize) -> Vec<usize> {
        (0..virtual_len)
            .map(|i| schedule.real_idx(i, virtual_len, real_len))
            .collect()
    }

    #[test]
    fn schedule() {
        assert_eq!(
            plain_indices(&Schedule::Cyclic, 8, 3),
            vec![0, 1, 2, 0, 1, 2, 0, 1]
        );
        assert_eq!(
            plain_indices(&Schedule::Stretched, 8, 3),
            vec![0, 0, 0, 1, 1, 1, 2, 2]
        );
        // A custom map must come back verbatim.
        let map = vec![0, 1, 2, 2, 1, 0, 0, 0];
        assert_eq!(plain_indices(&Schedule::Custom(map.clone()), 8, 3), map);
    }

    #[test]
    fn bidi_schedule() {
        // Every case uses virtual len = 10 and real len = 4; pairs are (→, ←).
        let map = vec![0, 1, /**/ 2, 2, /**/ 1, 0, /**/ 0, 0, /**/ 3, 2];
        let cases = [
            (
                BidiSchedule::StridedCyclic,
                vec![0, 1, /**/ 2, 3, /**/ 0, 1, /**/ 2, 3, /**/ 0, 1],
            ),
            (
                BidiSchedule::StridedStretched,
                vec![0, 1, /**/ 0, 1, /**/ 0, 1, /**/ 2, 3, /**/ 2, 3],
            ),
            (
                BidiSchedule::SymmetricCyclic,
                vec![0, 0, /**/ 1, 1, /**/ 2, 2, /**/ 3, 3, /**/ 0, 0],
            ),
            (
                BidiSchedule::SymmetricStretched,
                vec![0, 0, /**/ 0, 0, /**/ 1, 1, /**/ 2, 2, /**/ 3, 3],
            ),
            (BidiSchedule::Custom(map.clone()), map),
        ];
        for (schedule, expected) in cases {
            assert_eq!(bidi_indices(&schedule, 10, 4), expected, "{schedule:?}");
        }
    }
}
186}