From f90618be429a6832de3f3d6cd072991e30a4ed39 Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Tue, 1 Sep 2020 17:26:27 +0200 Subject: [PATCH] use diff_op_col as fallback for Upwind4 --- sbp/src/operators/upwind4.rs | 46 ++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index 69fb97f..5c17ffc 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -1,4 +1,6 @@ -use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; +use super::{ + diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d, +}; use crate::Float; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis}; @@ -229,17 +231,21 @@ impl SbpOperator2d for (&SBP, &Upwind4) { assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); match (prev.strides(), fut.strides()) { - ([_, 1], [_, 1]) => { - diff_op_row( - Upwind4::BLOCK, - Upwind4::DIAG, - super::Symmetry::AntiSymmetric, - super::OperatorType::Normal, - )(prev, fut); - } + ([_, 1], [_, 1]) => diff_op_row( + Upwind4::BLOCK, + Upwind4::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::Normal, + )(prev, fut), ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { - diff_simd_col(prev, fut); + diff_simd_col(prev, fut) } + ([1, _], [1, _]) => diff_op_col( + Upwind4::BLOCK, + Upwind4::DIAG, + super::Symmetry::AntiSymmetric, + super::OperatorType::Normal, + )(prev, fut), ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { @@ -373,17 +379,21 @@ impl UpwindOperator2d for (&UO, &Upwind4) { assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); match (prev.strides(), fut.strides()) { - ([_, 1], [_, 1]) => { - diff_op_row( - Upwind4::DISS_BLOCK, - Upwind4::DISS_DIAG, - super::Symmetry::Symmetric, - super::OperatorType::Normal, - )(prev, fut); - } + ([_, 1], [_, 1]) => diff_op_row( + Upwind4::DISS_BLOCK, + Upwind4::DISS_DIAG, + super::Symmetry::Symmetric, + super::OperatorType::Normal, + )(prev, fut), ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { diss_simd_col(prev, fut); } + ([1, _], [1, _]) => diff_op_row( + Upwind4::DISS_BLOCK, + Upwind4::DISS_DIAG, + super::Symmetry::Symmetric, + super::OperatorType::Normal, + )(prev, fut), ([_, _], [_, _]) => { // Fallback, work row by row for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {