diff --git a/sbp/src/operators/algos.rs b/sbp/src/operators/algos.rs
index a504b4c..b32425c 100644
--- a/sbp/src/operators/algos.rs
+++ b/sbp/src/operators/algos.rs
@@ -572,6 +572,82 @@ pub(crate) fn diff_op_col_naive(
     }
 }
 
+#[inline(always)]
+#[allow(unused)] // NOTE(review): the traditional4.rs hunk of this patch calls this function, so this allow may no longer be needed — confirm.
+pub(crate) fn diff_op_col_naive_matrix( // Overwrites `fut` with a stencil applied along axis 1 of `prev`, scaled by `idx`: the first M columns use rows of `block`, the last M columns use rows of `blockend`, the columns in between use `diag`. `M`, `N`, `D` are not declared in the visible text — presumably const generic parameters lost in extraction; TODO confirm.
+    block: &Matrix, // coefficient rows for the first M output columns; each row is paired with the leading columns of `prev`
+    blockend: &Matrix, // coefficient rows for the last M output columns; each row is paired with the last N columns of `prev`
+    diag: &RowVector, // coefficients applied to each D-column window of `prev` for the middle output columns
+    optype: OperatorType, // only compared against `OperatorType::H2` here, to choose the divisor of `dx`
+    prev: ArrayView2, // input; must have the same shape as `fut` and unit stride along axis 0
+    mut fut: ArrayViewMut2, // output; zeroed and then accumulated into; must have unit stride along axis 0
+) {
+    assert_eq!(prev.shape(), fut.shape()); // panics unless input and output have identical shapes
+    let nx = prev.shape()[1]; // extent along axis 1, the axis the stencil runs along
+    let ny = prev.shape()[0]; // extent along axis 0, the length of every column view below
+    assert!(nx >= 2 * M); // keeps the three column ranges `..M`, `M..nx - M`, `nx - M..` sliced below from overlapping
+
+    assert_eq!(prev.strides()[0], 1); // panics unless elements adjacent along axis 0 are adjacent in memory
+    assert_eq!(fut.strides()[0], 1); // same layout requirement for the output
+
+    let dx = if optype == OperatorType::H2 { // presumably the grid spacing along axis 1 — TODO confirm
+        1.0 / (nx - 2) as Float
+    } else {
+        1.0 / (nx - 1) as Float
+    };
+    let idx = 1.0 / dx; // reciprocal of `dx`; every coefficient below is multiplied by it
+
+    fut.fill(0.0); // start from zero: everything below only accumulates into `fut` via `scaled_add`
+
+    let (mut fut0, mut futmid, mut futn) = fut.multi_slice_mut(( // three disjoint mutable views of `fut` split along axis 1
+        ndarray::s![.., ..M], // first M columns
+        ndarray::s![.., M..nx - M], // everything between the two boundary regions
+        ndarray::s![.., nx - M..], // last M columns
+    ));
+
+    // First block: row i of `block` is paired with output column i (i < M); each coefficient in that row is paired, in order, with a column of `prev` starting at column 0, and `idx * coefficient * column` is accumulated into the output column.
+    for (bl, mut fut) in block.iter_rows().zip(fut0.axis_iter_mut(ndarray::Axis(1))) {
+        debug_assert_eq!(fut.len(), prev.shape()[0]);
+        for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
+            if bl == 0.0 { // zero coefficients contribute nothing, so skip the whole-column update
+                continue;
+            }
+            debug_assert_eq!(prev.len(), fut.len());
+            fut.scaled_add(idx * bl, &prev);
+        }
+    }
+
+    let window_elems_to_skip = M - ((D - 1) / 2); // number of leading windows to drop: the first window kept then starts at this column, so (for odd D) its centre column is M, the first column of `futmid`
+
+    // Diagonal entries: each column of `futmid` is paired with one (ny, D) window of `prev`, and the window's columns are accumulated into it weighted by `idx * diag[j]`.
+    for (mut fut, id) in futmid
+        .axis_iter_mut(ndarray::Axis(1))
+        .zip(prev.windows((ny, D)).into_iter().skip(window_elems_to_skip))
+    {
+        for (&d, id) in diag.iter().zip(id.axis_iter(ndarray::Axis(1))) {
+            if d == 0.0 { // skip zero coefficients, as in the first block
+                continue;
+            }
+            fut.scaled_add(idx * d, &id)
+        }
+    }
+
+    // End block: same scheme as the first block, but rows of `blockend` are paired with the last M output columns and read from the last N columns of `prev`.
+    let prev = prev.slice(ndarray::s!(.., nx - N..)); // shadows `prev` with a view of its last N columns; NOTE(review): needs nx >= N, which the `2 * M` assert above does not by itself establish — confirm N <= 2 * M for all callers
+    for (bl, mut fut) in blockend
+        .iter_rows()
+        .zip(futn.axis_iter_mut(ndarray::Axis(1)))
+    {
+        fut.fill(0.0); // NOTE(review): looks redundant — the whole of `fut` was zeroed above and these columns have not been written since
+        for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
+            if bl == 0.0 { // skip zero coefficients
+                continue;
+            }
+            fut.scaled_add(idx * bl, &prev);
+        }
+    }
+}
+
 #[inline(always)]
 pub(crate) fn
diff_op_col( block: &'static [&'static [Float]], diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index db5802f..f480cda 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -106,7 +106,7 @@ fn diff_op_row_local(prev: ndarray::ArrayView2, mut fut: ndarray::ArrayVi } fn diff_op_col_local(prev: ndarray::ArrayView2, fut: ndarray::ArrayViewMut2) { let optype = super::OperatorType::Normal; - super::diff_op_col_matrix( + super::diff_op_col_naive_matrix( &SBP4::BLOCK_MATRIX, &SBP4::BLOCKEND_MATRIX, &SBP4::DIAG_MATRIX,