diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 70d8fea..4aa79ef 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -169,6 +169,72 @@ pub(crate) fn diff_op_1d( } } +#[inline(always)] +pub(crate) fn diff_op_row( + block: &[&[Float]], + diag: &[Float], + symmetric: bool, + is_h2: bool, + prev: ArrayView2, + mut fut: ArrayViewMut2, +) { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[1]; + assert!(nx >= 2 * block.len()); + + assert_eq!(prev.strides()[1], 1); + assert_eq!(fut.strides()[1], 1); + + let dx = if is_h2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; + + for (prev, mut fut) in prev + .axis_iter(ndarray::Axis(0)) + .zip(fut.axis_iter_mut(ndarray::Axis(0))) + { + let prev = prev.as_slice().unwrap(); + let fut = fut.as_slice_mut().unwrap(); + + for (bl, f) in block.iter().zip(fut.iter_mut()) { + let diff = bl + .iter() + .zip(prev.iter()) + .map(|(x, y)| x * y) + .sum::(); + *f = diff * idx; + } + + // The window needs to be aligned to the diagonal elements, + // based on the block size + let window_elems_to_skip = block.len() - ((diag.len() - 1) / 2); + + for (window, f) in prev + .windows(diag.len()) + .into_iter() + .skip(window_elems_to_skip) + .zip(fut.iter_mut().skip(block.len())) + .take(nx - 2 * block.len()) + { + let diff = diag.iter().zip(window).map(|(&x, &y)| x * y).sum::(); + *f = diff * idx; + } + + for (bl, f) in block.iter().zip(fut.iter_mut().rev()) { + let diff = bl + .iter() + .zip(prev.iter().rev()) + .map(|(x, y)| x * y) + .sum::(); + + *f = idx * if symmetric { diff } else { -diff }; + } + } +} + mod upwind4; pub use upwind4::Upwind4; mod upwind9; diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index e1018c6..2221966 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -1,6 +1,6 @@ -use super::*; +use super::{diff_op_row, SbpOperator1d, SbpOperator2d}; use crate::Float; -use ndarray::{ArrayView1, ArrayViewMut1}; +use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; #[derive(Debug, Copy, Clone)] pub struct SBP4; @@ -33,6 +33,26 @@ impl SbpOperator1d for SBP4 { } } +impl SbpOperator2d for (&SBP, &SBP4) { + fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.shape()[1] >= 2 * SBP4::BLOCK.len()); + + match (prev.strides(), fut.strides()) { + ([_, 1], [_, 1]) => { + diff_op_row(SBP4::BLOCK, SBP4::DIAG, false, false, prev, fut); + } + ([_, _], [_, _]) => { + // Fallback, work row by row + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + SBP4.diff(r0, r1); + } + } + _ => unreachable!("Should only be two elements in the strides vectors"), + } + } +} + #[test] fn test_trad4() { use super::testing::*; diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs index 84ec3ab..062d134 100644 --- a/sbp/src/operators/traditional8.rs +++ b/sbp/src/operators/traditional8.rs @@ -1,6 +1,6 @@ -use super::*; +use super::{diff_op_row, SbpOperator1d, SbpOperator2d}; use crate::Float; -use ndarray::{ArrayView1, ArrayViewMut1}; +use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; #[derive(Debug, Copy, Clone)] pub struct SBP8; @@ -37,6 +37,26 @@ impl SbpOperator1d for SBP8 { } } +impl SbpOperator2d for (&SBP, &SBP8) { + fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.shape()[1] >= 2 * SBP8::BLOCK.len()); + + match (prev.strides(), fut.strides()) { + ([_, 1], [_, 1]) => { + diff_op_row(SBP8::BLOCK, SBP8::DIAG, false, false, prev, fut); + } + ([_, _], [_, _]) => { + // Fallback, work row by row + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + SBP8.diff(r0, r1); + } + } + _ => unreachable!("Should only be two elements in the strides vectors"), + } + } +} + #[test] fn test_trad8() { use super::testing::*; diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index 42a958c..6e9afb5 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -1,4 +1,4 @@ -use super::*; +use super::{SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis}; @@ -284,7 +284,7 @@ impl SbpOperator1d for Upwind4 { } } -impl SbpOperator2d for (&Upwind4, &SBP) { +impl SbpOperator2d for (&SBP, &Upwind4) { fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); @@ -405,7 +405,7 @@ impl UpwindOperator1d for Upwind4 { } } -impl UpwindOperator2d for (&Upwind4, &SBP) { +impl UpwindOperator2d for (&UO, &Upwind4) { fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); diff --git a/sbp/src/operators/upwind4h2.rs b/sbp/src/operators/upwind4h2.rs index ede6b04..a8d3dd3 100644 --- a/sbp/src/operators/upwind4h2.rs +++ b/sbp/src/operators/upwind4h2.rs @@ -1,6 +1,6 @@ -use super::*; +use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; use crate::Float; -use ndarray::{ArrayView1, ArrayViewMut1}; +use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; #[derive(Debug, Copy, Clone)] pub struct Upwind4h2; @@ -49,6 +49,53 @@ impl SbpOperator1d for Upwind4h2 { } } +impl SbpOperator2d for (&SBP, &Upwind4h2) { + fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.shape()[1] >= 2 * Upwind4h2::BLOCK.len()); + + match (prev.strides(), fut.strides()) { + ([_, 1], [_, 1]) => { + diff_op_row(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true, prev, fut); + } + ([_, _], [_, _]) => { + // Fallback, work row by row + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + Upwind4h2.diff(r0, r1); + } + } + _ => unreachable!("Should only be two elements in the strides vectors"), + } + } +} + +impl UpwindOperator2d for (&UO, &Upwind4h2) { + fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.shape()[1] >= 2 * Upwind4h2::BLOCK.len()); + + match (prev.strides(), fut.strides()) { + ([_, 1], [_, 1]) => { + diff_op_row( + Upwind4h2::DISS_BLOCK, + Upwind4h2::DISS_DIAG, + true, + true, + prev, + fut, + ); + } + ([_, _], [_, _]) => { + // Fallback, work row by row + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + Upwind4h2.diss(r0, r1); + } + } + _ => unreachable!("Should only be two elements in the strides vectors"), + } + } +} + #[test] fn upwind4h2_test() { let nx = 20; diff --git a/sbp/src/operators/upwind9.rs b/sbp/src/operators/upwind9.rs index f5c775d..cb9abaf 100644 --- a/sbp/src/operators/upwind9.rs +++ b/sbp/src/operators/upwind9.rs @@ -1,6 +1,6 @@ -use super::*; +use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; use crate::Float; -use ndarray::{ArrayView1, ArrayViewMut1}; +use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; #[derive(Debug, Copy, Clone)] pub struct Upwind9; @@ -54,6 +54,26 @@ impl SbpOperator1d for Upwind9 { } } +impl SbpOperator2d for (&SBP, &Upwind9) { + fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.shape()[1] >= 2 * Upwind9::BLOCK.len()); + + match (prev.strides(), fut.strides()) { + ([_, 1], [_, 1]) => { + diff_op_row(Upwind9::BLOCK, Upwind9::DIAG, false, false, prev, fut); + } + ([_, _], [_, _]) => { + // Fallback, work row by row + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + Upwind9.diff(r0, r1); + } + } + _ => unreachable!("Should only be two elements in the strides vectors"), + } + } +} + impl UpwindOperator1d for Upwind9 { fn diss(&self, prev: ArrayView1, fut: ArrayViewMut1) { super::diff_op_1d(Self::DISS_BLOCK, Self::DISS_DIAG, true, false, prev, fut) @@ -64,6 +84,33 @@ impl UpwindOperator1d for Upwind9 { } } +impl UpwindOperator2d for (&UO, &Upwind9) { + fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.shape()[1] >= 2 * Upwind9::BLOCK.len()); + + match (prev.strides(), fut.strides()) { + ([_, 1], [_, 1]) => { + diff_op_row( + Upwind9::DISS_BLOCK, + Upwind9::DISS_DIAG, + true, + false, + prev, + fut, + ); + } + ([_, _], [_, _]) => { + // Fallback, work row by row + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + Upwind9.diss(r0, r1); + } + } + _ => unreachable!("Should only be two elements in the strides vectors"), + } + } +} + #[test] fn test_upwind9() { use super::testing::*; diff --git a/sbp/src/operators/upwind9h2.rs b/sbp/src/operators/upwind9h2.rs index 51def84..a070893 100644 --- a/sbp/src/operators/upwind9h2.rs +++ b/sbp/src/operators/upwind9h2.rs @@ -1,6 +1,6 @@ -use super::*; +use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; use crate::Float; -use ndarray::{ArrayView1, ArrayViewMut1}; +use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; #[derive(Debug, Copy, Clone)] pub struct Upwind9h2; @@ -57,6 +57,26 @@ impl SbpOperator1d for Upwind9h2 { } } +impl SbpOperator2d for (&SBP, &Upwind9h2) { + fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.shape()[1] >= 2 * Upwind9h2::BLOCK.len()); + + match (prev.strides(), fut.strides()) { + ([_, 1], [_, 1]) => { + diff_op_row(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true, prev, fut); + } + ([_, _], [_, _]) => { + // Fallback, work row by row + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + Upwind9h2.diff(r0, r1); + } + } + _ => unreachable!("Should only be two elements in the strides vectors"), + } + } +} + #[test] fn upwind9h2_test() { let nx = 30; @@ -90,3 +110,30 @@ impl UpwindOperator1d for Upwind9h2 { self } } + +impl UpwindOperator2d for (&UO, &Upwind9h2) { + fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + assert!(prev.shape()[1] >= 2 * Upwind9h2::BLOCK.len()); + + match (prev.strides(), fut.strides()) { + ([_, 1], [_, 1]) => { + diff_op_row( + Upwind9h2::DISS_BLOCK, + Upwind9h2::DISS_DIAG, + true, + true, + prev, + fut, + ); + } + ([_, _], [_, _]) => { + // Fallback, work row by row + for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { + Upwind9h2.diss(r0, r1); + } + } + _ => unreachable!("Should only be two elements in the strides vectors"), + } + } +}