simplify loop conditions
This commit is contained in:
parent
3ac17995cd
commit
64e1aec294
|
@ -6,7 +6,7 @@ use ndarray::{s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Upwind4 {}
|
pub struct Upwind4 {}
|
||||||
|
|
||||||
/// Simdtype used in diff_simd_col
|
/// Simdtype used in diff_simd_col and diff_simd_row
|
||||||
#[cfg(feature = "f32")]
|
#[cfg(feature = "f32")]
|
||||||
type SimdT = packed_simd::f32x8;
|
type SimdT = packed_simd::f32x8;
|
||||||
#[cfg(not(feature = "f32"))]
|
#[cfg(not(feature = "f32"))]
|
||||||
|
@ -29,7 +29,7 @@ macro_rules! diff_simd_row_7_47 {
|
||||||
assert_eq!(prev.shape(), fut.shape());
|
assert_eq!(prev.shape(), fut.shape());
|
||||||
assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len());
|
assert!(prev.len_of(Axis(1)) >= 2 * $BLOCK.len());
|
||||||
assert!(prev.len() >= SimdT::lanes());
|
assert!(prev.len() >= SimdT::lanes());
|
||||||
// The prev and fut array must have contigous last dimension
|
// The prev and fut array must have contiguous last dimension
|
||||||
assert_eq!(prev.strides()[1], 1);
|
assert_eq!(prev.strides()[1], 1);
|
||||||
assert_eq!(fut.strides()[1], 1);
|
assert_eq!(fut.strides()[1], 1);
|
||||||
|
|
||||||
|
@ -44,7 +44,6 @@ macro_rules! diff_simd_row_7_47 {
|
||||||
let fut = unsafe {
|
let fut = unsafe {
|
||||||
slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx)
|
slice::from_raw_parts_mut(fut.uget_mut((j, 0)) as *mut Float, nx)
|
||||||
};
|
};
|
||||||
//let mut fut = fut.slice_mut(s![j, ..]);
|
|
||||||
|
|
||||||
let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) };
|
let first_elems = unsafe { SimdT::from_slice_unaligned_unchecked(prev) };
|
||||||
let block = {
|
let block = {
|
||||||
|
@ -79,17 +78,11 @@ macro_rules! diff_simd_row_7_47 {
|
||||||
diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0,
|
diag[0], diag[1], diag[2], diag[3], diag[4], diag[5], diag[6], 0.0,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
for (f, p) in fut
|
for i in 4..nx - 4 {
|
||||||
.iter_mut()
|
unsafe {
|
||||||
.skip(block.len())
|
let prev = SimdT::from_slice_unaligned_unchecked(&prev[i - 3..]);
|
||||||
.zip(
|
*fut.get_unchecked_mut(i) = idx * (prev * diag).sum();
|
||||||
prev.windows(SimdT::lanes())
|
}
|
||||||
.map(SimdT::from_slice_unaligned)
|
|
||||||
.skip(1),
|
|
||||||
)
|
|
||||||
.take(nx - 2 * block.len())
|
|
||||||
{
|
|
||||||
*f = idx * (p * diag).sum();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let last_elems =
|
let last_elems =
|
||||||
|
@ -129,7 +122,7 @@ macro_rules! diff_simd_col_7_47 {
|
||||||
use std::slice;
|
use std::slice;
|
||||||
assert_eq!(prev.shape(), fut.shape());
|
assert_eq!(prev.shape(), fut.shape());
|
||||||
assert_eq!(prev.stride_of(Axis(0)), 1);
|
assert_eq!(prev.stride_of(Axis(0)), 1);
|
||||||
assert_eq!(prev.stride_of(Axis(0)), 1);
|
assert_eq!(fut.stride_of(Axis(0)), 1);
|
||||||
let ny = prev.len_of(Axis(0));
|
let ny = prev.len_of(Axis(0));
|
||||||
let nx = prev.len_of(Axis(1));
|
let nx = prev.len_of(Axis(1));
|
||||||
assert!(nx >= 2 * $BLOCK.len());
|
assert!(nx >= 2 * $BLOCK.len());
|
||||||
|
|
Loading…
Reference in New Issue