use diff_op_col as fallback for Upwind4

This commit is contained in:
Magnus Ulimoen 2020-09-01 17:26:27 +02:00
parent 8d90d8106d
commit f90618be42
1 changed files with 28 additions and 18 deletions

View File

@ -1,4 +1,6 @@
use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; use super::{
diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d,
};
use crate::Float; use crate::Float;
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis}; use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
@ -229,17 +231,21 @@ impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind4) {
assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
match (prev.strides(), fut.strides()) { match (prev.strides(), fut.strides()) {
([_, 1], [_, 1]) => { ([_, 1], [_, 1]) => diff_op_row(
diff_op_row(
Upwind4::BLOCK, Upwind4::BLOCK,
Upwind4::DIAG, Upwind4::DIAG,
super::Symmetry::AntiSymmetric, super::Symmetry::AntiSymmetric,
super::OperatorType::Normal, super::OperatorType::Normal,
)(prev, fut); )(prev, fut),
}
([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
diff_simd_col(prev, fut); diff_simd_col(prev, fut)
} }
([1, _], [1, _]) => diff_op_col(
Upwind4::BLOCK,
Upwind4::DIAG,
super::Symmetry::AntiSymmetric,
super::OperatorType::Normal,
)(prev, fut),
([_, _], [_, _]) => { ([_, _], [_, _]) => {
// Fallback, work row by row // Fallback, work row by row
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
@ -373,17 +379,21 @@ impl<UO: UpwindOperator1d> UpwindOperator2d for (&UO, &Upwind4) {
assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
match (prev.strides(), fut.strides()) { match (prev.strides(), fut.strides()) {
([_, 1], [_, 1]) => { ([_, 1], [_, 1]) => diff_op_row(
diff_op_row(
Upwind4::DISS_BLOCK, Upwind4::DISS_BLOCK,
Upwind4::DISS_DIAG, Upwind4::DISS_DIAG,
super::Symmetry::Symmetric, super::Symmetry::Symmetric,
super::OperatorType::Normal, super::OperatorType::Normal,
)(prev, fut); )(prev, fut),
}
([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
diss_simd_col(prev, fut); diss_simd_col(prev, fut);
} }
([1, _], [1, _]) => diff_op_row(
Upwind4::DISS_BLOCK,
Upwind4::DISS_DIAG,
super::Symmetry::Symmetric,
super::OperatorType::Normal,
)(prev, fut),
([_, _], [_, _]) => { ([_, _], [_, _]) => {
// Fallback, work row by row // Fallback, work row by row
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {