From 177a6abd994d4a249e6b1d069b281cc940ec1a3a Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Thu, 30 Apr 2020 23:37:51 +0200 Subject: [PATCH] use generalised diff_op_row for Upwind4 --- sbp/src/operators/upwind4.rs | 90 ++---------------------------------- 1 file changed, 4 insertions(+), 86 deletions(-) diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index 6e9afb5..9e4bda4 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -1,98 +1,16 @@ -use super::{SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; +use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis}; #[derive(Debug, Copy, Clone)] pub struct Upwind4; -/// Simdtype used in diff_simd_col and diff_simd_row +/// Simdtype used in diff_simd_col #[cfg(feature = "f32")] type SimdT = packed_simd::f32x8; #[cfg(not(feature = "f32"))] type SimdT = packed_simd::f64x8; -macro_rules! diff_simd_row_7_47 { - ($name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => { - #[inline(never)] - fn $name(prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len()); - assert!(prev.len() >= SimdT::lanes()); - // The prev and fut array must have contiguous last dimension - assert_eq!(prev.strides()[1], 1); - assert_eq!(fut.strides()[1], 1); - - let nx = prev.len_of(Axis(1)); - let dx = 1.0 / (nx - 1) as Float; - let idx = 1.0 / dx; - - for j in 0..prev.len_of(Axis(0)) { - use std::slice; - let prev = unsafe { slice::from_raw_parts(prev.uget((j, 0)) as *const Float, nx) }; - let fut = - unsafe { slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx) }; - - let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) }; - let block = { - let bl = $BLOCK; - [ - SimdT::new( - bl[0][0], bl[0][1], bl[0][2], bl[0][3], bl[0][4], bl[0][5], bl[0][6], - 0.0, - ), - SimdT::new( - bl[1][0], bl[1][1], bl[1][2], bl[1][3], bl[1][4], bl[1][5], bl[1][6], - 0.0, - ), - SimdT::new( - bl[2][0], bl[2][1], bl[2][2], bl[2][3], bl[2][4], bl[2][5], bl[2][6], - 0.0, - ), - SimdT::new( - bl[3][0], bl[3][1], bl[3][2], bl[3][3], bl[3][4], bl[3][5], bl[3][6], - 0.0, - ), - ] - }; - fut[0] = idx * (block[0] * first_elems).sum(); - fut[1] = idx * (block[1] * first_elems).sum(); - fut[2] = idx * (block[2] * first_elems).sum(); - fut[3] = idx * (block[3] * first_elems).sum(); - - let diag = { - let diag = $DIAG; - SimdT::new( - diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0, - ) - }; - for i in 4..nx - 4 { - unsafe { - let prev = SimdT::from_slice_unaligned_unchecked(&prev[i - 3..]); - *fut.get_unchecked_mut(i) = idx * (prev * diag).sum(); - } - } - - let last_elems = unsafe { SimdT::from_slice_unaligned_unchecked(&prev[nx - 8..]) } - .shuffle1_dyn([7, 6, 5, 4, 3, 2, 1, 0].into()); - if $symmetric { - fut[nx - 4] = idx * (block[3] * last_elems).sum(); - fut[nx - 3] = idx * (block[2] * last_elems).sum(); - fut[nx - 2] = idx * (block[1] * last_elems).sum(); - fut[nx - 1] = idx * (block[0] * last_elems).sum(); - } else { - fut[nx - 4] = -idx * (block[3] * last_elems).sum(); - fut[nx - 3] = -idx * (block[2] * last_elems).sum(); - fut[nx - 2] = -idx * (block[1] * last_elems).sum(); - fut[nx - 1] = -idx * (block[0] * last_elems).sum(); - } - } - } - }; -} - -diff_simd_row_7_47!(diff_simd_row, Upwind4::BLOCK, Upwind4::DIAG, false); -diff_simd_row_7_47!(diss_simd_row, Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true); - macro_rules! diff_simd_col_7_47 { ($name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => { #[inline(never)] @@ -291,7 +209,7 @@ impl SbpOperator2d for (&SBP, &Upwind4) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_simd_row(prev, fut); + diff_op_row(Upwind4::BLOCK, Upwind4::DIAG, false, false)(prev, fut); } ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { diff_simd_col(prev, fut); @@ -412,7 +330,7 @@ impl UpwindOperator2d for (&UO, &Upwind4) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diss_simd_row(prev, fut); + diff_op_row(Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true, false)(prev, fut); } ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { diss_simd_col(prev, fut);