From 5de82f393b548ad955f9864682777ba2fac49b3e Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Mon, 18 Jan 2021 22:37:24 +0100 Subject: [PATCH] move algos to separate file --- sbp/src/operators.rs | 495 +------------------------------------ sbp/src/operators/algos.rs | 493 ++++++++++++++++++++++++++++++++++++ 2 files changed, 496 insertions(+), 492 deletions(-) create mode 100644 sbp/src/operators/algos.rs diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 0a4f7f0..8f8598f 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -1,6 +1,9 @@ #![allow(clippy::excessive_precision)] #![allow(clippy::unreadable_literal)] +mod algos; +pub(crate) use algos::*; + use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; @@ -142,498 +145,6 @@ impl UpwindOperator2d for (Box, Box) } } -#[inline(always)] -fn diff_op_1d( - block: &[&[Float]], - diag: &[Float], - symmetry: Symmetry, - optype: OperatorType, - prev: ArrayView1, - mut fut: ArrayViewMut1, -) { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[0]; - assert!(nx >= 2 * block.len()); - - let dx = if optype == OperatorType::H2 { - 1.0 / (nx - 2) as Float - } else { - 1.0 / (nx - 1) as Float - }; - let idx = 1.0 / dx; - - for (bl, f) in block.iter().zip(&mut fut) { - let diff = bl - .iter() - .zip(prev.iter()) - .map(|(x, y)| x * y) - .sum::(); - *f = diff * idx; - } - - // The window needs to be aligned to the diagonal elements, - // based on the block size - let window_elems_to_skip = block.len() - ((diag.len() - 1) / 2); - - for (window, f) in prev - .windows(diag.len()) - .into_iter() - .skip(window_elems_to_skip) - .zip(fut.iter_mut().skip(block.len())) - .take(nx - 2 * block.len()) - { - let diff = diag.iter().zip(&window).map(|(x, y)| x * y).sum::(); - *f = diff * idx; - } - - for (bl, f) in block.iter().zip(fut.iter_mut().rev()) { - let diff = bl - .iter() - .zip(prev.iter().rev()) - .map(|(x, y)| x * y) - .sum::(); - - *f = idx - * if symmetry == Symmetry::Symmetric { - diff - } else { - -diff - }; - } -} - -#[derive(PartialEq, Copy, Clone)] -enum Symmetry { - Symmetric, - AntiSymmetric, -} - -#[derive(PartialEq, Copy, Clone)] -enum OperatorType { - Normal, - H2, -} - -#[inline(always)] -#[allow(unused)] -fn diff_op_col_naive( - block: &'static [&'static [Float]], - diag: &'static [Float], - symmetry: Symmetry, - optype: OperatorType, -) -> impl Fn(ArrayView2, ArrayViewMut2) { - #[inline(always)] - move |prev: ArrayView2, mut fut: ArrayViewMut2| { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[1]; - assert!(nx >= 2 * block.len()); - - assert_eq!(prev.strides()[0], 1); - assert_eq!(fut.strides()[0], 1); - - let dx = if optype == OperatorType::H2 { - 1.0 / (nx - 2) as Float - } else { - 1.0 / (nx - 1) as Float - }; - let idx = 1.0 / dx; - - fut.fill(0.0); - - // First block - for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) { - debug_assert_eq!(fut.len(), prev.shape()[0]); - for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { - debug_assert_eq!(prev.len(), fut.len()); - fut.scaled_add(idx * bl, &prev); - } - } - - let half_diag_width = (diag.len() - 1) / 2; - assert!(half_diag_width <= block.len()); - - // Diagonal entries - for (ifut, mut fut) in fut - .axis_iter_mut(ndarray::Axis(1)) - .enumerate() - .skip(block.len()) - .take(nx - 2 * block.len()) - { - for (id, &d) in diag.iter().enumerate() { - let offset = ifut - half_diag_width + id; - fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset])) - } - } - - // End block - for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { - fut.fill(0.0); - for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { - if symmetry == Symmetry::Symmetric { - fut.scaled_add(idx * bl, &prev); - } else { - fut.scaled_add(-idx * bl, &prev); - } - } - } - } -} - -#[inline(always)] -fn diff_op_col( - block: &'static [&'static [Float]], - diag: &'static [Float], - symmetry: Symmetry, - optype: OperatorType, -) -> impl Fn(ArrayView2, ArrayViewMut2) { - diff_op_col_simd(block, diag, symmetry, optype) -} - -#[inline(always)] -fn diff_op_col_simd( - block: &'static [&'static [Float]], - diag: &'static [Float], - symmetry: Symmetry, - optype: OperatorType, -) -> impl Fn(ArrayView2, ArrayViewMut2) { - #[inline(always)] - move |prev: ArrayView2, mut fut: ArrayViewMut2| { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[1]; - assert!(nx >= 2 * block.len()); - - assert_eq!(prev.strides()[0], 1); - assert_eq!(fut.strides()[0], 1); - - let dx = if optype == OperatorType::H2 { - 1.0 / (nx - 2) as Float - } else { - 1.0 / (nx - 1) as Float - }; - let idx = 1.0 / dx; - - #[cfg(not(feature = "f32"))] - type SimdT = packed_simd::f64x8; - #[cfg(feature = "f32")] - type SimdT = packed_simd::f32x16; - - let ny = prev.shape()[0]; - // How many elements that can be simdified - let simdified = SimdT::lanes() * (ny / SimdT::lanes()); - - let half_diag_width = (diag.len() - 1) / 2; - assert!(half_diag_width <= block.len()); - - let fut_base_ptr = fut.as_mut_ptr(); - let fut_stride = fut.strides()[1]; - let fut_ptr = |j, i| { - debug_assert!(j < ny && i < nx); - unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) } - }; - - let prev_base_ptr = prev.as_ptr(); - let prev_stride = prev.strides()[1]; - let prev_ptr = |j, i| { - debug_assert!(j < ny && i < nx); - unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) } - }; - - // Not algo necessary, but gives performance increase - assert_eq!(fut_stride, prev_stride); - - // First block - { - for (ifut, &bl) in block.iter().enumerate() { - for j in (0..simdified).step_by(SimdT::lanes()) { - let index_to_simd = |i| unsafe { - // j never moves past end of slice due to step_by and - // rounding down - SimdT::from_slice_unaligned(std::slice::from_raw_parts( - prev_ptr(j, i), - SimdT::lanes(), - )) - }; - let mut f = SimdT::splat(0.0); - for (iprev, &bl) in bl.iter().enumerate() { - f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f); - } - f *= idx; - - unsafe { - f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( - fut_ptr(j, ifut), - SimdT::lanes(), - )); - } - } - for j in simdified..ny { - unsafe { - let mut f = 0.0; - for (iprev, bl) in bl.iter().enumerate() { - f += bl * *prev_ptr(j, iprev); - } - *fut_ptr(j, ifut) = f * idx; - } - } - } - } - - // Diagonal elements - { - for ifut in block.len()..nx - block.len() { - for j in (0..simdified).step_by(SimdT::lanes()) { - let index_to_simd = |i| unsafe { - // j never moves past end of slice due to step_by and - // rounding down - SimdT::from_slice_unaligned(std::slice::from_raw_parts( - prev_ptr(j, i), - SimdT::lanes(), - )) - }; - let mut f = SimdT::splat(0.0); - for (id, &d) in diag.iter().enumerate() { - let offset = ifut - half_diag_width + id; - f = index_to_simd(offset).mul_adde(SimdT::splat(d), f); - } - f *= idx; - unsafe { - // puts simd along stride 1, j never goes past end of slice - f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( - fut_ptr(j, ifut), - SimdT::lanes(), - )); - } - } - for j in simdified..ny { - let mut f = 0.0; - for (id, &d) in diag.iter().enumerate() { - let offset = ifut - half_diag_width + id; - unsafe { - f += d * *prev_ptr(j, offset); - } - } - unsafe { - *fut_ptr(j, ifut) = idx * f; - } - } - } - } - - // End block - { - // Get blocks and corresponding offsets - // (rev to iterate in ifut increasing order) - for (bl, ifut) in block.iter().zip((0..nx).rev()) { - for j in (0..simdified).step_by(SimdT::lanes()) { - let index_to_simd = |i| unsafe { - // j never moves past end of slice due to step_by and - // rounding down - SimdT::from_slice_unaligned(std::slice::from_raw_parts( - prev_ptr(j, i), - SimdT::lanes(), - )) - }; - let mut f = SimdT::splat(0.0); - for (&bl, iprev) in bl.iter().zip((0..nx).rev()) { - f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f); - } - f = if symmetry == Symmetry::Symmetric { - f * idx - } else { - -f * idx - }; - unsafe { - f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( - fut_ptr(j, ifut), - SimdT::lanes(), - )); - } - } - - for j in simdified..ny { - unsafe { - let mut f = 0.0; - for (&bl, iprev) in bl.iter().zip((0..nx).rev()).rev() { - f += bl * *prev_ptr(j, iprev); - } - *fut_ptr(j, ifut) = if symmetry == Symmetry::Symmetric { - f * idx - } else { - -f * idx - }; - } - } - } - } - } -} - -#[inline(always)] -fn product_fast<'a>( - u: impl Iterator, - v: impl Iterator, -) -> Float { - use std::intrinsics::{fadd_fast, fmul_fast}; - u.zip(v).fold(0.0, |acc, (&u, &v)| unsafe { - // We do not care about the order of multiplication nor addition - fadd_fast(acc, fmul_fast(u, v)) - }) -} - -#[inline(always)] -fn diff_op_row( - block: &'static [&'static [Float]], - diag: &'static [Float], - symmetry: Symmetry, - optype: OperatorType, -) -> impl Fn(ArrayView2, ArrayViewMut2) { - #[inline(always)] - move |prev: ArrayView2, mut fut: ArrayViewMut2| { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[1]; - assert!(nx >= 2 * block.len()); - - assert_eq!(prev.strides()[1], 1); - assert_eq!(fut.strides()[1], 1); - - let dx = if optype == OperatorType::H2 { - 1.0 / (nx - 2) as Float - } else { - 1.0 / (nx - 1) as Float - }; - let idx = 1.0 / dx; - - for (prev, mut fut) in prev - .axis_iter(ndarray::Axis(0)) - .zip(fut.axis_iter_mut(ndarray::Axis(0))) - { - let prev = prev.as_slice().unwrap(); - let fut = fut.as_slice_mut().unwrap(); - assert_eq!(prev.len(), fut.len()); - assert!(prev.len() >= 2 * block.len()); - - for (bl, f) in block.iter().zip(fut.iter_mut()) { - let diff = product_fast(bl.iter(), prev[..bl.len()].iter()); - *f = diff * idx; - } - - // The window needs to be aligned to the diagonal elements, - // based on the block size - let window_elems_to_skip = block.len() - ((diag.len() - 1) / 2); - - for (window, f) in prev - .windows(diag.len()) - .skip(window_elems_to_skip) - .zip(fut.iter_mut().skip(block.len())) - .take(nx - 2 * block.len()) - { - let diff = product_fast(diag.iter(), window.iter()); - *f = diff * idx; - } - - for (bl, f) in block.iter().zip(fut.iter_mut().rev()) { - let diff = product_fast(bl.iter(), prev.iter().rev()); - - *f = idx - * if symmetry == Symmetry::Symmetric { - diff - } else { - -diff - }; - } - } - } -} - -#[cfg(feature = "sparse")] -fn sparse_from_block( - block: &[&[Float]], - diag: &[Float], - symmetry: Symmetry, - optype: OperatorType, - n: usize, -) -> sprs::CsMat { - assert!(n >= 2 * block.len()); - - let nnz = { - let block_elems = block.iter().fold(0, |acc, x| { - acc + x - .iter() - .fold(0, |acc, &x| if x != 0.0 { acc + 1 } else { acc }) - }); - - let diag_elems = diag - .iter() - .fold(0, |acc, &x| if x != 0.0 { acc + 1 } else { acc }); - - 2 * block_elems + (n - 2 * block.len()) * diag_elems - }; - - let mut mat = sprs::TriMat::with_capacity((n, n), nnz); - - let dx = if optype == OperatorType::H2 { - 1.0 / (n - 2) as Float - } else { - 1.0 / (n - 1) as Float - }; - let idx = 1.0 / dx; - - for (j, bl) in block.iter().enumerate() { - for (i, &b) in bl.iter().enumerate() { - if b == 0.0 { - continue; - } - mat.add_triplet(j, i, b * idx); - } - } - - for j in block.len()..n - block.len() { - let half_diag_len = diag.len() / 2; - for (&d, i) in diag.iter().zip(j - half_diag_len..) { - if d == 0.0 { - continue; - } - mat.add_triplet(j, i, d * idx); - } - } - - for (bl, j) in block.iter().zip((0..n).rev()).rev() { - for (&b, i) in bl.iter().zip((0..n).rev()).rev() { - if b == 0.0 { - continue; - } - if symmetry == Symmetry::AntiSymmetric { - mat.add_triplet(j, i, -b * idx); - } else { - mat.add_triplet(j, i, b * idx); - } - } - } - - mat.to_csr() -} - -#[cfg(feature = "sparse")] -fn h_matrix(diag: &[Float], n: usize, is_h2: bool) -> sprs::CsMat { - let h = if is_h2 { - 1.0 / (n - 2) as Float - } else { - 1.0 / (n - 1) as Float - }; - let nmiddle = n - 2 * diag.len(); - let iter = diag - .iter() - .chain(std::iter::repeat(&1.0).take(nmiddle)) - .chain(diag.iter().rev()) - .map(|&x| h * x); - - let mut mat = sprs::TriMat::with_capacity((n, n), n); - for (i, d) in iter.enumerate() { - mat.add_triplet(i, i, d); - } - mat.to_csr() -} - mod upwind4; pub use upwind4::Upwind4; mod upwind9; diff --git a/sbp/src/operators/algos.rs b/sbp/src/operators/algos.rs new file mode 100644 index 0000000..58a3072 --- /dev/null +++ b/sbp/src/operators/algos.rs @@ -0,0 +1,493 @@ +use super::*; + +#[inline(always)] +pub(crate) fn diff_op_1d( + block: &[&[Float]], + diag: &[Float], + symmetry: Symmetry, + optype: OperatorType, + prev: ArrayView1, + mut fut: ArrayViewMut1, +) { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[0]; + assert!(nx >= 2 * block.len()); + + let dx = if optype == OperatorType::H2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; + + for (bl, f) in block.iter().zip(&mut fut) { + let diff = bl + .iter() + .zip(prev.iter()) + .map(|(x, y)| x * y) + .sum::(); + *f = diff * idx; + } + + // The window needs to be aligned to the diagonal elements, + // based on the block size + let window_elems_to_skip = block.len() - ((diag.len() - 1) / 2); + + for (window, f) in prev + .windows(diag.len()) + .into_iter() + .skip(window_elems_to_skip) + .zip(fut.iter_mut().skip(block.len())) + .take(nx - 2 * block.len()) + { + let diff = diag.iter().zip(&window).map(|(x, y)| x * y).sum::(); + *f = diff * idx; + } + + for (bl, f) in block.iter().zip(fut.iter_mut().rev()) { + let diff = bl + .iter() + .zip(prev.iter().rev()) + .map(|(x, y)| x * y) + .sum::(); + + *f = idx + * if symmetry == Symmetry::Symmetric { + diff + } else { + -diff + }; + } +} + +#[derive(PartialEq, Copy, Clone)] +pub(crate) enum Symmetry { + Symmetric, + AntiSymmetric, +} + +#[derive(PartialEq, Copy, Clone)] +pub(crate) enum OperatorType { + Normal, + H2, +} + +#[inline(always)] +#[allow(unused)] +pub(crate) fn diff_op_col_naive( + block: &'static [&'static [Float]], + diag: &'static [Float], + symmetry: Symmetry, + optype: OperatorType, +) -> impl Fn(ArrayView2, ArrayViewMut2) { + #[inline(always)] + move |prev: ArrayView2, mut fut: ArrayViewMut2| { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[1]; + assert!(nx >= 2 * block.len()); + + assert_eq!(prev.strides()[0], 1); + assert_eq!(fut.strides()[0], 1); + + let dx = if optype == OperatorType::H2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; + + fut.fill(0.0); + + // First block + for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) { + debug_assert_eq!(fut.len(), prev.shape()[0]); + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { + debug_assert_eq!(prev.len(), fut.len()); + fut.scaled_add(idx * bl, &prev); + } + } + + let half_diag_width = (diag.len() - 1) / 2; + assert!(half_diag_width <= block.len()); + + // Diagonal entries + for (ifut, mut fut) in fut + .axis_iter_mut(ndarray::Axis(1)) + .enumerate() + .skip(block.len()) + .take(nx - 2 * block.len()) + { + for (id, &d) in diag.iter().enumerate() { + let offset = ifut - half_diag_width + id; + fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset])) + } + } + + // End block + for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { + fut.fill(0.0); + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { + if symmetry == Symmetry::Symmetric { + fut.scaled_add(idx * bl, &prev); + } else { + fut.scaled_add(-idx * bl, &prev); + } + } + } + } +} + +#[inline(always)] +pub(crate) fn diff_op_col( + block: &'static [&'static [Float]], + diag: &'static [Float], + symmetry: Symmetry, + optype: OperatorType, +) -> impl Fn(ArrayView2, ArrayViewMut2) { + diff_op_col_simd(block, diag, symmetry, optype) +} + +#[inline(always)] +pub(crate) fn diff_op_col_simd( + block: &'static [&'static [Float]], + diag: &'static [Float], + symmetry: Symmetry, + optype: OperatorType, +) -> impl Fn(ArrayView2, ArrayViewMut2) { + #[inline(always)] + move |prev: ArrayView2, mut fut: ArrayViewMut2| { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[1]; + assert!(nx >= 2 * block.len()); + + assert_eq!(prev.strides()[0], 1); + assert_eq!(fut.strides()[0], 1); + + let dx = if optype == OperatorType::H2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; + + #[cfg(not(feature = "f32"))] + type SimdT = packed_simd::f64x8; + #[cfg(feature = "f32")] + type SimdT = packed_simd::f32x16; + + let ny = prev.shape()[0]; + // How many elements that can be simdified + let simdified = SimdT::lanes() * (ny / SimdT::lanes()); + + let half_diag_width = (diag.len() - 1) / 2; + assert!(half_diag_width <= block.len()); + + let fut_base_ptr = fut.as_mut_ptr(); + let fut_stride = fut.strides()[1]; + let fut_ptr = |j, i| { + debug_assert!(j < ny && i < nx); + unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) } + }; + + let prev_base_ptr = prev.as_ptr(); + let prev_stride = prev.strides()[1]; + let prev_ptr = |j, i| { + debug_assert!(j < ny && i < nx); + unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) } + }; + + // Not algo necessary, but gives performance increase + assert_eq!(fut_stride, prev_stride); + + // First block + { + for (ifut, &bl) in block.iter().enumerate() { + for j in (0..simdified).step_by(SimdT::lanes()) { + let index_to_simd = |i| unsafe { + // j never moves past end of slice due to step_by and + // rounding down + SimdT::from_slice_unaligned(std::slice::from_raw_parts( + prev_ptr(j, i), + SimdT::lanes(), + )) + }; + let mut f = SimdT::splat(0.0); + for (iprev, &bl) in bl.iter().enumerate() { + f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f); + } + f *= idx; + + unsafe { + f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( + fut_ptr(j, ifut), + SimdT::lanes(), + )); + } + } + for j in simdified..ny { + unsafe { + let mut f = 0.0; + for (iprev, bl) in bl.iter().enumerate() { + f += bl * *prev_ptr(j, iprev); + } + *fut_ptr(j, ifut) = f * idx; + } + } + } + } + + // Diagonal elements + { + for ifut in block.len()..nx - block.len() { + for j in (0..simdified).step_by(SimdT::lanes()) { + let index_to_simd = |i| unsafe { + // j never moves past end of slice due to step_by and + // rounding down + SimdT::from_slice_unaligned(std::slice::from_raw_parts( + prev_ptr(j, i), + SimdT::lanes(), + )) + }; + let mut f = SimdT::splat(0.0); + for (id, &d) in diag.iter().enumerate() { + let offset = ifut - half_diag_width + id; + f = index_to_simd(offset).mul_adde(SimdT::splat(d), f); + } + f *= idx; + unsafe { + // puts simd along stride 1, j never goes past end of slice + f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( + fut_ptr(j, ifut), + SimdT::lanes(), + )); + } + } + for j in simdified..ny { + let mut f = 0.0; + for (id, &d) in diag.iter().enumerate() { + let offset = ifut - half_diag_width + id; + unsafe { + f += d * *prev_ptr(j, offset); + } + } + unsafe { + *fut_ptr(j, ifut) = idx * f; + } + } + } + } + + // End block + { + // Get blocks and corresponding offsets + // (rev to iterate in ifut increasing order) + for (bl, ifut) in block.iter().zip((0..nx).rev()) { + for j in (0..simdified).step_by(SimdT::lanes()) { + let index_to_simd = |i| unsafe { + // j never moves past end of slice due to step_by and + // rounding down + SimdT::from_slice_unaligned(std::slice::from_raw_parts( + prev_ptr(j, i), + SimdT::lanes(), + )) + }; + let mut f = SimdT::splat(0.0); + for (&bl, iprev) in bl.iter().zip((0..nx).rev()) { + f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f); + } + f = if symmetry == Symmetry::Symmetric { + f * idx + } else { + -f * idx + }; + unsafe { + f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( + fut_ptr(j, ifut), + SimdT::lanes(), + )); + } + } + + for j in simdified..ny { + unsafe { + let mut f = 0.0; + for (&bl, iprev) in bl.iter().zip((0..nx).rev()).rev() { + f += bl * *prev_ptr(j, iprev); + } + *fut_ptr(j, ifut) = if symmetry == Symmetry::Symmetric { + f * idx + } else { + -f * idx + }; + } + } + } + } + } +} + +#[inline(always)] +fn product_fast<'a>( + u: impl Iterator, + v: impl Iterator, +) -> Float { + use std::intrinsics::{fadd_fast, fmul_fast}; + u.zip(v).fold(0.0, |acc, (&u, &v)| unsafe { + // We do not care about the order of multiplication nor addition + fadd_fast(acc, fmul_fast(u, v)) + }) +} + +#[inline(always)] +pub(crate) fn diff_op_row( + block: &'static [&'static [Float]], + diag: &'static [Float], + symmetry: Symmetry, + optype: OperatorType, +) -> impl Fn(ArrayView2, ArrayViewMut2) { + #[inline(always)] + move |prev: ArrayView2, mut fut: ArrayViewMut2| { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[1]; + assert!(nx >= 2 * block.len()); + + assert_eq!(prev.strides()[1], 1); + assert_eq!(fut.strides()[1], 1); + + let dx = if optype == OperatorType::H2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; + + for (prev, mut fut) in prev + .axis_iter(ndarray::Axis(0)) + .zip(fut.axis_iter_mut(ndarray::Axis(0))) + { + let prev = prev.as_slice().unwrap(); + let fut = fut.as_slice_mut().unwrap(); + assert_eq!(prev.len(), fut.len()); + assert!(prev.len() >= 2 * block.len()); + + for (bl, f) in block.iter().zip(fut.iter_mut()) { + let diff = product_fast(bl.iter(), prev[..bl.len()].iter()); + *f = diff * idx; + } + + // The window needs to be aligned to the diagonal elements, + // based on the block size + let window_elems_to_skip = block.len() - ((diag.len() - 1) / 2); + + for (window, f) in prev + .windows(diag.len()) + .skip(window_elems_to_skip) + .zip(fut.iter_mut().skip(block.len())) + .take(nx - 2 * block.len()) + { + let diff = product_fast(diag.iter(), window.iter()); + *f = diff * idx; + } + + for (bl, f) in block.iter().zip(fut.iter_mut().rev()) { + let diff = product_fast(bl.iter(), prev.iter().rev()); + + *f = idx + * if symmetry == Symmetry::Symmetric { + diff + } else { + -diff + }; + } + } + } +} + +#[cfg(feature = "sparse")] +fn sparse_from_block( + block: &[&[Float]], + diag: &[Float], + symmetry: Symmetry, + optype: OperatorType, + n: usize, +) -> sprs::CsMat { + assert!(n >= 2 * block.len()); + + let nnz = { + let block_elems = block.iter().fold(0, |acc, x| { + acc + x + .iter() + .fold(0, |acc, &x| if x != 0.0 { acc + 1 } else { acc }) + }); + + let diag_elems = diag + .iter() + .fold(0, |acc, &x| if x != 0.0 { acc + 1 } else { acc }); + + 2 * block_elems + (n - 2 * block.len()) * diag_elems + }; + + let mut mat = sprs::TriMat::with_capacity((n, n), nnz); + + let dx = if optype == OperatorType::H2 { + 1.0 / (n - 2) as Float + } else { + 1.0 / (n - 1) as Float + }; + let idx = 1.0 / dx; + + for (j, bl) in block.iter().enumerate() { + for (i, &b) in bl.iter().enumerate() { + if b == 0.0 { + continue; + } + mat.add_triplet(j, i, b * idx); + } + } + + for j in block.len()..n - block.len() { + let half_diag_len = diag.len() / 2; + for (&d, i) in diag.iter().zip(j - half_diag_len..) { + if d == 0.0 { + continue; + } + mat.add_triplet(j, i, d * idx); + } + } + + for (bl, j) in block.iter().zip((0..n).rev()).rev() { + for (&b, i) in bl.iter().zip((0..n).rev()).rev() { + if b == 0.0 { + continue; + } + if symmetry == Symmetry::AntiSymmetric { + mat.add_triplet(j, i, -b * idx); + } else { + mat.add_triplet(j, i, b * idx); + } + } + } + + mat.to_csr() +} + +#[cfg(feature = "sparse")] +fn h_matrix(diag: &[Float], n: usize, is_h2: bool) -> sprs::CsMat { + let h = if is_h2 { + 1.0 / (n - 2) as Float + } else { + 1.0 / (n - 1) as Float + }; + let nmiddle = n - 2 * diag.len(); + let iter = diag + .iter() + .chain(std::iter::repeat(&1.0).take(nmiddle)) + .chain(diag.iter().rev()) + .map(|&x| h * x); + + let mut mat = sprs::TriMat::with_capacity((n, n), n); + for (i, d) in iter.enumerate() { + mat.add_triplet(i, i, d); + } + mat.to_csr() +}