1#![allow(non_snake_case)]
18
19use crate::mamba3::double_ssd::ssd::serial_recalculated::{
20 k1_ssd_chunk_cumsum, k2_ssd_bmm, k3_ssd_chunk_state, k4_ssd_state_passing,
21};
22use crate::mamba3::single_ssd::prelude::*;
23use crate::utils::fprim::{F, san};
24use burn::backend::tensor::FloatTensor;
25use burn::backend::*;
26use burn::backend::{Backend, Dispatch, backend_extension};
27use burn::tensor::Tensor;
28
29impl Mamba3SingleSsdInput {
30 pub fn single_ssd_serial_recalculated(self) -> (Tensor<6>, Tensor<4>) {
41 let input = self;
42 input.sanity();
43 assert!(
44 input.init_state_hpr.is_none(),
45 "init_state_hpr not yet implemented for single_ssd_serial_recalculated"
46 );
47
48 let (y_bnlmhp, final_state_bhpr) =
49 <Dispatch as Mamba3SingleSsdBackendExt>::single_ssd_serial_recalculated(
50 input.v_bnlmhp.into_primitive(),
51 input.da_bnlh.into_primitive(),
52 input.b_bnlmhr.into_primitive(),
53 input.c_bnlmhr.into_primitive(),
54 input.gamma_bnlh.into_primitive(),
55 input.scale_bnlh.into_primitive(),
56 input.initial_state_bhpr.into_primitive(),
57 );
58 let y_bnlmhp = Tensor::from_primitive(y_bnlmhp);
59 let final_state_bhpr = Tensor::from_primitive(final_state_bhpr);
60 (y_bnlmhp, final_state_bhpr)
61 }
62}
63
64#[backend_extension(
72 Cpu: cfg(feature = "backend-cpu"),
73 Cuda: cfg(feature = "backend-cuda"),
74 Rocm: cfg(feature = "backend-rocm"),
75 Metal: cfg(feature = "backend-metal"),
76 Vulkan: cfg(feature = "backend-vulkan"),
77 Wgpu: cfg(feature = "backend-wgpu"),
78 WebGpu: cfg(feature = "backend-webgpu"),
79 Flex: cfg(feature = "backend-flex"),
80 NdArray: cfg(feature = "backend-ndarray"),
81 LibTorch: cfg(any(feature = "backend-tch-cpu", feature = "backend-tch-gpu")),
82 Autodiff: cfg(feature = "autodiff"),
83)]
84pub trait Mamba3SingleSsdBackendExt: Backend {
85 fn single_ssd_serial_recalculated(
100 v_bnlmhp: FloatTensor<Self>,
101 da_bnlh: FloatTensor<Self>,
102 b_bnlmhr: FloatTensor<Self>,
103 c_bnlmhr: FloatTensor<Self>,
104 gamma_bnlh: FloatTensor<Self>,
105 scale_bnlh: FloatTensor<Self>,
106 initial_state_bhpr: FloatTensor<Self>,
107 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
108 let v_bnlmhp = F::<Self, 6>::new(v_bnlmhp);
110 let da_bnlh = F::<Self, 4>::new(da_bnlh);
111 let b_bnlmhr = F::<Self, 6>::new(b_bnlmhr);
112 let c_bnlmhr = F::<Self, 6>::new(c_bnlmhr);
113 let gamma_bnlh = F::<Self, 4>::new(gamma_bnlh);
114 let scale_bnlh = F::<Self, 4>::new(scale_bnlh);
115 let initial_state_bhpr = F::<Self, 4>::new(initial_state_bhpr);
116
117 let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum::<Self>(da_bnlh);
119 san(&da_cumsum_bhnl);
120
121 let cb_bnhLMLM = k2_ssd_bmm::<Self>(c_bnlmhr.clone(), b_bnlmhr.clone());
123 san(&cb_bnhLMLM);
124
125 let scale_bnlh11 = scale_bnlh.clone().unsqueeze_dims::<6>(&[3, 5]);
127 let k_scaled_bnlmhr = b_bnlmhr.clone() * scale_bnlh11;
128 let intra_chunk_state_bnhpr =
129 k3_ssd_chunk_state::<Self>(v_bnlmhp.clone(), k_scaled_bnlmhr, da_cumsum_bhnl.clone());
130 san(&intra_chunk_state_bnhpr);
131
132 let (chunk_input_state_bnhpr, final_state_bhpr) = k4_ssd_state_passing::<Self>(
134 intra_chunk_state_bnhpr,
135 da_chunk_end_bhn,
136 initial_state_bhpr,
137 );
138 san(&chunk_input_state_bnhpr);
139 san(&final_state_bhpr);
140
141 let y_bnlmhp = k5_single_ssd_chunk_scan::<Self>(
143 da_cumsum_bhnl,
144 v_bnlmhp,
145 c_bnlmhr,
146 b_bnlmhr,
147 cb_bnhLMLM,
148 gamma_bnlh,
149 scale_bnlh,
150 chunk_input_state_bnhpr,
151 );
152 san(&y_bnlmhp);
153
154 (y_bnlmhp.inner(), final_state_bhpr.inner())
155 }
156}
157
// NOTE(review): this macro is defined elsewhere in the crate. Judging by its name and arguments
// it presumably declares `Mamba3SingleSsdAutodiffBackendExt` as the autodiff counterpart of
// `Mamba3SingleSsdBackendExt` — confirm against the macro definition.
crate::decl_ssd_autodiff_backend_ext!(Mamba3SingleSsdAutodiffBackendExt, Mamba3SingleSsdBackendExt);

