From cbf6042055f4b5dc82a202b3463828f2ea67198a Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Fri, 1 May 2020 18:21:14 +0200 Subject: [PATCH] use raw pointer in simd code --- sbp/src/operators.rs | 134 ++++++++++++++++++++++++++----------------- 1 file changed, 82 insertions(+), 52 deletions(-) diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 76e4a55..4ba77e5 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -290,61 +290,91 @@ fn diff_op_col_simd( type SimdT = packed_simd::f32x16; let ny = prev.shape()[0]; + // How many elements that can be simdified + let simdified = SimdT::lanes() * (ny / SimdT::lanes()); // First block - for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) { - fut.fill(0.0); - debug_assert_eq!(fut.len(), prev.shape()[0]); - for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { - debug_assert_eq!(prev.len(), fut.len()); - fut.scaled_add(idx * bl, &prev); - } - } - - let half_diag_width = (diag.len() - 1) / 2; - assert!(half_diag_width <= block.len()); - - for ifut in block.len()..nx - block.len() { - let simdified = SimdT::lanes() * (ny / SimdT::lanes()); - for j in (0..simdified).step_by(SimdT::lanes()) { - let index_to_simd = |(j, i)| unsafe { - // gets simd along stride 1, j never goes past end of slice - SimdT::from_slice_unaligned(std::slice::from_raw_parts( - prev.uget((j, i)), - SimdT::lanes(), - )) - }; - let mut f = SimdT::splat(0.0); - for (id, &d) in diag.iter().enumerate() { - let offset = ifut - half_diag_width + id; - f = f + d * index_to_simd((j, offset)); - } - f = f * idx; - unsafe { - // puts simd along stride 1, j never goes past end of slice - f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( - fut.uget_mut((j, ifut)), - SimdT::lanes(), - )); - } - } - for j in simdified..ny { - let mut f = 0.0; - for (id, &d) in diag.iter().enumerate() { - let offset = ifut - half_diag_width + id; - f += d * prev[(j, offset)]; - } - fut[(j, ifut)] = idx * f; - } - } - // End block - for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { - fut.fill(0.0); - for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { - if symmetry == Symmetry::Symmetric { + { + for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) { + fut.fill(0.0); + debug_assert_eq!(fut.len(), prev.shape()[0]); + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { + debug_assert_eq!(prev.len(), fut.len()); fut.scaled_add(idx * bl, &prev); - } else { - fut.scaled_add(-idx * bl, &prev); + } + } + } + + // Diagonal elements + { + let half_diag_width = (diag.len() - 1) / 2; + assert!(half_diag_width <= block.len()); + + let fut_base_ptr = fut.as_mut_ptr(); + let fut_stride = fut.strides()[1]; + let fut_ptr = |j, i| { + debug_assert!(j < ny && i < nx); + unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) } + }; + + let prev_base_ptr = prev.as_ptr(); + let prev_stride = prev.strides()[1]; + let prev_ptr = |j, i| { + debug_assert!(j < ny && i < nx); + unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) } + }; + + assert_eq!(fut_stride, prev_stride); + + for ifut in block.len()..nx - block.len() { + for j in (0..simdified).step_by(SimdT::lanes()) { + let index_to_simd = |(j, i)| unsafe { + // j never moves past end of slice due to step_by and + // rounding down + SimdT::from_slice_unaligned(std::slice::from_raw_parts( + prev_ptr(j, i), + SimdT::lanes(), + )) + }; + let mut f = SimdT::splat(0.0); + for (id, &d) in diag.iter().enumerate() { + let offset = ifut - half_diag_width + id; + f = index_to_simd((j, offset)).mul_adde(SimdT::splat(d), f); + } + f = f * idx; + unsafe { + // puts simd along stride 1, j never goes past end of slice + f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( + fut_ptr(j, ifut), + SimdT::lanes(), + )); + } + } + for j in simdified..ny { + let mut f = 0.0; + for (id, &d) in diag.iter().enumerate() { + let offset = ifut - half_diag_width + id; + unsafe { + f += d * *prev_ptr(j, offset); + } + } + unsafe { + *fut_ptr(j, ifut) = idx * f; + } + } + } + } + + // End block + { + for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { + fut.fill(0.0); + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { + if symmetry == Symmetry::Symmetric { + fut.scaled_add(idx * bl, &prev); + } else { + fut.scaled_add(-idx * bl, &prev); + } } } }