diff --git a/src/operators/upwind4.rs b/src/operators/upwind4.rs
index f72fb03..ace2a43 100644
--- a/src/operators/upwind4.rs
+++ b/src/operators/upwind4.rs
@@ -16,6 +16,276 @@ diff_op_1d!(
     true
 );
 
+/// Generates `$self::$name`, a SIMD kernel applying a stencil along axis 1 of
+/// `prev` and storing the result, scaled by `1 / dx`, in `fut`.
+/// `$BLOCK` (read as 4 rows x 7 coefficients) closes the first four points,
+/// `$DIAG` (7 coefficients) covers the interior, and `$BLOCK` is reused on the
+/// reversed row end, negated unless `$symmetric` is true.
+macro_rules! diff_simd_row_7_47 {
+    ($self: ident, $name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => {
+        impl $self {
+            #[inline(never)]
+            fn $name(prev: ArrayView2, mut fut: ArrayViewMut2) {
+                use packed_simd::{f32x8, u32x8};
+                assert_eq!(prev.shape(), fut.shape());
+                assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len());
+                assert!(prev.len() >= f32x8::lanes());
+                // Every row is reinterpreted as a raw slice below, so both arrays
+                // must be contiguous along the last dimension
+                assert_eq!(prev.strides()[1], 1);
+                assert_eq!(fut.strides()[1], 1);
+
+                let nx = prev.len_of(Axis(1));
+                let dx = 1.0 / (nx - 1) as f32;
+                let idx = 1.0 / dx;
+
+                for j in 0..prev.len_of(Axis(0)) {
+                    use std::slice;
+                    // SAFETY: shapes match and both axis-1 strides are 1 (asserted
+                    // above), so row j holds nx contiguous f32 values in each array
+                    let prev =
+                        unsafe { slice::from_raw_parts(prev.uget((j, 0)) as *const f32, nx) };
+                    let fut =
+                        unsafe { slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut f32, nx) };
+
+                    let first_elems = unsafe { f32x8::from_slice_unaligned_unchecked(prev) };
+                    let block = {
+                        let bl = $BLOCK;
+                        [
+                            f32x8::new(
+                                bl[0][0], bl[0][1], bl[0][2], bl[0][3], bl[0][4], bl[0][5],
+                                bl[0][6], 0.0,
+                            ),
+                            f32x8::new(
+                                bl[1][0], bl[1][1], bl[1][2], bl[1][3], bl[1][4], bl[1][5],
+                                bl[1][6], 0.0,
+                            ),
+                            f32x8::new(
+                                bl[2][0], bl[2][1], bl[2][2], bl[2][3], bl[2][4], bl[2][5],
+                                bl[2][6], 0.0,
+                            ),
+                            f32x8::new(
+                                bl[3][0], bl[3][1], bl[3][2], bl[3][3], bl[3][4], bl[3][5],
+                                bl[3][6], 0.0,
+                            ),
+                        ]
+                    };
+                    fut[0] = idx * (block[0] * first_elems).sum();
+                    fut[1] = idx * (block[1] * first_elems).sum();
+                    fut[2] = idx * (block[2] * first_elems).sum();
+                    fut[3] = idx * (block[3] * first_elems).sum();
+
+                    let diag = {
+                        let diag = $DIAG;
+                        f32x8::new(
+                            diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0,
+                        )
+                    };
+                    for (f, p) in fut
+                        .iter_mut()
+                        .skip(block.len())
+                        .zip(
+                            prev.windows(f32x8::lanes())
+                                .map(f32x8::from_slice_unaligned)
+                                .skip(1),
+                        )
+                        .take(nx - 2 * block.len())
+                    {
+                        *f = idx * (p * diag).sum();
+                    }
+
+                    let last_elems =
+                        unsafe { f32x8::from_slice_unaligned_unchecked(&prev[nx - 8..]) }
+                            .shuffle1_dyn(u32x8::new(7, 6, 5, 4, 3, 2, 1, 0));
+                    if $symmetric {
+                        fut[nx - 4] = idx * (block[3] * last_elems).sum();
+                        fut[nx - 3] = idx * (block[2] * last_elems).sum();
+                        fut[nx - 2] = idx * (block[1] * last_elems).sum();
+                        fut[nx - 1] = idx * (block[0] * last_elems).sum();
+                    } else {
+                        fut[nx - 4] = -idx * (block[3] * last_elems).sum();
+                        fut[nx - 3] = -idx * (block[2] * last_elems).sum();
+                        fut[nx - 2] = -idx * (block[1] * last_elems).sum();
+                        fut[nx - 1] = -idx * (block[0] * last_elems).sum();
+                    }
+                }
+            }
+        }
+    };
+}
+
+diff_simd_row_7_47!(Upwind4, diff_simd_row, Upwind4::BLOCK, Upwind4::DIAG, false);
+diff_simd_row_7_47!(
+    Upwind4,
+    diss_simd_row,
+    Upwind4::DISS_BLOCK,
+    Upwind4::DISS_DIAG,
+    true
+);
+
+/// Like `diff_simd_row_7_47`, but for arrays with unit stride along axis 0:
+/// the stencil still runs along axis 1 while each SIMD vector holds
+/// `SimdT::lanes()` consecutive axis-0 elements.
+macro_rules! diff_simd_col_7_47 {
+    ($self: ident, $name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => {
+        impl $self {
+            #[inline(never)]
+            fn $name(prev: ArrayView2, mut fut: ArrayViewMut2) {
+                use std::slice;
+                assert_eq!(prev.shape(), fut.shape());
+                // SIMD loads/stores below run along axis 0 of both arrays
+                assert_eq!(prev.stride_of(Axis(0)), 1);
+                assert_eq!(fut.stride_of(Axis(0)), 1);
+                let ny = prev.len_of(Axis(0));
+                let nx = prev.len_of(Axis(1));
+                assert!(nx >= 2 * $BLOCK.len());
+                assert!(ny >= SimdT::lanes());
+                assert!(ny % SimdT::lanes() == 0);
+
+                let dx = 1.0 / (nx - 1) as f32;
+                let idx = 1.0 / dx;
+
+                for j in (0..ny).step_by(SimdT::lanes()) {
+                    let a = unsafe {
+                        [
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, 0)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, 1)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, 2)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, 3)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, 4)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, 5)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, 6)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                        ]
+                    };
+
+                    for (i, bl) in $BLOCK.iter().enumerate() {
+                        let b = idx
+                            * (a[0] * bl[0]
+                                + a[1] * bl[1]
+                                + a[2] * bl[2]
+                                + a[3] * bl[3]
+                                + a[4] * bl[4]
+                                + a[5] * bl[5]
+                                + a[6] * bl[6]);
+                        unsafe {
+                            b.write_to_slice_unaligned(slice::from_raw_parts_mut(
+                                fut.uget_mut((j, i)) as *mut f32,
+                                SimdT::lanes(),
+                            ));
+                        }
+                    }
+
+                    let mut a = a;
+                    for i in $BLOCK.len()..nx - $BLOCK.len() {
+                        // Push a onto circular buffer
+                        a = [a[1], a[2], a[3], a[4], a[5], a[6], unsafe {
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, i + 3)) as *const f32,
+                                SimdT::lanes(),
+                            ))
+                        }];
+                        let b = idx
+                            * (a[0] * $DIAG[0]
+                                + a[1] * $DIAG[1]
+                                + a[2] * $DIAG[2]
+                                + a[3] * $DIAG[3]
+                                + a[4] * $DIAG[4]
+                                + a[5] * $DIAG[5]
+                                + a[6] * $DIAG[6]);
+                        unsafe {
+                            b.write_to_slice_unaligned(slice::from_raw_parts_mut(
+                                fut.uget_mut((j, i)) as *mut f32,
+                                SimdT::lanes(),
+                            ));
+                        }
+                    }
+
+                    let a = unsafe {
+                        [
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, nx - 1)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, nx - 2)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, nx - 3)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, nx - 4)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, nx - 5)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, nx - 6)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                            SimdT::from_slice_unaligned(slice::from_raw_parts(
+                                prev.uget((j, nx - 7)) as *const f32,
+                                SimdT::lanes(),
+                            )),
+                        ]
+                    };
+
+                    for (i, bl) in $BLOCK.iter().enumerate() {
+                        let idx = if $symmetric { idx } else { -idx };
+                        let b = idx
+                            * (a[0] * bl[0]
+                                + a[1] * bl[1]
+                                + a[2] * bl[2]
+                                + a[3] * bl[3]
+                                + a[4] * bl[4]
+                                + a[5] * bl[5]
+                                + a[6] * bl[6]);
+                        unsafe {
+                            b.write_to_slice_unaligned(slice::from_raw_parts_mut(
+                                fut.slice_mut(s![j.., nx - 1 - i]).as_mut_ptr(),
+                                SimdT::lanes(),
+                            ));
+                        }
+                    }
+                }
+            }
+        }
+    };
+}
+
+diff_simd_col_7_47!(Upwind4, diff_simd_col, Upwind4::BLOCK, Upwind4::DIAG, false);
+diff_simd_col_7_47!(
+    Upwind4,
+    diss_simd_col,
+    Upwind4::DISS_BLOCK,
+    Upwind4::DISS_DIAG,
+    true
+);
+
 impl Upwind4 {
     #[rustfmt::skip]
     const HBLOCK: &'static [f32] = &[
@@ -45,392 +315,6 @@ impl Upwind4 {
     const DISS_DIAG: &'static [f32; 7] = &[
         1.0 / 24.0, -1.0 / 4.0, 5.0 / 8.0, -5.0 / 6.0, 5.0 / 8.0, -1.0 / 4.0, 1.0 / 24.0
     ];
-
-    #[inline(never)]
-    fn diff_simd_row(prev: ArrayView2, mut fut: ArrayViewMut2) {
-        use packed_simd::{f32x8, u32x8};
-        assert_eq!(prev.shape(), fut.shape());
-        assert!(prev.len_of(Axis(1)) >= 2 * Self::BLOCK.len());
-        assert!(prev.len() >= f32x8::lanes());
-        // The prev array must have contigous last dimension
-        assert_eq!(prev.strides()[1], 1);
-
-        let nx = prev.len_of(Axis(1));
-        let dx = 1.0 / (nx - 1) as f32;
-        let idx = 1.0 / dx;
-
-        for j in 0..prev.len_of(Axis(0)) {
-            use std::slice;
-            let prev = unsafe { slice::from_raw_parts(prev.uget((j, 0)) as *const f32, nx) };
-            let fut = unsafe { slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut f32, nx) };
-            //let mut fut = fut.slice_mut(s![j, ..]);
-
-            let first_elems = unsafe { f32x8::from_slice_unaligned_unchecked(prev) };
-            let block = {
-                let bl = Self::BLOCK;
-                [
-                    f32x8::new(
-                        bl[0][0], bl[0][1], bl[0][2], bl[0][3], bl[0][4], bl[0][5], bl[0][6], 0.0,
-                    ),
-                    f32x8::new(
-                        bl[1][0], bl[1][1], bl[1][2], bl[1][3], bl[1][4], bl[1][5], bl[1][6], 0.0,
-                    ),
-                    f32x8::new(
-                        bl[2][0], bl[2][1], bl[2][2], bl[2][3], bl[2][4], bl[2][5], bl[2][6], 0.0,
-                    ),
-                    f32x8::new(
-                        bl[3][0], bl[3][1], bl[3][2], bl[3][3], bl[3][4], bl[3][5], bl[3][6], 0.0,
-                    ),
-                ]
-            };
-            fut[0] = idx * (block[0] * first_elems).sum();
-            fut[1] = idx * (block[1] * first_elems).sum();
-            fut[2] = idx * (block[2] * first_elems).sum();
-            fut[3] = idx * (block[3] * first_elems).sum();
-
-            let diag = {
-                let diag = Self::DIAG;
-                f32x8::new(
-                    diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0,
-                )
-            };
-            for (f, p) in fut
-                .iter_mut()
-                .skip(block.len())
-                .zip(
-                    prev.windows(f32x8::lanes())
-                        .map(f32x8::from_slice_unaligned)
-                        .skip(1),
-                )
-                .take(nx - 2 * block.len())
-            {
-                *f = idx * (p * diag).sum();
-            }
-
-            let last_elems = unsafe { f32x8::from_slice_unaligned_unchecked(&prev[nx - 8..]) }
-                .shuffle1_dyn(u32x8::new(7, 6, 5, 4, 3, 2, 1, 0));
-            fut[nx - 4] = -idx * (block[3] * last_elems).sum();
-            fut[nx - 3] = -idx * (block[2] * last_elems).sum();
-            fut[nx - 2] = -idx * (block[1] * last_elems).sum();
-            fut[nx - 1] = -idx * (block[0] * last_elems).sum();
-        }
-    }
-
-    #[inline(never)]
-    fn diff_simd_col(prev: ArrayView2, mut fut: ArrayViewMut2) {
-        use std::slice;
-        assert_eq!(prev.shape(), fut.shape());
-        assert_eq!(prev.stride_of(Axis(0)), 1);
-        assert_eq!(prev.stride_of(Axis(0)), 1);
-        let ny = prev.len_of(Axis(0));
-        let nx = prev.len_of(Axis(1));
-        assert!(nx >= 2 * Self::BLOCK.len());
-        assert!(ny >= SimdT::lanes());
-        assert!(ny % SimdT::lanes() == 0);
-
-        let dx = 1.0 / (nx - 1) as f32;
-        let idx = 1.0 / dx;
-
-        for j in (0..ny).step_by(SimdT::lanes()) {
-            let a = unsafe {
-                [
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, 0)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, 1)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, 2)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, 3)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, 4)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, 5)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, 6)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                ]
-            };
-
-            for (i, bl) in Self::BLOCK.iter().enumerate() {
-                let b = idx
-                    * (a[0] * bl[0]
-                        + a[1] * bl[1]
-                        + a[2] * bl[2]
-                        + a[3] * bl[3]
-                        + a[4] * bl[4]
-                        + a[5] * bl[5]
-                        + a[6] * bl[6]);
-                unsafe {
-                    b.write_to_slice_unaligned(slice::from_raw_parts_mut(
-                        fut.uget_mut((j, i)) as *mut f32,
-                        SimdT::lanes(),
-                    ));
-                }
-            }
-
-            let mut a = a;
-            for i in Self::BLOCK.len()..nx - Self::BLOCK.len() {
-                // Push a onto circular buffer
-                a = [a[1], a[2], a[3], a[4], a[5], a[6], unsafe {
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, i + 3)) as *const f32,
-                        SimdT::lanes(),
-                    ))
-                }];
-                let b = idx
-                    * (a[0] * Self::DIAG[0]
-                        + a[1] * Self::DIAG[1]
-                        + a[2] * Self::DIAG[2]
-                        + a[3] * Self::DIAG[3]
-                        + a[4] * Self::DIAG[4]
-                        + a[5] * Self::DIAG[5]
-                        + a[6] * Self::DIAG[6]);
-                unsafe {
-                    b.write_to_slice_unaligned(slice::from_raw_parts_mut(
-                        fut.uget_mut((j, i)) as *mut f32,
-                        SimdT::lanes(),
-                    ));
-                }
-            }
-
-            let a = unsafe {
-                [
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, nx - 1)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, nx - 2)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, nx - 3)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, nx - 4)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, nx - 5)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, nx - 6)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                    SimdT::from_slice_unaligned(slice::from_raw_parts(
-                        prev.uget((j, nx - 7)) as *const f32,
-                        SimdT::lanes(),
-                    )),
-                ]
-            };
-
-            for (i, bl) in Self::BLOCK.iter().enumerate() {
-                let b = -idx
-                    * (a[0] * bl[0]
-                        + a[1] * bl[1]
-                        + a[2] * bl[2]
-                        + a[3] * bl[3]
-                        + a[4] * bl[4]
-                        + a[5] * bl[5]
-                        + a[6] * bl[6]);
-                unsafe {
-                    b.write_to_slice_unaligned(slice::from_raw_parts_mut(
-                        fut.slice_mut(s![j.., nx - 1 - i]).as_mut_ptr(),
-                        SimdT::lanes(),
-                    ));
-                }
-            }
-        }
-    }
-
-    #[inline(never)]
-    fn diss_simd(prev: &[f32], fut: &mut [f32]) {
-        use packed_simd::{f32x8, u32x8};
-        assert_eq!(prev.len(), fut.len());
-        assert!(prev.len() >= 2 * Self::DISS_BLOCK.len());
-        let nx = prev.len();
-        let dx = 1.0 / (nx - 1) as f32;
-        let idx = 1.0 / dx;
-
-        let first_elems = unsafe { f32x8::from_slice_unaligned_unchecked(prev) };
-        let block = [
-            f32x8::new(
-                Self::DISS_BLOCK[0][0],
-                Self::DISS_BLOCK[0][1],
-                Self::DISS_BLOCK[0][2],
-                Self::DISS_BLOCK[0][3],
-                Self::DISS_BLOCK[0][4],
-                Self::DISS_BLOCK[0][5],
-                Self::DISS_BLOCK[0][6],
-                0.0,
-            ),
-            f32x8::new(
-                Self::DISS_BLOCK[1][0],
-                Self::DISS_BLOCK[1][1],
-                Self::DISS_BLOCK[1][2],
-                Self::DISS_BLOCK[1][3],
-                Self::DISS_BLOCK[1][4],
-                Self::DISS_BLOCK[1][5],
-                Self::DISS_BLOCK[1][6],
-                0.0,
-            ),
-            f32x8::new(
-                Self::DISS_BLOCK[2][0],
-                Self::DISS_BLOCK[2][1],
-                Self::DISS_BLOCK[2][2],
-                Self::DISS_BLOCK[2][3],
-                Self::DISS_BLOCK[2][4],
-                Self::DISS_BLOCK[2][5],
-                Self::DISS_BLOCK[2][6],
-                0.0,
-            ),
-            f32x8::new(
-                Self::DISS_BLOCK[3][0],
-                Self::DISS_BLOCK[3][1],
-                Self::DISS_BLOCK[3][2],
-                Self::DISS_BLOCK[3][3],
-                Self::DISS_BLOCK[3][4],
-                Self::DISS_BLOCK[3][5],
-                Self::DISS_BLOCK[3][6],
-                0.0,
-            ),
-        ];
-        unsafe {
-            *fut.get_unchecked_mut(0) = idx * (block[0] * first_elems).sum();
-            *fut.get_unchecked_mut(1) = idx * (block[1] * first_elems).sum();
-            *fut.get_unchecked_mut(2) = idx * (block[2] * first_elems).sum();
-            *fut.get_unchecked_mut(3) = idx * (block[3] * first_elems).sum()
-        };
-
-        let diag = f32x8::new(
-            Self::DISS_DIAG[0],
-            Self::DISS_DIAG[1],
-            Self::DISS_DIAG[2],
-            Self::DISS_DIAG[3],
-            Self::DISS_DIAG[4],
-            Self::DISS_DIAG[5],
-            Self::DISS_DIAG[6],
-            0.0,
-        );
-        for (f, p) in fut
-            .iter_mut()
-            .skip(block.len())
-            .zip(
-                prev.windows(f32x8::lanes())
-                    .map(f32x8::from_slice_unaligned)
-                    .skip(1),
-            )
-            .take(nx - 2 * block.len())
-        {
-            *f = idx * (p * diag).sum();
-        }
-
-        let last_elems = unsafe { f32x8::from_slice_unaligned_unchecked(&prev[nx - 8..]) }
-            .shuffle1_dyn(u32x8::new(7, 6, 5, 4, 3, 2, 1, 0));
-        unsafe {
-            *fut.get_unchecked_mut(nx - 4) = idx * (block[3] * last_elems).sum();
-            *fut.get_unchecked_mut(nx - 3) = idx * (block[2] * last_elems).sum();
-            *fut.get_unchecked_mut(nx - 2) = idx * (block[1] * last_elems).sum();
-            *fut.get_unchecked_mut(nx - 1) = idx * (block[0] * last_elems).sum();
-        }
-    }
-
-    #[inline(never)]
-    fn disseta_simd(prev: &[f32], fut: &mut [f32], nx: usize, ny: usize) {
-        assert!(ny >= 2 * Self::DISS_BLOCK.len());
-        assert!(nx >= SimdT::lanes());
-        assert!(nx % SimdT::lanes() == 0);
-        assert_eq!(prev.len(), fut.len());
-        assert_eq!(prev.len(), nx * ny);
-
-        let dy = 1.0 / (ny - 1) as f32;
-        let idy = 1.0 / dy;
-
-        for j in (0..nx).step_by(SimdT::lanes()) {
-            let a = [
-                SimdT::from_slice_unaligned(&prev[0 * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[1 * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[2 * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[3 * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[4 * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[5 * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[6 * nx + j..]),
-            ];
-
-            for (i, bl) in Self::DISS_BLOCK.iter().enumerate() {
-                let b = idy
-                    * (a[0] * bl[0]
-                        + a[1] * bl[1]
-                        + a[2] * bl[2]
-                        + a[3] * bl[3]
-                        + a[4] * bl[4]
-                        + a[5] * bl[5]
-                        + a[6] * bl[6]);
-                b.write_to_slice_unaligned(&mut fut[i * nx + j..]);
-            }
-
-            let mut a = a;
-            for i in Self::DISS_BLOCK.len()..ny - Self::DISS_BLOCK.len() {
-                // Push a onto circular buffer
-                a = [
-                    a[1],
-                    a[2],
-                    a[3],
-                    a[4],
-                    a[5],
-                    a[6],
-                    SimdT::from_slice_unaligned(&prev[nx * (i + 3) + j..]),
-                ];
-                let b = idy
-                    * (a[0] * Self::DISS_DIAG[0]
-                        + a[1] * Self::DISS_DIAG[1]
-                        + a[2] * Self::DISS_DIAG[2]
-                        + a[3] * Self::DISS_DIAG[3]
-                        + a[4] * Self::DISS_DIAG[4]
-                        + a[5] * Self::DISS_DIAG[5]
-                        + a[6] * Self::DISS_DIAG[6]);
-                b.write_to_slice_unaligned(&mut fut[nx * i + j..]);
-            }
-
-            let a = [
-                SimdT::from_slice_unaligned(&prev[(ny - 1) * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[(ny - 2) * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[(ny - 3) * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[(ny - 4) * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[(ny - 5) * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[(ny - 6) * nx + j..]),
-                SimdT::from_slice_unaligned(&prev[(ny - 7) * nx + j..]),
-            ];
-
-            for (i, bl) in Self::DISS_BLOCK.iter().enumerate() {
-                let b = idy
-                    * (a[0] * bl[0]
-                        + a[1] * bl[1]
-                        + a[2] * bl[2]
-                        + a[3] * bl[3]
-                        + a[4] * bl[4]
-                        + a[5] * bl[5]
-                        + a[6] * bl[6]);
-                b.write_to_slice_unaligned(&mut fut[(ny - 1 - i) * nx + j..]);
-            }
-        }
-    }
 }
 
 impl SbpOperator for Upwind4 {
@@ -556,24 +440,29 @@ fn upwind4_test() {
 impl UpwindOperator for Upwind4 {
     fn dissxi(prev: ArrayView2, mut fut: ArrayViewMut2) {
         assert_eq!(prev.shape(), fut.shape());
         assert!(prev.shape()[1] >= 2 * Self::DISS_BLOCK.len());
-        for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
-            Self::diss_1d(r0, r1)
+
+        // The SIMD kernels access rows/columns through raw pointers, so both
+        // arrays must share the unit-stride axis; anything else falls back
+        match (prev.strides(), fut.strides()) {
+            ([_, 1], [_, 1]) => {
+                Self::diss_simd_row(prev, fut);
+            }
+            ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
+                Self::diss_simd_col(prev, fut);
+            }
+            ([_, _], [_, _]) => {
+                // Fallback, work row by row
+                for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
+                    Self::diss_1d(r0, r1);
+                }
+            }
+            _ => unreachable!("Should only be two elements in the strides vectors"),
         }
     }
 
-    fn disseta(prev: ArrayView2, mut fut: ArrayViewMut2) {
-        assert_eq!(prev.shape(), fut.shape());
-        assert!(prev.shape()[0] >= 2 * Self::DISS_BLOCK.len());
-        let nx = prev.shape()[1];
-        let ny = prev.shape()[0];
-        if nx >= SimdT::lanes() && nx % SimdT::lanes() == 0 {
-            if let (Some(p), Some(f)) = (prev.as_slice(), fut.as_slice_mut()) {
-                Self::disseta_simd(p, f, nx, ny);
-                return;
-            }
-        }
-        // diffeta = transpose then use diffxi
+    fn disseta(prev: ArrayView2, fut: ArrayViewMut2) {
+        // disseta = transpose then use dissxi
         Self::dissxi(prev.reversed_axes(), fut.reversed_axes());
     }
 }