diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 4ba77e5..f572b0c 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -293,42 +293,68 @@ fn diff_op_col_simd( // How many elements that can be simdified let simdified = SimdT::lanes() * (ny / SimdT::lanes()); + let half_diag_width = (diag.len() - 1) / 2; + assert!(half_diag_width <= block.len()); + + let fut_base_ptr = fut.as_mut_ptr(); + let fut_stride = fut.strides()[1]; + let fut_ptr = |j, i| { + debug_assert!(j < ny && i < nx); + unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) } + }; + + let prev_base_ptr = prev.as_ptr(); + let prev_stride = prev.strides()[1]; + let prev_ptr = |j, i| { + debug_assert!(j < ny && i < nx); + unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) } + }; + + // Not algo necessary, but gives performance increase + assert_eq!(fut_stride, prev_stride); + // First block { - for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) { - fut.fill(0.0); - debug_assert_eq!(fut.len(), prev.shape()[0]); - for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { - debug_assert_eq!(prev.len(), fut.len()); - fut.scaled_add(idx * bl, &prev); + for (ifut, &bl) in block.iter().enumerate() { + for j in (0..simdified).step_by(SimdT::lanes()) { + let index_to_simd = |i| unsafe { + // j never moves past end of slice due to step_by and + // rounding down + SimdT::from_slice_unaligned(std::slice::from_raw_parts( + prev_ptr(j, i), + SimdT::lanes(), + )) + }; + let mut f = SimdT::splat(0.0); + for (iprev, &bl) in bl.iter().enumerate() { + f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f); + } + f = f * idx; + + unsafe { + f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( + fut_ptr(j, ifut), + SimdT::lanes(), + )); + } + } + for j in simdified..ny { + unsafe { + let mut f = 0.0; + for (iprev, bl) in bl.iter().enumerate() { + f += bl * *prev_ptr(j, iprev); + } + *fut_ptr(j, ifut) = f * idx; + } } } } // Diagonal elements { - let half_diag_width = (diag.len() - 1) / 2; - assert!(half_diag_width <= block.len()); - - let fut_base_ptr = fut.as_mut_ptr(); - let fut_stride = fut.strides()[1]; - let fut_ptr = |j, i| { - debug_assert!(j < ny && i < nx); - unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) } - }; - - let prev_base_ptr = prev.as_ptr(); - let prev_stride = prev.strides()[1]; - let prev_ptr = |j, i| { - debug_assert!(j < ny && i < nx); - unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) } - }; - - assert_eq!(fut_stride, prev_stride); - for ifut in block.len()..nx - block.len() { for j in (0..simdified).step_by(SimdT::lanes()) { - let index_to_simd = |(j, i)| unsafe { + let index_to_simd = |i| unsafe { // j never moves past end of slice due to step_by and // rounding down SimdT::from_slice_unaligned(std::slice::from_raw_parts( @@ -339,7 +365,7 @@ fn diff_op_col_simd( let mut f = SimdT::splat(0.0); for (id, &d) in diag.iter().enumerate() { let offset = ifut - half_diag_width + id; - f = index_to_simd((j, offset)).mul_adde(SimdT::splat(d), f); + f = index_to_simd(offset).mul_adde(SimdT::splat(d), f); } f = f * idx; unsafe { @@ -367,13 +393,46 @@ fn diff_op_col_simd( // End block { - for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { - fut.fill(0.0); - for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { - if symmetry == Symmetry::Symmetric { - fut.scaled_add(idx * bl, &prev); + // Get blocks and corresponding offsets + // (rev to iterate in ifut increasing order) + for (bl, ifut) in block.iter().zip((0..nx).rev()) { + for j in (0..simdified).step_by(SimdT::lanes()) { + let index_to_simd = |i| unsafe { + // j never moves past end of slice due to step_by and + // rounding down + SimdT::from_slice_unaligned(std::slice::from_raw_parts( + prev_ptr(j, i), + SimdT::lanes(), + )) + }; + let mut f = SimdT::splat(0.0); + for (&bl, iprev) in bl.iter().zip((0..nx).rev()) { + f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f); + } + f = if symmetry == Symmetry::Symmetric { + f * idx } else { - fut.scaled_add(-idx * bl, &prev); + -f * idx + }; + unsafe { + f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( + fut_ptr(j, ifut), + SimdT::lanes(), + )); + } + } + + for j in simdified..ny { + unsafe { + let mut f = 0.0; + for (&bl, iprev) in bl.iter().zip((0..nx).rev()).rev() { + f += bl * *prev_ptr(j, iprev); + } + *fut_ptr(j, ifut) = if symmetry == Symmetry::Symmetric { + f * idx + } else { + -f * idx + }; } } }