burn_mamba/schedule.rs
1#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
2pub enum Schedule {
3 /// Fills virtual positions by wrapping around the real schedule in a looping fashion.
4 ///
5 /// # Example
6 /// - virtual len = 8, real len = 3:
7 /// ` → → → → → → → → `
8 /// `(0⇒0, 1⇒1, 2⇒2), (3⇒0, 4⇒1, 5⇒2), (6⇒0, 7⇒1, ...)`
9 #[default]
10 Cyclic,
11 /// Fills virtual positions by stretching the real schedule.
12 ///
13 /// # Example
14 /// - virtual len = 8, real len = 3:
15 /// ` → → → → → → → → `
16 /// `(0⇒0, 1⇒0, 2⇒0), (3⇒1, 4⇒1, 5⇒1), (6⇒2, 7⇒2, ...)`
17 Stretched,
18 /// Fills virtual positions by referring to the index vector.
19 ///
20 /// # Example
21 /// - virtual len = 8, real len = 3, custom = [0, 1, 2, 2, 1, 0, 0, 0]:
22 /// ` → → → → → → → → `
23 /// `(0⇒0, 1⇒1, 2⇒2, 3⇒2, 4⇒1, 5⇒0, 6⇒0, 7⇒0, ...)`
24 Custom(Vec<usize>),
25}
26
27impl Schedule {
28 pub fn real_idx(&self, virtual_idx: usize, virtual_len: usize, real_len: usize) -> usize {
29 match self {
30 Schedule::Cyclic => virtual_idx % real_len,
31 Schedule::Stretched => (virtual_idx * real_len) / virtual_len,
32 Schedule::Custom(map) => *map.get(virtual_idx).unwrap(),
33 }
34 }
35}
36
37#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
38pub enum BidiSchedule {
39 /// Use even virtual positions for straight-direction (→), and odd virtual positions for
40 /// reverse-direction (←), wrapping around for each schedule.
41 //
42 /// # Example
43 /// - virtual len = 10, real len = 4:
44 /// ` → ← → ← → ← → ← → ← `
45 /// `[(0⇒0, 1⇒1), (2⇒2, 3⇒3)], [(4⇒0, 5⇒1), (6⇒2, 7⇒3)], [(8⇒0, 9⇒1), (...)]`
46 #[default]
47 StridedCyclic,
48 /// Use even virtual positions for straight-direction (→), and odd virtual positions for
49 /// reverse-direction (←), stretching for each schedule.
50 ///
51 /// # Example
52 /// - virtual len = 10, real len = 4:
53 /// ` → ← → ← → ← → ← → ← `
54 /// `[(0⇒0, 1⇒1), (2⇒0, 3⇒1), (4⇒0, 5⇒1)], [(6⇒2, 7⇒3), (8⇒2, 9⇒3), (...)]`
55 StridedStretched,
56 /// Fills virtual positions by wrapping around the real schedule in a looping fashion,
57 /// replicating between the straight (→) and reverse (←) directions.
58 ///
59 /// # Example
60 /// - virtual len = 10, real len = 4:
61 /// ` → ← → ← → ← → ← → ← `
62 /// `[(0⇒0, 1⇒0), (2⇒1, 3⇒1), (4⇒2, 5⇒2), (6⇒3, 7⇒3)], [(8⇒0, 9⇒0), (...)]`
63 SymmetricCyclic,
64 /// Fills virtual positions by stretching the real schedule, replicating between
65 /// the straight (→) and reverse (←) directions.
66 ///
67 /// # Example
68 /// - virtual len = 10, real len = 4:
69 /// ` → ← → ← → ← → ← → ← `
70 /// `[(0⇒0, 1⇒0), (2⇒0, 3⇒0)],[(4⇒1, 5⇒1), (...)], [(6⇒2, 7⇒2)], [(8⇒3, 9⇒3)]`
71 SymmetricStretched,
72 /// Fills virtual positions by referring to the index vector.
73 ///
74 /// # Example
75 /// - virtual len = 10, real len = 4, custom = [0, 1, 2, 2, 1, 0, 0, 0, 3, 2]:
76 /// ` → ← → ← → ← → ← → ← `
77 /// `[(0⇒0, 1⇒1)], [(2⇒2, 3⇒2)], [(4⇒1, 5⇒0)], [(6⇒0, 7⇒0)], [(8⇒3, 9⇒2)], [(...)]`
78 Custom(Vec<usize>),
79}
80
81impl BidiSchedule {
82 pub fn real_idx(&self, virtual_idx: usize, virtual_len: usize, real_len: usize) -> usize {
83 let virtual_outer_idx = virtual_idx / 2;
84 let virtual_outer_len = virtual_len / 2;
85 match self {
86 BidiSchedule::StridedCyclic => {
87 let odd_len = real_len / 2;
88 let even_len = odd_len + real_len % 2;
89 let is_even = virtual_idx.is_multiple_of(2);
90 if is_even {
91 (virtual_outer_idx % even_len) * 2
92 } else {
93 (virtual_outer_idx % odd_len) * 2 + 1
94 }
95 }
96 BidiSchedule::StridedStretched => {
97 let odd_len = real_len / 2;
98 let even_len = odd_len + real_len % 2;
99 let is_even = virtual_idx.is_multiple_of(2);
100 if is_even {
101 ((virtual_outer_idx * even_len) / virtual_outer_len) * 2
102 } else {
103 ((virtual_outer_idx * odd_len) / virtual_outer_len) * 2 + 1
104 }
105 }
106 BidiSchedule::SymmetricCyclic => virtual_outer_idx % real_len,
107 BidiSchedule::SymmetricStretched => (virtual_outer_idx * real_len) / virtual_outer_len,
108 BidiSchedule::Custom(map) => *map.get(virtual_idx).unwrap(),
109 }
110 }
111}
112
#[cfg(test)]
mod test {
    use super::*;

