burn_mamba/mamba3/ssd/serial_recalculated/
serial_recalculated.rs1#![allow(non_snake_case)]
2
3use crate::mamba3::prelude::*;
4use crate::mamba3::ssd::serial;
5use burn::prelude::*;
6use burn::tensor::{Tensor, TensorPrimitive, ops::FloatTensor};
7
8impl<B: Backend + Mamba3BackendExt> Mamba3<B> {
9 pub fn ssd_serial_recalculated(
21 input: super::super::Mamba3SsdInput<B>,
22 ) -> (Tensor<B, 6>, Tensor<B, 4>) {
23 assert!(
24 input.init_state_hpr.is_none(),
25 "init_state_hpr not yet implemented for ssd_serial_recalculated"
26 );
27
28 let (da_cumsum_bhnl, _da_chunk_end_bhn): (Tensor<B, 4>, Tensor<B, 3>) =
30 serial::k1_ssd_chunk_cumsum(input.da_bnlh.clone());
31
32 let (y_bnlrhp, final_state_bhpr) = <B as Mamba3BackendExt>::ssd_serial_recalculated(
33 input.v_bnlrhp.into_primitive().tensor(),
34 input.da_bnlh.into_primitive().tensor(),
35 input.b_bnlrhn.into_primitive().tensor(),
36 input.c_bnlrhn.into_primitive().tensor(),
37 input.initial_state_bhpr.into_primitive().tensor(),
38 da_cumsum_bhnl.into_primitive().tensor(),
39 );
40 let y_bnlrhp = Tensor::from_primitive(TensorPrimitive::Float(y_bnlrhp));
41 let final_state_bhpr = Tensor::from_primitive(TensorPrimitive::Float(final_state_bhpr));
42 (y_bnlrhp, final_state_bhpr)
43 }
44}
45
46pub trait Mamba3BackendExt: burn::tensor::backend::Backend {
51 fn ssd_serial_recalculated(
65 v_bnlrhp: FloatTensor<Self>,
66 _da_bnlh: FloatTensor<Self>,
67 b_bnlrhn: FloatTensor<Self>,
68 c_bnlrhn: FloatTensor<Self>,
69 initial_state_bhpr: FloatTensor<Self>,
70 da_cumsum_bhnl: FloatTensor<Self>,
71 ) -> (FloatTensor<Self>, FloatTensor<Self>) {
72 let v: Tensor<Self, 6> = mk(v_bnlrhp);
75 let b: Tensor<Self, 6> = mk(b_bnlrhn);
76 let c: Tensor<Self, 6> = mk(c_bnlrhn);
77 let da_cumsum: Tensor<Self, 4> = mk(da_cumsum_bhnl);
78 let init_state: Tensor<Self, 4> = mk(initial_state_bhpr);
79
80 let da_chunk_end_bhn: Tensor<Self, 3> =
82 da_cumsum.clone().slice(s![.., .., .., -1]).squeeze_dim(3);
83
84 let cb_bnhLL: Tensor<Self, 5> = serial::k2_ssd_bmm(c.clone(), b.clone());
85 let intra_chunk_state_bnhpr: Tensor<Self, 5> =
86 serial::k3_ssd_chunk_state(v.clone(), b, da_cumsum.clone());
87 let (chunk_input_state_bnhpr, final_state_bhpr): (Tensor<Self, 5>, Tensor<Self, 4>) =
88 serial::k4_ssd_state_passing(intra_chunk_state_bnhpr, da_chunk_end_bhn, init_state);
89 let y_bnlrhp: Tensor<Self, 6> =
90 serial::k5_ssd_chunk_scan(da_cumsum, v, c, cb_bnhLL, chunk_input_state_bnhpr);
91
92 let y_bnlrhp = y_bnlrhp.into_primitive().tensor();
93 let final_state_bhpr = final_state_bhpr.into_primitive().tensor();
94 (y_bnlrhp, final_state_bhpr)
95 }
96}
97
/// Marker trait bundling [`Mamba3BackendExt`] with burn's `AutodiffBackend`.
///
/// It declares no items of its own; it only names the combination of bounds.
#[cfg(feature = "autodiff")]
pub trait Mamba3AutodiffBackendExt
where
    Self: Backend + Mamba3BackendExt + burn::tensor::backend::AutodiffBackend,
{
}
// The autodiff decorator relies on the provided method body of `Mamba3BackendExt`.
// NOTE(review): presumably this is what lets gradients be tracked through the plain
// tensor ops of that body — confirm.
#[cfg(feature = "autodiff")]
impl<B> Mamba3BackendExt for burn::backend::Autodiff<B> where B: Backend + Mamba3BackendExt {}

// Any inner backend that opts into `Mamba3BackendExt` also gets the autodiff marker.
#[cfg(feature = "autodiff")]
impl<B> Mamba3AutodiffBackendExt for burn::backend::Autodiff<B> where
    B: Backend + Mamba3BackendExt
{
}
110
// Opt-in impls for the concrete backends. Each one is empty, so every backend below
// runs the provided method body of `Mamba3BackendExt`.
//
// NOTE(review): the generic impls put no bounds on their element parameters, while
// the supertrait requires the implementing type to be a `Backend` — confirm each
// of these compiles with its feature enabled.

#[cfg(feature = "backend-ndarray")]
impl<Elem, IntElem> Mamba3BackendExt for burn::backend::NdArray<Elem, IntElem> {}

#[cfg(feature = "backend-flex")]
impl Mamba3BackendExt for burn::backend::Flex {}

#[cfg(any(feature = "backend-tch-cpu", feature = "backend-tch-gpu"))]
impl<Elem, IntElem> Mamba3BackendExt for burn::backend::libtorch::LibTorch<Elem, IntElem> {}

#[cfg(feature = "backend-remote")]
impl<Elem, IntElem> Mamba3BackendExt for burn::backend::RemoteBackend<Elem, IntElem> {}
123
/// Opt-in impl for CubeCL-based backends; only compiled with the `cubecl` feature.
#[cfg(feature = "cubecl")]
mod cubecl {
    use burn_cubecl::element::BoolElement;
    use burn_cubecl::{CubeBackend, CubeRuntime, FloatElement, IntElement};

    // Empty impl: every `CubeBackend` instantiation uses the provided method body.
    impl<R, F, I, BT> super::Mamba3BackendExt for CubeBackend<R, F, I, BT>
    where
        R: CubeRuntime,
        F: FloatElement,
        I: IntElement,
        BT: BoolElement,
    {
    }
}
133
/// Opt-in impl for the fusion decorator; only compiled with the `fusion` feature.
#[cfg(feature = "fusion")]
mod fusion {
    use burn_fusion::{Fusion, FusionBackend};

    // Empty impl: `Fusion<B>` uses the provided method body whenever the inner
    // backend itself implements `Mamba3BackendExt`.
    impl<B> super::Mamba3BackendExt for Fusion<B> where B: FusionBackend + super::Mamba3BackendExt {}
}
140
141pub(crate) fn mk<B: Backend, const D: usize>(p: FloatTensor<B>) -> Tensor<B, D> {
143 Tensor::from_primitive(TensorPrimitive::Float(p))
144}