// NOTE(review): also defined elsewhere; presumably emits the `Mamba3SingleSsdBackendExt` impls
// for the concrete burn backends, relying on the trait's default method body — confirm.
crate::impl_ssd_backend_ext_for_burn_backends!(Mamba3SingleSsdBackendExt);
165
166#[allow(clippy::too_many_arguments)]
184fn k5_single_ssd_chunk_scan<B: Backend>(
185 da_cumsum_bhnl: F<B, 4>,
186 v_bnlmhp: F<B, 6>,
187 c_bnlmhr: F<B, 6>,
188 b_bnlmhr: F<B, 6>,
189 cb_bnhLMLM: F<B, 5>,
190 gamma_bnlh: F<B, 4>,
191 scale_bnlh: F<B, 4>,
192 chunk_input_state_bnhpr: F<B, 5>,
193) -> F<B, 6> {
194 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
195 let [.., state_rank] = c_bnlmhr.dims();
196 let device = v_bnlmhp.device();
197 let dtype = v_bnlmhp.dtype();
198 let fused = chunk_len * mimo_rank;
199
200 let v_bnLMhp = v_bnlmhp
202 .clone()
203 .reshape([batch, nchunks, fused, nheads, per_head_dim]);
204 let c_bnLMhr = c_bnlmhr
205 .clone()
206 .reshape([batch, nchunks, fused, nheads, state_rank]);
207
208 let da_cumsum_bhnLM = da_cumsum_bhnl
210 .unsqueeze_dim::<5>(4)
211 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
212 .reshape([batch, nheads, nchunks, fused]);
213
214 let exp_da_bnhLMp = da_cumsum_bhnLM
216 .clone()
217 .exp()
218 .permute([0, 2, 1, 3]) .unsqueeze_dim::<5>(4) .expand([batch, nchunks, nheads, fused, per_head_dim]);
221 let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
222 let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
223 let ch_bnhLMp = c_bnhLMr.matmul(chunk_input_state_bnhrp);
224 let y_off_bnhLMp = ch_bnhLMp * exp_da_bnhLMp;
225
226 let da_cumsum_bnhLM = da_cumsum_bhnLM.permute([0, 2, 1, 3]); let target_da_cumsum_bnhLMLM = da_cumsum_bnhLM
229 .clone()
230 .unsqueeze_dim::<5>(4) .expand([batch, nchunks, nheads, fused, fused]);
232 let source_da_cumsum_bnhLMLM = da_cumsum_bnhLM
233 .unsqueeze_dim::<5>(3) .expand([batch, nchunks, nheads, fused, fused]);
235 let diff_bnhLMLM = target_da_cumsum_bnhLMLM - source_da_cumsum_bnhLMLM;
236
237 let inf_upper_bnhLMLM =
240 F::<B, 2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device, dtype)
241 .triu(0) .unsqueeze_dims::<5>(&[0, 1, 2])
243 .expand([batch, nchunks, nheads, chunk_len, chunk_len])
244 .unsqueeze_dim::<6>(4)
245 .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len])
246 .reshape([batch, nchunks, nheads, fused, chunk_len])
247 .unsqueeze_dim::<6>(5)
248 .expand([batch, nchunks, nheads, fused, chunk_len, mimo_rank])
249 .reshape([batch, nchunks, nheads, fused, fused]);
250 let decay_strict_bnhLMLM = (diff_bnhLMLM + inf_upper_bnhLMLM).exp();
251
252 let scale_bnhLM = scale_bnlh
254 .permute([0, 1, 3, 2]) .unsqueeze_dim::<5>(4) .expand([batch, nchunks, nheads, chunk_len, mimo_rank])
257 .reshape([batch, nchunks, nheads, fused]);
258 let scale_col_bnhLMLM = scale_bnhLM
259 .unsqueeze_dim::<5>(3) .expand([batch, nchunks, nheads, fused, fused]);
261
262 let kernel_bnhLMLM = decay_strict_bnhLMLM * scale_col_bnhLMLM;
263 let masked_cb_bnhLMLM = cb_bnhLMLM * kernel_bnhLMLM;
264 let v_bnhLMp = v_bnLMhp.permute([0, 1, 3, 2, 4]);
265 let y_lower_bnhLMp = masked_cb_bnhLMLM.matmul(v_bnhLMp);
266
267 let c_bnlhmr = c_bnlmhr.permute([0, 1, 2, 4, 3, 5]);
269 let b_bnlhrm = b_bnlmhr.permute([0, 1, 2, 4, 5, 3]);
270 let qk_dot_bnlhmM = c_bnlhmr.matmul(b_bnlhrm); let v_bnlhmp = v_bnlmhp.permute([0, 1, 2, 4, 3, 5]);
272 let y_d_bnlhmp = qk_dot_bnlhmM.matmul(v_bnlhmp); let gamma_bnlh11 = gamma_bnlh.unsqueeze_dims::<6>(&[4, 5]);
274 let y_d_bnlhmp_scaled = y_d_bnlhmp * gamma_bnlh11;
275
276 let y_diag_bnlmhp = y_d_bnlhmp_scaled.permute([0, 1, 2, 4, 3, 5]);
277 let y_diag_bnLMhp = y_diag_bnlmhp.reshape([batch, nchunks, fused, nheads, per_head_dim]);
278 let y_diag_bnhLMp = y_diag_bnLMhp.permute([0, 1, 3, 2, 4]);
279
280 let y_bnhLMp = y_off_bnhLMp + y_lower_bnhLMp + y_diag_bnhLMp;
282 let y_bnLMhp = y_bnhLMp.permute([0, 1, 3, 2, 4]);
283 y_bnLMhp.reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim])
284}