From 78da9baaeade32fd3565507769050c2de7c92823 Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Fri, 1 May 2020 00:09:46 +0200 Subject: [PATCH] name booleans --- sbp/src/operators.rs | 52 ++++++++++++++++++++++--------- sbp/src/operators/traditional4.rs | 16 ++++++++-- sbp/src/operators/traditional8.rs | 16 ++++++++-- sbp/src/operators/upwind4.rs | 32 ++++++++++++++++--- sbp/src/operators/upwind4h2.rs | 42 +++++++++++++++++++++---- sbp/src/operators/upwind9.rs | 32 +++++++++++++++---- sbp/src/operators/upwind9h2.rs | 42 +++++++++++++++++++++---- 7 files changed, 189 insertions(+), 43 deletions(-) diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 1b3e25d..050c451 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -115,11 +115,11 @@ pub trait InterpolationOperator: Send + Sync { } #[inline(always)] -pub(crate) fn diff_op_1d( +fn diff_op_1d( block: &[&[Float]], diag: &[Float], - symmetric: bool, - is_h2: bool, + symmetry: Symmetry, + optype: OperatorType, prev: ArrayView1, mut fut: ArrayViewMut1, ) { @@ -127,7 +127,7 @@ pub(crate) fn diff_op_1d( let nx = prev.shape()[0]; assert!(nx >= 2 * block.len()); - let dx = if is_h2 { + let dx = if optype == OperatorType::H2 { 1.0 / (nx - 2) as Float } else { 1.0 / (nx - 1) as Float @@ -165,16 +165,33 @@ pub(crate) fn diff_op_1d( .map(|(x, y)| x * y) .sum::(); - *f = idx * if symmetric { diff } else { -diff }; + *f = idx + * if symmetry == Symmetry::Symmetric { + diff + } else { + -diff + }; } } +#[derive(PartialEq, Copy, Clone)] +enum Symmetry { + Symmetric, + AntiSymmetric, +} + +#[derive(PartialEq, Copy, Clone)] +enum OperatorType { + Normal, + H2, +} + #[inline(always)] -pub(crate) fn diff_op_col( +fn diff_op_col( block: &'static [&'static [Float]], diag: &'static [Float], - symmetric: bool, - is_h2: bool, + symmetry: Symmetry, + optype: OperatorType, ) -> impl Fn(ArrayView2, ArrayViewMut2) { #[inline(always)] move |prev: ArrayView2, mut fut: ArrayViewMut2| { @@ -185,7 +202,7 @@ pub(crate) fn diff_op_col( assert_eq!(prev.strides()[0], 1); assert_eq!(fut.strides()[0], 1); - let dx = if is_h2 { + let dx = if optype == OperatorType::H2 { 1.0 / (nx - 2) as Float } else { 1.0 / (nx - 1) as Float @@ -249,7 +266,7 @@ pub(crate) fn diff_op_col( for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { fut.fill(0.0); for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { - if symmetric { + if symmetry == Symmetry::Symmetric { fut.scaled_add(idx * bl, &prev); } else { fut.scaled_add(-idx * bl, &prev); @@ -260,11 +277,11 @@ pub(crate) fn diff_op_col( } #[inline(always)] -pub(crate) fn diff_op_row( +fn diff_op_row( block: &'static [&'static [Float]], diag: &'static [Float], - symmetric: bool, - is_h2: bool, + symmetry: Symmetry, + optype: OperatorType, ) -> impl Fn(ArrayView2, ArrayViewMut2) { #[inline(always)] move |prev: ArrayView2, mut fut: ArrayViewMut2| { @@ -275,7 +292,7 @@ pub(crate) fn diff_op_row( assert_eq!(prev.strides()[1], 1); assert_eq!(fut.strides()[1], 1); - let dx = if is_h2 { + let dx = if optype == OperatorType::H2 { 1.0 / (nx - 2) as Float } else { 1.0 / (nx - 1) as Float @@ -319,7 +336,12 @@ pub(crate) fn diff_op_row( .map(|(x, y)| x * y) .sum::(); - *f = idx * if symmetric { diff } else { -diff }; + *f = idx + * if symmetry == Symmetry::Symmetric { + diff + } else { + -diff + }; } } } diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index 1a26ae5..deb4a13 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -25,7 +25,14 @@ impl SBP4 { impl SbpOperator1d for SBP4 { fn diff(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::BLOCK, Self::DIAG, false, false, prev, fut) + super::diff_op_1d( + Self::BLOCK, + Self::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::Normal, + prev, + fut, + ) } fn h(&self) -> &'static [Float] { @@ -38,12 +45,15 @@ impl SbpOperator2d for (&SBP, &SBP4) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * SBP4::BLOCK.len()); + let symmetry = super::Symmetry::AntiSymmetric; + let optype = super::OperatorType::Normal; + match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(SBP4::BLOCK, SBP4::DIAG, false, false)(prev, fut); + diff_op_row(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(SBP4::BLOCK, SBP4::DIAG, false, false)(prev, fut); + diff_op_col(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs index 7d34931..cdcbde1 100644 --- a/sbp/src/operators/traditional8.rs +++ b/sbp/src/operators/traditional8.rs @@ -29,7 +29,14 @@ impl SBP8 { impl SbpOperator1d for SBP8 { fn diff(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::BLOCK, Self::DIAG, false, false, prev, fut) + super::diff_op_1d( + Self::BLOCK, + Self::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::Normal, + prev, + fut, + ) } fn h(&self) -> &'static [Float] { @@ -42,12 +49,15 @@ impl SbpOperator2d for (&SBP, &SBP8) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * SBP8::BLOCK.len()); + let symmetry = super::Symmetry::AntiSymmetric; + let optype = super::OperatorType::Normal; + match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(SBP8::BLOCK, SBP8::DIAG, false, false)(prev, fut); + diff_op_row(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(SBP8::BLOCK, SBP8::DIAG, false, false)(prev, fut); + diff_op_col(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index 9e4bda4..e5f6df8 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -195,7 +195,14 @@ impl Upwind4 { impl SbpOperator1d for Upwind4 { fn diff(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::BLOCK, Self::DIAG, false, false, prev, fut) + super::diff_op_1d( + Self::BLOCK, + Self::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::Normal, + prev, + fut, + ) } fn h(&self) -> &'static [Float] { Self::HBLOCK @@ -209,7 +216,12 @@ impl SbpOperator2d for (&SBP, &Upwind4) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind4::BLOCK, Upwind4::DIAG, false, false)(prev, fut); + diff_op_row( + Upwind4::BLOCK, + Upwind4::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::Normal, + )(prev, fut); } ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { diff_simd_col(prev, fut); @@ -315,7 +327,14 @@ fn upwind4_test() { impl UpwindOperator1d for Upwind4 { fn diss(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::DISS_BLOCK, Self::DISS_DIAG, true, false, prev, fut) + super::diff_op_1d( + Self::DISS_BLOCK, + Self::DISS_DIAG, + super::Symmetry::Symmetric, + super::OperatorType::Normal, + prev, + fut, + ) } fn as_sbp(&self) -> &dyn SbpOperator1d { @@ -330,7 +349,12 @@ impl UpwindOperator2d for (&UO, &Upwind4) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true, false)(prev, fut); + diff_op_row( + Upwind4::DISS_BLOCK, + Upwind4::DISS_DIAG, + super::Symmetry::Symmetric, + super::OperatorType::Normal, + )(prev, fut); } ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { diss_simd_col(prev, fut); diff --git a/sbp/src/operators/upwind4h2.rs b/sbp/src/operators/upwind4h2.rs index 06aa12f..0f8acc0 100644 --- a/sbp/src/operators/upwind4h2.rs +++ b/sbp/src/operators/upwind4h2.rs @@ -40,7 +40,14 @@ impl Upwind4h2 { impl SbpOperator1d for Upwind4h2 { fn diff(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::BLOCK, Self::DIAG, false, true, prev, fut) + super::diff_op_1d( + Self::BLOCK, + Self::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::H2, + prev, + fut, + ) } fn h(&self) -> &'static [Float] { @@ -56,12 +63,15 @@ impl SbpOperator2d for (&SBP, &Upwind4h2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind4h2::BLOCK.len()); + let symmetry = super::Symmetry::AntiSymmetric; + let optype = super::OperatorType::H2; + match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true)(prev, fut); + diff_op_row(Upwind4h2::BLOCK, Upwind4h2::DIAG, symmetry, optype)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true)(prev, fut); + diff_op_col(Upwind4h2::BLOCK, Upwind4h2::DIAG, symmetry, optype)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row @@ -79,12 +89,25 @@ impl UpwindOperator2d for (&UO, &Upwind4h2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind4h2::BLOCK.len()); + let symmetry = super::Symmetry::Symmetric; + let optype = super::OperatorType::H2; + match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind4h2::DISS_BLOCK, Upwind4h2::DISS_DIAG, true, true)(prev, fut); + diff_op_row( + Upwind4h2::DISS_BLOCK, + Upwind4h2::DISS_DIAG, + symmetry, + optype, + )(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind4h2::DISS_BLOCK, Upwind4h2::DISS_DIAG, true, true)(prev, fut); + diff_op_col( + Upwind4h2::DISS_BLOCK, + Upwind4h2::DISS_DIAG, + symmetry, + optype, + )(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row @@ -124,7 +147,14 @@ fn upwind4h2_test() { impl UpwindOperator1d for Upwind4h2 { fn diss(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::DISS_BLOCK, Self::DISS_DIAG, true, true, prev, fut) + super::diff_op_1d( + Self::DISS_BLOCK, + Self::DISS_DIAG, + super::Symmetry::Symmetric, + super::OperatorType::H2, + prev, + fut, + ) } fn as_sbp(&self) -> &dyn SbpOperator1d { diff --git a/sbp/src/operators/upwind9.rs b/sbp/src/operators/upwind9.rs index 484b316..cefd516 100644 --- a/sbp/src/operators/upwind9.rs +++ b/sbp/src/operators/upwind9.rs @@ -48,7 +48,14 @@ impl Upwind9 { impl SbpOperator1d for Upwind9 { fn diff(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::BLOCK, Self::DIAG, false, false, prev, fut) + super::diff_op_1d( + Self::BLOCK, + Self::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::Normal, + prev, + fut, + ) } fn h(&self) -> &'static [Float] { @@ -61,12 +68,15 @@ impl SbpOperator2d for (&SBP, &Upwind9) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind9::BLOCK.len()); + let symmetry = super::Symmetry::AntiSymmetric; + let optype = super::OperatorType::Normal; + match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind9::BLOCK, Upwind9::DIAG, false, false)(prev, fut); + diff_op_row(Upwind9::BLOCK, Upwind9::DIAG, symmetry, optype)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind9::BLOCK, Upwind9::DIAG, false, false)(prev, fut); + diff_op_col(Upwind9::BLOCK, Upwind9::DIAG, symmetry, optype)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row @@ -81,7 +91,14 @@ impl SbpOperator2d for (&SBP, &Upwind9) { impl UpwindOperator1d for Upwind9 { fn diss(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::DISS_BLOCK, Self::DISS_DIAG, true, false, prev, fut) + super::diff_op_1d( + Self::DISS_BLOCK, + Self::DISS_DIAG, + super::Symmetry::Symmetric, + super::OperatorType::Normal, + prev, + fut, + ) } fn as_sbp(&self) -> &dyn SbpOperator1d { @@ -94,12 +111,15 @@ impl UpwindOperator2d for (&UO, &Upwind9) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind9::BLOCK.len()); + let symmetry = super::Symmetry::Symmetric; + let optype = super::OperatorType::Normal; + match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, true, false)(prev, fut); + diff_op_row(Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, symmetry, optype)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, true, false)(prev, fut); + diff_op_col(Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, symmetry, optype)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row diff --git a/sbp/src/operators/upwind9h2.rs b/sbp/src/operators/upwind9h2.rs index c4ebd42..0d44f0f 100644 --- a/sbp/src/operators/upwind9h2.rs +++ b/sbp/src/operators/upwind9h2.rs @@ -48,7 +48,14 @@ impl Upwind9h2 { impl SbpOperator1d for Upwind9h2 { fn diff(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::BLOCK, Self::DIAG, false, true, prev, fut) + super::diff_op_1d( + Self::BLOCK, + Self::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::H2, + prev, + fut, + ) } fn h(&self) -> &'static [Float] { @@ -64,12 +71,15 @@ impl SbpOperator2d for (&SBP, &Upwind9h2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind9h2::BLOCK.len()); + let symmetry = super::Symmetry::AntiSymmetric; + let optype = super::OperatorType::H2; + match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true)(prev, fut); + diff_op_row(Upwind9h2::BLOCK, Upwind9h2::DIAG, symmetry, optype)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true)(prev, fut); + diff_op_col(Upwind9h2::BLOCK, Upwind9h2::DIAG, symmetry, optype)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row @@ -109,7 +119,14 @@ fn upwind9h2_test() { impl UpwindOperator1d for Upwind9h2 { fn diss(&self, prev: ArrayView1, fut: ArrayViewMut1) { - super::diff_op_1d(Self::DISS_BLOCK, Self::DISS_DIAG, true, true, prev, fut) + super::diff_op_1d( + Self::DISS_BLOCK, + Self::DISS_DIAG, + super::Symmetry::Symmetric, + super::OperatorType::H2, + prev, + fut, + ) } fn as_sbp(&self) -> &dyn SbpOperator1d { self @@ -121,12 +138,25 @@ impl UpwindOperator2d for (&UO, &Upwind9h2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind9h2::BLOCK.len()); + let symmetry = super::Symmetry::Symmetric; + let optype = super::OperatorType::H2; + match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind9h2::DISS_BLOCK, Upwind9h2::DISS_DIAG, true, true)(prev, fut); + diff_op_row( + Upwind9h2::DISS_BLOCK, + Upwind9h2::DISS_DIAG, + symmetry, + optype, + )(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind9h2::DISS_BLOCK, Upwind9h2::DISS_DIAG, true, true)(prev, fut); + diff_op_col( + Upwind9h2::DISS_BLOCK, + Upwind9h2::DISS_DIAG, + symmetry, + optype, + )(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row