1#![allow(non_snake_case)]
16
17use crate::mamba3::double_ssd::prelude::*;
18use crate::utils::fprim::{F, san};
19use burn::backend::tensor::FloatTensor;
20use burn::backend::*;
21use burn::backend::{Backend, Dispatch, backend_extension};
22use burn::tensor::Tensor;
23use burn::tensor::s;
24
25impl Mamba3DoubleSsdInput {
26 pub fn double_ssd_serial_recalculated(self) -> (Tensor<6>, Tensor<4>) {
37 let input = self;
38 assert!(
39 input.init_state_hpr.is_none(),
40 "init_state_hpr not yet implemented for ssd_serial_recalculated"
41 );
42
43 let (y_bnlmhp, final_state_bhpr) =
44 <Dispatch as Mamba3DoubleSsdBackendExt>::double_ssd_serial_recalculated(
45 input.v_bnlmhp.into_primitive(),
46 input.da_bnlh.into_primitive(),
47 input.b_bnlmhr.into_primitive(),
48 input.c_bnlmhr.into_primitive(),
49 input.initial_state_bhpr.into_primitive(),
50 );
51 let y_bnlmhp = Tensor::from_primitive(y_bnlmhp);
52 let final_state_bhpr = Tensor::from_primitive(final_state_bhpr);
53 (y_bnlmhp, final_state_bhpr)
54 }
55}
56
57#[backend_extension(
64 Cpu: cfg(feature = "backend-cpu"),
65 Cuda: cfg(feature = "backend-cuda"),
66 Rocm: cfg(feature = "backend-rocm"),
67 Metal: cfg(feature = "backend-metal"),
68 Vulkan: cfg(feature = "backend-vulkan"),
69 Wgpu: cfg(feature = "backend-wgpu"),
70 WebGpu: cfg(feature = "backend-webgpu"),
71 Flex: cfg(feature = "backend-flex"),
72 NdArray: cfg(feature = "backend-ndarray"),
73 LibTorch: cfg(any(feature = "backend-tch-cpu", feature = "backend-tch-gpu")),
74 Autodiff: cfg(feature = "autodiff"),
75)]
76pub trait Mamba3DoubleSsdBackendExt: Backend {
77 fn double_ssd_serial_recalculated(
90 v_bnlmhp: FloatTensor<Self>,
91 da_bnlh: FloatTensor<Self>,
92 b_bnlmhr: FloatTensor<Self>,
93 c_bnlmhr: FloatTensor<Self>,
94 initial_state_bhpr: FloatTensor<Self>,
95 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
96 let v_bnlmhp = F::<Self, 6>::new(v_bnlmhp);
98 let da_bnlh = F::<Self, 4>::new(da_bnlh);
99 let b_bnlmhr = F::<Self, 6>::new(b_bnlmhr);
100 let c_bnlmhr = F::<Self, 6>::new(c_bnlmhr);
101 let initial_state_bhpr = F::<Self, 4>::new(initial_state_bhpr);
102
103 let nchunks = v_bnlmhp.dims()[1];
104 assert!(nchunks > 0, "sequence length must be at least 1");
105
106 let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum::<Self>(da_bnlh);
107 san(&da_cumsum_bhnl);
108
109 let cb_bnhLMLM = k2_ssd_bmm::<Self>(c_bnlmhr.clone(), b_bnlmhr.clone());
110 san(&cb_bnhLMLM);
111
112 let intra_chunk_state_bnhpr =
113 k3_ssd_chunk_state::<Self>(v_bnlmhp.clone(), b_bnlmhr, da_cumsum_bhnl.clone());
114 san(&intra_chunk_state_bnhpr);
115
116 let (chunk_input_state_bnhpr, final_state_bhpr) = k4_ssd_state_passing::<Self>(
117 intra_chunk_state_bnhpr,
118 da_chunk_end_bhn,
119 initial_state_bhpr,
120 );
121 san(&chunk_input_state_bnhpr);
122 san(&final_state_bhpr);
123
124 let y_bnlmhp = k5_ssd_chunk_scan::<Self>(
125 da_cumsum_bhnl,
126 v_bnlmhp,
127 c_bnlmhr,
128 cb_bnhLMLM,
129 chunk_input_state_bnhpr,
130 );
131 san(&y_bnlmhp);
132
133 (y_bnlmhp.inner(), final_state_bhpr.inner())
134 }
135}
136
// NOTE(review): presumably declares `Mamba3DoubleSsdAutodiffBackendExt` as the
// autodiff-capable counterpart of `Mamba3DoubleSsdBackendExt` — confirm against
// the macro definition.
crate::decl_ssd_autodiff_backend_ext!(Mamba3DoubleSsdAutodiffBackendExt, Mamba3DoubleSsdBackendExt);

