simplify loop conditions

This commit is contained in:
Magnus Ulimoen 2020-02-28 22:26:18 +01:00
parent 3ac17995cd
commit 64e1aec294
1 changed file with 8 additions and 15 deletions

View File

@ -6,7 +6,7 @@ use ndarray::{s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
#[derive(Debug)] #[derive(Debug)]
pub struct Upwind4 {} pub struct Upwind4 {}
/// Simdtype used in diff_simd_col /// Simdtype used in diff_simd_col and diff_simd_row
#[cfg(feature = "f32")] #[cfg(feature = "f32")]
type SimdT = packed_simd::f32x8; type SimdT = packed_simd::f32x8;
#[cfg(not(feature = "f32"))] #[cfg(not(feature = "f32"))]
@ -29,7 +29,7 @@ macro_rules! diff_simd_row_7_47 {
assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.shape(), fut.shape());
assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len()); assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len());
assert!(prev.len() >= SimdT::lanes()); assert!(prev.len() >= SimdT::lanes());
// The prev and fut array must have contigous last dimension // The prev and fut array must have contiguous last dimension
assert_eq!(prev.strides()[1], 1); assert_eq!(prev.strides()[1], 1);
assert_eq!(fut.strides()[1], 1); assert_eq!(fut.strides()[1], 1);
@ -44,7 +44,6 @@ macro_rules! diff_simd_row_7_47 {
let fut = unsafe { let fut = unsafe {
slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx) slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx)
}; };
//let mut fut = fut.slice_mut(s![j, ..]);
let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) }; let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) };
let block = { let block = {
@ -79,17 +78,11 @@ macro_rules! diff_simd_row_7_47 {
diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0, diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0,
) )
}; };
for (f, p) in fut for i in 4..nx - 4 {
.iter_mut() unsafe {
.skip(block.len()) let prev = SimdT::from_slice_unaligned_unchecked(&prev[i - 3..]);
.zip( *fut.get_unchecked_mut(i) = idx * (prev * diag).sum();
prev.windows(SimdT::lanes()) }
.map(SimdT::from_slice_unaligned)
.skip(1),
)
.take(nx - 2 * block.len())
{
*f = idx * (p * diag).sum();
} }
let last_elems = let last_elems =
@ -129,7 +122,7 @@ macro_rules! diff_simd_col_7_47 {
use std::slice; use std::slice;
assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.shape(), fut.shape());
assert_eq!(prev.stride_of(Axis(0)), 1); assert_eq!(prev.stride_of(Axis(0)), 1);
assert_eq!(prev.stride_of(Axis(0)), 1); assert_eq!(fut.stride_of(Axis(0)), 1);
let ny = prev.len_of(Axis(0)); let ny = prev.len_of(Axis(0));
let nx = prev.len_of(Axis(1)); let nx = prev.len_of(Axis(1));
assert!(nx >= 2 * $BLOCK.len()); assert!(nx >= 2 * $BLOCK.len());