From 64e1aec2941adc9bddcf1bca715b9ed8b83bf6f1 Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Fri, 28 Feb 2020 22:26:18 +0100 Subject: [PATCH] simplify loop conditions --- sbp/src/operators/upwind4.rs | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index 83d5d8c..44d3b12 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -6,7 +6,7 @@ use ndarray::{s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis}; #[derive(Debug)] pub struct Upwind4 {} -/// Simdtype used in diff_simd_col +/// Simdtype used in diff_simd_col and diff_simd_row #[cfg(feature = "f32")] type SimdT = packed_simd::f32x8; #[cfg(not(feature = "f32"))] @@ -29,7 +29,7 @@ macro_rules! diff_simd_row_7_47 { assert_eq!(prev.shape(), fut.shape()); assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len()); assert!(prev.len() >= SimdT::lanes()); - // The prev and fut array must have contigous last dimension + // The prev and fut array must have contiguous last dimension assert_eq!(prev.strides()[1], 1); assert_eq!(fut.strides()[1], 1); @@ -44,7 +44,6 @@ macro_rules! diff_simd_row_7_47 { let fut = unsafe { slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx) }; - //let mut fut = fut.slice_mut(s![j, ..]); let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) }; let block = { @@ -79,17 +78,11 @@ macro_rules! diff_simd_row_7_47 { diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0, ) }; - for (f, p) in fut - .iter_mut() - .skip(block.len()) - .zip( - prev.windows(SimdT::lanes()) - .map(SimdT::from_slice_unaligned) - .skip(1), - ) - .take(nx - 2 * block.len()) - { - *f = idx * (p * diag).sum(); + for i in 4..nx - 4 { + unsafe { + let prev = SimdT::from_slice_unaligned_unchecked(&prev[i - 3..]); + *fut.get_unchecked_mut(i) = idx * (prev * diag).sum(); + } } let last_elems = @@ -129,7 +122,7 @@ macro_rules! diff_simd_col_7_47 { use std::slice; assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.stride_of(Axis(0)), 1); - assert_eq!(prev.stride_of(Axis(0)), 1); + assert_eq!(fut.stride_of(Axis(0)), 1); let ny = prev.len_of(Axis(0)); let nx = prev.len_of(Axis(1)); assert!(nx >= 2 * $BLOCK.len());