use diff_op_col as fallback for Upwind4
This commit is contained in:
		@@ -1,4 +1,6 @@
 | 
			
		||||
use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d};
 | 
			
		||||
use super::{
 | 
			
		||||
    diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d,
 | 
			
		||||
};
 | 
			
		||||
use crate::Float;
 | 
			
		||||
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis};
 | 
			
		||||
 | 
			
		||||
@@ -229,17 +231,21 @@ impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind4) {
 | 
			
		||||
        assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
 | 
			
		||||
 | 
			
		||||
        match (prev.strides(), fut.strides()) {
 | 
			
		||||
            ([_, 1], [_, 1]) => {
 | 
			
		||||
                diff_op_row(
 | 
			
		||||
                    Upwind4::BLOCK,
 | 
			
		||||
                    Upwind4::DIAG,
 | 
			
		||||
                    super::Symmetry::AntiSymmetric,
 | 
			
		||||
                    super::OperatorType::Normal,
 | 
			
		||||
                )(prev, fut);
 | 
			
		||||
            }
 | 
			
		||||
            ([_, 1], [_, 1]) => diff_op_row(
 | 
			
		||||
                Upwind4::BLOCK,
 | 
			
		||||
                Upwind4::DIAG,
 | 
			
		||||
                super::Symmetry::AntiSymmetric,
 | 
			
		||||
                super::OperatorType::Normal,
 | 
			
		||||
            )(prev, fut),
 | 
			
		||||
            ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
 | 
			
		||||
                diff_simd_col(prev, fut);
 | 
			
		||||
                diff_simd_col(prev, fut)
 | 
			
		||||
            }
 | 
			
		||||
            ([1, _], [1, _]) => diff_op_col(
 | 
			
		||||
                Upwind4::BLOCK,
 | 
			
		||||
                Upwind4::DIAG,
 | 
			
		||||
                super::Symmetry::AntiSymmetric,
 | 
			
		||||
                super::OperatorType::Normal,
 | 
			
		||||
            )(prev, fut),
 | 
			
		||||
            ([_, _], [_, _]) => {
 | 
			
		||||
                // Fallback, work row by row
 | 
			
		||||
                for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
 | 
			
		||||
@@ -373,17 +379,21 @@ impl<UO: UpwindOperator1d> UpwindOperator2d for (&UO, &Upwind4) {
 | 
			
		||||
        assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len());
 | 
			
		||||
 | 
			
		||||
        match (prev.strides(), fut.strides()) {
 | 
			
		||||
            ([_, 1], [_, 1]) => {
 | 
			
		||||
                diff_op_row(
 | 
			
		||||
                    Upwind4::DISS_BLOCK,
 | 
			
		||||
                    Upwind4::DISS_DIAG,
 | 
			
		||||
                    super::Symmetry::Symmetric,
 | 
			
		||||
                    super::OperatorType::Normal,
 | 
			
		||||
                )(prev, fut);
 | 
			
		||||
            }
 | 
			
		||||
            ([_, 1], [_, 1]) => diff_op_row(
 | 
			
		||||
                Upwind4::DISS_BLOCK,
 | 
			
		||||
                Upwind4::DISS_DIAG,
 | 
			
		||||
                super::Symmetry::Symmetric,
 | 
			
		||||
                super::OperatorType::Normal,
 | 
			
		||||
            )(prev, fut),
 | 
			
		||||
            ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => {
 | 
			
		||||
                diss_simd_col(prev, fut);
 | 
			
		||||
            }
 | 
			
		||||
            ([1, _], [1, _]) => diff_op_row(
 | 
			
		||||
                Upwind4::DISS_BLOCK,
 | 
			
		||||
                Upwind4::DISS_DIAG,
 | 
			
		||||
                super::Symmetry::Symmetric,
 | 
			
		||||
                super::OperatorType::Normal,
 | 
			
		||||
            )(prev, fut),
 | 
			
		||||
            ([_, _], [_, _]) => {
 | 
			
		||||
                // Fallback, work row by row
 | 
			
		||||
                for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user