add back specialization

This commit is contained in:
Magnus Ulimoen 2020-04-15 00:37:46 +02:00
parent 1667eaaca0
commit 6df4632719
3 changed files with 19 additions and 27 deletions

View File

@ -1,4 +1,5 @@
#![feature(str_strip)] #![feature(str_strip)]
#![feature(specialization)]
#[cfg(feature = "f32")] #[cfg(feature = "f32")]
pub type Float = f32; pub type Float = f32;

View File

@ -25,7 +25,7 @@ pub trait SbpOperator2d: Copy + Clone {
} }
impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (SBPeta, SBPxi) { impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (SBPeta, SBPxi) {
fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) { default fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.shape(), fut.shape());
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
self.1.diff(r0, r1) self.1.diff(r0, r1)
@ -50,26 +50,23 @@ impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (SBPeta, SBP
} }
impl<SBP: SbpOperator1d> SbpOperator2d for SBP { impl<SBP: SbpOperator1d> SbpOperator2d for SBP {
fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) { fn diffxi(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) {
assert_eq!(prev.shape(), fut.shape()); <(SBP, SBP) as SbpOperator2d>::diffxi(&(*self, *self), prev, fut)
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
self.diff(r0, r1)
}
} }
fn diffeta(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) { fn diffeta(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) {
self.diffxi(prev.reversed_axes(), fut.reversed_axes()) <(SBP, SBP) as SbpOperator2d>::diffeta(&(*self, *self), prev, fut)
} }
fn hxi(&self) -> &'static [Float] { fn hxi(&self) -> &'static [Float] {
self.h() <(SBP, SBP) as SbpOperator2d>::hxi(&(*self, *self))
} }
fn heta(&self) -> &'static [Float] { fn heta(&self) -> &'static [Float] {
self.h() <(SBP, SBP) as SbpOperator2d>::heta(&(*self, *self))
} }
fn is_h2xi(&self) -> bool { fn is_h2xi(&self) -> bool {
self.is_h2() <(SBP, SBP) as SbpOperator2d>::is_h2xi(&(*self, *self))
} }
fn is_h2eta(&self) -> bool { fn is_h2eta(&self) -> bool {
self.is_h2() <(SBP, SBP) as SbpOperator2d>::is_h2eta(&(*self, *self))
} }
} }
@ -83,7 +80,7 @@ pub trait UpwindOperator2d: SbpOperator2d + Copy + Clone {
} }
impl<UOeta: UpwindOperator1d, UOxi: UpwindOperator1d> UpwindOperator2d for (UOeta, UOxi) { impl<UOeta: UpwindOperator1d, UOxi: UpwindOperator1d> UpwindOperator2d for (UOeta, UOxi) {
fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) { default fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.shape(), fut.shape());
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
self.1.diss(r0, r1); self.1.diss(r0, r1);
@ -96,14 +93,11 @@ impl<UOeta: UpwindOperator1d, UOxi: UpwindOperator1d> UpwindOperator2d for (UOet
} }
impl<UO: UpwindOperator1d> UpwindOperator2d for UO { impl<UO: UpwindOperator1d> UpwindOperator2d for UO {
fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) { fn dissxi(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) {
assert_eq!(prev.shape(), fut.shape()); <(UO, UO) as UpwindOperator2d>::dissxi(&(*self, *self), prev, fut)
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
self.diss(r0, r1);
}
} }
fn disseta(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) { fn disseta(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) {
self.dissxi(prev.reversed_axes(), fut.reversed_axes()) <(UO, UO) as UpwindOperator2d>::disseta(&(*self, *self), prev, fut)
} }
} }

View File

@ -291,10 +291,10 @@ impl SbpOperator1d for Upwind4 {
} }
} }
/* impl<SBP: SbpOperator1d> SbpOperator2d for (Upwind4, SBP) {
fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) { fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.shape(), fut.shape());
assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
match (prev.strides(), fut.strides()) { match (prev.strides(), fut.strides()) {
([_, 1], [_, 1]) => { ([_, 1], [_, 1]) => {
@ -306,15 +306,13 @@ impl SbpOperator1d for Upwind4 {
([_, _], [_, _]) => { ([_, _], [_, _]) => {
// Fallback, work row by row // Fallback, work row by row
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
Self.diff1d(r0, r1); Upwind4.diff(r0, r1);
} }
} }
_ => unreachable!("Should only be two elements in the strides vectors"), _ => unreachable!("Should only be two elements in the strides vectors"),
} }
} }
} }
*/
#[test] #[test]
fn upwind4_test() { fn upwind4_test() {
@ -417,10 +415,10 @@ impl UpwindOperator1d for Upwind4 {
} }
} }
/* impl<SBP: UpwindOperator1d> UpwindOperator2d for (Upwind4, SBP) {
fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) { fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.shape(), fut.shape());
assert!(prev.shape()[1] >= 2 * Self::BLOCK.len()); assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
match (prev.strides(), fut.strides()) { match (prev.strides(), fut.strides()) {
([_, 1], [_, 1]) => { ([_, 1], [_, 1]) => {
@ -432,14 +430,13 @@ impl UpwindOperator1d for Upwind4 {
([_, _], [_, _]) => { ([_, _], [_, _]) => {
// Fallback, work row by row // Fallback, work row by row
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
Self.diss1d(r0, r1); Upwind4.diss(r0, r1);
} }
} }
_ => unreachable!("Should only be two elements in the strides vectors"), _ => unreachable!("Should only be two elements in the strides vectors"),
} }
} }
} }
*/
#[test] #[test]
fn upwind4_test2() { fn upwind4_test2() {