diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 82941c2..19c3199 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -232,65 +232,66 @@ pub(crate) fn diff_op_col( #[inline(always)] pub(crate) fn diff_op_row( - block: &[&[Float]], - diag: &[Float], + block: &'static [&'static [Float]], + diag: &'static [Float], symmetric: bool, is_h2: bool, - prev: ArrayView2, - mut fut: ArrayViewMut2, -) { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[1]; - assert!(nx >= 2 * block.len()); +) -> impl Fn(ArrayView2, ArrayViewMut2) { + #[inline(always)] + move |prev: ArrayView2, mut fut: ArrayViewMut2| { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[1]; + assert!(nx >= 2 * block.len()); - assert_eq!(prev.strides()[1], 1); - assert_eq!(fut.strides()[1], 1); + assert_eq!(prev.strides()[1], 1); + assert_eq!(fut.strides()[1], 1); - let dx = if is_h2 { - 1.0 / (nx - 2) as Float - } else { - 1.0 / (nx - 1) as Float - }; - let idx = 1.0 / dx; + let dx = if is_h2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; - for (prev, mut fut) in prev - .axis_iter(ndarray::Axis(0)) - .zip(fut.axis_iter_mut(ndarray::Axis(0))) - { - let prev = prev.as_slice().unwrap(); - let fut = fut.as_slice_mut().unwrap(); - - for (bl, f) in block.iter().zip(fut.iter_mut()) { - let diff = bl - .iter() - .zip(prev.iter()) - .map(|(x, y)| x * y) - .sum::(); - *f = diff * idx; - } - - // The window needs to be aligned to the diagonal elements, - // based on the block size - let window_elems_to_skip = block.len() - ((diag.len() - 1) / 2); - - for (window, f) in prev - .windows(diag.len()) - .skip(window_elems_to_skip) - .zip(fut.iter_mut().skip(block.len())) - .take(nx - 2 * block.len()) + for (prev, mut fut) in prev + .axis_iter(ndarray::Axis(0)) + .zip(fut.axis_iter_mut(ndarray::Axis(0))) { - let diff = diag.iter().zip(window).map(|(&x, &y)| x * y).sum::(); - *f = diff * idx; - } + let prev = prev.as_slice().unwrap(); + let fut = fut.as_slice_mut().unwrap(); - for (bl, f) in block.iter().zip(fut.iter_mut().rev()) { - let diff = bl - .iter() - .zip(prev.iter().rev()) - .map(|(x, y)| x * y) - .sum::(); + for (bl, f) in block.iter().zip(fut.iter_mut()) { + let diff = bl + .iter() + .zip(prev.iter()) + .map(|(x, y)| x * y) + .sum::(); + *f = diff * idx; + } - *f = idx * if symmetric { diff } else { -diff }; + // The window needs to be aligned to the diagonal elements, + // based on the block size + let window_elems_to_skip = block.len() - ((diag.len() - 1) / 2); + + for (window, f) in prev + .windows(diag.len()) + .skip(window_elems_to_skip) + .zip(fut.iter_mut().skip(block.len())) + .take(nx - 2 * block.len()) + { + let diff = diag.iter().zip(window).map(|(&x, &y)| x * y).sum::(); + *f = diff * idx; + } + + for (bl, f) in block.iter().zip(fut.iter_mut().rev()) { + let diff = bl + .iter() + .zip(prev.iter().rev()) + .map(|(x, y)| x * y) + .sum::(); + + *f = idx * if symmetric { diff } else { -diff }; + } } } } diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index cb3c2b1..e0ff86d 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -40,7 +40,7 @@ impl SbpOperator2d for (&SBP, &SBP4) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(SBP4::BLOCK, SBP4::DIAG, false, false, prev, fut); + diff_op_row(SBP4::BLOCK, SBP4::DIAG, false, false)(prev, fut); } ([1, _], [1, _]) => { diff_op_col(SBP4::BLOCK, SBP4::DIAG, false, false, prev, fut); diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs index b738ece..d288255 100644 --- a/sbp/src/operators/traditional8.rs +++ b/sbp/src/operators/traditional8.rs @@ -44,7 +44,7 @@ impl SbpOperator2d for (&SBP, &SBP8) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(SBP8::BLOCK, SBP8::DIAG, false, false, prev, fut); + diff_op_row(SBP8::BLOCK, SBP8::DIAG, false, false)(prev, fut); } ([1, _], [1, _]) => { diff_op_col(SBP8::BLOCK, SBP8::DIAG, false, false, prev, fut); diff --git a/sbp/src/operators/upwind4h2.rs b/sbp/src/operators/upwind4h2.rs index 20be562..d507799 100644 --- a/sbp/src/operators/upwind4h2.rs +++ b/sbp/src/operators/upwind4h2.rs @@ -58,7 +58,7 @@ impl SbpOperator2d for (&SBP, &Upwind4h2) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true, prev, fut); + diff_op_row(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true)(prev, fut); } ([1, _], [1, _]) => { diff_op_col(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true, prev, fut); @@ -81,14 +81,7 @@ impl UpwindOperator2d for (&UO, &Upwind4h2) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row( - Upwind4h2::DISS_BLOCK, - Upwind4h2::DISS_DIAG, - true, - true, - prev, - fut, - ); + diff_op_row(Upwind4h2::DISS_BLOCK, Upwind4h2::DISS_DIAG, true, true)(prev, fut); } ([1, _], [1, _]) => { diff_op_col( diff --git a/sbp/src/operators/upwind9.rs b/sbp/src/operators/upwind9.rs index 2a52632..2aab4e6 100644 --- a/sbp/src/operators/upwind9.rs +++ b/sbp/src/operators/upwind9.rs @@ -63,7 +63,7 @@ impl SbpOperator2d for (&SBP, &Upwind9) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind9::BLOCK, Upwind9::DIAG, false, false, prev, fut); + diff_op_row(Upwind9::BLOCK, Upwind9::DIAG, false, false)(prev, fut); } ([1, _], [1, _]) => { diff_op_col(Upwind9::BLOCK, Upwind9::DIAG, false, false, prev, fut); @@ -96,14 +96,7 @@ impl UpwindOperator2d for (&UO, &Upwind9) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row( - Upwind9::DISS_BLOCK, - Upwind9::DISS_DIAG, - true, - false, - prev, - fut, - ); + diff_op_row(Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, true, false)(prev, fut); } ([1, _], [1, _]) => { diff_op_col( diff --git a/sbp/src/operators/upwind9h2.rs b/sbp/src/operators/upwind9h2.rs index 06a0259..1b0e2e4 100644 --- a/sbp/src/operators/upwind9h2.rs +++ b/sbp/src/operators/upwind9h2.rs @@ -66,7 +66,7 @@ impl SbpOperator2d for (&SBP, &Upwind9h2) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true, prev, fut); + diff_op_row(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true)(prev, fut); } ([1, _], [1, _]) => { diff_op_col(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true, prev, fut); @@ -123,14 +123,7 @@ impl UpwindOperator2d for (&UO, &Upwind9h2) { match (prev.strides(), fut.strides()) { ([_, 1], [_, 1]) => { - diff_op_row( - Upwind9h2::DISS_BLOCK, - Upwind9h2::DISS_DIAG, - true, - true, - prev, - fut, - ); + diff_op_row(Upwind9h2::DISS_BLOCK, Upwind9h2::DISS_DIAG, true, true)(prev, fut); } ([1, _], [1, _]) => { diff_op_col(