diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs
index 050c451..76e4a55 100644
--- a/sbp/src/operators.rs
+++ b/sbp/src/operators.rs
@@ -186,12 +186,120 @@ enum OperatorType {
     H2,
 }
 
+/// Builds a closure that applies the operator described by `block` and `diag`
+/// along axis 1 of a 2-D array, one output column at a time.
+///
+/// * `block` - boundary rows: row `i` holds the weights that combine the leading
+///   input columns into output column `i`.
+/// * `diag` - interior weights: entry `id` multiplies input column
+///   `c - (diag.len() - 1) / 2 + id` when producing output column `c`.
+/// * `symmetry` - `Symmetry::Symmetric` reuses the `block` weights unchanged at
+///   the far end; any other value negates them.
+/// * `optype` - `OperatorType::H2` sets the spacing `dx` to `1 / (nx - 2)`; every
+///   other variant uses `1 / (nx - 1)`, where `nx` is the number of columns.
+///
+/// The returned closure reads `prev` and overwrites `fut`.
+///
+/// # Panics
+///
+/// The closure panics if the two arrays differ in shape, if axis 0 of either
+/// array does not have stride 1, if there are fewer than `2 * block.len()`
+/// columns, or if `(diag.len() - 1) / 2` exceeds `block.len()`.
+// NOTE(review): `diff_op_col` forwards to `diff_op_col_simd` rather than to this
+// function, which is presumably why `allow(unused)` is needed here.
+#[inline(always)]
+#[allow(unused)]
+fn diff_op_col_naive(
+    block: &'static [&'static [Float]],
+    diag: &'static [Float],
+    symmetry: Symmetry,
+    optype: OperatorType,
+) -> impl Fn(ArrayView2, ArrayViewMut2) {
+    #[inline(always)]
+    move |prev: ArrayView2, mut fut: ArrayViewMut2| {
+        assert_eq!(prev.shape(), fut.shape());
+        // The operator acts along axis 1; `nx` is the number of columns.
+        let nx = prev.shape()[1];
+        // The first and end blocks each own `block.len()` columns and must not overlap.
+        assert!(nx >= 2 * block.len());
+
+        // Both arrays must have unit stride along axis 0.
+        assert_eq!(prev.strides()[0], 1);
+        assert_eq!(fut.strides()[0], 1);
+
+        let dx = if optype == OperatorType::H2 {
+            1.0 / (nx - 2) as Float
+        } else {
+            1.0 / (nx - 1) as Float
+        };
+        // Inverse spacing: every weight below is multiplied by it.
+        let idx = 1.0 / dx;
+
+        // Everything below accumulates into `fut`, so start from zero.
+        fut.fill(0.0);
+
+        // First block: output column `i` accumulates `idx * block[i][j] * prev[.., j]`
+        // over the entries `j` of `block[i]`.
+        for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) {
+            debug_assert_eq!(fut.len(), prev.shape()[0]);
+            for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
+                debug_assert_eq!(prev.len(), fut.len());
+                fut.scaled_add(idx * bl, &prev);
+            }
+        }
+
+        let half_diag_width = (diag.len() - 1) / 2;
+        // Keeps `ifut - half_diag_width` in the loop below from underflowing.
+        assert!(half_diag_width <= block.len());
+
+        // Diagonal entries: the `nx - 2 * block.len()` columns between the two
+        // blocks use `diag`, shifted so that it follows the output column.
+        for (ifut, mut fut) in fut
+            .axis_iter_mut(ndarray::Axis(1))
+            .enumerate()
+            .skip(block.len())
+            .take(nx - 2 * block.len())
+        {
+            for (id, &d) in diag.iter().enumerate() {
+                let offset = ifut - half_diag_width + id;
+                fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset]))
+            }
+        }
+
+        // End block: the same `block` rows applied from the far end, walking both
+        // the output and the input columns in reverse order.
+        for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) {
+            fut.fill(0.0);
+            for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) {
+                // Weights are reused unchanged for `Symmetry::Symmetric`, negated otherwise.
+                if symmetry == Symmetry::Symmetric {
+                    fut.scaled_add(idx * bl, &prev);
+                } else {
+                    fut.scaled_add(-idx * bl, &prev);
+                }
+            }
+        }
+    }
+}
+
+/// Returns the column-wise operator closure; currently forwards every argument
+/// unchanged to `diff_op_col_simd`.
 #[inline(always)]
 fn diff_op_col(
     block: &'static [&'static [Float]],
     diag: &'static [Float],
     symmetry: Symmetry,
     optype: OperatorType,
+) -> impl Fn(ArrayView2, ArrayViewMut2) {
+    diff_op_col_simd(block, diag, symmetry, optype)
+}
+
+#[inline(always)]
+fn diff_op_col_simd(
+    block: &'static [&'static [Float]],
+    diag: &'static [Float],
+    symmetry: Symmetry,
+    optype: OperatorType,
 ) -> impl Fn(ArrayView2, ArrayViewMut2) {
     #[inline(always)]
     move |prev: ArrayView2, mut fut: ArrayViewMut2| {