diff --git a/sbp/src/operators/algos.rs b/sbp/src/operators/algos.rs index 3ec4aff..72750e9 100644 --- a/sbp/src/operators/algos.rs +++ b/sbp/src/operators/algos.rs @@ -390,10 +390,16 @@ pub(crate) fn diff_op_2d_sliceable_y_simd= 2 * M); - assert_eq!(prev.strides()[0], 1); - assert_eq!(fut.strides()[0], 1); + assert_eq!(prev.strides(), fut.strides()); + assert_eq!(prev.strides(), &[1, ny as isize]); + + let prev = prev.as_slice_memory_order().unwrap(); + let fut = fut.as_slice_memory_order_mut().unwrap(); + let prev = &prev[..nx * ny]; + let fut = &mut fut[..nx * ny]; let dx = if optype == OperatorType::H2 { 1.0 / (nx - 2) as Float @@ -407,103 +413,88 @@ pub(crate) fn diff_op_2d_sliceable_y_simd( + matrix: &Matrix, + idx: Float, + ny: usize, + simdified: usize, + prev: &[Float], + fut: &mut [Float], + ) { + assert_eq!(prev.len(), N * ny); + assert_eq!(fut.len(), M * ny); + let prev = &prev[..N * ny]; + let fut = &mut fut[..M * ny]; - let prev_base_ptr = prev.as_ptr(); - let prev_stride = prev.strides()[1]; - let prev_ptr = |j, i| { - debug_assert!(j < ny && i < nx); - unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) } - }; + let prevcol = |i: usize| -> &[Float] { &prev[i * ny..(i + 1) * ny] }; - // Not algo necessary, but gives performance increase - assert_eq!(fut_stride, prev_stride); - - // First block - { - for (ifut, &bl) in matrix.start.iter_rows().enumerate() { - for j in (0..simdified).step_by(SimdT::lanes()) { - let index_to_simd = |i| unsafe { - // j never moves past end of slice due to step_by and - // rounding down - SimdT::from_slice_unaligned(std::slice::from_raw_parts( - prev_ptr(j, i), - SimdT::lanes(), - )) - }; + for (&bl, fut) in matrix.iter_rows().zip(fut.chunks_exact_mut(ny)) { + let mut fut = fut.chunks_exact_mut(SimdT::lanes()); + for (j, fut) in fut.by_ref().enumerate() { + let index_to_simd = + |i| SimdT::from_slice_unaligned(&prevcol(i)[SimdT::lanes() * j..]); let mut f = SimdT::splat(0.0); for (iprev, &bl) in bl.iter().enumerate() { f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f); } f *= idx; - - unsafe { - f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( - fut_ptr(j, ifut), - SimdT::lanes(), - )); - } + f.write_to_slice_unaligned(fut); } - for j in simdified..ny { - unsafe { - let mut f = 0.0; - for (iprev, bl) in bl.iter().enumerate() { - f += bl * *prev_ptr(j, iprev); - } - *fut_ptr(j, ifut) = f * idx; + for (j, fut) in (simdified..ny).zip(fut.into_remainder()) { + let mut f = 0.0; + for (iprev, bl) in bl.iter().enumerate() { + f += bl * prevcol(iprev)[j]; } + *fut = f * idx; } } } + // First block + { + let prev = &prev[..N * ny]; + block_multiply(&matrix.start, idx, ny, simdified, prev, fut0); + } + // Diagonal elements { - for ifut in M..nx - M { - for j in (0..simdified).step_by(SimdT::lanes()) { - let index_to_simd = |i| unsafe { - // j never moves past end of slice due to step_by and - // rounding down - SimdT::from_slice_unaligned(std::slice::from_raw_parts( - prev_ptr(j, i), - SimdT::lanes(), - )) - }; + let half_diag_width = (D - 1) / 2; + assert!(half_diag_width <= M); + + let prevcol = |i: usize| -> &[Float] { &prev[i * ny..(i + 1) * ny] }; + for (fut, ifut) in futmid.chunks_exact_mut(ny).zip(M..nx - M) { + let mut fut = fut.chunks_exact_mut(SimdT::lanes()); + for (j, fut) in fut.by_ref().enumerate() { + let index_to_simd = + |i| SimdT::from_slice_unaligned(&prevcol(i)[SimdT::lanes() * j..]); let mut f = SimdT::splat(0.0); for (id, &d) in matrix.diag.iter().enumerate() { let offset = ifut - half_diag_width + id; f = index_to_simd(offset).mul_adde(SimdT::splat(d), f); } f *= idx; - unsafe { + { // puts simd along stride 1, j never goes past end of slice - f.write_to_slice_unaligned(std::slice::from_raw_parts_mut( - fut_ptr(j, ifut), - SimdT::lanes(), - )); + f.write_to_slice_unaligned(fut); } } - for j in simdified..ny { + for (j, fut) in (simdified..ny).zip(fut.into_remainder()) { let mut f = 0.0; for (id, &d) in matrix.diag.iter().enumerate() { let offset = ifut - half_diag_width + id; - unsafe { - f += d * *prev_ptr(j, offset); + { + f += d * prevcol(offset)[j]; } } - unsafe { - *fut_ptr(j, ifut) = idx * f; + { + *fut = idx * f; } } } @@ -511,41 +502,8 @@ pub(crate) fn diff_op_2d_sliceable_y_simd( assert_eq!(prev.shape(), fut.shape()); match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => diff_op_2d_sliceable(matrix, optype, prev, fut), - ([1, _], [1, _]) => diff_op_2d_sliceable_y_simd(matrix, optype, prev, fut), + ([1, _], [1, _]) + if prev.as_slice_memory_order().is_some() && fut.as_slice_memory_order().is_some() => + { + diff_op_2d_sliceable_y_simd(matrix, optype, prev, fut) + } _ => diff_op_2d_fallback(matrix, optype, prev, fut), } }