    /// Resolves every virtual position in `0..virtual_len` through `schedule`.
    fn resolve(schedule: &Schedule, virtual_len: usize, real_len: usize) -> Vec<usize> {
        (0..virtual_len)
            .map(|virtual_idx| schedule.real_idx(virtual_idx, virtual_len, real_len))
            .collect()
    }

    /// Resolves every virtual position in `0..virtual_len` through the bidirectional `schedule`.
    fn resolve_bidi(schedule: &BidiSchedule, virtual_len: usize, real_len: usize) -> Vec<usize> {
        (0..virtual_len)
            .map(|virtual_idx| schedule.real_idx(virtual_idx, virtual_len, real_len))
            .collect()
    }

    /// Documented examples of every [`Schedule`] variant (virtual len 8, real len 3).
    #[test]
    fn schedule() {
        let (virtual_len, real_len) = (8, 3);

        let cyclic = resolve(&Schedule::Cyclic, virtual_len, real_len);
        assert_eq!(cyclic, vec![0, 1, 2, 0, 1, 2, 0, 1]);

        let stretched = resolve(&Schedule::Stretched, virtual_len, real_len);
        assert_eq!(stretched, vec![0, 0, 0, 1, 1, 1, 2, 2]);

        // A custom schedule is reproduced verbatim.
        let table = vec![0, 1, 2, 2, 1, 0, 0, 0];
        let custom = resolve(&Schedule::Custom(table.clone()), virtual_len, real_len);
        assert_eq!(custom, table);
    }

    /// Documented examples of every [`BidiSchedule`] variant (virtual len 10, real len 4).
    #[test]
    fn bidi_schedule() {
        let (virtual_len, real_len) = (10, 4);

        // Each expected vector lists five (→, ←) pairs back to back.
        let expectations = [
            (BidiSchedule::StridedCyclic, vec![0, 1, 2, 3, 0, 1, 2, 3, 0, 1]),
            (BidiSchedule::StridedStretched, vec![0, 1, 0, 1, 0, 1, 2, 3, 2, 3]),
            (BidiSchedule::SymmetricCyclic, vec![0, 0, 1, 1, 2, 2, 3, 3, 0, 0]),
            (BidiSchedule::SymmetricStretched, vec![0, 0, 0, 0, 1, 1, 2, 2, 3, 3]),
        ];
        for (schedule, expected) in expectations {
            assert_eq!(resolve_bidi(&schedule, virtual_len, real_len), expected);
        }

        // A custom schedule is reproduced verbatim.
        let table = vec![0, 1, 2, 2, 1, 0, 0, 0, 3, 2];
        let custom = resolve_bidi(&BidiSchedule::Custom(table.clone()), virtual_len, real_len);
        assert_eq!(custom, table);
    }
}