diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 19c3199..64127cf 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -171,60 +171,61 @@ pub(crate) fn diff_op_1d( #[inline(always)] pub(crate) fn diff_op_col( - block: &[&[Float]], - diag: &[Float], + block: &'static [&'static [Float]], + diag: &'static [Float], symmetric: bool, is_h2: bool, - prev: ArrayView2, - mut fut: ArrayViewMut2, -) { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[1]; - assert!(nx >= 2 * block.len()); +) -> impl Fn(ArrayView2, ArrayViewMut2) { + #[inline(always)] + move |prev: ArrayView2, mut fut: ArrayViewMut2| { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[1]; + assert!(nx >= 2 * block.len()); - assert_eq!(prev.strides()[0], 1); - assert_eq!(fut.strides()[0], 1); + assert_eq!(prev.strides()[0], 1); + assert_eq!(fut.strides()[0], 1); - let dx = if is_h2 { - 1.0 / (nx - 2) as Float - } else { - 1.0 / (nx - 1) as Float - }; - let idx = 1.0 / dx; + let dx = if is_h2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; - fut.fill(0.0); + fut.fill(0.0); - // First block - for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) { - debug_assert_eq!(fut.len(), prev.shape()[0]); - for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { - debug_assert_eq!(prev.len(), fut.len()); - fut.scaled_add(idx * bl, &prev); - } - } - - let half_diag_width = (diag.len() - 1) / 2; - - // Diagonal entries - for (ifut, mut fut) in fut - .axis_iter_mut(ndarray::Axis(1)) - .enumerate() - .skip(block.len()) - .take(nx - 2 * block.len()) - { - for (id, d) in diag.iter().enumerate() { - let offset = ifut - half_diag_width + id; - fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset])) - } - } - - // End block - for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { - for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { - if symmetric { + // First block + for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) { + debug_assert_eq!(fut.len(), prev.shape()[0]); + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { + debug_assert_eq!(prev.len(), fut.len()); fut.scaled_add(idx * bl, &prev); - } else { - fut.scaled_add(-idx * bl, &prev); + } + } + + let half_diag_width = (diag.len() - 1) / 2; + + // Diagonal entries + for (ifut, mut fut) in fut + .axis_iter_mut(ndarray::Axis(1)) + .enumerate() + .skip(block.len()) + .take(nx - 2 * block.len()) + { + for (id, d) in diag.iter().enumerate() { + let offset = ifut - half_diag_width + id; + fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset])) + } + } + + // End block + for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { + if symmetric { + fut.scaled_add(idx * bl, &prev); + } else { + fut.scaled_add(-idx * bl, &prev); + } } } } diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index e0ff86d..1a26ae5 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -43,7 +43,7 @@ impl SbpOperator2d for (&SBP, &SBP4) { diff_op_row(SBP4::BLOCK, SBP4::DIAG, false, false)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(SBP4::BLOCK, SBP4::DIAG, false, false, prev, fut); + diff_op_col(SBP4::BLOCK, SBP4::DIAG, false, false)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs index d288255..7d34931 100644 --- a/sbp/src/operators/traditional8.rs +++ b/sbp/src/operators/traditional8.rs @@ -47,7 +47,7 @@ impl SbpOperator2d for (&SBP, &SBP8) { diff_op_row(SBP8::BLOCK, SBP8::DIAG, false, false)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(SBP8::BLOCK, SBP8::DIAG, false, false, prev, fut); + diff_op_col(SBP8::BLOCK, SBP8::DIAG, false, false)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row diff --git a/sbp/src/operators/upwind4h2.rs b/sbp/src/operators/upwind4h2.rs index d507799..06aa12f 100644 --- a/sbp/src/operators/upwind4h2.rs +++ b/sbp/src/operators/upwind4h2.rs @@ -61,7 +61,7 @@ impl SbpOperator2d for (&SBP, &Upwind4h2) { diff_op_row(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true, prev, fut); + diff_op_col(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row @@ -84,14 +84,7 @@ impl UpwindOperator2d for (&UO, &Upwind4h2) { diff_op_row(Upwind4h2::DISS_BLOCK, Upwind4h2::DISS_DIAG, true, true)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col( - Upwind4h2::DISS_BLOCK, - Upwind4h2::DISS_DIAG, - true, - true, - prev, - fut, - ); + diff_op_col(Upwind4h2::DISS_BLOCK, Upwind4h2::DISS_DIAG, true, true)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row diff --git a/sbp/src/operators/upwind9.rs b/sbp/src/operators/upwind9.rs index 2aab4e6..484b316 100644 --- a/sbp/src/operators/upwind9.rs +++ b/sbp/src/operators/upwind9.rs @@ -66,7 +66,7 @@ impl SbpOperator2d for (&SBP, &Upwind9) { diff_op_row(Upwind9::BLOCK, Upwind9::DIAG, false, false)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind9::BLOCK, Upwind9::DIAG, false, false, prev, fut); + diff_op_col(Upwind9::BLOCK, Upwind9::DIAG, false, false)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row @@ -99,14 +99,7 @@ impl UpwindOperator2d for (&UO, &Upwind9) { diff_op_row(Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, true, false)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col( - Upwind9::DISS_BLOCK, - Upwind9::DISS_DIAG, - true, - false, - prev, - fut, - ); + diff_op_col(Upwind9::DISS_BLOCK, Upwind9::DISS_DIAG, true, false)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row diff --git a/sbp/src/operators/upwind9h2.rs b/sbp/src/operators/upwind9h2.rs index 1b0e2e4..c4ebd42 100644 --- a/sbp/src/operators/upwind9h2.rs +++ b/sbp/src/operators/upwind9h2.rs @@ -69,7 +69,7 @@ impl SbpOperator2d for (&SBP, &Upwind9h2) { diff_op_row(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true, prev, fut); + diff_op_col(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row @@ -126,14 +126,7 @@ impl UpwindOperator2d for (&UO, &Upwind9h2) { diff_op_row(Upwind9h2::DISS_BLOCK, Upwind9h2::DISS_DIAG, true, true)(prev, fut); } ([1, _], [1, _]) => { - diff_op_col( - Upwind9h2::DISS_BLOCK, - Upwind9h2::DISS_DIAG, - true, - true, - prev, - fut, - ); + diff_op_col(Upwind9h2::DISS_BLOCK, Upwind9h2::DISS_DIAG, true, true)(prev, fut); } ([_, _], [_, _]) => { // Fallback, work row by row