use iterators to remove bounds checks

This commit is contained in:
Magnus Ulimoen 2019-11-08 08:19:02 +01:00
parent a6bf554c60
commit 1f745745ca
1 changed files with 9 additions and 10 deletions

View File

@ -58,14 +58,15 @@ impl Upwind4 {
let nx = prev.shape()[0]; let nx = prev.shape()[0];
let dx = 1.0 / (nx - 1) as f32; let dx = 1.0 / (nx - 1) as f32;
let idx = 1.0 / dx;
let diag = arr1(Self::DIAG); let diag = arr1(Self::DIAG);
let block = arr2(Self::BLOCK); let block = arr2(Self::BLOCK);
let first_elems = prev.slice(s!(..7)); let first_elems = prev.slice(s!(..7));
for i in 0..4 { for (bl, f) in block.outer_iter().zip(&mut fut) {
let diff = first_elems.dot(&block.slice(s!(i, ..))); let diff = first_elems.dot(&bl);
fut[i] += diff / dx; *f = diff * idx;
} }
for (window, f) in prev for (window, f) in prev
@ -76,15 +77,13 @@ impl Upwind4 {
.take(nx - 8) .take(nx - 8)
{ {
let diff = diag.dot(&window); let diff = diag.dot(&window);
*f += diff / dx; *f += diff * idx;
} }
let last_elems = prev.slice(s!(nx - 7..)); let last_elems = prev.slice(s!(nx - 7..;-1));
for i in 0..4 { for (bl, f) in block.outer_iter().zip(&mut fut.slice_mut(s![nx - 4..;-1])) {
let ii = nx - 4 + i; let diff = bl.dot(&last_elems);
let block = block.slice(s!(3 - i, ..;-1)); *f += -diff * idx;
let diff = last_elems.dot(&block);
fut[ii] += -diff / dx;
} }
} }
} }