simplify loop conditions
This commit is contained in:
parent
3ac17995cd
commit
64e1aec294
|
@ -6,7 +6,7 @@ use ndarray::{s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
|
|||
#[derive(Debug)]
|
||||
pub struct Upwind4 {}
|
||||
|
||||
/// Simdtype used in diff_simd_col
|
||||
/// Simdtype used in diff_simd_col and diff_simd_row
|
||||
#[cfg(feature = "f32")]
|
||||
type SimdT = packed_simd::f32x8;
|
||||
#[cfg(not(feature = "f32"))]
|
||||
|
@ -29,7 +29,7 @@ macro_rules! diff_simd_row_7_47 {
|
|||
assert_eq!(prev.shape(), fut.shape());
|
||||
assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len());
|
||||
assert!(prev.len() >= SimdT::lanes());
|
||||
// The prev and fut array must have contigous last dimension
|
||||
// The prev and fut array must have contiguous last dimension
|
||||
assert_eq!(prev.strides()[1], 1);
|
||||
assert_eq!(fut.strides()[1], 1);
|
||||
|
||||
|
@ -44,7 +44,6 @@ macro_rules! diff_simd_row_7_47 {
|
|||
let fut = unsafe {
|
||||
slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx)
|
||||
};
|
||||
//let mut fut = fut.slice_mut(s![j, ..]);
|
||||
|
||||
let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) };
|
||||
let block = {
|
||||
|
@ -79,17 +78,11 @@ macro_rules! diff_simd_row_7_47 {
|
|||
diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0,
|
||||
)
|
||||
};
|
||||
for (f, p) in fut
|
||||
.iter_mut()
|
||||
.skip(block.len())
|
||||
.zip(
|
||||
prev.windows(SimdT::lanes())
|
||||
.map(SimdT::from_slice_unaligned)
|
||||
.skip(1),
|
||||
)
|
||||
.take(nx - 2 * block.len())
|
||||
{
|
||||
*f = idx * (p * diag).sum();
|
||||
for i in 4..nx - 4 {
|
||||
unsafe {
|
||||
let prev = SimdT::from_slice_unaligned_unchecked(&prev[i - 3..]);
|
||||
*fut.get_unchecked_mut(i) = idx * (prev * diag).sum();
|
||||
}
|
||||
}
|
||||
|
||||
let last_elems =
|
||||
|
@ -129,7 +122,7 @@ macro_rules! diff_simd_col_7_47 {
|
|||
use std::slice;
|
||||
assert_eq!(prev.shape(), fut.shape());
|
||||
assert_eq!(prev.stride_of(Axis(0)), 1);
|
||||
assert_eq!(prev.stride_of(Axis(0)), 1);
|
||||
assert_eq!(fut.stride_of(Axis(0)), 1);
|
||||
let ny = prev.len_of(Axis(0));
|
||||
let nx = prev.len_of(Axis(1));
|
||||
assert!(nx >= 2 * $BLOCK.len());
|
||||
|
|
Loading…
Reference in New Issue