add back specialization
This commit is contained in:
		@@ -1,4 +1,5 @@
 | 
			
		||||
#![feature(str_strip)]
 | 
			
		||||
#![feature(specialization)]
 | 
			
		||||
 | 
			
		||||
#[cfg(feature = "f32")]
 | 
			
		||||
pub type Float = f32;
 | 
			
		||||
 
 | 
			
		||||
@@ -25,7 +25,7 @@ pub trait SbpOperator2d: Copy + Clone {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (SBPeta, SBPxi) {
 | 
			
		||||
    fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
    default fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
        for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
 | 
			
		||||
            self.1.diff(r0, r1)
 | 
			
		||||
@@ -50,26 +50,23 @@ impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (SBPeta, SBP
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<SBP: SbpOperator1d> SbpOperator2d for SBP {
 | 
			
		||||
    fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
        for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
 | 
			
		||||
            self.diff(r0, r1)
 | 
			
		||||
        }
 | 
			
		||||
    fn diffxi(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        <(SBP, SBP) as SbpOperator2d>::diffxi(&(*self, *self), prev, fut)
 | 
			
		||||
    }
 | 
			
		||||
    fn diffeta(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        self.diffxi(prev.reversed_axes(), fut.reversed_axes())
 | 
			
		||||
        <(SBP, SBP) as SbpOperator2d>::diffeta(&(*self, *self), prev, fut)
 | 
			
		||||
    }
 | 
			
		||||
    fn hxi(&self) -> &'static [Float] {
 | 
			
		||||
        self.h()
 | 
			
		||||
        <(SBP, SBP) as SbpOperator2d>::hxi(&(*self, *self))
 | 
			
		||||
    }
 | 
			
		||||
    fn heta(&self) -> &'static [Float] {
 | 
			
		||||
        self.h()
 | 
			
		||||
        <(SBP, SBP) as SbpOperator2d>::heta(&(*self, *self))
 | 
			
		||||
    }
 | 
			
		||||
    fn is_h2xi(&self) -> bool {
 | 
			
		||||
        self.is_h2()
 | 
			
		||||
        <(SBP, SBP) as SbpOperator2d>::is_h2xi(&(*self, *self))
 | 
			
		||||
    }
 | 
			
		||||
    fn is_h2eta(&self) -> bool {
 | 
			
		||||
        self.is_h2()
 | 
			
		||||
        <(SBP, SBP) as SbpOperator2d>::is_h2eta(&(*self, *self))
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -83,7 +80,7 @@ pub trait UpwindOperator2d: SbpOperator2d + Copy + Clone {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<UOeta: UpwindOperator1d, UOxi: UpwindOperator1d> UpwindOperator2d for (UOeta, UOxi) {
 | 
			
		||||
    fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
    default fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
        for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
 | 
			
		||||
            self.1.diss(r0, r1);
 | 
			
		||||
@@ -96,14 +93,11 @@ impl<UOeta: UpwindOperator1d, UOxi: UpwindOperator1d> UpwindOperator2d for (UOet
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<UO: UpwindOperator1d> UpwindOperator2d for UO {
 | 
			
		||||
    fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
        for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
 | 
			
		||||
            self.diss(r0, r1);
 | 
			
		||||
        }
 | 
			
		||||
    fn dissxi(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        <(UO, UO) as UpwindOperator2d>::dissxi(&(*self, *self), prev, fut)
 | 
			
		||||
    }
 | 
			
		||||
    fn disseta(&self, prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        self.dissxi(prev.reversed_axes(), fut.reversed_axes())
 | 
			
		||||
        <(UO, UO) as UpwindOperator2d>::disseta(&(*self, *self), prev, fut)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -291,10 +291,10 @@ impl SbpOperator1d for Upwind4 {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
impl<SBP: SbpOperator1d> SbpOperator2d for (Upwind4, SBP) {
 | 
			
		||||
    fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
        assert!(prev.shape()[1] >= 2 * Self::BLOCK.len());
 | 
			
		||||
        assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
 | 
			
		||||
 | 
			
		||||
        match (prev.strides(), fut.strides()) {
 | 
			
		||||
            ([_, 1], [_, 1]) => {
 | 
			
		||||
@@ -306,15 +306,13 @@ impl SbpOperator1d for Upwind4 {
 | 
			
		||||
            ([_, _], [_, _]) => {
 | 
			
		||||
                // Fallback, work row by row
 | 
			
		||||
                for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
 | 
			
		||||
                    Self.diff1d(r0, r1);
 | 
			
		||||
                    Upwind4.diff(r0, r1);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            _ => unreachable!("Should only be two elements in the strides vectors"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
#[test]
 | 
			
		||||
fn upwind4_test() {
 | 
			
		||||
@@ -417,10 +415,10 @@ impl UpwindOperator1d for Upwind4 {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/*
 | 
			
		||||
impl<SBP: UpwindOperator1d> UpwindOperator2d for (Upwind4, SBP) {
 | 
			
		||||
    fn dissxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
        assert!(prev.shape()[1] >= 2 * Self::BLOCK.len());
 | 
			
		||||
        assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
 | 
			
		||||
 | 
			
		||||
        match (prev.strides(), fut.strides()) {
 | 
			
		||||
            ([_, 1], [_, 1]) => {
 | 
			
		||||
@@ -432,14 +430,13 @@ impl UpwindOperator1d for Upwind4 {
 | 
			
		||||
            ([_, _], [_, _]) => {
 | 
			
		||||
                // Fallback, work row by row
 | 
			
		||||
                for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
 | 
			
		||||
                    Self.diss1d(r0, r1);
 | 
			
		||||
                    Upwind4.diss(r0, r1);
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            _ => unreachable!("Should only be two elements in the strides vectors"),
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
#[test]
 | 
			
		||||
fn upwind4_test2() {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user