use diff_op_col as fallback for Upwind4
This commit is contained in:
parent
8d90d8106d
commit
f90618be42
|
@ -1,4 +1,6 @@
|
||||||
use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d};
|
use super::{
|
||||||
|
diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d,
|
||||||
|
};
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
|
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
|
||||||
|
|
||||||
|
@ -229,17 +231,21 @@ impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind4) {
|
||||||
assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
|
assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
|
||||||
|
|
||||||
match (prev.strides(), fut.strides()) {
|
match (prev.strides(), fut.strides()) {
|
||||||
([_, 1], [_, 1]) => {
|
([_, 1], [_, 1]) => diff_op_row(
|
||||||
diff_op_row(
|
Upwind4::BLOCK,
|
||||||
Upwind4::BLOCK,
|
Upwind4::DIAG,
|
||||||
Upwind4::DIAG,
|
super::Symmetry::AntiSymmetric,
|
||||||
super::Symmetry::AntiSymmetric,
|
super::OperatorType::Normal,
|
||||||
super::OperatorType::Normal,
|
)(prev, fut),
|
||||||
)(prev, fut);
|
|
||||||
}
|
|
||||||
([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
|
([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
|
||||||
diff_simd_col(prev, fut);
|
diff_simd_col(prev, fut)
|
||||||
}
|
}
|
||||||
|
([1, _], [1, _]) => diff_op_col(
|
||||||
|
Upwind4::BLOCK,
|
||||||
|
Upwind4::DIAG,
|
||||||
|
super::Symmetry::AntiSymmetric,
|
||||||
|
super::OperatorType::Normal,
|
||||||
|
)(prev, fut),
|
||||||
([_, _], [_, _]) => {
|
([_, _], [_, _]) => {
|
||||||
// Fallback, work row by row
|
// Fallback, work row by row
|
||||||
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
|
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
|
||||||
|
@ -373,17 +379,21 @@ impl<UO: UpwindOperator1d> UpwindOperator2d for (&UO, &Upwind4) {
|
||||||
assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
|
assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
|
||||||
|
|
||||||
match (prev.strides(), fut.strides()) {
|
match (prev.strides(), fut.strides()) {
|
||||||
([_, 1], [_, 1]) => {
|
([_, 1], [_, 1]) => diff_op_row(
|
||||||
diff_op_row(
|
Upwind4::DISS_BLOCK,
|
||||||
Upwind4::DISS_BLOCK,
|
Upwind4::DISS_DIAG,
|
||||||
Upwind4::DISS_DIAG,
|
super::Symmetry::Symmetric,
|
||||||
super::Symmetry::Symmetric,
|
super::OperatorType::Normal,
|
||||||
super::OperatorType::Normal,
|
)(prev, fut),
|
||||||
)(prev, fut);
|
|
||||||
}
|
|
||||||
([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
|
([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
|
||||||
diss_simd_col(prev, fut);
|
diss_simd_col(prev, fut);
|
||||||
}
|
}
|
||||||
|
([1, _], [1, _]) => diff_op_row(
|
||||||
|
Upwind4::DISS_BLOCK,
|
||||||
|
Upwind4::DISS_DIAG,
|
||||||
|
super::Symmetry::Symmetric,
|
||||||
|
super::OperatorType::Normal,
|
||||||
|
)(prev, fut),
|
||||||
([_, _], [_, _]) => {
|
([_, _], [_, _]) => {
|
||||||
// Fallback, work row by row
|
// Fallback, work row by row
|
||||||
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
|
for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
|
||||||
|
|
Loading…
Reference in New Issue