burn_mamba/mamba2/ssd/serial_recalculated/
serial_recalculated.rs1use crate::mamba2::prelude::*;
2use crate::mamba2::ssd::serial;
3use burn::prelude::*;
4use burn::tensor::{Tensor, TensorPrimitive, ops::FloatTensor};
5
6impl<B: Backend + Mamba2BackendExt> Mamba2<B> {
7 #[allow(non_snake_case)]
13 pub fn ssd_serial_recalculated(
14 input: super::super::Mamba2SsdInput<B>,
15 ) -> (Tensor<B, 5>, Tensor<B, 4>) {
16 let [batch, nchunks, chunk_len, nheads, _per_head_dim] = input.x_bnlhp.dims();
22 let [.., ngroups, _state_rank] = input.b_bnlgr.dims();
23 assert_ne!(ngroups, 0);
24 assert_eq!(nheads % ngroups, 0);
25 assert!(nchunks > 0, "sequence length must be at least 1");
26
27 assert!(
28 input.init_state_hpr.is_none(),
29 "init_state_hpr not yet implemented"
30 );
31
32 let dt_discretized_bhnl = input.dt_bnlh.permute([0, 3, 1, 2]);
35 assert_eq!(
36 [batch, nheads, nchunks, chunk_len],
37 dt_discretized_bhnl.dims()
38 );
39
40 let (y_bnlhp, final_state_bhpr) = <B as Mamba2BackendExt>::ssd_serial_recalculated(
44 input.x_bnlhp.into_primitive().tensor(),
45 dt_discretized_bhnl.into_primitive().tensor(),
46 input.b_bnlgr.into_primitive().tensor(),
47 input.c_bnlgr.into_primitive().tensor(),
48 input.d_h.into_primitive().tensor(),
49 input.initial_state_bhpr.into_primitive().tensor(),
50 input.a_decay_h.into_primitive().tensor(),
51 );
52 let y_bnlhp = Tensor::from_primitive(TensorPrimitive::Float(y_bnlhp));
53 let final_state_bhpr = Tensor::from_primitive(TensorPrimitive::Float(final_state_bhpr));
54 (y_bnlhp, final_state_bhpr)
55 }
56}
57
58pub trait Mamba2BackendExt: burn::tensor::backend::Backend {
60 fn ssd_serial_recalculated(
64 x_bnlhp: FloatTensor<Self>,
65 dt_discretized_bhnl: FloatTensor<Self>,
66 b_bnlgr: FloatTensor<Self>,
67 c_bnlgr: FloatTensor<Self>,
68 d_h: FloatTensor<Self>,
69 initial_state_bhpr: FloatTensor<Self>,
70 a_decay_h: FloatTensor<Self>,
71 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
72 let x_bnlhp: Tensor<Self, 5> = mk(x_bnlhp);
75 let dt_discretized_bhnl: Tensor<Self, 4> = mk(dt_discretized_bhnl);
76 let b_bnlgr: Tensor<Self, 5> = mk(b_bnlgr);
77 let c_bnlgr: Tensor<Self, 5> = mk(c_bnlgr);
78 let d_h: Tensor<Self, 1> = mk(d_h);
79 let initial_state_bhpr: Tensor<Self, 4> = mk(initial_state_bhpr);
80 let a_decay_h: Tensor<Self, 1> = mk(a_decay_h);
81
82 let [batch, nchunks, chunk_len, nheads, per_head_dim] = x_bnlhp.dims();
83 let [.., ngroups, state_rank] = b_bnlgr.dims();
84 assert_ne!(ngroups, 0);
85 assert_eq!(nheads % ngroups, 0);
86 assert!(nchunks > 0, "sequence length must be at least 1");
87 let (da_cumsum_bhnl, da_chunk_end_bhn) =
93 serial::k1_ssd_chunk_cumsum(dt_discretized_bhnl.clone(), a_decay_h);
94 assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
95 assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
96
97 let cb_bngll: Tensor<Self, 5> = serial::k2_ssd_bmm(c_bnlgr.clone(), b_bnlgr.clone());
100 assert_eq!(
101 [batch, nchunks, ngroups, chunk_len, chunk_len],
102 cb_bngll.dims()
103 );
104 let intra_chunk_state_bnhpr: Tensor<Self, 5> = serial::k3_ssd_chunk_state(
109 x_bnlhp.clone(),
110 b_bnlgr.clone(),
111 da_cumsum_bhnl.clone(),
112 dt_discretized_bhnl.clone(),
113 );
114 assert_eq!(
115 [batch, nchunks, nheads, per_head_dim, state_rank],
116 intra_chunk_state_bnhpr.dims()
117 );
118
119 let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<Self, 5>, Tensor<Self, 4>) =
122 serial::k4_ssd_state_passing(
123 intra_chunk_state_bnhpr.clone(),
124 da_chunk_end_bhn.clone(),
125 initial_state_bhpr,
126 );
127 assert_eq!(
128 [batch, nchunks, nheads, per_head_dim, state_rank],
129 chunk_input_state_bnhpr.dims()
130 );
131 assert_eq!(
132 [batch, nheads, per_head_dim, state_rank],
133 final_state_bhpr.dims()
134 );
135
136 let y_bnlhp: Tensor<Self, 5> = serial::k5_ssd_chunk_scan(
138 da_cumsum_bhnl,
139 dt_discretized_bhnl,
140 x_bnlhp,
141 c_bnlgr,
142 cb_bngll,
143 chunk_input_state_bnhpr,
144 d_h,
145 );
146 assert_eq!(
147 [batch, nchunks, chunk_len, nheads, per_head_dim],
148 y_bnlhp.dims()
149 );
150
151 let y_bnlhp = y_bnlhp.into_primitive().tensor();
152 let final_state_bhpr = final_state_bhpr.into_primitive().tensor();
153 (y_bnlhp, final_state_bhpr)
154 }
155}
156
// NdArray backend (only with the `backend-ndarray` feature): the empty body means
// the trait's provided `ssd_serial_recalculated` is used unchanged.
#[cfg(feature = "backend-ndarray")]
impl<F, I> Mamba2BackendExt for burn::backend::NdArray<F, I> {}
// Flex backend (only with the `backend-flex` feature): uses the trait's provided
// `ssd_serial_recalculated` unchanged.
#[cfg(feature = "backend-flex")]
impl Mamba2BackendExt for burn::backend::Flex {}
// LibTorch backend (enabled by either the `backend-tch-cpu` or the `backend-tch-gpu`
// feature): uses the trait's provided `ssd_serial_recalculated` unchanged.
#[cfg(any(feature = "backend-tch-cpu", feature = "backend-tch-gpu"))]
impl<F, I> Mamba2BackendExt for burn::backend::libtorch::LibTorch<F, I> {}
// Remote backend (only with the `backend-remote` feature): uses the trait's provided
// `ssd_serial_recalculated` unchanged.
#[cfg(feature = "backend-remote")]
impl<F, I> Mamba2BackendExt for burn::backend::RemoteBackend<F, I> {}
// CubeCL-based backends (only with the `cubecl` feature). The impl is empty, so the
// trait's provided `ssd_serial_recalculated` is what runs.
#[cfg(feature = "cubecl")]
mod cubecl {
    use burn_cubecl::{CubeBackend, CubeRuntime, FloatElement, IntElement, element::BoolElement};

