From 6df463271932c3cd8311ea0cec8befafad79543c Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Wed, 15 Apr 2020 00:37:46 +0200 Subject: [PATCH] add back specialization --- sbp/src/lib.rs | 1 + sbp/src/operators.rs | 30 ++++++++++++------------------ sbp/src/operators/upwind4.rs | 15 ++++++--------- 3 files changed, 19 insertions(+), 27 deletions(-) diff --git a/sbp/src/lib.rs b/sbp/src/lib.rs index 2f37194..6192d39 100644 --- a/sbp/src/lib.rs +++ b/sbp/src/lib.rs @@ -1,4 +1,5 @@ #![feature(str_strip)] +#![feature(specialization)] #[cfg(feature = "f32")] pub type Float = f32; diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 8b02820..4d99597 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -25,7 +25,7 @@ pub trait SbpOperator2d: Copy + Clone { } impl SbpOperator2d for (SBPeta, SBPxi) { - fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + default fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { self.1.diff(r0, r1) @@ -50,26 +50,23 @@ impl SbpOperator2d for (SBPeta, SBP } impl SbpOperator2d for SBP { - fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - self.diff(r0, r1) - } + fn diffxi(&self, prev: ArrayView2, fut: ArrayViewMut2) { + <(SBP, SBP) as SbpOperator2d>::diffxi(&(*self, *self), prev, fut) } fn diffeta(&self, prev: ArrayView2, fut: ArrayViewMut2) { - self.diffxi(prev.reversed_axes(), fut.reversed_axes()) + <(SBP, SBP) as SbpOperator2d>::diffeta(&(*self, *self), prev, fut) } fn hxi(&self) -> &'static [Float] { - self.h() + <(SBP, SBP) as SbpOperator2d>::hxi(&(*self, *self)) } fn heta(&self) -> &'static [Float] { - self.h() + <(SBP, SBP) as SbpOperator2d>::heta(&(*self, *self)) } fn is_h2xi(&self) -> bool { - self.is_h2() + <(SBP, SBP) as SbpOperator2d>::is_h2xi(&(*self, *self)) } fn is_h2eta(&self) -> bool { - self.is_h2() + <(SBP, SBP) as SbpOperator2d>::is_h2eta(&(*self, *self)) } } @@ -83,7 +80,7 @@ pub trait UpwindOperator2d: SbpOperator2d + Copy + Clone { } impl UpwindOperator2d for (UOeta, UOxi) { - fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + default fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { self.1.diss(r0, r1); @@ -96,14 +93,11 @@ impl UpwindOperator2d for (UOet } impl UpwindOperator2d for UO { - fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - self.diss(r0, r1); - } + fn dissxi(&self, prev: ArrayView2, fut: ArrayViewMut2) { + <(UO, UO) as UpwindOperator2d>::dissxi(&(*self, *self), prev, fut) } fn disseta(&self, prev: ArrayView2, fut: ArrayViewMut2) { - self.dissxi(prev.reversed_axes(), fut.reversed_axes()) + <(UO, UO) as UpwindOperator2d>::disseta(&(*self, *self), prev, fut) } } diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index ba0f362..85c5b3c 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -291,10 +291,10 @@ impl SbpOperator1d for Upwind4 { } } -/* +impl SbpOperator2d for (Upwind4, SBP) { fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); - assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); + assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { @@ -306,15 +306,13 @@ impl SbpOperator1d for Upwind4 { ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - Self.diff1d(r0, r1); + Upwind4.diff(r0, r1); } } _ => unreachable!("Should only be two elements in the strides vectors"), } } - } -*/ #[test] fn upwind4_test() { @@ -417,10 +415,10 @@ impl UpwindOperator1d for Upwind4 { } } -/* +impl UpwindOperator2d for (Upwind4, SBP) { fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); - assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); + assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { @@ -432,14 +430,13 @@ impl UpwindOperator1d for Upwind4 { ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - Self.diss1d(r0, r1); + Upwind4.diss(r0, r1); } } _ => unreachable!("Should only be two elements in the strides vectors"), } } } -*/ #[test] fn upwind4_test2() {