use simd
This commit is contained in:
parent
1bf4753e9a
commit
c30a463e83
|
@ -16,6 +16,7 @@ console_error_panic_hook = { version = "0.1.6", optional = true }
|
||||||
wee_alloc = { version = "0.4.5", optional = true }
|
wee_alloc = { version = "0.4.5", optional = true }
|
||||||
ndarray = { version = "0.13.0", features = ["approx"] }
|
ndarray = { version = "0.13.0", features = ["approx"] }
|
||||||
approx = "0.3.2"
|
approx = "0.3.2"
|
||||||
|
packed_simd = "0.3.3"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
opt-level = 3
|
opt-level = 3
|
||||||
|
|
|
@ -53,10 +53,106 @@ impl Upwind4 {
|
||||||
],
|
],
|
||||||
];
|
];
|
||||||
|
|
||||||
|
fn diff_simd(prev: &[f32], fut: &mut [f32]) {
|
||||||
|
use packed_simd::{f32x8, u32x8};
|
||||||
|
assert_eq!(prev.len(), fut.len());
|
||||||
|
assert!(prev.len() > 8);
|
||||||
|
let nx = prev.len();
|
||||||
|
let dx = 1.0 / (nx - 1) as f32;
|
||||||
|
let idx = 1.0 / dx;
|
||||||
|
|
||||||
|
let first_elems = unsafe { f32x8::from_slice_unaligned_unchecked(prev) };
|
||||||
|
let block = [
|
||||||
|
f32x8::new(
|
||||||
|
Self::BLOCK[0][0],
|
||||||
|
Self::BLOCK[0][1],
|
||||||
|
Self::BLOCK[0][2],
|
||||||
|
Self::BLOCK[0][3],
|
||||||
|
Self::BLOCK[0][4],
|
||||||
|
Self::BLOCK[0][5],
|
||||||
|
Self::BLOCK[0][6],
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
f32x8::new(
|
||||||
|
Self::BLOCK[1][0],
|
||||||
|
Self::BLOCK[1][1],
|
||||||
|
Self::BLOCK[1][2],
|
||||||
|
Self::BLOCK[1][3],
|
||||||
|
Self::BLOCK[1][4],
|
||||||
|
Self::BLOCK[1][5],
|
||||||
|
Self::BLOCK[1][6],
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
f32x8::new(
|
||||||
|
Self::BLOCK[2][0],
|
||||||
|
Self::BLOCK[2][1],
|
||||||
|
Self::BLOCK[2][2],
|
||||||
|
Self::BLOCK[2][3],
|
||||||
|
Self::BLOCK[2][4],
|
||||||
|
Self::BLOCK[2][5],
|
||||||
|
Self::BLOCK[2][6],
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
f32x8::new(
|
||||||
|
Self::BLOCK[3][0],
|
||||||
|
Self::BLOCK[3][1],
|
||||||
|
Self::BLOCK[3][2],
|
||||||
|
Self::BLOCK[3][3],
|
||||||
|
Self::BLOCK[3][4],
|
||||||
|
Self::BLOCK[3][5],
|
||||||
|
Self::BLOCK[3][6],
|
||||||
|
0.0,
|
||||||
|
),
|
||||||
|
];
|
||||||
|
unsafe {
|
||||||
|
*fut.get_unchecked_mut(0) += idx * (block[0] * first_elems).sum();
|
||||||
|
*fut.get_unchecked_mut(1) += idx * (block[1] * first_elems).sum();
|
||||||
|
*fut.get_unchecked_mut(2) += idx * (block[2] * first_elems).sum();
|
||||||
|
*fut.get_unchecked_mut(3) += idx * (block[3] * first_elems).sum()
|
||||||
|
};
|
||||||
|
|
||||||
|
let diag = f32x8::new(
|
||||||
|
Self::DIAG[0],
|
||||||
|
Self::DIAG[1],
|
||||||
|
Self::DIAG[2],
|
||||||
|
Self::DIAG[3],
|
||||||
|
Self::DIAG[4],
|
||||||
|
Self::DIAG[5],
|
||||||
|
Self::DIAG[6],
|
||||||
|
0.0,
|
||||||
|
);
|
||||||
|
for (f, p) in fut
|
||||||
|
.iter_mut()
|
||||||
|
.skip(block.len())
|
||||||
|
.zip(
|
||||||
|
prev.windows(f32x8::lanes())
|
||||||
|
.map(f32x8::from_slice_unaligned)
|
||||||
|
.skip(1),
|
||||||
|
)
|
||||||
|
.take(nx - 2 * block.len())
|
||||||
|
{
|
||||||
|
*f += idx * (p * diag).sum();
|
||||||
|
}
|
||||||
|
|
||||||
|
let last_elems = unsafe { f32x8::from_slice_unaligned_unchecked(&prev[nx - 8..]) }
|
||||||
|
.shuffle1_dyn(u32x8::new(7, 6, 5, 4, 3, 2, 1, 0));
|
||||||
|
unsafe {
|
||||||
|
*fut.get_unchecked_mut(nx - 4) += -idx * (block[3] * last_elems).sum();
|
||||||
|
*fut.get_unchecked_mut(nx - 3) += -idx * (block[2] * last_elems).sum();
|
||||||
|
*fut.get_unchecked_mut(nx - 2) += -idx * (block[1] * last_elems).sum();
|
||||||
|
*fut.get_unchecked_mut(nx - 1) += -idx * (block[0] * last_elems).sum();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn diff(prev: ArrayView1<f32>, mut fut: ArrayViewMut1<f32>) {
|
fn diff(prev: ArrayView1<f32>, mut fut: ArrayViewMut1<f32>) {
|
||||||
assert_eq!(prev.shape(), fut.shape());
|
assert_eq!(prev.shape(), fut.shape());
|
||||||
let nx = prev.shape()[0];
|
let nx = prev.shape()[0];
|
||||||
|
|
||||||
|
if let (Some(p), Some(f)) = (prev.as_slice(), fut.as_slice_mut()) {
|
||||||
|
Self::diff_simd(p, f);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
let dx = 1.0 / (nx - 1) as f32;
|
let dx = 1.0 / (nx - 1) as f32;
|
||||||
let idx = 1.0 / dx;
|
let idx = 1.0 / dx;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue