From c3a40d81eeb9300a4d08bdb616e28b5947b90279 Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Wed, 8 Apr 2020 23:07:14 +0200 Subject: [PATCH] use default trait methods --- sbp/src/operators.rs | 104 ++++--- sbp/src/operators/traditional4.rs | 18 +- sbp/src/operators/traditional8.rs | 18 +- sbp/src/operators/upwind4.rs | 467 ++++++++++++++---------------- sbp/src/operators/upwind9.rs | 40 +-- 5 files changed, 293 insertions(+), 354 deletions(-) diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index a624f2b..1e6d82f 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -3,68 +3,82 @@ use crate::Float; -use ndarray::{ArrayView2, ArrayViewMut2}; +use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; pub trait SbpOperator: Send + Sync { - fn diffxi(prev: ArrayView2, fut: ArrayViewMut2); - fn diffeta(prev: ArrayView2, fut: ArrayViewMut2); + fn diff1d(prev: ArrayView1, fut: ArrayViewMut1); + fn diffxi(prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + Self::diff1d(r0, r1); + } + } + fn diffeta(prev: ArrayView2, fut: ArrayViewMut2) { + Self::diffxi(prev.reversed_axes(), fut.reversed_axes()) + } fn h() -> &'static [Float]; } pub trait UpwindOperator: SbpOperator { - fn dissxi(prev: ArrayView2, fut: ArrayViewMut2); - fn disseta(prev: ArrayView2, fut: ArrayViewMut2); + fn diss1d(prev: ArrayView1, fut: ArrayViewMut1); + fn dissxi(prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + Self::diss1d(r0, r1); + } + } + fn disseta(prev: ArrayView2, fut: ArrayViewMut2) { + Self::dissxi(prev.reversed_axes(), fut.reversed_axes()) + } } #[macro_export] macro_rules! diff_op_1d { - ($self: ty, $name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => { - impl $self { - fn $name(prev: ArrayView1, mut fut: ArrayViewMut1) { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[0]; - assert!(nx >= 2 * $BLOCK.len()); + ($name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => { + fn $name(prev: ArrayView1, mut fut: ArrayViewMut1) { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[0]; + assert!(nx >= 2 * $BLOCK.len()); - let dx = 1.0 / (nx - 1) as Float; - let idx = 1.0 / dx; + let dx = 1.0 / (nx - 1) as Float; + let idx = 1.0 / dx; - let block = ::ndarray::arr2($BLOCK); - let diag = ::ndarray::arr1($DIAG); + let block = ::ndarray::arr2($BLOCK); + let diag = ::ndarray::arr1($DIAG); - let first_elems = prev.slice(::ndarray::s!(..block.len_of(::ndarray::Axis(1)))); - for (bl, f) in block.outer_iter().zip(&mut fut) { - let diff = first_elems.dot(&bl); - *f = diff * idx; - } + let first_elems = prev.slice(::ndarray::s!(..block.len_of(::ndarray::Axis(1)))); + for (bl, f) in block.outer_iter().zip(&mut fut) { + let diff = first_elems.dot(&bl); + *f = diff * idx; + } - // The window needs to be aligned to the diagonal elements, - // based on the block size - let window_elems_to_skip = - block.len_of(::ndarray::Axis(0)) - ((diag.len() - 1) / 2); + // The window needs to be aligned to the diagonal elements, + // based on the block size + let window_elems_to_skip = + block.len_of(::ndarray::Axis(0)) - ((diag.len() - 1) / 2); - for (window, f) in prev - .windows(diag.len()) - .into_iter() - .skip(window_elems_to_skip) - .zip(fut.iter_mut().skip(block.len_of(::ndarray::Axis(0)))) - .take(nx - 2 * block.len_of(::ndarray::Axis(0))) - { - let diff = diag.dot(&window); - *f = diff * idx; - } + for (window, f) in prev + .windows(diag.len()) + .into_iter() + .skip(window_elems_to_skip) + .zip(fut.iter_mut().skip(block.len_of(::ndarray::Axis(0)))) + .take(nx - 2 * block.len_of(::ndarray::Axis(0))) + { + let diff = diag.dot(&window); + *f = diff * idx; + } - let last_elems = prev.slice(::ndarray::s!(nx - block.len_of(::ndarray::Axis(1))..;-1)); - for (bl, f) in block.outer_iter() - .zip(&mut fut.slice_mut(s![nx - block.len_of(::ndarray::Axis(0))..;-1])) - { - let diff = if $symmetric { - bl.dot(&last_elems) - } else { - -bl.dot(&last_elems) - }; - *f = diff * idx; - } + let last_elems = prev.slice(::ndarray::s!(nx - block.len_of(::ndarray::Axis(1))..;-1)); + for (bl, f) in block.outer_iter() + .zip(&mut fut.slice_mut(s![nx - block.len_of(::ndarray::Axis(0))..;-1])) + { + let diff = if $symmetric { + bl.dot(&last_elems) + } else { + -bl.dot(&last_elems) + }; + *f = diff * idx; } } }; diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index 91706d4..e856f8f 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -1,12 +1,12 @@ use super::SbpOperator; use crate::diff_op_1d; use crate::Float; -use ndarray::{s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; +use ndarray::{s, ArrayView1, ArrayViewMut1}; #[derive(Debug)] pub struct SBP4 {} -diff_op_1d!(SBP4, diff_1d, SBP4::BLOCK, SBP4::DIAG, false); +diff_op_1d!(diff_1d, SBP4::BLOCK, SBP4::DIAG, false); impl SBP4 { #[rustfmt::skip] @@ -27,18 +27,8 @@ impl SBP4 { } impl SbpOperator for SBP4 { - fn diffxi(prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); - - for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - Self::diff_1d(r0, r1); - } - } - - fn diffeta(prev: ArrayView2, fut: ArrayViewMut2) { - // transpose then use diffxi - Self::diffxi(prev.reversed_axes(), fut.reversed_axes()); + fn diff1d(prev: ArrayView1, fut: ArrayViewMut1) { + diff_1d(prev, fut) } fn h() -> &'static [Float] { diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs index 856e367..3b2f44f 100644 --- a/sbp/src/operators/traditional8.rs +++ b/sbp/src/operators/traditional8.rs @@ -1,12 +1,12 @@ use super::SbpOperator; use crate::diff_op_1d; use crate::Float; -use ndarray::{s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; +use ndarray::{s, ArrayView1, ArrayViewMut1}; #[derive(Debug)] pub struct SBP8 {} -diff_op_1d!(SBP8, diff_1d, SBP8::BLOCK, SBP8::DIAG, false); +diff_op_1d!(diff_1d, SBP8::BLOCK, SBP8::DIAG, false); impl SBP8 { #[rustfmt::skip] @@ -31,18 +31,8 @@ impl SBP8 { } impl SbpOperator for SBP8 { - fn diffxi(prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); - - for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - Self::diff_1d(r0, r1); - } - } - - fn diffeta(prev: ArrayView2, fut: ArrayViewMut2) { - // transpose then use diffxi - Self::diffxi(prev.reversed_axes(), fut.reversed_axes()); + fn diff1d(prev: ArrayView1, fut: ArrayViewMut1) { + diff_1d(prev, fut) } fn h() -> &'static [Float] { diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index 0de501b..acb79e6 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -12,251 +12,232 @@ type SimdT = packed_simd::f32x8; #[cfg(not(feature = "f32"))] type SimdT = packed_simd::f64x8; -diff_op_1d!(Upwind4, diff_1d, Upwind4::BLOCK, Upwind4::DIAG, false); -diff_op_1d!( - Upwind4, - diss_1d, - Upwind4::DISS_BLOCK, - Upwind4::DISS_DIAG, - true -); +diff_op_1d!(diff_1d, Upwind4::BLOCK, Upwind4::DIAG, false); +diff_op_1d!(diss_1d, Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true); macro_rules! diff_simd_row_7_47 { - ($self: ident, $name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => { - impl $self { - #[inline(never)] - fn $name(prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len()); - assert!(prev.len() >= SimdT::lanes()); - // The prev and fut array must have contiguous last dimension - assert_eq!(prev.strides()[1], 1); - assert_eq!(fut.strides()[1], 1); + ($name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => { + #[inline(never)] + fn $name(prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len()); + assert!(prev.len() >= SimdT::lanes()); + // The prev and fut array must have contiguous last dimension + assert_eq!(prev.strides()[1], 1); + assert_eq!(fut.strides()[1], 1); - let nx = prev.len_of(Axis(1)); - let dx = 1.0 / (nx - 1) as Float; - let idx = 1.0 / dx; + let nx = prev.len_of(Axis(1)); + let dx = 1.0 / (nx - 1) as Float; + let idx = 1.0 / dx; - for j in 0..prev.len_of(Axis(0)) { - use std::slice; - let prev = - unsafe { slice::from_raw_parts(prev.uget((j, 0)) as *const Float, nx) }; - let fut = unsafe { - slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx) - }; + for j in 0..prev.len_of(Axis(0)) { + use std::slice; + let prev = unsafe { slice::from_raw_parts(prev.uget((j, 0)) as *const Float, nx) }; + let fut = + unsafe { slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx) }; - let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) }; - let block = { - let bl = $BLOCK; - [ - SimdT::new( - bl[0][0], bl[0][1], bl[0][2], bl[0][3], bl[0][4], bl[0][5], - bl[0][6], 0.0, - ), - SimdT::new( - bl[1][0], bl[1][1], bl[1][2], bl[1][3], bl[1][4], bl[1][5], - bl[1][6], 0.0, - ), - SimdT::new( - bl[2][0], bl[2][1], bl[2][2], bl[2][3], bl[2][4], bl[2][5], - bl[2][6], 0.0, - ), - SimdT::new( - bl[3][0], bl[3][1], bl[3][2], bl[3][3], bl[3][4], bl[3][5], - bl[3][6], 0.0, - ), - ] - }; - fut[0] = idx * (block[0] * first_elems).sum(); - fut[1] = idx * (block[1] * first_elems).sum(); - fut[2] = idx * (block[2] * first_elems).sum(); - fut[3] = idx * (block[3] * first_elems).sum(); - - let diag = { - let diag = $DIAG; + let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) }; + let block = { + let bl = $BLOCK; + [ SimdT::new( - diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0, - ) - }; - for i in 4..nx - 4 { - unsafe { - let prev = SimdT::from_slice_unaligned_unchecked(&prev[i - 3..]); - *fut.get_unchecked_mut(i) = idx * (prev * diag).sum(); - } - } + bl[0][0], bl[0][1], bl[0][2], bl[0][3], bl[0][4], bl[0][5], bl[0][6], + 0.0, + ), + SimdT::new( + bl[1][0], bl[1][1], bl[1][2], bl[1][3], bl[1][4], bl[1][5], bl[1][6], + 0.0, + ), + SimdT::new( + bl[2][0], bl[2][1], bl[2][2], bl[2][3], bl[2][4], bl[2][5], bl[2][6], + 0.0, + ), + SimdT::new( + bl[3][0], bl[3][1], bl[3][2], bl[3][3], bl[3][4], bl[3][5], bl[3][6], + 0.0, + ), + ] + }; + fut[0] = idx * (block[0] * first_elems).sum(); + fut[1] = idx * (block[1] * first_elems).sum(); + fut[2] = idx * (block[2] * first_elems).sum(); + fut[3] = idx * (block[3] * first_elems).sum(); - let last_elems = - unsafe { SimdT::from_slice_unaligned_unchecked(&prev[nx - 8..]) } - .shuffle1_dyn([7, 6, 5, 4, 3, 2, 1, 0].into()); - if $symmetric { - fut[nx - 4] = idx * (block[3] * last_elems).sum(); - fut[nx - 3] = idx * (block[2] * last_elems).sum(); - fut[nx - 2] = idx * (block[1] * last_elems).sum(); - fut[nx - 1] = idx * (block[0] * last_elems).sum(); - } else { - fut[nx - 4] = -idx * (block[3] * last_elems).sum(); - fut[nx - 3] = -idx * (block[2] * last_elems).sum(); - fut[nx - 2] = -idx * (block[1] * last_elems).sum(); - fut[nx - 1] = -idx * (block[0] * last_elems).sum(); + let diag = { + let diag = $DIAG; + SimdT::new( + diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0, + ) + }; + for i in 4..nx - 4 { + unsafe { + let prev = SimdT::from_slice_unaligned_unchecked(&prev[i - 3..]); + *fut.get_unchecked_mut(i) = idx * (prev * diag).sum(); } } + + let last_elems = unsafe { SimdT::from_slice_unaligned_unchecked(&prev[nx - 8..]) } + .shuffle1_dyn([7, 6, 5, 4, 3, 2, 1, 0].into()); + if $symmetric { + fut[nx - 4] = idx * (block[3] * last_elems).sum(); + fut[nx - 3] = idx * (block[2] * last_elems).sum(); + fut[nx - 2] = idx * (block[1] * last_elems).sum(); + fut[nx - 1] = idx * (block[0] * last_elems).sum(); + } else { + fut[nx - 4] = -idx * (block[3] * last_elems).sum(); + fut[nx - 3] = -idx * (block[2] * last_elems).sum(); + fut[nx - 2] = -idx * (block[1] * last_elems).sum(); + fut[nx - 1] = -idx * (block[0] * last_elems).sum(); + } } } }; } -diff_simd_row_7_47!(Upwind4, diff_simd_row, Upwind4::BLOCK, Upwind4::DIAG, false); -diff_simd_row_7_47!( - Upwind4, - diss_simd_row, - Upwind4::DISS_BLOCK, - Upwind4::DISS_DIAG, - true -); +diff_simd_row_7_47!(diff_simd_row, Upwind4::BLOCK, Upwind4::DIAG, false); +diff_simd_row_7_47!(diss_simd_row, Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true); macro_rules! diff_simd_col_7_47 { - ($self: ident, $name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => { - impl $self { - #[inline(never)] - fn $name(prev: ArrayView2, mut fut: ArrayViewMut2) { - use std::slice; - assert_eq!(prev.shape(), fut.shape()); - assert_eq!(prev.stride_of(Axis(0)), 1); - assert_eq!(fut.stride_of(Axis(0)), 1); - let ny = prev.len_of(Axis(0)); - let nx = prev.len_of(Axis(1)); - assert!(nx >= 2 * $BLOCK.len()); - assert!(ny >= SimdT::lanes()); - assert!(ny % SimdT::lanes() == 0); + ($name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => { + #[inline(never)] + fn $name(prev: ArrayView2, mut fut: ArrayViewMut2) { + use std::slice; + assert_eq!(prev.shape(), fut.shape()); + assert_eq!(prev.stride_of(Axis(0)), 1); + assert_eq!(fut.stride_of(Axis(0)), 1); + let ny = prev.len_of(Axis(0)); + let nx = prev.len_of(Axis(1)); + assert!(nx >= 2 * $BLOCK.len()); + assert!(ny >= SimdT::lanes()); + assert!(ny % SimdT::lanes() == 0); - let dx = 1.0 / (nx - 1) as Float; - let idx = 1.0 / dx; + let dx = 1.0 / (nx - 1) as Float; + let idx = 1.0 / dx; - for j in (0..ny).step_by(SimdT::lanes()) { - let a = unsafe { - [ - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, 0)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, 1)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, 2)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, 3)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, 4)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, 5)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, 6)) as *const Float, - SimdT::lanes(), - )), - ] - }; + for j in (0..ny).step_by(SimdT::lanes()) { + let a = unsafe { + [ + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, 0)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, 1)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, 2)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, 3)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, 4)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, 5)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, 6)) as *const Float, + SimdT::lanes(), + )), + ] + }; - for (i, bl) in $BLOCK.iter().enumerate() { - let b = idx - * (a[0] * bl[0] - + a[1] * bl[1] - + a[2] * bl[2] - + a[3] * bl[3] - + a[4] * bl[4] - + a[5] * bl[5] - + a[6] * bl[6]); - unsafe { - b.write_to_slice_unaligned(slice::from_raw_parts_mut( - fut.uget_mut((j, i)) as *mut Float, - SimdT::lanes(), - )); - } + for (i, bl) in $BLOCK.iter().enumerate() { + let b = idx + * (a[0] * bl[0] + + a[1] * bl[1] + + a[2] * bl[2] + + a[3] * bl[3] + + a[4] * bl[4] + + a[5] * bl[5] + + a[6] * bl[6]); + unsafe { + b.write_to_slice_unaligned(slice::from_raw_parts_mut( + fut.uget_mut((j, i)) as *mut Float, + SimdT::lanes(), + )); } + } - let mut a = a; - for i in $BLOCK.len()..nx - $BLOCK.len() { - // Push a onto circular buffer - a = [a[1], a[2], a[3], a[4], a[5], a[6], unsafe { - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, i + 3)) as *const Float, - SimdT::lanes(), - )) - }]; - let b = idx - * (a[0] * $DIAG[0] - + a[1] * $DIAG[1] - + a[2] * $DIAG[2] - + a[3] * $DIAG[3] - + a[4] * $DIAG[4] - + a[5] * $DIAG[5] - + a[6] * $DIAG[6]); - unsafe { - b.write_to_slice_unaligned(slice::from_raw_parts_mut( - fut.uget_mut((j, i)) as *mut Float, - SimdT::lanes(), - )); - } + let mut a = a; + for i in $BLOCK.len()..nx - $BLOCK.len() { + // Push a onto circular buffer + a = [a[1], a[2], a[3], a[4], a[5], a[6], unsafe { + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, i + 3)) as *const Float, + SimdT::lanes(), + )) + }]; + let b = idx + * (a[0] * $DIAG[0] + + a[1] * $DIAG[1] + + a[2] * $DIAG[2] + + a[3] * $DIAG[3] + + a[4] * $DIAG[4] + + a[5] * $DIAG[5] + + a[6] * $DIAG[6]); + unsafe { + b.write_to_slice_unaligned(slice::from_raw_parts_mut( + fut.uget_mut((j, i)) as *mut Float, + SimdT::lanes(), + )); } + } - let a = unsafe { - [ - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, nx - 1)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, nx - 2)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, nx - 3)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, nx - 4)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, nx - 5)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, nx - 6)) as *const Float, - SimdT::lanes(), - )), - SimdT::from_slice_unaligned(slice::from_raw_parts( - prev.uget((j, nx - 7)) as *const Float, - SimdT::lanes(), - )), - ] - }; + let a = unsafe { + [ + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, nx - 1)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, nx - 2)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, nx - 3)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, nx - 4)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, nx - 5)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, nx - 6)) as *const Float, + SimdT::lanes(), + )), + SimdT::from_slice_unaligned(slice::from_raw_parts( + prev.uget((j, nx - 7)) as *const Float, + SimdT::lanes(), + )), + ] + }; - for (i, bl) in $BLOCK.iter().enumerate() { - let idx = if $symmetric { idx } else { -idx }; - let b = idx - * (a[0] * bl[0] - + a[1] * bl[1] - + a[2] * bl[2] - + a[3] * bl[3] - + a[4] * bl[4] - + a[5] * bl[5] - + a[6] * bl[6]); - unsafe { - b.write_to_slice_unaligned(slice::from_raw_parts_mut( - fut.uget_mut((j, nx - 1 - i)) as *mut Float, - SimdT::lanes(), - )); - } + for (i, bl) in $BLOCK.iter().enumerate() { + let idx = if $symmetric { idx } else { -idx }; + let b = idx + * (a[0] * bl[0] + + a[1] * bl[1] + + a[2] * bl[2] + + a[3] * bl[3] + + a[4] * bl[4] + + a[5] * bl[5] + + a[6] * bl[6]); + unsafe { + b.write_to_slice_unaligned(slice::from_raw_parts_mut( + fut.uget_mut((j, nx - 1 - i)) as *mut Float, + SimdT::lanes(), + )); } } } @@ -264,14 +245,8 @@ macro_rules! diff_simd_col_7_47 { }; } -diff_simd_col_7_47!(Upwind4, diff_simd_col, Upwind4::BLOCK, Upwind4::DIAG, false); -diff_simd_col_7_47!( - Upwind4, - diss_simd_col, - Upwind4::DISS_BLOCK, - Upwind4::DISS_DIAG, - true -); +diff_simd_col_7_47!(diff_simd_col, Upwind4::BLOCK, Upwind4::DIAG, false); +diff_simd_col_7_47!(diss_simd_col, Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true); impl Upwind4 { #[rustfmt::skip] @@ -305,32 +280,30 @@ impl Upwind4 { } impl SbpOperator for Upwind4 { + fn diff1d(prev: ArrayView1, fut: ArrayViewMut1) { + diff_1d(prev, fut) + } fn diffxi(prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - Self::diff_simd_row(prev, fut); + diff_simd_row(prev, fut); } ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { - Self::diff_simd_col(prev, fut); + diff_simd_col(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - Self::diff_1d(r0, r1); + Self::diff1d(r0, r1); } } _ => unreachable!("Should only be two elements in the strides vectors"), } } - fn diffeta(prev: ArrayView2, fut: ArrayViewMut2) { - // transpose then use diffxi - Self::diffxi(prev.reversed_axes(), fut.reversed_axes()); - } - fn h() -> &'static [Float] { Self::HBLOCK } @@ -350,7 +323,7 @@ fn upwind4_test() { target[i] = 1.0; } res.fill(0.0); - Upwind4::diff_1d(source.view(), res.view_mut()); + Upwind4::diff1d(source.view(), res.view_mut()); approx::assert_abs_diff_eq!(&res, &target, epsilon = 1e-4); { let source = source.to_owned().insert_axis(ndarray::Axis(0)); @@ -376,7 +349,7 @@ fn upwind4_test() { target[i] = 2.0 * x; } res.fill(0.0); - Upwind4::diff_1d(source.view(), res.view_mut()); + Upwind4::diff1d(source.view(), res.view_mut()); approx::assert_abs_diff_eq!(&res, &target, epsilon = 1e-4); { let source = source.to_owned().insert_axis(ndarray::Axis(0)); @@ -402,7 +375,7 @@ fn upwind4_test() { target[i] = 3.0 * x * x; } res.fill(0.0); - Upwind4::diff_1d(source.view(), res.view_mut()); + Upwind4::diff1d(source.view(), res.view_mut()); approx::assert_abs_diff_eq!(&res, &target, epsilon = 1e-2); { @@ -425,31 +398,29 @@ fn upwind4_test() { } impl UpwindOperator for Upwind4 { + fn diss1d(prev: ArrayView1, fut: ArrayViewMut1) { + diss_1d(prev, fut) + } fn dissxi(prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - Self::diss_simd_row(prev, fut); + diss_simd_row(prev, fut); } ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { - Self::diss_simd_col(prev, fut); + diss_simd_col(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - Self::diss_1d(r0, r1); + Self::diss1d(r0, r1); } } _ => unreachable!("Should only be two elements in the strides vectors"), } } - - fn disseta(prev: ArrayView2, fut: ArrayViewMut2) { - // diffeta = transpose then use dissxi - Self::dissxi(prev.reversed_axes(), fut.reversed_axes()); - } } #[test] diff --git a/sbp/src/operators/upwind9.rs b/sbp/src/operators/upwind9.rs index 3b11b99..1db5fa1 100644 --- a/sbp/src/operators/upwind9.rs +++ b/sbp/src/operators/upwind9.rs @@ -1,19 +1,13 @@ use super::{SbpOperator, UpwindOperator}; use crate::diff_op_1d; use crate::Float; -use ndarray::{s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; +use ndarray::{s, ArrayView1, ArrayViewMut1}; #[derive(Debug)] pub struct Upwind9 {} -diff_op_1d!(Upwind9, diff_1d, Upwind9::BLOCK, Upwind9::DIAG, false); -diff_op_1d!( - Upwind9, - diss_1d, - Upwind9::DISS_BLOCK, - Upwind9::DISS_DIAG, - true -); +diff_op_1d!(diff_1d, Upwind9::BLOCK, Upwind9::DIAG, false); +diff_op_1d!(diss_1d, Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, true); impl Upwind9 { #[rustfmt::skip] @@ -55,18 +49,8 @@ impl Upwind9 { } impl SbpOperator for Upwind9 { - fn diffxi(prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); - - for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - Self::diff_1d(r0, r1); - } - } - - fn diffeta(prev: ArrayView2, fut: ArrayViewMut2) { - // transpose then use diffxi - Self::diffxi(prev.reversed_axes(), fut.reversed_axes()); + fn diff1d(prev: ArrayView1, fut: ArrayViewMut1) { + diff_1d(prev, fut) } fn h() -> &'static [Float] { @@ -75,18 +59,8 @@ impl SbpOperator for Upwind9 { } impl UpwindOperator for Upwind9 { - fn dissxi(prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); - - for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - Self::diss_1d(r0, r1); - } - } - - fn disseta(prev: ArrayView2, fut: ArrayViewMut2) { - // diffeta = transpose then use dissxi - Self::dissxi(prev.reversed_axes(), fut.reversed_axes()); + fn diss1d(prev: ArrayView1, fut: ArrayViewMut1) { + diss_1d(prev, fut) } }