diff --git a/sbp/src/operators/algos.rs b/sbp/src/operators/algos.rs index 8340632..2a12fc3 100644 --- a/sbp/src/operators/algos.rs +++ b/sbp/src/operators/algos.rs @@ -281,6 +281,105 @@ pub(crate) fn diff_op_2d_fallback( + matrix: &BlockMatrix, + optype: OperatorType, + prev: ArrayView2, + mut fut: ArrayViewMut2, +) { + /* Does not increase the perf... + #[cfg(feature = "fast-float")] + let (matrix, prev, mut fut) = unsafe { + ( + std::mem::transmute::<_, &BlockMatrix>(matrix), + std::mem::transmute::<_, ArrayView2>(prev), + std::mem::transmute::<_, ArrayViewMut2>(fut), + ) + }; + #[cfg(not(feature = "fast-float"))] + let mut fut = fut; + */ + + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[1]; + let ny = prev.shape()[0]; + assert!(nx >= 2 * M); + assert_eq!(prev.strides()[0], 1); + assert_eq!(fut.strides()[0], 1); + + let dx = if optype == OperatorType::H2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; + + fut.fill(0.0.into()); + let (mut fut0, mut futmid, mut futn) = fut.multi_slice_mut(( + ndarray::s![.., ..M], + ndarray::s![.., M..nx - M], + ndarray::s![.., nx - M..], + )); + + // First block + for (bl, mut fut) in matrix + .start + .iter_rows() + .zip(fut0.axis_iter_mut(ndarray::Axis(1))) + { + let fut = &mut fut.as_slice_mut().unwrap()[..ny]; + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { + if bl.is_zero() { + continue; + } + let prev = &prev.as_slice().unwrap()[..ny]; + for (fut, prev) in fut.iter_mut().zip(prev) { + *fut = *fut + idx * prev * bl; + } + } + } + + let window_elems_to_skip = M - ((D - 1) / 2); + + // Diagonal entries + for (mut fut, id) in futmid + .axis_iter_mut(ndarray::Axis(1)) + .zip(prev.windows((ny, D)).into_iter().skip(window_elems_to_skip)) + { + let fut = &mut fut.as_slice_mut().unwrap()[..ny]; + for (&d, id) in matrix.diag.iter().zip(id.axis_iter(ndarray::Axis(1))) { + if d.is_zero() { + continue; + } + let id = id.as_slice().unwrap(); + for (fut, id) in fut.iter_mut().zip(id) { + *fut = *fut + idx * id * d; + } + } + } + + // End block + let prev = prev.slice(ndarray::s!(.., nx - N..)); + for (bl, mut fut) in matrix + .end + .iter_rows() + .zip(futn.axis_iter_mut(ndarray::Axis(1))) + { + let fut = &mut fut.as_slice_mut().unwrap()[..ny]; + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { + if bl.is_zero() { + continue; + } + let prev = &prev.as_slice().unwrap()[..ny]; + for (fut, prev) in fut.iter_mut().zip(prev) { + *fut = *fut + idx * prev * bl; + } + } + } +} + #[inline(always)] pub(crate) fn diff_op_2d_sliceable( matrix: &BlockMatrix, @@ -308,6 +407,7 @@ pub(crate) fn diff_op_2d( assert_eq!(prev.shape(), fut.shape()); match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => diff_op_2d_sliceable(matrix, optype, prev, fut), + ([1, _], [1, _]) => diff_op_2d_sliceable_y(matrix, optype, prev, fut), _ => diff_op_2d_fallback(matrix, optype, prev, fut), } }