use diff_op_col as fallback for Upwind4
This commit is contained in:
		
							parent
							
								
									8d90d8106d
								
							
						
					
					
						commit
						f90618be42
					
				| @ -1,4 +1,6 @@ | ||||
| use super::{diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d}; | ||||
| use super::{ | ||||
|     diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d, UpwindOperator1d, UpwindOperator2d, | ||||
| }; | ||||
| use crate::Float; | ||||
| use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2, Axis}; | ||||
| 
 | ||||
| @ -229,17 +231,21 @@ impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind4) { | ||||
|         assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); | ||||
| 
 | ||||
|         match (prev.strides(), fut.strides()) { | ||||
|             ([_, 1], [_, 1]) => { | ||||
|                 diff_op_row( | ||||
|                     Upwind4::BLOCK, | ||||
|                     Upwind4::DIAG, | ||||
|                     super::Symmetry::AntiSymmetric, | ||||
|                     super::OperatorType::Normal, | ||||
|                 )(prev, fut); | ||||
|             } | ||||
|             ([_, 1], [_, 1]) => diff_op_row( | ||||
|                 Upwind4::BLOCK, | ||||
|                 Upwind4::DIAG, | ||||
|                 super::Symmetry::AntiSymmetric, | ||||
|                 super::OperatorType::Normal, | ||||
|             )(prev, fut), | ||||
|             ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { | ||||
|                 diff_simd_col(prev, fut); | ||||
|                 diff_simd_col(prev, fut) | ||||
|             } | ||||
|             ([1, _], [1, _]) => diff_op_col( | ||||
|                 Upwind4::BLOCK, | ||||
|                 Upwind4::DIAG, | ||||
|                 super::Symmetry::AntiSymmetric, | ||||
|                 super::OperatorType::Normal, | ||||
|             )(prev, fut), | ||||
|             ([_, _], [_, _]) => { | ||||
|                 // Fallback, work row by row
 | ||||
|                 for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { | ||||
| @ -373,17 +379,21 @@ impl<UO: UpwindOperator1d> UpwindOperator2d for (&UO, &Upwind4) { | ||||
|         assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); | ||||
| 
 | ||||
|         match (prev.strides(), fut.strides()) { | ||||
|             ([_, 1], [_, 1]) => { | ||||
|                 diff_op_row( | ||||
|                     Upwind4::DISS_BLOCK, | ||||
|                     Upwind4::DISS_DIAG, | ||||
|                     super::Symmetry::Symmetric, | ||||
|                     super::OperatorType::Normal, | ||||
|                 )(prev, fut); | ||||
|             } | ||||
|             ([_, 1], [_, 1]) => diff_op_row( | ||||
|                 Upwind4::DISS_BLOCK, | ||||
|                 Upwind4::DISS_DIAG, | ||||
|                 super::Symmetry::Symmetric, | ||||
|                 super::OperatorType::Normal, | ||||
|             )(prev, fut), | ||||
|             ([1, _], [1, _]) if prev.len_of(Axis(0)) % SimdT::lanes() == 0 => { | ||||
|                 diss_simd_col(prev, fut); | ||||
|             } | ||||
|             ([1, _], [1, _]) => diff_op_row( | ||||
|                 Upwind4::DISS_BLOCK, | ||||
|                 Upwind4::DISS_DIAG, | ||||
|                 super::Symmetry::Symmetric, | ||||
|                 super::OperatorType::Normal, | ||||
|             )(prev, fut), | ||||
|             ([_, _], [_, _]) => { | ||||
|                 // Fallback, work row by row
 | ||||
|                 for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user