Skip to main content

burn_mamba/mamba1/
mamba1.rs

1use crate::mamba1::prelude::*;
2use crate::utils::silu::Silu;
3use burn::prelude::*;
4use burn::{
5    module::{Module, Param},
6    nn::conv::{Conv1d, Conv1dConfig},
7    nn::{Initializer, Linear, LinearConfig, PaddingConfig1d},
8};
9
10#[derive(Module, Debug)]
11pub struct Mamba1<B: Backend> {
12    /// Input channel: d_model.
13    /// Output channel: 2 * d_inner.
14    pub in_proj: Linear<B>,
15
16    /// Input channel: d_inner.
17    /// Output channel: d_inner.
18    pub conv1d: Conv1d<B>,
19
20    /// Input channel: d_inner.
21    /// Output channel: dt_rank + 2 * d_state.
22    pub x_proj: Linear<B>,
23
24    /// Input channel: dt_rank.
25    /// Output channel: d_inner.
26    pub dt_proj: Linear<B>,
27
28    /// Dims: [d_inner, d_state].
29    pub a_log: Param<Tensor<B, 2>>,
30
31    /// Dims: [d_inner].
32    pub d: Param<Tensor<B, 1>>,
33
34    /// Input channel: d_inner.
35    /// Output channel: d_model.
36    pub out_proj: Linear<B>,
37}
38
39#[derive(Config, Debug)]
40pub struct Mamba1Config {
41    /// Hidden dimension.
42    pub d_model: usize,
43
44    /// latent state dimension (`N` in Algorithm 2 from the Mamba paper).
45    #[config(default = 16)]
46    pub d_state: usize,
47
48    #[config(default = 4)]
49    pub d_conv: usize,
50
51    #[config(default = 2)]
52    pub expand: usize,
53
54    /// Minimum dt value.
55    #[config(default = 1e-3)]
56    pub dt_min: f64,
57
58    /// Maximum dt value.
59    #[config(default = 1e-1)]
60    pub dt_max: f64,
61
62    /// Scale for dt initialization.
63    #[config(default = 1.)]
64    pub dt_scale: f64,
65
66    /// Floor for dt initialization.
67    #[config(default = 1e-4)]
68    pub dt_init_floor: f64,
69
70    /// Whether conv1d should have a bias.
71    #[config(default = true)]
72    pub conv_bias: bool,
73
74    /// Whether in_proj and out_proj should have a bias.
75    #[config(default = false)]
76    pub bias: bool,
77
78    /// Rank of Δ (See Section 3.6 "Parameterization of ∆" from the Mamba paper).
79    /// Δ or delta: input-dependent step size.
80    ///
81    /// By default, set to (d_model + d_state - 1) / d_state.
82    pub dt_rank: Option<usize>,
83
84    /// DModel * expand (`D` in Algorithm 2 from the Mamba paper).
85    ///
86    /// By default, set to expand * d_model.
87    pub d_inner: Option<usize>,
88}
89
90impl Mamba1Config {
91    /// Returns the initialized model.
92    pub fn init<B: Backend>(&self, device: &B::Device) -> Mamba1<B> {
93        let d_inner = self.d_inner();
94        debug_assert_ne!(self.d_state, 0);
95        debug_assert!(self.d_model + self.d_state > 0);
96        let dt_rank = self.dt_rank();
97
98        // Helper function for PyTorch-style uniform initialization
99        let uniform_init = |d_input: usize| {
100            let bound = 1.0 / (d_input as f64).sqrt();
101            Initializer::Uniform {
102                min: -bound,
103                max: bound,
104            }
105        };
106
107        let dt_proj = {
108            use burn::tensor::Distribution;
109            let weight: Tensor<B, 2> = {
110                let dt_init_std = (dt_rank as f64).powf(-0.5) * self.dt_scale;
111                Tensor::random(
112                    [dt_rank, d_inner],
113                    Distribution::Uniform(-dt_init_std, dt_init_std),
114                    device,
115                )
116            };
117            debug_assert_eq!([dt_rank, d_inner], weight.dims());
118            let bias: Tensor<B, 1> = {
119                // note: this placeholder impl may lose precision for very small values,
120                // and a Taylor series could approximate it: e^x - 1 = x + x^2/2! + x^3/3! + ⋯
121                // but with the clamp at dt_init_floor, this isn't necessary
122                let expm1 = |t: Tensor<B, 1>| t.exp() - 1.;
123                let dt = Tensor::random([d_inner], Distribution::Uniform(0.0, 1.0), device)
124                    * (f64::ln(self.dt_max) - f64::ln(self.dt_min))
125                    + f64::ln(self.dt_min);
126                let dt = dt.exp().clamp_min(self.dt_init_floor);
127                // Inverse of softplus
128                let inv_dt = dt.clone() + (-expm1(-dt)).log();
129                inv_dt
130            };
131            debug_assert_eq!([d_inner], bias.dims());
132            Linear {
133                weight: Param::from_tensor(weight),
134                bias: Some(Param::from_tensor(bias)),
135            }
136        };
137
138        let a_log = {
139            let a_row: Tensor<B, 1> =
140                Tensor::<B, 1, Int>::arange(1..self.d_state as i64 + 1, device).float();
141            debug_assert_eq!([self.d_state], a_row.dims());
142            let a_row = a_row.unsqueeze();
143            debug_assert_eq!([1, self.d_state], a_row.dims());
144            let a = a_row.repeat(&[d_inner, 1]);
145            debug_assert_eq!([d_inner, self.d_state], a.dims());
146            let a_log = a.log();
147            Param::from_tensor(a_log)
148        };
149
150        Mamba1 {
151            in_proj: LinearConfig::new(self.d_model, 2 * d_inner)
152                .with_bias(self.bias)
153                // follows PyTorch's default initializer
154                .with_initializer(uniform_init(self.d_model))
155                .init(device),
156            conv1d: Conv1dConfig::new(d_inner, d_inner, self.d_conv)
157                // TODO: only left-padding is necessary,
158                // and possibly a narrowing will no longer be necessary
159                .with_padding(PaddingConfig1d::Explicit(self.d_conv - 1, self.d_conv - 1))
160                .with_groups(d_inner)
161                .with_bias(self.conv_bias)
162                // follows PyTorch's default initializer
163                // fan_in = in_channels / groups * kernel_size
164                .with_initializer(uniform_init(self.d_conv))
165                .init(device),
166            x_proj: LinearConfig::new(d_inner, dt_rank + 2 * self.d_state)
167                .with_bias(false)
168                // follows PyTorch's default initializer
169                .with_initializer(uniform_init(d_inner))
170                .init(device),
171            dt_proj,
172            a_log,
173            d: Initializer::Ones.init([d_inner], device),
174            out_proj: LinearConfig::new(d_inner, self.d_model)
175                .with_bias(self.bias)
176                // follows PyTorch's default initializer
177                .with_initializer(uniform_init(d_inner))
178                .init(device),
179        }
180    }
181    pub fn d_inner(&self) -> usize {
182        self.d_inner.unwrap_or(self.expand * self.d_model)
183    }
184    pub fn dt_rank(&self) -> usize {
185        self.dt_rank.unwrap_or(self.d_model.div_ceil(self.d_state))
186    }
187}
188
impl<B: Backend> Mamba1<B> {
    /// Processes a whole input sequence through the Mamba block.
    ///
    /// Steps:
    /// 1. `in_proj` maps `x` to `2 * d_inner` channels, split into the SSM branch
    ///    `xs` and the gating branch `res`;
    /// 2. `xs` goes through `conv1d` (kept causal by narrowing, see below) and SiLU;
    /// 3. the selective SSM ([`Self::ss`]) runs over `xs`;
    /// 4. the SSM output is multiplied by `SiLU(res)`;
    /// 5. `out_proj` maps back to `d_model` channels.
    ///
    /// See also [`Self::step`] for processing one timestep at a time with a cache.
    ///
    /// # Shapes
    ///   - Input [batch, sequence, d_model]
    ///   - Output [batch, sequence, d_model]
    pub fn forward(&self, x: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, sequence, d_model] = x.dims();
        // d_inner and d_conv are recovered from parameter shapes
        let [d_inner] = self.d.dims();
        let [_, _, d_conv] = self.conv1d.weight.dims();