// NOTE(review): presumably implements `Mamba3DoubleSsdBackendExt` for the burn
// backends supported by this crate — confirm against the macro definition.
crate::impl_ssd_backend_ext_for_burn_backends!(Mamba3DoubleSsdBackendExt);
144
145pub(crate) fn k1_ssd_chunk_cumsum<B: Backend>(da_bnlh: F<B, 4>) -> (F<B, 4>, F<B, 3>) {
157 let da_bhnl = da_bnlh.permute([0, 3, 1, 2]);
158 let da_cumsum_bhnl = da_bhnl.cumsum(3);
159 let da_chunk_end_bhn = da_cumsum_bhnl
160 .clone()
161 .slice(s![.., .., .., -1])
162 .squeeze_dim::<3>(3);
163 (da_cumsum_bhnl, da_chunk_end_bhn)
164}
165
166pub(crate) fn k2_ssd_bmm<B: Backend>(c_bnlmhr: F<B, 6>, b_bnlmhr: F<B, 6>) -> F<B, 5> {
170 let [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank] = c_bnlmhr.dims();
171 let c_bnLMhr = c_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
172 let b_bnLMhr = b_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
173 let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
174 let b_bnhrLM = b_bnLMhr.permute([0, 1, 3, 4, 2]);
175 c_bnhLMr.matmul(b_bnhrLM)
176}
177
178pub(crate) fn k3_ssd_chunk_state<B: Backend>(
184 v_bnlmhp: F<B, 6>,
185 b_bnlmhr: F<B, 6>,
186 da_cumsum_bhnl: F<B, 4>,
187) -> F<B, 5> {
188 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
189 let [.., state_rank] = b_bnlmhr.dims();
190
191 let v_bnLMhp = v_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
192 let b_bnLMhr = b_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
193
194 let da_cumsum_last_bhn1 = da_cumsum_bhnl.clone().slice(s![.., .., .., -1]);
195 let da_cumsum_bhnLM = da_cumsum_bhnl
196 .unsqueeze_dim::<5>(4) .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); let decay_bhnLM = (da_cumsum_last_bhn1 - da_cumsum_bhnLM).exp();
200
201 let decay_bnLMh1 = decay_bhnLM.permute([0, 2, 3, 1]).unsqueeze_dim::<5>(4); let decayed_v_bnLMhp = decay_bnLMh1 * v_bnLMhp;
203
204 let decayed_v_bnhpLM = decayed_v_bnLMhp.permute([0, 1, 3, 4, 2]);
205 let b_bnhLMr = b_bnLMhr.permute([0, 1, 3, 2, 4]);
206 let intra_chunk_state_bnhpr = decayed_v_bnhpLM.matmul(b_bnhLMr);
207 assert_eq!(
208 [batch, nchunks, nheads, per_head_dim, state_rank],
209 intra_chunk_state_bnhpr.dims()
210 );
211 intra_chunk_state_bnhpr
212}
213
214pub(crate) fn k4_ssd_state_passing<B: Backend>(
219 intra_chunk_state_bnhpr: F<B, 5>,
220 da_chunk_end_bhn: F<B, 3>,
221 initial_state_bhpr: F<B, 4>,
222) -> (F<B, 5>, F<B, 4>) {
223 let [batch, nchunks, nheads, per_head_dim, state_rank] = intra_chunk_state_bnhpr.dims();
224
225 let mut running_state_bhpr = initial_state_bhpr;
226 let mut chunk_input_state_vec_bhpr = Vec::with_capacity(nchunks + 1);
227 chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
228
229 for i_chunk in 0..nchunks {
230 let intra_state_bhpr = intra_chunk_state_bnhpr
231 .clone()
232 .slice(s![.., i_chunk, .., .., ..])
233 .squeeze_dim::<4>(1);
234 let decay_bhpr = da_chunk_end_bhn
235 .clone()
236 .slice(s![.., .., i_chunk])
237 .unsqueeze_dim::<4>(3)
238 .exp()
239 .expand([batch, nheads, per_head_dim, state_rank]);
240 running_state_bhpr = decay_bhpr * running_state_bhpr + intra_state_bhpr;
241 chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
242 }
243
244 let final_state_bhpr = chunk_input_state_vec_bhpr.pop().unwrap();
245 let chunk_input_state_bnhpr = F::stack(chunk_input_state_vec_bhpr, 1);
246 (chunk_input_state_bnhpr, final_state_bhpr)
247}
248
249fn k5_ssd_chunk_scan<B: Backend>(
255 da_cumsum_bhnl: F<B, 4>,
256 v_bnlmhp: F<B, 6>,
257 c_bnlmhr: F<B, 6>,
258 cb_bnhLMLM: F<B, 5>,
259 chunk_input_state_bnhpr: F<B, 5>,
260) -> F<B, 6> {
261 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlmhp.dims();
262 let [.., state_rank] = c_bnlmhr.dims();
263 let device = v_bnlmhp.device();
264 let dtype = v_bnlmhp.dtype();
265
266 let v_bnLMhp = v_bnlmhp.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, per_head_dim]);
267 let c_bnLMhr = c_bnlmhr.reshape([batch, nchunks, chunk_len * mimo_rank, nheads, state_rank]);
268
269 let da_cumsum_bhnLM = da_cumsum_bhnl
271 .unsqueeze_dim::<5>(4) .expand([batch, nheads, nchunks, chunk_len, mimo_rank]) .reshape([batch, nheads, nchunks, chunk_len * mimo_rank]); let exp_da_bnhLMp = da_cumsum_bhnLM
277 .clone()
278 .exp()
279 .permute([0, 2, 1, 3]) .unsqueeze_dim::<5>(4) .expand([batch, nchunks, nheads, chunk_len * mimo_rank, per_head_dim]);
282 let c_bnhLMr = c_bnLMhr.permute([0, 1, 3, 2, 4]);
283 let chunk_input_state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]);
284 let ch_bnhLMp = c_bnhLMr.matmul(chunk_input_state_bnhrp);
285 let blue_bnhLMp = ch_bnhLMp * exp_da_bnhLMp;
286
287 let da_cumsum_bnhLM = da_cumsum_bhnLM.permute([0, 2, 1, 3]);
289 let target_da_cumsum_bnhLMLM = da_cumsum_bnhLM
290 .clone()
291 .unsqueeze_dim::<5>(4) .expand([
293 batch,
294 nchunks,
295 nheads,
296 chunk_len * mimo_rank,
297 chunk_len * mimo_rank,
298 ]);
299 let source_da_cumsum_bnhLMLM = da_cumsum_bnhLM
300 .unsqueeze_dim::<5>(3) .expand([
302 batch,
303 nchunks,
304 nheads,
305 chunk_len * mimo_rank,
306 chunk_len * mimo_rank,
307 ]);
308 let diff_da_cumsum_bnhLMLM = target_da_cumsum_bnhLMLM - source_da_cumsum_bnhLMLM;
309
310 let neg_inf_base_bnhll =
314 F::<B, 2>::full([chunk_len, chunk_len], f32::NEG_INFINITY, &device, dtype)
315 .triu(1) .unsqueeze_dims::<5>(&[0, 1, 2]) .expand([batch, nchunks, nheads, chunk_len, chunk_len]); let neg_inf_bnhLMLM = neg_inf_base_bnhll
319 .unsqueeze_dim::<6>(4) .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len]) .reshape([batch, nchunks, nheads, chunk_len * mimo_rank, chunk_len]) .unsqueeze_dim::<6>(5) .expand([
324 batch,
325 nchunks,
326 nheads,
327 chunk_len * mimo_rank,
328 chunk_len,
329 mimo_rank,
330 ]) .reshape([
332 batch,
333 nchunks,
334 nheads,
335 chunk_len * mimo_rank,
336 chunk_len * mimo_rank,
337 ]); let decay_bnhLMLM = (diff_da_cumsum_bnhLMLM + neg_inf_bnhLMLM).exp();
340
341 let v_bnhLMp = v_bnLMhp.permute([0, 1, 3, 2, 4]);
342 let orange_bnhLMp = (cb_bnhLMLM * decay_bnhLMLM).matmul(v_bnhLMp);
343
344 let y_bnlmhp = (blue_bnhLMp + orange_bnhLMp)
346 .permute([0, 1, 3, 2, 4]) .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]); y_bnlmhp
349}