1#![allow(non_snake_case)]
2
3use crate::mamba3::prelude::*;
4use burn::prelude::*;
5
6impl<B: Backend> Mamba3<B> {
7 pub fn ssd_serial(input: super::Mamba3SsdInput<B>) -> (Tensor<B, 6>, Tensor<B, 4>) {
19 let [batch, nchunks, chunk_len, _mimo_rank, nheads, per_head_dim] = input.v_bnlrhp.dims();
20 let [.., state_rank] = input.b_bnlrhn.dims();
21
22 assert!(
23 input.init_state_hpr.is_none(),
24 "init_state_hpr is not yet supported in ssd_serial; use ssd_minimal instead"
25 );
26 assert!(nchunks > 0, "sequence length must be at least 1");
27
28 let (da_cumsum_bhnl, da_chunk_end_bhn) = k1_ssd_chunk_cumsum(input.da_bnlh.clone());
30 assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
31 assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
32
33 let cb_bnhLL: Tensor<B, 5> = k2_ssd_bmm(input.c_bnlrhn.clone(), input.b_bnlrhn.clone());
35 let intra_chunk_state_bnhpr: Tensor<B, 5> = k3_ssd_chunk_state(
39 input.v_bnlrhp.clone(),
40 input.b_bnlrhn.clone(),
41 da_cumsum_bhnl.clone(),
42 );
43 assert_eq!(
44 [batch, nchunks, nheads, per_head_dim, state_rank],
45 intra_chunk_state_bnhpr.dims()
46 );
47
48 let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<B, 5>, Tensor<B, 4>) =
50 k4_ssd_state_passing(
51 intra_chunk_state_bnhpr,
52 da_chunk_end_bhn,
53 input.initial_state_bhpr,
54 );
55 assert_eq!(
56 [batch, nchunks, nheads, per_head_dim, state_rank],
57 chunk_input_state_bnhpr.dims()
58 );
59 assert_eq!(
60 [batch, nheads, per_head_dim, state_rank],
61 final_state_bhpr.dims()
62 );
63
64 let y_bnlrhp: Tensor<B, 6> = k5_ssd_chunk_scan(
66 da_cumsum_bhnl,
67 input.v_bnlrhp,
68 input.c_bnlrhn,
69 cb_bnhLL,
70 chunk_input_state_bnhpr,
71 );
72
73 (y_bnlrhp, final_state_bhpr)
74 }
75}
76
77pub fn k1_ssd_chunk_cumsum<B: Backend>(da_bnlh: Tensor<B, 4>) -> (Tensor<B, 4>, Tensor<B, 3>) {
90 let [batch, nchunks, chunk_len, nheads] = da_bnlh.dims();
91 let da_bhnl = da_bnlh.permute([0, 3, 1, 2]);
93 let da_cumsum_bhnl = da_bhnl.cumsum(3);
94 assert_eq!([batch, nheads, nchunks, chunk_len], da_cumsum_bhnl.dims());
95
96 let da_chunk_end_bhn: Tensor<B, 3> = da_cumsum_bhnl
97 .clone()
98 .slice(s![.., .., .., -1]) .squeeze_dim(3); assert_eq!([batch, nheads, nchunks], da_chunk_end_bhn.dims());
101
102 (da_cumsum_bhnl, da_chunk_end_bhn)
103}
104
105pub fn k2_ssd_bmm<B: Backend>(c_bnlrhn: Tensor<B, 6>, b_bnlrhn: Tensor<B, 6>) -> Tensor<B, 5> {
118 let [batch, nchunks, chunk_len, mimo_rank, nheads, state_rank] = c_bnlrhn.dims();
119 let fused_len = chunk_len * mimo_rank;
120
121 let c_bnLhn = c_bnlrhn.reshape([batch, nchunks, fused_len, nheads, state_rank]);
123 let b_bnLhn = b_bnlrhn.reshape([batch, nchunks, fused_len, nheads, state_rank]);
124
125 let c_bnhLr = c_bnLhn.permute([0, 1, 3, 2, 4]); let b_bnhrL = b_bnLhn.permute([0, 1, 3, 4, 2]); let cb_bnhLL: Tensor<B, 5> = c_bnhLr.matmul(b_bnhrL);
129 assert_eq!(
130 [batch, nchunks, nheads, fused_len, fused_len],
131 cb_bnhLL.dims()
132 );
133 cb_bnhLL
134}
135
136pub fn k3_ssd_chunk_state<B: Backend>(
152 v_bnlrhp: Tensor<B, 6>,
153 b_bnlrhn: Tensor<B, 6>,
154 da_cumsum_bhnl: Tensor<B, 4>,
155) -> Tensor<B, 5> {
156 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlrhp.dims();
157 let [.., state_rank] = b_bnlrhn.dims();
158 let fused_len = chunk_len * mimo_rank;
159
160 let v_bnLhp = v_bnlrhp.reshape([batch, nchunks, fused_len, nheads, per_head_dim]);
162 let b_bnLhn = b_bnlrhn.reshape([batch, nchunks, fused_len, nheads, state_rank]);
163
164 let a_cumsum_last_bhn1 = da_cumsum_bhnl.clone().slice(s![.., .., .., -1]); let a_cumsum_fused_bhnL = da_cumsum_bhnl
169 .unsqueeze_dim::<5>(4)
170 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
171 .reshape([batch, nheads, nchunks, fused_len]);
172 let decay_bhnL = (a_cumsum_last_bhn1 - a_cumsum_fused_bhnL).exp();
174
175 let decay_bnLh1 = decay_bhnL.permute([0, 2, 3, 1]).unsqueeze_dim(4);
177 let decayed_v_bnLhp = decay_bnLh1 * v_bnLhp;
178
179 let decayed_v_bnhpL = decayed_v_bnLhp.permute([0, 1, 3, 4, 2]);
181 let b_bnhLN = b_bnLhn.permute([0, 1, 3, 2, 4]);
182 let intra_chunk_state_bnhpr: Tensor<B, 5> = decayed_v_bnhpL.matmul(b_bnhLN);
183 assert_eq!(
184 [batch, nchunks, nheads, per_head_dim, state_rank],
185 intra_chunk_state_bnhpr.dims()
186 );
187 intra_chunk_state_bnhpr
188}
189
190pub fn k4_ssd_state_passing<B: Backend>(
208 intra_chunk_state_bnhpr: Tensor<B, 5>,
209 da_chunk_end_bhn: Tensor<B, 3>,
210 initial_state_bhpr: Tensor<B, 4>,
211) -> (Tensor<B, 5>, Tensor<B, 4>) {
212 let [batch, nchunks, nheads, per_head_dim, state_rank] = intra_chunk_state_bnhpr.dims();
213
214 let mut running_state_bhpr = initial_state_bhpr;
215 assert_eq!(
216 [batch, nheads, per_head_dim, state_rank],
217 running_state_bhpr.dims()
218 );
219
220 let mut chunk_input_state_vec_bhpr = Vec::with_capacity(nchunks + 1);
221 chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
222
223 for i_chunk in 0..nchunks {
224 let intra_state_bhpr: Tensor<B, 4> = intra_chunk_state_bnhpr
225 .clone()
226 .slice(s![.., i_chunk, .., .., ..])
227 .squeeze_dim(1);
228
229 let decay_bhpr = da_chunk_end_bhn
230 .clone()
231 .slice(s![.., .., i_chunk])
232 .exp()
233 .unsqueeze_dim::<4>(3)
234 .expand([batch, nheads, per_head_dim, state_rank]);
235
236 running_state_bhpr = decay_bhpr * running_state_bhpr + intra_state_bhpr;
238 chunk_input_state_vec_bhpr.push(running_state_bhpr.clone());
239 }
240
241 let final_state_bhpr = chunk_input_state_vec_bhpr.pop().unwrap();
242 assert_eq!(
243 [batch, nheads, per_head_dim, state_rank],
244 final_state_bhpr.dims()
245 );
246
247 let chunk_input_state_bnhpr = Tensor::stack(chunk_input_state_vec_bhpr, 1);
248 assert_eq!(
249 [batch, nchunks, nheads, per_head_dim, state_rank],
250 chunk_input_state_bnhpr.dims()
251 );
252
253 (chunk_input_state_bnhpr, final_state_bhpr)
254}
255
256pub fn k5_ssd_chunk_scan<B: Backend>(
278 da_cumsum_bhnl: Tensor<B, 4>,
279 v_bnlrhp: Tensor<B, 6>,
280 c_bnlrhn: Tensor<B, 6>,
281 cb_bnhLL: Tensor<B, 5>,
282 chunk_input_state_bnhpr: Tensor<B, 5>,
283) -> Tensor<B, 6> {
284 let [batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim] = v_bnlrhp.dims();
285 let [.., state_rank] = c_bnlrhn.dims();
286 let fused_len = chunk_len * mimo_rank;
287 let device = v_bnlrhp.device();
288
289 let v_bnLhp = v_bnlrhp.reshape([batch, nchunks, fused_len, nheads, per_head_dim]);
291 let c_bnLhn = c_bnlrhn.reshape([batch, nchunks, fused_len, nheads, state_rank]);
292
293 let da_cumsum_fused_bhnL = da_cumsum_bhnl
295 .unsqueeze_dim::<5>(4)
296 .expand([batch, nheads, nchunks, chunk_len, mimo_rank])
297 .reshape([batch, nheads, nchunks, fused_len]);
298
299 let exp_da_fused_bnhLp = da_cumsum_fused_bhnL
302 .clone()
303 .exp()
304 .permute([0, 2, 1, 3])
305 .unsqueeze_dim::<5>(4)
306 .expand([batch, nchunks, nheads, fused_len, per_head_dim]);
307
308 let c_bnhLr = c_bnLhn.permute([0, 1, 3, 2, 4]); let state_bnhrp = chunk_input_state_bnhpr.permute([0, 1, 2, 4, 3]); let ch_bnhLp = c_bnhLr.matmul(state_bnhrp); let blue_bnhLp = ch_bnhLp * exp_da_fused_bnhLp; let da_fused_bnhL = da_cumsum_fused_bhnL.permute([0, 2, 1, 3]); let da_target_bnhLL = da_fused_bnhL
319 .clone()
320 .unsqueeze_dim::<5>(4)
321 .expand([batch, nchunks, nheads, fused_len, fused_len]); let da_source_bnhLL = da_fused_bnhL
323 .unsqueeze_dim::<5>(3)
324 .expand([batch, nchunks, nheads, fused_len, fused_len]); let diff_bnhLL = da_target_bnhLL - da_source_bnhLL;
326
327 let neg_inf_base_bnhll: Tensor<B, 5> = {
330 let zero_ll: Tensor<B, 2> = Tensor::zeros([chunk_len, chunk_len], &device);
331 Tensor::full_like(&zero_ll, f32::NEG_INFINITY)
332 .triu(1) .unsqueeze_dims::<5>(&[0, 1, 2])
334 .expand([batch, nchunks, nheads, chunk_len, chunk_len])
335 };
336 let neg_inf_mimo_bnhLL: Tensor<B, 5> = neg_inf_base_bnhll
338 .unsqueeze_dim::<6>(4)
339 .expand([batch, nchunks, nheads, chunk_len, mimo_rank, chunk_len])
340 .reshape([batch, nchunks, nheads, fused_len, chunk_len])
341 .unsqueeze_dim::<6>(5)
342 .expand([batch, nchunks, nheads, fused_len, chunk_len, mimo_rank])
343 .reshape([batch, nchunks, nheads, fused_len, fused_len]);
344
345 let decay_bnhLL = (diff_bnhLL + neg_inf_mimo_bnhLL).exp(); let v_bnhLp = v_bnLhp.permute([0, 1, 3, 2, 4]); let orange_bnhLp = (cb_bnhLL * decay_bnhLL).matmul(v_bnhLp); let y_bnlrhp = (blue_bnhLp + orange_bnhLp)
353 .permute([0, 1, 3, 2, 4])
354 .reshape([batch, nchunks, chunk_len, mimo_rank, nheads, per_head_dim]);
355
356 y_bnlrhp
357}