        // layer 1 (in_proj)
        let (xs, res) = {
            // projects the input d_model into 2 * d_inner
            let xs_and_res = self.in_proj.forward(x);
            debug_assert_eq!([batch, sequence, 2 * d_inner], xs_and_res.dims());

            // first half: SSM branch; second half: gating branch
            let mut split = xs_and_res
                .split_with_sizes(vec![d_inner, d_inner], 2)
                .into_iter();
            debug_assert_eq!(split.len(), 2);
            (split.next().unwrap(), split.next().unwrap())
        };
        debug_assert_eq!([batch, sequence, d_inner], xs.dims());
        debug_assert_eq!([batch, sequence, d_inner], res.dims());

        // layer 2 (conv1d)
        let xs = {
            // Conv1d expects channels before the sequence axis.
            // let xs = xs.swap_dims(1, 2);
            let xs = xs.permute([0, 2, 1]);
            debug_assert_eq!([batch, d_inner, sequence], xs.dims());

            debug_assert!(d_conv > 0);
            // The conv is padded by d_conv - 1 on both sides, so its output is
            // d_conv - 1 positions longer than the input.
            let xs = self.conv1d.forward(xs);
            debug_assert_eq!([batch, d_inner, sequence + d_conv - 1], xs.dims());

            // Keeping only the first `sequence` outputs makes output t depend on
            // inputs t - d_conv + 1 ..= t only, i.e. the convolution is causal.
            let xs = xs.narrow(2, 0, sequence);
            debug_assert_eq!([batch, d_inner, sequence], xs.dims());

            // restore original positioning as per before the layer 2
            // let xs = xs.swap_dims(1, 2);
            let xs = xs.permute([0, 2, 1]);
            debug_assert_eq!([batch, sequence, d_inner], xs.dims());

            // activation
            let xs = Silu::new().forward(xs);
            debug_assert_eq!([batch, sequence, d_inner], xs.dims());

            xs
        };
        debug_assert_eq!([batch, sequence, d_inner], xs.dims());

