// burn_mamba/utils/schedule/mod.rs
//! # Virtual-layer → real-weight scheduling
//!
//! A `{Model}Layers` stack can run `n_virtual_layers` logical passes over only
//! `n_real_layers` weight sets (e.g. 48 logical from 12 real); each virtual
//! layer keeps its own cache but shares parameters. A [`Schedule`] maps a
//! virtual layer index to the real weight index to use.
//!
//! For **bidirectional** stacks, [`BidiSchedule`] additionally interleaves the
//! two directions: even virtual indices run the straight (→) pass and odd
//! indices run the reverse (←) pass.
//!
//! Each variant is documented with a worked virtual→real mapping example.
13
14/// How a unidirectional layer stack maps virtual layer indices to real
15/// (weight-bearing) layer indices.
16#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub enum Schedule {
18 /// Fills virtual positions by wrapping around the real schedule in a looping fashion.
19 ///
20 /// # Example
21 /// - virtual len = 8, real len = 3:
22 /// ` → → → → → → → → `
23 /// `(0⇒0, 1⇒1, 2⇒2), (3⇒0, 4⇒1, 5⇒2), (6⇒0, 7⇒1, ...)`
24 #[default]
25 Cyclic,
26 /// Fills virtual positions by stretching the real schedule.
27 ///
28 /// # Example
29 /// - virtual len = 8, real len = 3:
30 /// ` → → → → → → → → `
31 /// `(0⇒0, 1⇒0, 2⇒0), (3⇒1, 4⇒1, 5⇒1), (6⇒2, 7⇒2, ...)`
32 Stretched,
33 /// Fills virtual positions by referring to the index vector.
34 ///
35 /// # Example
36 /// - virtual len = 8, real len = 3, custom = `[0, 1, 2, 2, 1, 0, 0, 0]`:
37 /// ` → → → → → → → → `
38 /// `(0⇒0, 1⇒1, 2⇒2, 3⇒2, 4⇒1, 5⇒0, 6⇒0, 7⇒0, ...)`
39 Custom(Vec<usize>),
40}
41
42impl Schedule {
43 /// Map `virtual_idx` (in `0..virtual_len`) to a real layer index in
44 /// `0..real_len` according to this schedule.
45 pub fn real_idx(&self, virtual_idx: usize, virtual_len: usize, real_len: usize) -> usize {
46 match self {
47 Schedule::Cyclic => virtual_idx % real_len,
48 Schedule::Stretched => (virtual_idx * real_len) / virtual_len,
49 Schedule::Custom(map) => *map.get(virtual_idx).unwrap(),
50 }
51 }
52}
53
54/// How a bidirectional layer stack maps virtual layer indices to real layer
55/// indices, interleaving the straight (→, even indices) and reverse (←, odd
56/// indices) directions.
57#[derive(Default, Debug, Clone, serde::Serialize, serde::Deserialize)]
58pub enum BidiSchedule {
59 /// Use even virtual positions for straight-direction (→), and odd virtual positions for
60 /// reverse-direction (←), wrapping around for each schedule.
61 //
62 /// # Example
63 /// - virtual len = 10, real len = 4:
64 /// ` → ← → ← → ← → ← → ← `
65 /// `[(0⇒0, 1⇒1), (2⇒2, 3⇒3)], [(4⇒0, 5⇒1), (6⇒2, 7⇒3)], [(8⇒0, 9⇒1), (...)]`
66 #[default]
67 StridedCyclic,
68 /// Use even virtual positions for straight-direction (→), and odd virtual positions for
69 /// reverse-direction (←), stretching for each schedule.
70 ///
71 /// # Example
72 /// - virtual len = 10, real len = 4:
73 /// ` → ← → ← → ← → ← → ← `
74 /// `[(0⇒0, 1⇒1), (2⇒0, 3⇒1), (4⇒0, 5⇒1)], [(6⇒2, 7⇒3), (8⇒2, 9⇒3), (...)]`
75 StridedStretched,
76 /// Fills virtual positions by wrapping around the real schedule in a looping fashion,
77 /// replicating between the straight (→) and reverse (←) directions.
78 ///
79 /// # Example
80 /// - virtual len = 10, real len = 4:
81 /// ` → ← → ← → ← → ← → ← `
82 /// `[(0⇒0, 1⇒0), (2⇒1, 3⇒1), (4⇒2, 5⇒2), (6⇒3, 7⇒3)], [(8⇒0, 9⇒0), (...)]`
83 SymmetricCyclic,
84 /// Fills virtual positions by stretching the real schedule, replicating between
85 /// the straight (→) and reverse (←) directions.
86 ///
87 /// # Example
88 /// - virtual len = 10, real len = 4:
89 /// ` → ← → ← → ← → ← → ← `
90 /// `[(0⇒0, 1⇒0), (2⇒0, 3⇒0)],[(4⇒1, 5⇒1), (...)], [(6⇒2, 7⇒2)], [(8⇒3, 9⇒3)]`
91 SymmetricStretched,
92 /// Fills virtual positions by referring to the index vector.
93 ///
94 /// # Example
95 /// - virtual len = 10, real len = 4, custom = `[0, 1, 2, 2, 1, 0, 0, 0, 3, 2]`:
96 /// ` → ← → ← → ← → ← → ← `
97 /// `[(0⇒0, 1⇒1)], [(2⇒2, 3⇒2)], [(4⇒1, 5⇒0)], [(6⇒0, 7⇒0)], [(8⇒3, 9⇒2)], [(...)]`
98 Custom(Vec<usize>),
99}
100
101impl BidiSchedule {
102 /// Map `virtual_idx` (in `0..virtual_len`) to a real layer index in
103 /// `0..real_len`. Even/odd `virtual_idx` selects the straight/reverse
104 /// direction; the outer index `virtual_idx / 2` is what the schedule cycles
105 /// or stretches over.
106 pub fn real_idx(&self, virtual_idx: usize, virtual_len: usize, real_len: usize) -> usize {
107 let virtual_outer_idx = virtual_idx / 2;
108 let virtual_outer_len = virtual_len / 2;
109 match self {
110 BidiSchedule::StridedCyclic => {
111 let odd_len = real_len / 2;
112 let even_len = odd_len + real_len % 2;
113 let is_even = virtual_idx.is_multiple_of(2);
114 if is_even {
115 (virtual_outer_idx % even_len) * 2
116 } else {
117 (virtual_outer_idx % odd_len) * 2 + 1
118 }
119 }
120 BidiSchedule::StridedStretched => {
121 let odd_len = real_len / 2;
122 let even_len = odd_len + real_len % 2;
123 let is_even = virtual_idx.is_multiple_of(2);
124 if is_even {
125 ((virtual_outer_idx * even_len) / virtual_outer_len) * 2
126 } else {
127 ((virtual_outer_idx * odd_len) / virtual_outer_len) * 2 + 1
128 }
129 }
130 BidiSchedule::SymmetricCyclic => virtual_outer_idx % real_len,
131 BidiSchedule::SymmetricStretched => (virtual_outer_idx * real_len) / virtual_outer_len,
132 BidiSchedule::Custom(map) => *map.get(virtual_idx).unwrap(),
133 }
134 }
135}
136
137#[cfg(test)]
138mod tests;