diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs
index 64127cf..1b3e25d 100644
--- a/sbp/src/operators.rs
+++ b/sbp/src/operators.rs
@@ -192,10 +192,16 @@ pub(crate) fn diff_op_col(
     };
     let idx = 1.0 / dx;
 
-    fut.fill(0.0);
+    #[cfg(not(feature = "f32"))]
+    type SimdT = packed_simd::f64x8;
+    #[cfg(feature = "f32")]
+    type SimdT = packed_simd::f32x16;
+
+    let ny = prev.shape()[0];
 
     // First block
     for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) {
+        fut.fill(0.0);
         debug_assert_eq!(fut.len(), prev.shape()[0]);
         for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
             debug_assert_eq!(prev.len(), fut.len());
@@ -204,22 +210,44 @@
     }
 
     let half_diag_width = (diag.len() - 1) / 2;
+    assert!(half_diag_width <= block.len());
 
-    // Diagonal entries
-    for (ifut, mut fut) in fut
-        .axis_iter_mut(ndarray::Axis(1))
-        .enumerate()
-        .skip(block.len())
-        .take(nx - 2 * block.len())
-    {
-        for (id, d) in diag.iter().enumerate() {
-            let offset = ifut - half_diag_width + id;
-            fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset]))
+    for ifut in block.len()..nx - block.len() {
+        let simdified = SimdT::lanes() * (ny / SimdT::lanes());
+        for j in (0..simdified).step_by(SimdT::lanes()) {
+            let index_to_simd = |(j, i)| unsafe {
+                // gets simd along stride 1, j never goes past end of slice
+                SimdT::from_slice_unaligned(std::slice::from_raw_parts(
+                    prev.uget((j, i)),
+                    SimdT::lanes(),
+                ))
+            };
+            let mut f = SimdT::splat(0.0);
+            for (id, &d) in diag.iter().enumerate() {
+                let offset = ifut - half_diag_width + id;
+                f = f + d * index_to_simd((j, offset));
+            }
+            f = f * idx;
+            unsafe {
+                // puts simd along stride 1, j never goes past end of slice
+                f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
+                    fut.uget_mut((j, ifut)),
+                    SimdT::lanes(),
+                ));
+            }
+        }
+        for j in simdified..ny {
+            let mut f = 0.0;
+            for (id, &d) in diag.iter().enumerate() {
+                let offset = ifut - half_diag_width + id;
+                f += d * prev[(j, offset)];
+            }
+            fut[(j, ifut)] = idx * f;
         }
     }
 
-    // End block
     for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) {
+        fut.fill(0.0);
         for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) {
             if symmetric {
                 fut.scaled_add(idx * bl, &prev);