diff --git a/src/operators/upwind4.rs b/src/operators/upwind4.rs index 14136d2..a8330b0 100644 --- a/src/operators/upwind4.rs +++ b/src/operators/upwind4.rs @@ -1,6 +1,9 @@ use super::SbpOperator; use ndarray::{arr1, arr2, s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; +/// Simdtype used in diffy_simd +type SimdT = packed_simd::f32x8; + pub struct Upwind4 {} impl Upwind4 { @@ -57,7 +60,7 @@ impl Upwind4 { fn diff_simd(prev: &[f32], fut: &mut [f32]) { use packed_simd::{f32x8, u32x8}; assert_eq!(prev.len(), fut.len()); - assert_eq!(prev.len() % 8, 0); + assert!(prev.len() >= 2 * Self::BLOCK.len()); let nx = prev.len(); let dx = 1.0 / (nx - 1) as f32; let idx = 1.0 / dx; @@ -147,25 +150,24 @@ impl Upwind4 { #[inline(never)] fn diffy_simd(prev: &[f32], fut: &mut [f32], nx: usize, ny: usize) { - use packed_simd::f32x4; - assert!(ny >= 8); - assert!(nx > 4); - assert!(nx % 4 == 0); + assert!(ny >= 2 * Self::BLOCK.len()); + assert!(nx >= SimdT::lanes()); + assert!(nx % SimdT::lanes() == 0); assert_eq!(prev.len(), fut.len()); assert_eq!(prev.len(), nx * ny); let dy = 1.0 / (ny - 1) as f32; let idy = 1.0 / dy; - for j in (0..nx).step_by(4) { + for j in (0..nx).step_by(SimdT::lanes()) { let a = [ - f32x4::from_slice_unaligned(&prev[0 * nx + j..]), - f32x4::from_slice_unaligned(&prev[1 * nx + j..]), - f32x4::from_slice_unaligned(&prev[2 * nx + j..]), - f32x4::from_slice_unaligned(&prev[3 * nx + j..]), - f32x4::from_slice_unaligned(&prev[4 * nx + j..]), - f32x4::from_slice_unaligned(&prev[5 * nx + j..]), - f32x4::from_slice_unaligned(&prev[6 * nx + j..]), + SimdT::from_slice_unaligned(&prev[0 * nx + j..]), + SimdT::from_slice_unaligned(&prev[1 * nx + j..]), + SimdT::from_slice_unaligned(&prev[2 * nx + j..]), + SimdT::from_slice_unaligned(&prev[3 * nx + j..]), + SimdT::from_slice_unaligned(&prev[4 * nx + j..]), + SimdT::from_slice_unaligned(&prev[5 * nx + j..]), + SimdT::from_slice_unaligned(&prev[6 * nx + j..]), ]; for (i, bl) in Self::BLOCK.iter().enumerate() { @@ -190,7 +192,7 @@ impl Upwind4 { a[4], a[5], a[6], - f32x4::from_slice_unaligned(&prev[nx * (i + 3) + j..]), + SimdT::from_slice_unaligned(&prev[nx * (i + 3) + j..]), ]; let b = idy * (a[0] * Self::DIAG[0] @@ -204,13 +206,13 @@ impl Upwind4 { } let a = [ - f32x4::from_slice_unaligned(&prev[(ny - 1) * nx + j..]), - f32x4::from_slice_unaligned(&prev[(ny - 2) * nx + j..]), - f32x4::from_slice_unaligned(&prev[(ny - 3) * nx + j..]), - f32x4::from_slice_unaligned(&prev[(ny - 4) * nx + j..]), - f32x4::from_slice_unaligned(&prev[(ny - 5) * nx + j..]), - f32x4::from_slice_unaligned(&prev[(ny - 6) * nx + j..]), - f32x4::from_slice_unaligned(&prev[(ny - 7) * nx + j..]), + SimdT::from_slice_unaligned(&prev[(ny - 1) * nx + j..]), + SimdT::from_slice_unaligned(&prev[(ny - 2) * nx + j..]), + SimdT::from_slice_unaligned(&prev[(ny - 3) * nx + j..]), + SimdT::from_slice_unaligned(&prev[(ny - 4) * nx + j..]), + SimdT::from_slice_unaligned(&prev[(ny - 5) * nx + j..]), + SimdT::from_slice_unaligned(&prev[(ny - 6) * nx + j..]), + SimdT::from_slice_unaligned(&prev[(ny - 7) * nx + j..]), ]; for (i, bl) in Self::BLOCK.iter().enumerate() { @@ -230,13 +232,11 @@ impl Upwind4 { fn diff(prev: ArrayView1, mut fut: ArrayViewMut1) { assert_eq!(prev.shape(), fut.shape()); let nx = prev.shape()[0]; - assert!(nx >= 8); + assert!(nx >= 2 * Self::BLOCK.len()); - if nx % 8 == 0 { - if let (Some(p), Some(f)) = (prev.as_slice(), fut.as_slice_mut()) { - Self::diff_simd(p, f); - return; - } + if let (Some(p), Some(f)) = (prev.as_slice(), fut.as_slice_mut()) { + Self::diff_simd(p, f); + return; } let dx = 1.0 / (nx - 1) as f32; @@ -273,7 +273,7 @@ impl Upwind4 { impl SbpOperator for Upwind4 { fn diffx(prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); - assert!(prev.shape()[1] >= 8); + assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { Self::diff(r0, r1) } @@ -281,10 +281,10 @@ impl SbpOperator for Upwind4 { fn diffy(prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); - assert!(prev.shape()[0] >= 8); + assert!(prev.shape()[0] >= 2 * Self::BLOCK.len()); let nx = prev.shape()[1]; let ny = prev.shape()[0]; - if nx >= 4 && nx % 4 == 0 { + if nx >= SimdT::lanes() && nx % SimdT::lanes() == 0 { if let (Some(p), Some(f)) = (prev.as_slice(), fut.as_slice_mut()) { Self::diffy_simd(p, f, nx, ny); return;