        let ss = self.ss(xs);
        debug_assert_eq!([batch, sequence, d_inner], ss.dims());

        // activation: gate the SSM output with the activated second branch
        let ys = ss * Silu::new().forward(res);
        debug_assert_eq!([batch, sequence, d_inner], ys.dims());

        let y = self.out_proj.forward(ys);
        debug_assert_eq!([batch, sequence, d_model], y.dims());

        y
    }

    /// Selective SSM over a whole sequence.
    ///
    /// Computes the state space parameters and runs [`Self::selective_scan`]:
    /// the input-independent `A = -exp(a_log)`, and the input-dependent Δ, B and C
    /// obtained from `x_proj` (Δ is then expanded by `dt_proj` and passed through
    /// softplus).
    ///
    /// # Shapes
    ///   - Input [batch, sequence, d_inner]
    ///   - Output [batch, sequence, d_inner]
    pub fn ss(&self, u: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, sequence, d_inner] = u.dims();
        let [_d_inner, d_state] = self.a_log.dims();
        let [dt_rank, _d_inner] = self.dt_proj.weight.dims();

        // Compute ∆ A B C D, the state space parameters.

        // A
        // this is input independent (see Section 3.5.2 "Interpretation of A" from the Mamba paper for why A isn't selective)
        let a = self.a_log.val().exp().neg();
        debug_assert_eq!([d_inner, d_state], a.dims());

        let x_dbl = self.x_proj.forward(u.clone());
        debug_assert_eq!([batch, sequence, dt_rank + 2 * d_state], x_dbl.dims());

        // ∆ (part 1/2)
        // ∆ is input-dependent
        // B and C are input-dependent
        let mut split = x_dbl
            .split_with_sizes(vec![dt_rank, d_state, d_state], 2)
            .into_iter();
        let delta = split.next().unwrap();
        let b = split.next().unwrap();
        let c = split.next().unwrap();
        debug_assert_eq!([batch, sequence, dt_rank], delta.dims());
        debug_assert_eq!([batch, sequence, d_state], b.dims());
        debug_assert_eq!([batch, sequence, d_state], c.dims());

        // ∆ (part 2/2)
        // ∆ is input-dependent: expand from dt_rank to d_inner (with bias),
        // then softplus keeps the step size positive
        let delta = self.dt_proj.forward(delta);
        debug_assert_eq!([batch, sequence, d_inner], delta.dims());

        let delta = burn::tensor::activation::softplus(delta, 1.);

        // selective_scan takes delta and c sequence-major; b is permuted inside it
        // let delta = delta.swap_dims(0, 1);
        let delta = delta.permute([1, 0, 2]);
        debug_assert_eq!([sequence, batch, d_inner], delta.dims());

        // let c = c.swap_dims(0, 1);
        let c = c.permute([1, 0, 2]);
        debug_assert_eq!([sequence, batch, d_state], c.dims());

        Self::selective_scan(delta, a, b, c, self.d.val(), u)
    }

    /// Selective Scan.
    ///
    /// Does selective scan algorithm. See:
    /// - Section 2 State Space Models from the Mamba paper;
    /// - Algorithm 2 in Section 3.2 from the Mamba paper;
    /// - run_SSM(A, B, C, u) from The Annotated S4.
    ///
    /// Starting from a zero hidden state, for each timestep `t`:
    ///
    /// ```text
    /// h_t = exp(Δ_t A) ⊙ h_{t-1} + (Δ_t B_t) u_t
    /// y_t = h_t C_t + D ⊙ u_t
    /// ```
    ///
    /// `exp(ΔA)` and `ΔBu` are materialized for all timesteps at once as
    /// `[sequence, batch, d_inner, d_state]` tensors; the recurrence itself is a
    /// sequential loop over the sequence axis.
    ///
    /// Note the mixed layouts: `delta` and `c` are sequence-major, `b` and `u` are
    /// batch-major.
    ///
    /// # Shapes
    ///   - Input delta [sequence, batch, d_inner]
    ///   - Input a [d_inner, d_state]
    ///   - Input b [batch, sequence, d_state]
    ///   - Input c [sequence, batch, d_state]
    ///   - Input d [d_inner]
    ///   - Input u [batch, sequence, d_inner]
    ///   - Output [batch, sequence, d_inner]
    pub fn selective_scan(
        delta: Tensor<B, 3>,
        a: Tensor<B, 2>,
        b: Tensor<B, 3>,
        c: Tensor<B, 3>,
        d: Tensor<B, 1>,
        u: Tensor<B, 3>,
    ) -> Tensor<B, 3> {
        let device = &u.device();
        let [sequence, batch, d_inner] = delta.dims();
        let [_d_inner, d_state] = a.dims();
        let outer_shape = [sequence, batch, d_inner, d_state];

        // Discretize continuous parameters (A, B)
        //  - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper)
        //  - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        //    "A is the more important term and the performance doesn't change much with the simplification on B"
        let (delta_a, delta_bu) = {
            let delta = delta.unsqueeze_dim(3);
            debug_assert_eq!([sequence, batch, d_inner, 1], delta.dims());
            let delta = delta.expand(outer_shape);
            debug_assert_eq!(outer_shape, delta.dims());

            let a = a.unsqueeze_dims(&[0, 1]);
            debug_assert_eq!([1, 1, d_inner, d_state], a.dims());
            let a = a.expand(outer_shape);
            debug_assert_eq!(outer_shape, a.dims());
            let delta_a = (delta.clone() * a).exp();
            debug_assert_eq!(outer_shape, delta_a.dims());

            // let b = b.swap_dims(0, 1);
            let b = b.permute([1, 0, 2]);
            debug_assert_eq!([sequence, batch, d_state], b.dims());
            let b = b.unsqueeze_dim(2);
            debug_assert_eq!([sequence, batch, 1, d_state], b.dims());
            let b = b.expand(outer_shape);
            debug_assert_eq!(outer_shape, b.dims());
            let delta_b = delta * b;
            debug_assert_eq!(outer_shape, delta_b.dims());

            // u is still needed afterwards for the D skip term, hence the clone
            // let u = u.clone().swap_dims(0, 1);
            let u = u.clone().permute([1, 0, 2]);
            debug_assert_eq!([sequence, batch, d_inner], u.dims());
            let u = u.unsqueeze_dim(3);
            debug_assert_eq!([sequence, batch, d_inner, 1], u.dims());
            let u = u.expand(outer_shape);
            debug_assert_eq!(outer_shape, u.dims());
            let delta_bu = delta_b * u;
            debug_assert_eq!(outer_shape, delta_bu.dims());

            (delta_a, delta_bu)
        };
        debug_assert_eq!(outer_shape, delta_a.dims());
        debug_assert_eq!(outer_shape, delta_bu.dims());

        // Perform selective scan (see scan_SSM() from The Annotated S4)
        // Note that the below is sequential, while the official implementation does a much faster parallel scan that
        // is additionally hardware-aware (like FlashAttention).

        // unstack the Sequence axis

        let delta_a = delta_a.split(1, 0);
        debug_assert_eq!(delta_a.len(), sequence);

        let delta_bu = delta_bu.split(1, 0);
        debug_assert_eq!(delta_bu.len(), sequence);

        // c gets a trailing unit axis so that h_t @ c_t is a batched matmul
        let c = c.unsqueeze_dim(3);
        debug_assert_eq!([sequence, batch, d_state, 1], c.dims());
        let c = c.split(1, 0);
        debug_assert_eq!(c.len(), sequence);

        // xs is the hidden state h, initialized to zero
        let inner_shape = [batch, d_inner, d_state];
        let mut xs: Tensor<B, 3> = Tensor::zeros(inner_shape, device);
        let mut ys = Vec::with_capacity(sequence); // inner shape: [batch, d_inner]
        for ((delta_a, delta_bu), c) in delta_a
            .into_iter()
            .zip(delta_bu.into_iter())
            .zip(c.into_iter())
        {
            let delta_a = delta_a.squeeze_dim(0);
            debug_assert_eq!(inner_shape, delta_a.dims());
            let delta_bu = delta_bu.squeeze_dim(0);
            debug_assert_eq!(inner_shape, delta_bu.dims());
            let c = c.squeeze_dim(0);
            debug_assert_eq!([batch, d_state, 1], c.dims());

            // h_t = exp(Δ_t A) ⊙ h_{t-1} + Δ_t B_t u_t
            xs = (xs.clone() * delta_a) + delta_bu;
            // y_t = h_t C_t (contracting over d_state)
            let y = xs.clone().matmul(c);
            debug_assert_eq!([batch, d_inner, 1], y.dims());
            let y = y.squeeze_dim(2);
            debug_assert_eq!([batch, d_inner], y.dims());
            ys.push(y);
        }

        let ys = Tensor::stack(ys, 1);
        debug_assert_eq!([batch, sequence, d_inner], ys.dims());

        // skip term: y += D ⊙ u
        let d = d.unsqueeze_dims(&[0, 1]);
        debug_assert_eq!([1, 1, d_inner], d.dims());
        let d = d.expand([batch, sequence, d_inner]);

        let ys = ys + (d * u);
        debug_assert_eq!([batch, sequence, d_inner], ys.dims());

        ys
    }
}
426
427mod step {
428    use super::*;
429
430    impl<B: Backend> Mamba1<B> {
431        /// # Shapes
432        ///   - Input [batch, d_model]
433        ///   - Output [batch, d_model]
434        pub fn step(
435            &self,
436            x: Tensor<B, 2>,
437            mut cache: Mamba1Cache<B>,
438        ) -> (Tensor<B, 2>, Mamba1Cache<B>) {
439            let [batch, d_inner, d_conv] = cache.conv.dims();
440            let [_batch, d_model] = x.dims();
441
442            // layer 1 (in_proj)
443            let (xs, res) = {
444                // projects the input d_model into 2 * d_inner
445                let xs_and_res = self.in_proj.forward(x);
446                debug_assert_eq!([batch, 2 * d_inner], xs_and_res.dims());
447
448                let mut split = xs_and_res
449                    .split_with_sizes(vec![d_inner, d_inner], 1)
450                    .into_iter();
451                (split.next().unwrap(), split.next().unwrap())
452            };
453            debug_assert_eq!([batch, d_inner], xs.dims());
454            debug_assert_eq!([batch, d_inner], res.dims());
455
456            // layer 2 (conv1d)
457            cache.conv = cache.conv.map(|conv| {
458                // split-off oldest/first column (i.e. rolling leftwards)
459                let t0 = conv.narrow(2, 1, d_conv - 1);
460                debug_assert_eq!([batch, d_inner, d_conv - 1], t0.dims());
461
462                // insert xs as a the newest/last column
463                let conv = Tensor::cat([t0, xs.unsqueeze_dim(2)].to_vec(), 2);
464                debug_assert_eq!([batch, d_inner, d_conv], conv.dims());
465
466                conv
467            });
468            let xs = {
469                let conv1d = self.conv1d.weight.val();
470                // [channels_out, channels_in / groups, kernel_size]
471                debug_assert_eq!([d_inner, 1, d_conv], conv1d.dims());
472                // let conv1d = conv1d.swap_dims(1, 0);
473                let conv1d = conv1d.permute([1, 0, 2]);
474                debug_assert_eq!([1, d_inner, d_conv], conv1d.dims());
475                let conv1d = conv1d.expand([batch, d_inner, d_conv]);
476                debug_assert_eq!([batch, d_inner, d_conv], conv1d.dims());
477
478                let xs = cache.conv.val() * conv1d;
479                let xs = xs.sum_dim(2);
480                debug_assert_eq!([batch, d_inner, 1], xs.dims());
481                let xs = xs.squeeze_dim(2);
482                debug_assert_eq!([batch, d_inner], xs.dims());
483
484                // conv1d bias
485                let conv1d_bias = self.conv1d.bias.as_ref().unwrap().val();
486                // [channels_out]
487                debug_assert_eq!([d_inner], conv1d_bias.dims());
488                let conv1d_bias = conv1d_bias.unsqueeze();
489                debug_assert_eq!([1, d_inner], conv1d_bias.dims());
490                let xs = xs + conv1d_bias;
491
492                // activation
493                let xs = Silu::new().forward(xs);
494                debug_assert_eq!([batch, d_inner], xs.dims());
495
496                xs
497            };
498            debug_assert_eq!([batch, d_inner], xs.dims());
499
500            let (ss, cache) = self.ss_step(xs, cache);
501            debug_assert_eq!([batch, d_inner], ss.dims());
502
503            // activation
504            let ys = ss * Silu::new().forward(res);
505            debug_assert_eq!([batch, d_inner], ys.dims());
506
507            let y = self.out_proj.forward(ys);
508            debug_assert_eq!([batch, d_model], y.dims());
509
510            (y, cache)
511        }
512
513        /// Runs the SSM. See:
514        /// - Algorithm 2 in Section 3.2 from the Mamba paper;
515        /// - run_SSM(A, B, C, u) from The Annotated S4.
516        ///
517        /// # Shapes
518        ///   - Input u [batch, d_inner]
519        ///   - Output [batch, d_inner]
520        pub fn ss_step(
521            &self,
522            u: Tensor<B, 2>,
523            cache: Mamba1Cache<B>,
524        ) -> (Tensor<B, 2>, Mamba1Cache<B>) {
525            let [batch, d_inner, d_state] = cache.ssm.dims();
526            let [dt_rank, _d_inner] = self.dt_proj.weight.dims();
527
528            // Compute ∆ A B C D, the state space parameters.
529
530            // A
531            // this is input independent (see Section 3.5.2 "Interpretation of A" form the Mamba paper for why A isn't selective)
532            let a = self.a_log.val().exp().neg();
533            debug_assert_eq!([d_inner, d_state], a.dims());
534
535            let x_dbl = self.x_proj.forward(u.clone());
536            debug_assert_eq!([batch, dt_rank + 2 * d_state], x_dbl.dims());
537
538            // ∆ (part 1/2)
539            // ∆ is input-dependent
540            // B and C are input-dependent
541            let mut split = x_dbl
542                .split_with_sizes(vec![dt_rank, d_state, d_state], 1)
543                .into_iter();
544            let delta = split.next().unwrap();
545            let b = split.next().unwrap();
546            let c = split.next().unwrap();
547            debug_assert_eq!([batch, dt_rank], delta.dims());
548            debug_assert_eq!([batch, d_state], b.dims());
549            debug_assert_eq!([batch, d_state], c.dims());
550
551            // ∆ (part 2/2)
552            // ∆ is input-dependent
553            let delta = self.dt_proj.forward(delta);
554            debug_assert_eq!([batch, d_inner], delta.dims());
555            let delta = burn::tensor::activation::softplus(delta, 1.);
556
557            Self::selective_scan_step(delta, a, b, c, self.d.val(), u, cache)
558        }
559
560        /// Selective Scan.
561        ///
562        /// Does selective scan algorithm. See:
563        /// - Section 2 State Space Models from the Mamba paper;
564        /// - Algorithm 2 in Section 3.2 from the Mamba paper;
565        /// - run_SSM(A, B, C, u) from The Annotated S4.
566        ///
567        /// # Shapes
568        ///   - Input delta [batch, d_inner]
569        ///   - Input a [d_inner, d_state]
570        ///   - Input b [batch, d_state]
571        ///   - Input c [batch, d_state]
572        ///   - Input d [d_inner]
573        ///   - Input u [batch, d_inner]
574        ///   - Output [batch, d_inner]
575        pub fn selective_scan_step(
576            delta: Tensor<B, 2>,
577            a: Tensor<B, 2>,
578            b: Tensor<B, 2>,
579            c: Tensor<B, 2>,
580            d: Tensor<B, 1>,
581            u: Tensor<B, 2>,
582            mut cache: Mamba1Cache<B>,
583        ) -> (Tensor<B, 2>, Mamba1Cache<B>) {
584            let [batch, d_inner, d_state] = cache.ssm.dims();
585            let outer_shape = [batch, d_inner, d_state];
586
587            // Discretize continuous parameters (A, B)
588            //  - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper)
589            //  - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
590            //    "A is the more important term and the performance doesn't change much with the simplification on B"
591            let (delta_a, delta_bu) = {
592                let delta = delta.unsqueeze_dim(2);
593                debug_assert_eq!([batch, d_inner, 1], delta.dims());
594                let delta = delta.expand(outer_shape);
595                debug_assert_eq!(outer_shape, delta.dims());
596
597                let a = a.unsqueeze();
598                debug_assert_eq!([1, d_inner, d_state], a.dims());
599                let a = a.expand(outer_shape);
600                debug_assert_eq!(outer_shape, a.dims());
601                let delta_a = (delta.clone() * a).exp();
602                debug_assert_eq!(outer_shape, delta_a.dims());
603
604                let b = b.unsqueeze_dim(1);
605                debug_assert_eq!([batch, 1, d_state], b.dims());
606                let b = b.expand(outer_shape);
607                debug_assert_eq!(outer_shape, b.dims());
608                let delta_b = delta * b;
609                debug_assert_eq!(outer_shape, delta_b.dims());
610
611                let u = u.clone().unsqueeze_dim(2);
612                debug_assert_eq!([batch, d_inner, 1], u.dims());
613                let u = u.expand(outer_shape);
614                debug_assert_eq!(outer_shape, u.dims());
615                let delta_bu = delta_b * u;
616                debug_assert_eq!(outer_shape, delta_bu.dims());
617
618                (delta_a, delta_bu)
619            };
620            debug_assert_eq!(outer_shape, delta_a.dims());
621            debug_assert_eq!(outer_shape, delta_bu.dims());
622
623            cache.ssm = cache.ssm.map(|ssm| (ssm * delta_a) + delta_bu);
624
625            let c = c.unsqueeze_dim(2);
626            debug_assert_eq!([batch, d_state, 1], c.dims());
627
628            let y = cache.ssm.val().matmul(c);
629            debug_assert_eq!([batch, d_inner, 1], y.dims());
630            let y = y.squeeze_dim(2);
631            debug_assert_eq!([batch, d_inner], y.dims());
632
633            let d = d.unsqueeze();
634            debug_assert_eq!([1, d_inner], d.dims());
635            let d = d.expand([batch, d_inner]);
636            debug_assert_eq!([batch, d_inner], d.dims());
637
638            let y = y + (d * u);
639            debug_assert_eq!([batch, d_inner], y.dims());
640
641            (y, cache)
642        }
643    }
644}