From c1335574593bf2d1a19bbed4a5c68f46486f417c Mon Sep 17 00:00:00 2001
From: Magnus Ulimoen
Date: Fri, 29 Jan 2021 17:36:05 +0100
Subject: [PATCH] use Matrix in SBP diff

---
 sbp/src/operators/algos.rs        | 207 ++++++++++++++++++++++++++----
 sbp/src/operators/traditional4.rs |  25 +++-
 sbp/src/operators/traditional8.rs |  24 +++-
 3 files changed, 215 insertions(+), 41 deletions(-)

diff --git a/sbp/src/operators/algos.rs b/sbp/src/operators/algos.rs
index bd8c4cb..2b60090 100644
--- a/sbp/src/operators/algos.rs
+++ b/sbp/src/operators/algos.rs
@@ -3,7 +3,7 @@ use super::*;
 pub(crate) mod constmatrix {
     #![allow(unused)]
     /// A row-major matrix
-    #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
+    #[derive(Debug, Clone, PartialEq, Eq, Hash)]
     #[repr(transparent)]
     pub struct Matrix<T, const M: usize, const N: usize> {
         data: [[T; N]; M],
@@ -11,19 +11,11 @@ pub(crate) mod constmatrix {
 
     pub type RowVector<T, const N: usize> = Matrix<T, 1, N>;
     pub type ColVector<T, const N: usize> = Matrix<T, N, 1>;
 
-    impl<T: Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
+    impl<T: Copy + Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
         fn default() -> Self {
-            use std::mem::MaybeUninit;
-            let mut d: [[MaybeUninit<T>; N]; M] = unsafe { MaybeUninit::uninit().assume_init() };
-
-            for row in d.iter_mut() {
-                for item in row.iter_mut() {
-                    *item = MaybeUninit::new(T::default());
-                }
+            Self {
+                data: [[T::default(); N]; M],
             }
-
-            let data = unsafe { std::mem::transmute_copy::<_, [[T; N]; M]>(&d) };
-            Self { data }
         }
     }
@@ -111,12 +103,12 @@
 
         pub fn flip(&self) -> Self
         where
-            T: Default + Clone,
+            T: Default + Copy,
         {
             let mut v = Self::default();
             for i in 0..M {
                 for j in 0..N {
-                    v[(i, j)] = self[(M - 1 - i, N - 1 - j)].clone()
+                    v[(i, j)] = self[(M - 1 - i, N - 1 - j)]
                 }
             }
             v
@@ -163,10 +155,6 @@
         let _m2 = Matrix::new([[1, 2], [3, 4]]);
     }
 
-    #[test]
-    fn construct_non_copy() {
-        let _m = Matrix::<String, 2, 2>::default();
-    }
 
     #[test]
     fn matmul() {
@@ -184,8 +172,8 @@ pub(crate) use constmatrix::{ColVector, Matrix, RowVector};
 #[inline(always)]
 pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
     block: &Matrix<Float, M, N>,
-    diag: &RowVector<Float, D>,
-    symmetry: Symmetry,
+    blockend: &Matrix<Float, M, N>,
+    diag: &RowVector<Float, D>,
     optype: OperatorType,
     prev: ArrayView1<Float>,
     mut fut: ArrayViewMut1<Float>,
@@ -226,19 +214,14 @@
         *f = diff * idx;
     }
 
-    for (bl, f) in block.iter_rows().zip(fut.iter_mut().rev()) {
+    for (bl, f) in blockend.iter_rows().zip(fut.iter_mut().rev().take(M).rev()) {
         let diff = bl
             .iter()
-            .zip(prev.iter().rev())
+            .zip(prev.iter())
             .map(|(x, y)| x * y)
             .sum::<Float>();
 
-        *f = idx
-            * if symmetry == Symmetry::Symmetric {
-                diff
-            } else {
-                -diff
-            };
+        *f = diff * idx;
     }
 }
 
@@ -708,6 +691,174 @@ fn product_fast<'a>(
     })
 }
 
+#[inline(always)]
+pub(crate) fn diff_op_col_matrix<const M: usize, const N: usize, const D: usize>(
+    block: &Matrix<Float, M, N>,
+    block2: &Matrix<Float, M, N>,
+    diag: &RowVector<Float, D>,
+    optype: OperatorType,
+    prev: ArrayView2<Float>,
+    mut fut: ArrayViewMut2<Float>,
+) {
+    assert_eq!(prev.shape(), fut.shape());
+    let nx = prev.shape()[1];
+    assert!(nx >= 2 * M);
+
+    assert_eq!(prev.strides()[0], 1);
+    assert_eq!(fut.strides()[0], 1);
+
+    let dx = if optype == OperatorType::H2 {
+        1.0 / (nx - 2) as Float
+    } else {
+        1.0 / (nx - 1) as Float
+    };
+    let idx = 1.0 / dx;
+
+    #[cfg(not(feature = "f32"))]
+    type SimdT = packed_simd::f64x8;
+    #[cfg(feature = "f32")]
+    type SimdT = packed_simd::f32x16;
+
+    let ny = prev.shape()[0];
+    // How many elements that can be simdified
+    let simdified = SimdT::lanes() * (ny / SimdT::lanes());
+
+    let half_diag_width = (D - 1) / 2;
+    assert!(half_diag_width <= M);
+
+    let fut_base_ptr = fut.as_mut_ptr();
+    let fut_stride = fut.strides()[1];
+    let fut_ptr = |j, i| {
+        debug_assert!(j < ny && i < nx);
+        unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) }
+    };
+
+    let prev_base_ptr = prev.as_ptr();
+    let prev_stride = prev.strides()[1];
+    let prev_ptr = |j, i| {
+        debug_assert!(j < ny && i < nx);
+        unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) }
+    };
+
+    // Not algo necessary, but gives performance increase
+    assert_eq!(fut_stride, prev_stride);
+
+    // First block
+    {
+        for (ifut, &bl) in block.iter_rows().enumerate() {
+            for j in (0..simdified).step_by(SimdT::lanes()) {
+                let index_to_simd = |i| unsafe {
+                    // j never moves past end of slice due to step_by and
+                    // rounding down
+                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
+                        prev_ptr(j, i),
+                        SimdT::lanes(),
+                    ))
+                };
+                let mut f = SimdT::splat(0.0);
+                for (iprev, &bl) in bl.iter().enumerate() {
+                    f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
+                }
+                f *= idx;
+
+                unsafe {
+                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
+                        fut_ptr(j, ifut),
+                        SimdT::lanes(),
+                    ));
+                }
+            }
+            for j in simdified..ny {
+                unsafe {
+                    let mut f = 0.0;
+                    for (iprev, bl) in bl.iter().enumerate() {
+                        f += bl * *prev_ptr(j, iprev);
+                    }
+                    *fut_ptr(j, ifut) = f * idx;
+                }
+            }
+        }
+    }
+
+    // Diagonal elements
+    {
+        for ifut in M..nx - M {
+            for j in (0..simdified).step_by(SimdT::lanes()) {
+                let index_to_simd = |i| unsafe {
+                    // j never moves past end of slice due to step_by and
+                    // rounding down
+                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
+                        prev_ptr(j, i),
+                        SimdT::lanes(),
+                    ))
+                };
+                let mut f = SimdT::splat(0.0);
+                for (id, &d) in diag.iter().enumerate() {
+                    let offset = ifut - half_diag_width + id;
+                    f = index_to_simd(offset).mul_adde(SimdT::splat(d), f);
+                }
+                f *= idx;
+                unsafe {
+                    // puts simd along stride 1, j never goes past end of slice
+                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
+                        fut_ptr(j, ifut),
+                        SimdT::lanes(),
+                    ));
+                }
+            }
+            for j in simdified..ny {
+                let mut f = 0.0;
+                for (id, &d) in diag.iter().enumerate() {
+                    let offset = ifut - half_diag_width + id;
+                    unsafe {
+                        f += d * *prev_ptr(j, offset);
+                    }
+                }
+                unsafe {
+                    *fut_ptr(j, ifut) = idx * f;
+                }
+            }
+        }
+    }
+
+    // End block
+    {
+        for (ifut, &bl) in (nx - M..nx).zip(block2.iter_rows()) {
+            for j in (0..simdified).step_by(SimdT::lanes()) {
+                let index_to_simd = |i| unsafe {
+                    // j never moves past end of slice due to step_by and
+                    // rounding down
+                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
+                        prev_ptr(j, i),
+                        SimdT::lanes(),
+                    ))
+                };
+                let mut f = SimdT::splat(0.0);
+                for (iprev, &bl) in (nx - N..nx).zip(bl.iter()) {
+                    f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
+                }
+                f *= idx;
+
+                unsafe {
+                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
+                        fut_ptr(j, ifut),
+                        SimdT::lanes(),
+                    ));
+                }
+            }
+            for j in simdified..ny {
+                unsafe {
+                    let mut f = 0.0;
+                    for (iprev, bl) in (nx - N..nx).zip(bl.iter()) {
+                        f += bl * *prev_ptr(j, iprev);
+                    }
+                    *fut_ptr(j, ifut) = f * idx;
+                }
+            }
+        }
+    }
+}
+
 #[inline(always)]
 pub(crate) fn diff_op_row(
     block: &'static [&'static [Float]],
diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs
index af84f0b..57afd35 100644
--- a/sbp/src/operators/traditional4.rs
+++ b/sbp/src/operators/traditional4.rs
@@ -1,4 +1,4 @@
-use super::{diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d};
+use super::{SbpOperator1d, SbpOperator2d};
 use crate::Float;
 use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
 
@@ -55,10 +55,10 @@ impl SBP4 {
 
 impl SbpOperator1d for SBP4 {
     fn diff(&self, prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
-        super::diff_op_1d(
-            Self::BLOCK,
-            Self::DIAG,
-            super::Symmetry::AntiSymmetric,
+        super::diff_op_1d_matrix(
+            &Self::BLOCK_MATRIX,
+            &Self::BLOCKEND_MATRIX,
+            &Self::DIAG_MATRIX,
             super::OperatorType::Normal,
             prev,
             fut,
@@ -104,6 +104,17 @@ fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayVi
         )
     }
 }
+fn diff_op_col_local(prev: ndarray::ArrayView2<Float>, fut: ndarray::ArrayViewMut2<Float>) {
+    let optype = super::OperatorType::Normal;
+    super::diff_op_col_matrix(
+        &SBP4::BLOCK_MATRIX,
+        &SBP4::BLOCKEND_MATRIX,
+        &SBP4::DIAG_MATRIX,
+        optype,
+        prev,
+        fut,
+    )
+}
 
 impl SbpOperator2d for SBP4 {
     fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
@@ -112,14 +123,14 @@ impl SbpOperator2d for SBP4 {
 
         let symmetry = super::Symmetry::AntiSymmetric;
         let optype = super::OperatorType::Normal;
-
         match (prev.strides(), fut.strides()) {
             ([_, 1], [_, 1]) => {
                 //diff_op_row(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
                 diff_op_row_local(prev, fut)
             }
             ([1, _], [1, _]) => {
-                diff_op_col(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
+                //diff_op_col(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
+                diff_op_col_local(prev, fut)
             }
             ([_, _], [_, _]) => {
                 // Fallback, work row by row
diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs
index 76cd2e6..88631b3 100644
--- a/sbp/src/operators/traditional8.rs
+++ b/sbp/src/operators/traditional8.rs
@@ -55,10 +55,10 @@ impl SBP8 {
 
 impl SbpOperator1d for SBP8 {
     fn diff(&self, prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
-        super::diff_op_1d(
-            Self::BLOCK,
-            Self::DIAG,
-            super::Symmetry::AntiSymmetric,
+        super::diff_op_1d_matrix(
+            &Self::BLOCK_MATRIX,
+            &Self::BLOCKEND_MATRIX,
+            &Self::DIAG_MATRIX,
             super::OperatorType::Normal,
             prev,
             fut,
@@ -107,6 +107,18 @@ fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayVi
     }
 }
 
+fn diff_op_col_local(prev: ndarray::ArrayView2<Float>, fut: ndarray::ArrayViewMut2<Float>) {
+    let optype = super::OperatorType::Normal;
+    super::diff_op_col_matrix(
+        &SBP8::BLOCK_MATRIX,
+        &SBP8::BLOCKEND_MATRIX,
+        &SBP8::DIAG_MATRIX,
+        optype,
+        prev,
+        fut,
+    )
+}
+
 impl SbpOperator2d for SBP8 {
     fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
         assert_eq!(prev.shape(), fut.shape());
@@ -114,14 +126,14 @@ impl SbpOperator2d for SBP8 {
 
         let symmetry = super::Symmetry::AntiSymmetric;
         let optype = super::OperatorType::Normal;
-
         match (prev.strides(), fut.strides()) {
             ([_, 1], [_, 1]) => {
                 //diff_op_row(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
                 diff_op_row_local(prev, fut);
             }
             ([1, _], [1, _]) => {
-                diff_op_col(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
+                //diff_op_col(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
+                diff_op_col_local(prev, fut)
             }
             ([_, _], [_, _]) => {
                 // Fallback, work row by row
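
Editor's note (not part of the patch, which is truncated above): the sketch below is a minimal, standalone illustration of the const-generic, row-major matrix pattern the patch switches to. The type and method names mirror `constmatrix::Matrix`, `Matrix::new` and `iter_rows`, but the trait bounds and bodies here are illustrative assumptions rather than the crate's exact API; the dot-product loop only mimics the shape of the boundary-block closure applied in `diff_op_1d_matrix`.

    // Standalone sketch, assuming a Matrix-like type similar to the one in the patch.
    #[derive(Debug, Clone, PartialEq)]
    struct Matrix<T, const M: usize, const N: usize> {
        data: [[T; N]; M],
    }

    impl<T: Copy + Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
        fn default() -> Self {
            // With `T: Copy`, the backing array can be built directly,
            // avoiding the `MaybeUninit` construction the patch removes.
            Self {
                data: [[T::default(); N]; M],
            }
        }
    }

    impl<T, const M: usize, const N: usize> Matrix<T, M, N> {
        const fn new(data: [[T; N]; M]) -> Self {
            Self { data }
        }

        fn iter_rows(&self) -> impl Iterator<Item = &[T; N]> + '_ {
            self.data.iter()
        }
    }

    fn main() {
        // A hypothetical 3x2 boundary block; the values are placeholders,
        // not an actual SBP operator stencil.
        let block = Matrix::new([[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]]);
        let prev = [0.5_f64, -0.5];

        // Each block row is dotted against the start of `prev`, which is the
        // general shape of the boundary closure in `diff_op_1d_matrix`.
        for row in block.iter_rows() {
            let diff: f64 = row.iter().zip(prev.iter()).map(|(x, y)| x * y).sum();
            println!("{diff}");
        }
    }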