Skip to main content

burn_mamba/mamba3/
ssd_path.rs

//! # Pathway-agnostic SSD algorithm selection (Mamba-3)
//!
//! [`Mamba3SsdPath`] picks the chunkwise SSD *algorithm* (Minimal / Serial /
//! SerialRecalculated) and chunk length, independent of the double-vs-single
//! *pathway* (which the supplied cache variant selects). It converts into the
//! per-pathway path types via `From` and is threaded by
//! [`Mamba3::forward`](crate::mamba3::mamba3::Mamba3::forward) into whichever
//! pathway the cache implies.

10use crate::mamba3::prelude::*;
11
/// Algorithm selection for the Mamba-3 chunkwise SSD.
///
/// This selects the chunkwise SSD *algorithm*. The *pathway* (double- vs
/// single-ssd) is selected separately, by the supplied cache variant (see
/// [`crate::mamba3::cache::Mamba3Caches`]); [`Mamba3::forward`] threads this
/// same selection into whichever pathway the cache implies, converting it into
/// the per-pathway input bundle ([`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdInput`]
/// or [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdInput`]) and calling
/// that bundle's `run`.
///
/// Each variant carries an optional chunk length. Larger values increase the
/// intra-chunk GEMM work and reduce the inter-chunk scan length; the optimal
/// value is approximately `√(state_rank · per_head_dim)` (see
/// [`Self::optimal_chunk_len`]). `None` falls back to that optimal value.
///
/// The selection is a small plain value (a variant tag plus an
/// `Option<usize>`), so it is `Copy` and can be compared or hashed, e.g. to
/// key per-configuration benchmarks or caches.
///
/// If no path is specified, the cache defaults to
/// [`crate::mamba3::cache::Mamba3Caches::SingleSsd`] with [`Self::default`]
/// (i.e. [`Self::SerialRecalculated`] with an unset chunk length).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Mamba3SsdPath {
    /// Minimal/segsum SSD: mostly batched matmuls; backward via autodiff.
    ///
    /// See [`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdInput::double_ssd_minimal`]
    /// / [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdInput::single_ssd_minimal`].
    /// For training, prefer [`Self::SerialRecalculated`].
    Minimal(Option<usize>),

    /// (Hybrid) serial SSD: a serial loop over the chunks plus batched matmuls;
    /// backward via autodiff.
    ///
    /// See [`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdInput::double_ssd_serial`]
    /// / [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdInput::single_ssd_serial`].
    /// For a memory-saving custom backward, see [`Self::SerialRecalculated`].
    Serial(Option<usize>),

    /// (Hybrid) serial SSD with a custom, memory-efficient backward that
    /// recomputes the forward intermediates instead of storing them.
    ///
    /// See [`crate::mamba3::double_ssd::ssd::Mamba3DoubleSsdInput::double_ssd_serial_recalculated`]
    /// / [`crate::mamba3::single_ssd::ssd::Mamba3SingleSsdInput::single_ssd_serial_recalculated`].
    /// For a plain autodiff backward, see [`Self::Serial`].
    SerialRecalculated(Option<usize>),
}
55
56impl Mamba3SsdPath {
57    /// Optimal chunk length, approximately `√(state_rank · per_head_dim)`,
58    /// rounded up to a multiple of 32 and capped at 512.
59    pub fn optimal_chunk_len(state_rank: usize, per_head_dim: usize) -> usize {
60        (state_rank * per_head_dim)
61            .isqrt()
62            .next_multiple_of(32) // rule-of-thumb: common plane dimension.
63            .min(512) // rule-of-thumb: ceiling at 512.
64    }
65
66    /// The chunk length carried by this variant, if any.
67    pub fn chunk_len(&self) -> Option<usize> {
68        match self {
69            Self::Minimal(chunk_len)
70            | Self::Serial(chunk_len)
71            | Self::SerialRecalculated(chunk_len) => *chunk_len,
72        }
73    }
74
75    /// The chunk length carried by this variant, or [`Self::optimal_chunk_len`]
76    /// when unset.
77    pub fn chunk_len_or_optimal(&self, state_rank: usize, per_head_dim: usize) -> usize {
78        self.chunk_len()
79            .unwrap_or_else(|| Self::optimal_chunk_len(state_rank, per_head_dim))
80    }
81
82    /// The recommended default path for a given block: [`Self::SerialRecalculated`]
83    /// with [`Self::optimal_chunk_len`] for the block's dimensions.
84    pub fn default_optimal_from_block(block: &Mamba3) -> Self {
85        let chunk_len = Self::optimal_chunk_len(block.state_rank, block.per_head_dim());
86        Self::SerialRecalculated(Some(chunk_len))
87    }
88}
89
90impl Default for Mamba3SsdPath {
91    fn default() -> Self {
92        // Defaults to the SerialRecalculated algorithm with the optimal chunk length.
93        Self::SerialRecalculated(None)
94    }
95}