    impl<R, F, I, BT> super::Mamba2BackendExt for CubeBackend<R, F, I, BT>
    where
        R: CubeRuntime,
        F: FloatElement,
        I: IntElement,
        BT: BoolElement,
    {
    }
}
181
// Fusion wrapper (only with the `fusion` feature): `Fusion<B>` gets the extension
// whenever the wrapped backend `B` has it. The impl is empty, so the trait's provided
// method is used.
#[cfg(feature = "fusion")]
mod fusion {
    use burn_fusion::{Fusion, FusionBackend};

    impl<B> super::Mamba2BackendExt for Fusion<B> where B: FusionBackend + super::Mamba2BackendExt {}
}
189
/// Marker trait (only with the `autodiff` feature) for backends that are at once a
/// `Backend`, a [`Mamba2BackendExt`] and an `AutodiffBackend`. It declares no items
/// of its own.
#[cfg(feature = "autodiff")]
pub trait Mamba2AutodiffBackendExt:
    Backend + Mamba2BackendExt + burn::tensor::backend::AutodiffBackend
{
}
// `Autodiff<B>` carries the marker for every inner backend `B` that is a
// `Backend + Mamba2BackendExt`.
// NOTE(review): the supertrait list requires `Autodiff<B>: Mamba2BackendExt`; that impl
// is not visible in this section — presumably it lives elsewhere in the crate. Confirm.
#[cfg(feature = "autodiff")]
impl<B: Backend + Mamba2BackendExt> Mamba2AutodiffBackendExt for burn::backend::Autodiff<B> {}
202
203pub(crate) fn mk<B: Backend, const D: usize>(p: FloatTensor<B>) -> Tensor<B, D> {
205 Tensor::from_primitive(TensorPrimitive::Float(p))
206}