Change macro to inlined function
This commit is contained in:
		@@ -37,59 +37,54 @@ pub trait InterpolationOperator: Send + Sync {
 | 
			
		||||
    fn coarse2fine(coarse: ArrayView1<Float>, fine: ArrayViewMut1<Float>);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[macro_export]
 | 
			
		||||
macro_rules! diff_op_1d {
 | 
			
		||||
    ($name: ident, $BLOCK: expr, $DIAG: expr) => {
 | 
			
		||||
        diff_op_1d!($name, $BLOCK, $DIAG, false);
 | 
			
		||||
    };
 | 
			
		||||
    ($name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => {
 | 
			
		||||
        fn $name(prev: ArrayView1<Float>, mut fut: ArrayViewMut1<Float>) {
 | 
			
		||||
            assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
            let nx = prev.shape()[0];
 | 
			
		||||
            assert!(nx >= 2 * $BLOCK.len());
 | 
			
		||||
#[inline(always)]
 | 
			
		||||
pub(crate) fn diff_op_1d(
 | 
			
		||||
    block: ndarray::ArrayView2<Float>,
 | 
			
		||||
    diag: ndarray::ArrayView1<Float>,
 | 
			
		||||
    symmetric: bool,
 | 
			
		||||
    prev: ArrayView1<Float>,
 | 
			
		||||
    mut fut: ArrayViewMut1<Float>,
 | 
			
		||||
) {
 | 
			
		||||
    assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
    let nx = prev.shape()[0];
 | 
			
		||||
    assert!(nx >= 2 * block.len_of(ndarray::Axis(0)));
 | 
			
		||||
 | 
			
		||||
            let dx = 1.0 / (nx - 1) as Float;
 | 
			
		||||
            let idx = 1.0 / dx;
 | 
			
		||||
    let dx = 1.0 / (nx - 1) as Float;
 | 
			
		||||
    let idx = 1.0 / dx;
 | 
			
		||||
 | 
			
		||||
            let block = ::ndarray::arr2($BLOCK);
 | 
			
		||||
            let diag = ::ndarray::arr1($DIAG);
 | 
			
		||||
    let first_elems = prev.slice(::ndarray::s!(..block.len_of(::ndarray::Axis(1))));
 | 
			
		||||
    for (bl, f) in block.outer_iter().zip(&mut fut) {
 | 
			
		||||
        let diff = first_elems.dot(&bl);
 | 
			
		||||
        *f = diff * idx;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // The window needs to be aligned to the diagonal elements,
 | 
			
		||||
    // based on the block size
 | 
			
		||||
    let window_elems_to_skip = block.len_of(::ndarray::Axis(0)) - ((diag.len() - 1) / 2);
 | 
			
		||||
 | 
			
		||||
            let first_elems = prev.slice(::ndarray::s!(..block.len_of(::ndarray::Axis(1))));
 | 
			
		||||
            for (bl, f) in block.outer_iter().zip(&mut fut) {
 | 
			
		||||
                let diff = first_elems.dot(&bl);
 | 
			
		||||
                *f = diff * idx;
 | 
			
		||||
            }
 | 
			
		||||
    for (window, f) in prev
 | 
			
		||||
        .windows(diag.len())
 | 
			
		||||
        .into_iter()
 | 
			
		||||
        .skip(window_elems_to_skip)
 | 
			
		||||
        .zip(fut.iter_mut().skip(block.len_of(::ndarray::Axis(0))))
 | 
			
		||||
        .take(nx - 2 * block.len_of(::ndarray::Axis(0)))
 | 
			
		||||
    {
 | 
			
		||||
        let diff = diag.dot(&window);
 | 
			
		||||
        *f = diff * idx;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
            // The window needs to be aligned to the diagonal elements,
 | 
			
		||||
            // based on the block size
 | 
			
		||||
            let window_elems_to_skip =
 | 
			
		||||
                block.len_of(::ndarray::Axis(0)) - ((diag.len() - 1) / 2);
 | 
			
		||||
 | 
			
		||||
            for (window, f) in prev
 | 
			
		||||
                .windows(diag.len())
 | 
			
		||||
                .into_iter()
 | 
			
		||||
                .skip(window_elems_to_skip)
 | 
			
		||||
                .zip(fut.iter_mut().skip(block.len_of(::ndarray::Axis(0))))
 | 
			
		||||
                .take(nx - 2 * block.len_of(::ndarray::Axis(0)))
 | 
			
		||||
            {
 | 
			
		||||
                let diff = diag.dot(&window);
 | 
			
		||||
                *f = diff * idx;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let last_elems = prev.slice(::ndarray::s!(nx - block.len_of(::ndarray::Axis(1))..;-1));
 | 
			
		||||
            for (bl, f) in block.outer_iter()
 | 
			
		||||
                .zip(&mut fut.slice_mut(s![nx - block.len_of(::ndarray::Axis(0))..;-1]))
 | 
			
		||||
            {
 | 
			
		||||
                let diff = if $symmetric {
 | 
			
		||||
                    bl.dot(&last_elems)
 | 
			
		||||
                } else {
 | 
			
		||||
                    -bl.dot(&last_elems)
 | 
			
		||||
                };
 | 
			
		||||
                *f = diff * idx;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
    let last_elems = prev.slice(::ndarray::s!(nx - block.len_of(::ndarray::Axis(1))..;-1));
 | 
			
		||||
    for (bl, f) in block
 | 
			
		||||
        .outer_iter()
 | 
			
		||||
        .zip(&mut fut.slice_mut(::ndarray::s![nx - block.len_of(::ndarray::Axis(0))..;-1]))
 | 
			
		||||
    {
 | 
			
		||||
        let diff = if symmetric {
 | 
			
		||||
            bl.dot(&last_elems)
 | 
			
		||||
        } else {
 | 
			
		||||
            -bl.dot(&last_elems)
 | 
			
		||||
        };
 | 
			
		||||
        *f = diff * idx;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
mod upwind4;
 | 
			
		||||
 
 | 
			
		||||
@@ -1,13 +1,10 @@
 | 
			
		||||
use super::SbpOperator;
 | 
			
		||||
use crate::diff_op_1d;
 | 
			
		||||
use crate::Float;
 | 
			
		||||
use ndarray::{s, ArrayView1, ArrayViewMut1};
 | 
			
		||||
use ndarray::{ArrayView1, ArrayViewMut1};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct SBP4 {}
 | 
			
		||||
 | 
			
		||||
diff_op_1d!(diff_1d, SBP4::BLOCK, SBP4::DIAG);
 | 
			
		||||
 | 
			
		||||
impl SBP4 {
 | 
			
		||||
    #[rustfmt::skip]
 | 
			
		||||
    const HBLOCK: &'static [Float] = &[
 | 
			
		||||
@@ -28,7 +25,13 @@ impl SBP4 {
 | 
			
		||||
 | 
			
		||||
impl SbpOperator for SBP4 {
 | 
			
		||||
    fn diff1d(prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
 | 
			
		||||
        diff_1d(prev, fut)
 | 
			
		||||
        super::diff_op_1d(
 | 
			
		||||
            ndarray::arr2(Self::BLOCK).view(),
 | 
			
		||||
            ndarray::arr1(Self::DIAG).view(),
 | 
			
		||||
            false,
 | 
			
		||||
            prev,
 | 
			
		||||
            fut,
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn h() -> &'static [Float] {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,13 +1,10 @@
 | 
			
		||||
use super::SbpOperator;
 | 
			
		||||
use crate::diff_op_1d;
 | 
			
		||||
use crate::Float;
 | 
			
		||||
use ndarray::{s, ArrayView1, ArrayViewMut1};
 | 
			
		||||
use ndarray::{ArrayView1, ArrayViewMut1};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct SBP8 {}
 | 
			
		||||
 | 
			
		||||
diff_op_1d!(diff_1d, SBP8::BLOCK, SBP8::DIAG);
 | 
			
		||||
 | 
			
		||||
impl SBP8 {
 | 
			
		||||
    #[rustfmt::skip]
 | 
			
		||||
    const HBLOCK: &'static [Float] = &[
 | 
			
		||||
@@ -32,7 +29,13 @@ impl SBP8 {
 | 
			
		||||
 | 
			
		||||
impl SbpOperator for SBP8 {
 | 
			
		||||
    fn diff1d(prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
 | 
			
		||||
        diff_1d(prev, fut)
 | 
			
		||||
        super::diff_op_1d(
 | 
			
		||||
            ndarray::arr2(Self::BLOCK).view(),
 | 
			
		||||
            ndarray::arr1(Self::DIAG).view(),
 | 
			
		||||
            false,
 | 
			
		||||
            prev,
 | 
			
		||||
            fut,
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn h() -> &'static [Float] {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,6 @@
 | 
			
		||||
use super::{SbpOperator, UpwindOperator};
 | 
			
		||||
use crate::diff_op_1d;
 | 
			
		||||
use crate::Float;
 | 
			
		||||
use ndarray::{s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
 | 
			
		||||
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct Upwind4 {}
 | 
			
		||||
@@ -12,9 +11,6 @@ type SimdT = packed_simd::f32x8;
 | 
			
		||||
#[cfg(not(feature = "f32"))]
 | 
			
		||||
type SimdT = packed_simd::f64x8;
 | 
			
		||||
 | 
			
		||||
diff_op_1d!(diff_1d, Upwind4::BLOCK, Upwind4::DIAG);
 | 
			
		||||
diff_op_1d!(diss_1d, Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true);
 | 
			
		||||
 | 
			
		||||
macro_rules! diff_simd_row_7_47 {
 | 
			
		||||
    ($name: ident, $BLOCK: expr, $DIAG: expr, $symmetric: expr) => {
 | 
			
		||||
        #[inline(never)]
 | 
			
		||||
@@ -281,7 +277,13 @@ impl Upwind4 {
 | 
			
		||||
 | 
			
		||||
impl SbpOperator for Upwind4 {
 | 
			
		||||
    fn diff1d(prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
 | 
			
		||||
        diff_1d(prev, fut)
 | 
			
		||||
        super::diff_op_1d(
 | 
			
		||||
            ndarray::arr2(Self::BLOCK).view(),
 | 
			
		||||
            ndarray::arr1(Self::DIAG).view(),
 | 
			
		||||
            false,
 | 
			
		||||
            prev,
 | 
			
		||||
            fut,
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
    fn diffxi(prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
@@ -399,7 +401,13 @@ fn upwind4_test() {
 | 
			
		||||
 | 
			
		||||
impl UpwindOperator for Upwind4 {
 | 
			
		||||
    fn diss1d(prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
 | 
			
		||||
        diss_1d(prev, fut)
 | 
			
		||||
        super::diff_op_1d(
 | 
			
		||||
            ndarray::arr2(Self::DISS_BLOCK).view(),
 | 
			
		||||
            ndarray::arr1(Self::DISS_DIAG).view(),
 | 
			
		||||
            true,
 | 
			
		||||
            prev,
 | 
			
		||||
            fut,
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
    fn dissxi(prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
 
 | 
			
		||||
@@ -1,14 +1,10 @@
 | 
			
		||||
use super::{SbpOperator, UpwindOperator};
 | 
			
		||||
use crate::diff_op_1d;
 | 
			
		||||
use crate::Float;
 | 
			
		||||
use ndarray::{s, ArrayView1, ArrayViewMut1};
 | 
			
		||||
use ndarray::{ArrayView1, ArrayViewMut1};
 | 
			
		||||
 | 
			
		||||
#[derive(Debug)]
 | 
			
		||||
pub struct Upwind9 {}
 | 
			
		||||
 | 
			
		||||
diff_op_1d!(diff_1d, Upwind9::BLOCK, Upwind9::DIAG);
 | 
			
		||||
diff_op_1d!(diss_1d, Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, true);
 | 
			
		||||
 | 
			
		||||
impl Upwind9 {
 | 
			
		||||
    #[rustfmt::skip]
 | 
			
		||||
    const HBLOCK: &'static [Float] = &[
 | 
			
		||||
@@ -50,7 +46,13 @@ impl Upwind9 {
 | 
			
		||||
 | 
			
		||||
impl SbpOperator for Upwind9 {
 | 
			
		||||
    fn diff1d(prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
 | 
			
		||||
        diff_1d(prev, fut)
 | 
			
		||||
        super::diff_op_1d(
 | 
			
		||||
            ndarray::arr2(Self::BLOCK).view(),
 | 
			
		||||
            ndarray::arr1(Self::DIAG).view(),
 | 
			
		||||
            false,
 | 
			
		||||
            prev,
 | 
			
		||||
            fut,
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn h() -> &'static [Float] {
 | 
			
		||||
@@ -60,7 +62,13 @@ impl SbpOperator for Upwind9 {
 | 
			
		||||
 | 
			
		||||
impl UpwindOperator for Upwind9 {
 | 
			
		||||
    fn diss1d(prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
 | 
			
		||||
        diss_1d(prev, fut)
 | 
			
		||||
        super::diff_op_1d(
 | 
			
		||||
            ndarray::arr2(Self::DISS_BLOCK).view(),
 | 
			
		||||
            ndarray::arr1(Self::DISS_DIAG).view(),
 | 
			
		||||
            true,
 | 
			
		||||
            prev,
 | 
			
		||||
            fut,
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user