From c6e467bc2d7732318e47cc9074c2c516884f53db Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Wed, 22 Apr 2020 00:32:49 +0200 Subject: [PATCH] add col-wise specialisation --- sbp/src/operators.rs | 61 +++++++++++++++++++++++++++++++ sbp/src/operators/traditional4.rs | 5 ++- sbp/src/operators/traditional8.rs | 5 ++- sbp/src/operators/upwind4h2.rs | 17 ++++++++- sbp/src/operators/upwind9.rs | 17 ++++++++- sbp/src/operators/upwind9h2.rs | 17 ++++++++- 6 files changed, 117 insertions(+), 5 deletions(-) diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 4aa79ef..423c1f5 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -169,6 +169,67 @@ pub(crate) fn diff_op_1d( } } +#[inline(always)] +pub(crate) fn diff_op_col( + block: &[&[Float]], + diag: &[Float], + symmetric: bool, + is_h2: bool, + prev: ArrayView2, + mut fut: ArrayViewMut2, +) { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[1]; + assert!(nx >= 2 * block.len()); + + assert_eq!(prev.strides()[0], 1); + assert_eq!(fut.strides()[0], 1); + + let dx = if is_h2 { + 1.0 / (nx - 2) as Float + } else { + 1.0 / (nx - 1) as Float + }; + let idx = 1.0 / dx; + + fut.fill(0.0); + + // First block + for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) { + debug_assert_eq!(fut.len(), prev.shape()[0]); + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) { + debug_assert_eq!(prev.len(), fut.len()); + fut.scaled_add(idx * bl, &prev); + } + } + + let half_diag_width = (diag.len() - 1) / 2; + + // Diagonal entries + for (ifut, mut fut) in fut + .axis_iter_mut(ndarray::Axis(1)) + .enumerate() + .skip(block.len()) + .take(nx - 2 * block.len()) + { + for (id, d) in diag.iter().enumerate() { + let offset = ifut - half_diag_width + id; + fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset])) + } + } + + // End block + for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) { + for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) { + if symmetric { + fut.scaled_add(idx * bl, &prev); + } else { + fut.scaled_add(-idx * bl, &prev); + } + } + } +} + #[inline(always)] pub(crate) fn diff_op_row( block: &[&[Float]], diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index 2221966..a83bfb8 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -1,4 +1,4 @@ -use super::{diff_op_row, SbpOperator1d, SbpOperator2d}; +use super::{diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d}; use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; @@ -42,6 +42,9 @@ impl SbpOperator2d for (&SBP, &SBP4) { ([_, 1], [_, 1]) => { diff_op_row(SBP4::BLOCK, SBP4::DIAG, false, false, prev, fut); } + ([1, _], [1, _]) => { + diff_op_col(SBP4::BLOCK, SBP4::DIAG, false, false, prev, fut); + } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs index 062d134..a109579 100644 --- a/sbp/src/operators/traditional8.rs +++ b/sbp/src/operators/traditional8.rs @@ -1,4 +1,4 @@ -use super::{diff_op_row, SbpOperator1d, SbpOperator2d}; +use super::{diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d}; use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; @@ -46,6 +46,9 @@ impl SbpOperator2d for (&SBP, &SBP8) { ([_, 1], [_, 1]) => { diff_op_row(SBP8::BLOCK, SBP8::DIAG, false, false, prev, fut); } + ([1, _], [1, _]) => { + diff_op_col(SBP8::BLOCK, SBP8::DIAG, false, false, prev, fut); + } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { diff --git a/sbp/src/operators/upwind4h2.rs b/sbp/src/operators/upwind4h2.rs index a8d3dd3..113a3e0 100644 --- a/sbp/src/operators/upwind4h2.rs +++ b/sbp/src/operators/upwind4h2.rs @@ -1,4 +1,6 @@ -use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; +use super::{ + diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d, +}; use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; @@ -58,6 +60,9 @@ impl SbpOperator2d for (&SBP, &Upwind4h2) { ([_, 1], [_, 1]) => { diff_op_row(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true, prev, fut); } + ([1, _], [1, _]) => { + diff_op_col(Upwind4h2::BLOCK, Upwind4h2::DIAG, false, true, prev, fut); + } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { @@ -85,6 +90,16 @@ impl UpwindOperator2d for (&UO, &Upwind4h2) { fut, ); } + ([1, _], [1, _]) => { + diff_op_col( + Upwind4h2::DISS_BLOCK, + Upwind4h2::DISS_DIAG, + true, + true, + prev, + fut, + ); + } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { diff --git a/sbp/src/operators/upwind9.rs b/sbp/src/operators/upwind9.rs index cb9abaf..b42327e 100644 --- a/sbp/src/operators/upwind9.rs +++ b/sbp/src/operators/upwind9.rs @@ -1,4 +1,6 @@ -use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; +use super::{ + diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d, +}; use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; @@ -63,6 +65,9 @@ impl SbpOperator2d for (&SBP, &Upwind9) { ([_, 1], [_, 1]) => { diff_op_row(Upwind9::BLOCK, Upwind9::DIAG, false, false, prev, fut); } + ([1, _], [1, _]) => { + diff_op_col(Upwind9::BLOCK, Upwind9::DIAG, false, false, prev, fut); + } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { @@ -100,6 +105,16 @@ impl UpwindOperator2d for (&UO, &Upwind9) { fut, ); } + ([1, _], [1, _]) => { + diff_op_col( + Upwind9::DISS_BLOCK, + Upwind9::DISS_DIAG, + true, + false, + prev, + fut, + ); + } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { diff --git a/sbp/src/operators/upwind9h2.rs b/sbp/src/operators/upwind9h2.rs index a070893..a32b0dd 100644 --- a/sbp/src/operators/upwind9h2.rs +++ b/sbp/src/operators/upwind9h2.rs @@ -1,4 +1,6 @@ -use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; +use super::{ + diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d, +}; use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; @@ -66,6 +68,9 @@ impl SbpOperator2d for (&SBP, &Upwind9h2) { ([_, 1], [_, 1]) => { diff_op_row(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true, prev, fut); } + ([1, _], [1, _]) => { + diff_op_col(Upwind9h2::BLOCK, Upwind9h2::DIAG, false, true, prev, fut); + } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { @@ -127,6 +132,16 @@ impl UpwindOperator2d for (&UO, &Upwind9h2) { fut, ); } + ([1, _], [1, _]) => { + diff_op_col( + Upwind9h2::DISS_BLOCK, + Upwind9h2::DISS_DIAG, + true, + true, + prev, + fut, + ); + } ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {