use Matrix in SBP diff
This commit is contained in:
		@@ -3,7 +3,7 @@ use super::*;
 | 
			
		||||
pub(crate) mod constmatrix {
 | 
			
		||||
    #![allow(unused)]
 | 
			
		||||
    /// A row-major matrix
 | 
			
		||||
    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
 | 
			
		||||
    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
 | 
			
		||||
    #[repr(transparent)]
 | 
			
		||||
    pub struct Matrix<T, const M: usize, const N: usize> {
 | 
			
		||||
        data: [[T; N]; M],
 | 
			
		||||
@@ -11,19 +11,11 @@ pub(crate) mod constmatrix {
 | 
			
		||||
    pub type RowVector<T, const N: usize> = Matrix<T, 1, N>;
 | 
			
		||||
    pub type ColVector<T, const N: usize> = Matrix<T, N, 1>;
 | 
			
		||||
 | 
			
		||||
    impl<T: Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
 | 
			
		||||
    impl<T: Copy + Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
 | 
			
		||||
        fn default() -> Self {
 | 
			
		||||
            use std::mem::MaybeUninit;
 | 
			
		||||
            let mut d: [[MaybeUninit<T>; N]; M] = unsafe { MaybeUninit::uninit().assume_init() };
 | 
			
		||||
 | 
			
		||||
            for row in d.iter_mut() {
 | 
			
		||||
                for item in row.iter_mut() {
 | 
			
		||||
                    *item = MaybeUninit::new(T::default());
 | 
			
		||||
                }
 | 
			
		||||
            Self {
 | 
			
		||||
                data: [[T::default(); N]; M],
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            let data = unsafe { std::mem::transmute_copy::<_, [[T; N]; M]>(&d) };
 | 
			
		||||
            Self { data }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -111,12 +103,12 @@ pub(crate) mod constmatrix {
 | 
			
		||||
 | 
			
		||||
        pub fn flip(&self) -> Self
 | 
			
		||||
        where
 | 
			
		||||
            T: Default + Clone,
 | 
			
		||||
            T: Default + Copy,
 | 
			
		||||
        {
 | 
			
		||||
            let mut v = Self::default();
 | 
			
		||||
            for i in 0..M {
 | 
			
		||||
                for j in 0..N {
 | 
			
		||||
                    v[(i, j)] = self[(M - 1 - i, N - 1 - j)].clone()
 | 
			
		||||
                    v[(i, j)] = self[(M - 1 - i, N - 1 - j)]
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            v
 | 
			
		||||
@@ -163,10 +155,6 @@ pub(crate) mod constmatrix {
 | 
			
		||||
 | 
			
		||||
            let _m2 = Matrix::new([[1, 2], [3, 4]]);
 | 
			
		||||
        }
 | 
			
		||||
        #[test]
 | 
			
		||||
        fn construct_non_copy() {
 | 
			
		||||
            let _m = Matrix::<String, 2, 1>::default();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        #[test]
 | 
			
		||||
        fn matmul() {
 | 
			
		||||
@@ -184,8 +172,8 @@ pub(crate) use constmatrix::{ColVector, Matrix, RowVector};
 | 
			
		||||
#[inline(always)]
 | 
			
		||||
pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
 | 
			
		||||
    block: &Matrix<Float, M, N>,
 | 
			
		||||
    diag: &RowVector<Float, N>,
 | 
			
		||||
    symmetry: Symmetry,
 | 
			
		||||
    blockend: &Matrix<Float, M, N>,
 | 
			
		||||
    diag: &RowVector<Float, D>,
 | 
			
		||||
    optype: OperatorType,
 | 
			
		||||
    prev: ArrayView1<Float>,
 | 
			
		||||
    mut fut: ArrayViewMut1<Float>,
 | 
			
		||||
@@ -226,19 +214,14 @@ pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
 | 
			
		||||
        *f = diff * idx;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (bl, f) in block.iter_rows().zip(fut.iter_mut().rev()) {
 | 
			
		||||
    for (bl, f) in blockend.iter_rows().zip(fut.iter_mut().rev().take(M).rev()) {
 | 
			
		||||
        let diff = bl
 | 
			
		||||
            .iter()
 | 
			
		||||
            .zip(prev.iter().rev())
 | 
			
		||||
            .zip(prev.iter())
 | 
			
		||||
            .map(|(x, y)| x * y)
 | 
			
		||||
            .sum::<Float>();
 | 
			
		||||
 | 
			
		||||
        *f = idx
 | 
			
		||||
            * if symmetry == Symmetry::Symmetric {
 | 
			
		||||
                diff
 | 
			
		||||
            } else {
 | 
			
		||||
                -diff
 | 
			
		||||
            };
 | 
			
		||||
        *f = diff * idx;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -708,6 +691,174 @@ fn product_fast<'a>(
 | 
			
		||||
    })
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[inline(always)]
 | 
			
		||||
pub(crate) fn diff_op_col_matrix<const M: usize, const N: usize, const D: usize>(
 | 
			
		||||
    block: &Matrix<Float, M, N>,
 | 
			
		||||
    block2: &Matrix<Float, M, N>,
 | 
			
		||||
    diag: &RowVector<Float, D>,
 | 
			
		||||
    optype: OperatorType,
 | 
			
		||||
    prev: ArrayView2<Float>,
 | 
			
		||||
    mut fut: ArrayViewMut2<Float>,
 | 
			
		||||
) {
 | 
			
		||||
    assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
    let nx = prev.shape()[1];
 | 
			
		||||
    assert!(nx >= 2 * M);
 | 
			
		||||
 | 
			
		||||
    assert_eq!(prev.strides()[0], 1);
 | 
			
		||||
    assert_eq!(fut.strides()[0], 1);
 | 
			
		||||
 | 
			
		||||
    let dx = if optype == OperatorType::H2 {
 | 
			
		||||
        1.0 / (nx - 2) as Float
 | 
			
		||||
    } else {
 | 
			
		||||
        1.0 / (nx - 1) as Float
 | 
			
		||||
    };
 | 
			
		||||
    let idx = 1.0 / dx;
 | 
			
		||||
 | 
			
		||||
    #[cfg(not(feature = "f32"))]
 | 
			
		||||
    type SimdT = packed_simd::f64x8;
 | 
			
		||||
    #[cfg(feature = "f32")]
 | 
			
		||||
    type SimdT = packed_simd::f32x16;
 | 
			
		||||
 | 
			
		||||
    let ny = prev.shape()[0];
 | 
			
		||||
    // How many elements that can be simdified
 | 
			
		||||
    let simdified = SimdT::lanes() * (ny / SimdT::lanes());
 | 
			
		||||
 | 
			
		||||
    let half_diag_width = (D - 1) / 2;
 | 
			
		||||
    assert!(half_diag_width <= M);
 | 
			
		||||
 | 
			
		||||
    let fut_base_ptr = fut.as_mut_ptr();
 | 
			
		||||
    let fut_stride = fut.strides()[1];
 | 
			
		||||
    let fut_ptr = |j, i| {
 | 
			
		||||
        debug_assert!(j < ny && i < nx);
 | 
			
		||||
        unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let prev_base_ptr = prev.as_ptr();
 | 
			
		||||
    let prev_stride = prev.strides()[1];
 | 
			
		||||
    let prev_ptr = |j, i| {
 | 
			
		||||
        debug_assert!(j < ny && i < nx);
 | 
			
		||||
        unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) }
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    // Not algo necessary, but gives performance increase
 | 
			
		||||
    assert_eq!(fut_stride, prev_stride);
 | 
			
		||||
 | 
			
		||||
    // First block
 | 
			
		||||
    {
 | 
			
		||||
        for (ifut, &bl) in block.iter_rows().enumerate() {
 | 
			
		||||
            for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
			
		||||
                let index_to_simd = |i| unsafe {
 | 
			
		||||
                    // j never moves past end of slice due to step_by and
 | 
			
		||||
                    // rounding down
 | 
			
		||||
                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
			
		||||
                        prev_ptr(j, i),
 | 
			
		||||
                        SimdT::lanes(),
 | 
			
		||||
                    ))
 | 
			
		||||
                };
 | 
			
		||||
                let mut f = SimdT::splat(0.0);
 | 
			
		||||
                for (iprev, &bl) in bl.iter().enumerate() {
 | 
			
		||||
                    f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
 | 
			
		||||
                }
 | 
			
		||||
                f *= idx;
 | 
			
		||||
 | 
			
		||||
                unsafe {
 | 
			
		||||
                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
			
		||||
                        fut_ptr(j, ifut),
 | 
			
		||||
                        SimdT::lanes(),
 | 
			
		||||
                    ));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            for j in simdified..ny {
 | 
			
		||||
                unsafe {
 | 
			
		||||
                    let mut f = 0.0;
 | 
			
		||||
                    for (iprev, bl) in bl.iter().enumerate() {
 | 
			
		||||
                        f += bl * *prev_ptr(j, iprev);
 | 
			
		||||
                    }
 | 
			
		||||
                    *fut_ptr(j, ifut) = f * idx;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // Diagonal elements
 | 
			
		||||
    {
 | 
			
		||||
        for ifut in M..nx - M {
 | 
			
		||||
            for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
			
		||||
                let index_to_simd = |i| unsafe {
 | 
			
		||||
                    // j never moves past end of slice due to step_by and
 | 
			
		||||
                    // rounding down
 | 
			
		||||
                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
			
		||||
                        prev_ptr(j, i),
 | 
			
		||||
                        SimdT::lanes(),
 | 
			
		||||
                    ))
 | 
			
		||||
                };
 | 
			
		||||
                let mut f = SimdT::splat(0.0);
 | 
			
		||||
                for (id, &d) in diag.iter().enumerate() {
 | 
			
		||||
                    let offset = ifut - half_diag_width + id;
 | 
			
		||||
                    f = index_to_simd(offset).mul_adde(SimdT::splat(d), f);
 | 
			
		||||
                }
 | 
			
		||||
                f *= idx;
 | 
			
		||||
                unsafe {
 | 
			
		||||
                    // puts simd along stride 1, j never goes past end of slice
 | 
			
		||||
                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
			
		||||
                        fut_ptr(j, ifut),
 | 
			
		||||
                        SimdT::lanes(),
 | 
			
		||||
                    ));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            for j in simdified..ny {
 | 
			
		||||
                let mut f = 0.0;
 | 
			
		||||
                for (id, &d) in diag.iter().enumerate() {
 | 
			
		||||
                    let offset = ifut - half_diag_width + id;
 | 
			
		||||
                    unsafe {
 | 
			
		||||
                        f += d * *prev_ptr(j, offset);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
                unsafe {
 | 
			
		||||
                    *fut_ptr(j, ifut) = idx * f;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // End block
 | 
			
		||||
    {
 | 
			
		||||
        for (ifut, &bl) in (nx - M..nx).zip(block2.iter_rows()) {
 | 
			
		||||
            for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
			
		||||
                let index_to_simd = |i| unsafe {
 | 
			
		||||
                    // j never moves past end of slice due to step_by and
 | 
			
		||||
                    // rounding down
 | 
			
		||||
                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
			
		||||
                        prev_ptr(j, i),
 | 
			
		||||
                        SimdT::lanes(),
 | 
			
		||||
                    ))
 | 
			
		||||
                };
 | 
			
		||||
                let mut f = SimdT::splat(0.0);
 | 
			
		||||
                for (iprev, &bl) in (nx - N..nx).zip(bl.iter()) {
 | 
			
		||||
                    f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
 | 
			
		||||
                }
 | 
			
		||||
                f *= idx;
 | 
			
		||||
 | 
			
		||||
                unsafe {
 | 
			
		||||
                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
			
		||||
                        fut_ptr(j, ifut),
 | 
			
		||||
                        SimdT::lanes(),
 | 
			
		||||
                    ));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            for j in simdified..ny {
 | 
			
		||||
                unsafe {
 | 
			
		||||
                    let mut f = 0.0;
 | 
			
		||||
                    for (iprev, bl) in (nx - N..nx).zip(bl.iter()) {
 | 
			
		||||
                        f += bl * *prev_ptr(j, iprev);
 | 
			
		||||
                    }
 | 
			
		||||
                    *fut_ptr(j, ifut) = f * idx;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[inline(always)]
 | 
			
		||||
pub(crate) fn diff_op_row(
 | 
			
		||||
    block: &'static [&'static [Float]],
 | 
			
		||||
 
 | 
			
		||||
@@ -1,4 +1,4 @@
 | 
			
		||||
use super::{diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d};
 | 
			
		||||
use super::{SbpOperator1d, SbpOperator2d};
 | 
			
		||||
use crate::Float;
 | 
			
		||||
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
 | 
			
		||||
 | 
			
		||||
@@ -55,10 +55,10 @@ impl SBP4 {
 | 
			
		||||
 | 
			
		||||
impl SbpOperator1d for SBP4 {
 | 
			
		||||
    fn diff(&self, prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
 | 
			
		||||
        super::diff_op_1d(
 | 
			
		||||
            Self::BLOCK,
 | 
			
		||||
            Self::DIAG,
 | 
			
		||||
            super::Symmetry::AntiSymmetric,
 | 
			
		||||
        super::diff_op_1d_matrix(
 | 
			
		||||
            &Self::BLOCK_MATRIX,
 | 
			
		||||
            &Self::BLOCKEND_MATRIX,
 | 
			
		||||
            &Self::DIAG_MATRIX,
 | 
			
		||||
            super::OperatorType::Normal,
 | 
			
		||||
            prev,
 | 
			
		||||
            fut,
 | 
			
		||||
@@ -104,6 +104,17 @@ fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayVi
 | 
			
		||||
        )
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
fn diff_op_col_local(prev: ndarray::ArrayView2<Float>, fut: ndarray::ArrayViewMut2<Float>) {
 | 
			
		||||
    let optype = super::OperatorType::Normal;
 | 
			
		||||
    super::diff_op_col_matrix(
 | 
			
		||||
        &SBP4::BLOCK_MATRIX,
 | 
			
		||||
        &SBP4::BLOCKEND_MATRIX,
 | 
			
		||||
        &SBP4::DIAG_MATRIX,
 | 
			
		||||
        optype,
 | 
			
		||||
        prev,
 | 
			
		||||
        fut,
 | 
			
		||||
    )
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl SbpOperator2d for SBP4 {
 | 
			
		||||
    fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
@@ -112,14 +123,14 @@ impl SbpOperator2d for SBP4 {
 | 
			
		||||
 | 
			
		||||
        let symmetry = super::Symmetry::AntiSymmetric;
 | 
			
		||||
        let optype = super::OperatorType::Normal;
 | 
			
		||||
 | 
			
		||||
        match (prev.strides(), fut.strides()) {
 | 
			
		||||
            ([_, 1], [_, 1]) => {
 | 
			
		||||
                //diff_op_row(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
 | 
			
		||||
                diff_op_row_local(prev, fut)
 | 
			
		||||
            }
 | 
			
		||||
            ([1, _], [1, _]) => {
 | 
			
		||||
                diff_op_col(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
 | 
			
		||||
                //diff_op_col(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
 | 
			
		||||
                diff_op_col_local(prev, fut)
 | 
			
		||||
            }
 | 
			
		||||
            ([_, _], [_, _]) => {
 | 
			
		||||
                // Fallback, work row by row
 | 
			
		||||
 
 | 
			
		||||
@@ -55,10 +55,10 @@ impl SBP8 {
 | 
			
		||||
 | 
			
		||||
impl SbpOperator1d for SBP8 {
 | 
			
		||||
    fn diff(&self, prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
 | 
			
		||||
        super::diff_op_1d(
 | 
			
		||||
            Self::BLOCK,
 | 
			
		||||
            Self::DIAG,
 | 
			
		||||
            super::Symmetry::AntiSymmetric,
 | 
			
		||||
        super::diff_op_1d_matrix(
 | 
			
		||||
            &Self::BLOCK_MATRIX,
 | 
			
		||||
            &Self::BLOCKEND_MATRIX,
 | 
			
		||||
            &Self::DIAG_MATRIX,
 | 
			
		||||
            super::OperatorType::Normal,
 | 
			
		||||
            prev,
 | 
			
		||||
            fut,
 | 
			
		||||
@@ -107,6 +107,18 @@ fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayVi
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn diff_op_col_local(prev: ndarray::ArrayView2<Float>, fut: ndarray::ArrayViewMut2<Float>) {
 | 
			
		||||
    let optype = super::OperatorType::Normal;
 | 
			
		||||
    super::diff_op_col_matrix(
 | 
			
		||||
        &SBP8::BLOCK_MATRIX,
 | 
			
		||||
        &SBP8::BLOCKEND_MATRIX,
 | 
			
		||||
        &SBP8::DIAG_MATRIX,
 | 
			
		||||
        optype,
 | 
			
		||||
        prev,
 | 
			
		||||
        fut,
 | 
			
		||||
    )
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl SbpOperator2d for SBP8 {
 | 
			
		||||
    fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
 | 
			
		||||
        assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
@@ -114,14 +126,14 @@ impl SbpOperator2d for SBP8 {
 | 
			
		||||
 | 
			
		||||
        let symmetry = super::Symmetry::AntiSymmetric;
 | 
			
		||||
        let optype = super::OperatorType::Normal;
 | 
			
		||||
 | 
			
		||||
        match (prev.strides(), fut.strides()) {
 | 
			
		||||
            ([_, 1], [_, 1]) => {
 | 
			
		||||
                //diff_op_row(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
 | 
			
		||||
                diff_op_row_local(prev, fut);
 | 
			
		||||
            }
 | 
			
		||||
            ([1, _], [1, _]) => {
 | 
			
		||||
                diff_op_col(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
 | 
			
		||||
                //diff_op_col(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
 | 
			
		||||
                diff_op_col_local(prev, fut)
 | 
			
		||||
            }
 | 
			
		||||
            ([_, _], [_, _]) => {
 | 
			
		||||
                // Fallback, work row by row
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user