burn_mamba/mamba3/ssd_path.rs
1//! # Pathway-agnostic SSD algorithm selection (Mamba-3)
2//!
3//! [`Mamba3SsdPath`] picks the chunkwise SSD *algorithm* (Minimal / Serial /
4//! SerialRecalculated) and chunk length, independent of the double-vs-single
5//! *pathway* (which the supplied cache variant selects). It converts into the
6//! per-pathway path types via `From` and is threaded by
7//! [`Mamba3::forward`](crate::mamba3::mamba3::Mamba3::forward) into whichever
8//! pathway the cache implies.
9
10use crate::mamba3::prelude::*;
11
/// Algorithm selection for the Mamba-3 chunkwise SSD.
///
/// This selects the chunkwise SSD *algorithm*. The *pathway* (double- vs
/// single-ssd) is selected separately, by the supplied cache variant (see
/// [`crate::mamba3::cache::Mamba3Caches`]); [`Mamba3::forward`] threads this
/// same selection into whichever pathway the cache implies, converting it into
/// the per-pathway input bundle ([`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdInput`]
/// or [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdInput`]) and calling
/// that bundle's `run`.
///
/// Each variant carries an optional chunk length. Larger values increase the
/// intra-chunk GEMM work and reduce the inter-chunk scan length; the optimal
/// value is approximately `√(state_rank · per_head_dim)` (see
/// [`Self::optimal_chunk_len`]). `None` falls back to that optimal value
/// (see [`Self::chunk_len_or_optimal`]).
///
/// If no path is specified, the cache defaults to
/// [`crate::mamba3::cache::Mamba3Caches::SingleSsd`] with [`Self::default`]
/// (i.e. [`Self::SerialRecalculated`] with an unset chunk length).
///
/// The type is a small plain value (a variant tag plus an `Option<usize>`), so
/// it is `Copy` and supports equality and hashing, e.g. for comparing or
/// keying configurations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Mamba3SsdPath {
    /// Minimal/segsum SSD: mostly batched matmuls; backward via autodiff.
    ///
    /// See [`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdInput::double_ssd_minimal`]
    /// / [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdInput::single_ssd_minimal`].
    /// For training, prefer [`Self::SerialRecalculated`].
    Minimal(Option<usize>),

    /// (Hybrid) serial SSD: a serial loop over the chunks plus batched matmuls;
    /// backward via autodiff.
    ///
    /// See [`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdInput::double_ssd_serial`]
    /// / [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdInput::single_ssd_serial`].
    /// For a memory-saving custom backward, see [`Self::SerialRecalculated`].
    Serial(Option<usize>),

    /// (Hybrid) serial SSD with a custom, memory-efficient backward that
    /// recomputes the forward intermediates instead of storing them.
    ///
    /// See [`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdInput::double_ssd_serial_recalculated`]
    /// / [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdInput::single_ssd_serial_recalculated`].
    /// For a plain autodiff backward, see [`Self::Serial`].
    SerialRecalculated(Option<usize>),
}
55
56impl Mamba3SsdPath {
57 /// Optimal chunk length, approximately `√(state_rank · per_head_dim)`,
58 /// rounded up to a multiple of 32 and capped at 512.
59 pub fn optimal_chunk_len(state_rank: usize, per_head_dim: usize) -> usize {
60 (state_rank * per_head_dim)
61 .isqrt()
62 .next_multiple_of(32) // rule-of-thumb: common plane dimension.
63 .min(512) // rule-of-thumb: ceiling at 512.
64 }
65
66 /// The chunk length carried by this variant, if any.
67 pub fn chunk_len(&self) -> Option<usize> {
68 match self {
69 Self::Minimal(chunk_len)
70 | Self::Serial(chunk_len)
71 | Self::SerialRecalculated(chunk_len) => *chunk_len,
72 }
73 }
74
75 /// The chunk length carried by this variant, or [`Self::optimal_chunk_len`]
76 /// when unset.
77 pub fn chunk_len_or_optimal(&self, state_rank: usize, per_head_dim: usize) -> usize {
78 self.chunk_len()
79 .unwrap_or_else(|| Self::optimal_chunk_len(state_rank, per_head_dim))
80 }
81
82 /// The recommended default path for a given block: [`Self::SerialRecalculated`]
83 /// with [`Self::optimal_chunk_len`] for the block's dimensions.
84 pub fn default_optimal_from_block(block: &Mamba3) -> Self {
85 let chunk_len = Self::optimal_chunk_len(block.state_rank, block.per_head_dim());
86 Self::SerialRecalculated(Some(chunk_len))
87 }
88}
89
90impl Default for Mamba3SsdPath {
91 fn default() -> Self {
92 // Defaults to the SerialRecalculated algorithm with the optimal chunk length.
93 Self::SerialRecalculated(None)
94 }
95}