From 1f15bcc056d12f7db9d027e9b1e432e092d945b2 Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Sun, 17 Jan 2021 15:37:45 +0100 Subject: [PATCH] revisit SBP Traits --- euler/src/lib.rs | 15 ++- maxwell/src/lib.rs | 4 +- multigrid/Cargo.toml | 1 - multigrid/src/main.rs | 38 ++---- multigrid/src/parsing.rs | 42 +++---- sbp/src/operators.rs | 195 ++++++++++++------------------ sbp/src/operators/traditional4.rs | 8 +- sbp/src/operators/traditional8.rs | 8 +- sbp/src/operators/upwind4.rs | 20 ++- sbp/src/operators/upwind4h2.rs | 19 ++- sbp/src/operators/upwind9.rs | 19 ++- sbp/src/operators/upwind9h2.rs | 19 ++- shallow_water/src/lib.rs | 60 ++++----- 13 files changed, 229 insertions(+), 219 deletions(-) diff --git a/euler/src/lib.rs b/euler/src/lib.rs index ee44f82..d95b2bf 100644 --- a/euler/src/lib.rs +++ b/euler/src/lib.rs @@ -115,7 +115,7 @@ impl System { } } -impl System { +impl System { pub fn advance_upwind(&mut self, dt: Float) { let bc = BoundaryCharacteristics { north: BoundaryCharacteristic::This, @@ -558,7 +558,7 @@ pub fn RHS_trad( #[allow(non_snake_case)] pub fn RHS_upwind( - op: &dyn UpwindOperator2d, + op: &dyn SbpOperator2d, k: &mut Field, y: &Field, metrics: &Metrics, @@ -583,7 +583,14 @@ pub fn RHS_upwind( let ad_xi = &mut tmp.4; let ad_eta = &mut tmp.5; - upwind_dissipation(op, (ad_xi, ad_eta), y, metrics, (&mut tmp.0, &mut tmp.1)); + let diss_op = op.upwind().expect("This is not an upwind operator"); + upwind_dissipation( + &*diss_op, + (ad_xi, ad_eta), + y, + metrics, + (&mut tmp.0, &mut tmp.1), + ); azip!((out in &mut k.0, eflux in &dE.0, @@ -594,7 +601,7 @@ pub fn RHS_upwind( *out = (-eflux - fflux + ad_xi + ad_eta)/detj }); - SAT_characteristics(op.as_sbp(), k, y, metrics, boundaries); + SAT_characteristics(op, k, y, metrics, boundaries); } #[allow(clippy::many_single_char_names)] diff --git a/maxwell/src/lib.rs b/maxwell/src/lib.rs index 3d43c5f..0fa6dc3 100644 --- a/maxwell/src/lib.rs +++ b/maxwell/src/lib.rs @@ -196,7 +196,7 @@ impl System { } } -impl System { +impl System { /// Using artificial dissipation with the upwind operator pub fn advance_upwind(&mut self, dt: Float) { let op = &self.op; @@ -271,7 +271,7 @@ fn RHS( } #[allow(non_snake_case)] -fn RHS_upwind( +fn RHS_upwind( op: &UO, k: &mut Field, y: &Field, diff --git a/multigrid/Cargo.toml b/multigrid/Cargo.toml index e3067a1..2c014c4 100644 --- a/multigrid/Cargo.toml +++ b/multigrid/Cargo.toml @@ -13,7 +13,6 @@ rayon = "1.3.0" indicatif = "0.15.0" structopt = "0.3.14" ndarray = { version = "0.13.1", features = ["serde"] } -either = "1.5.3" serde = { version = "1.0.115", features = ["derive"] } json5 = "0.2.8" indexmap = { version = "1.5.2", features = ["serde-1"] } diff --git a/multigrid/src/main.rs b/multigrid/src/main.rs index f24ddca..e885bfb 100644 --- a/multigrid/src/main.rs +++ b/multigrid/src/main.rs @@ -1,15 +1,12 @@ -use either::*; use structopt::StructOpt; -use sbp::operators::{SbpOperator2d, UpwindOperator2d}; +use sbp::operators::SbpOperator2d; use sbp::*; mod file; mod parsing; use file::*; -pub(crate) type DiffOp = Either, Box>; - struct System { fnow: Vec, fnext: Vec, @@ -20,14 +17,14 @@ struct System { bt: Vec, eb: Vec, time: Float, - operators: Vec, + operators: Vec>, } impl System { fn new( grids: Vec, bt: Vec, - operators: Vec, + operators: Vec>, ) -> Self { let fnow = grids .iter() @@ -42,10 +39,7 @@ impl System { let metrics = grids .iter() .zip(&operators) - .map(|(g, op)| { - let sbpop: &dyn SbpOperator2d = op.as_ref().either(|op| &**op, |uo| uo.as_sbp()); - g.metrics(sbpop).unwrap() - }) + .map(|(g, op)| g.metrics(&**op).unwrap()) .collect::>(); let eb = bt @@ -97,13 +91,10 @@ impl System { { s.spawn(move |_| { let bc = euler::boundary_extracts(prev_all, bt, prev, grid, eb, time); - match op.as_ref() { - Left(sbp) => { - euler::RHS_trad(&**sbp, fut, prev, metrics, &bc, &mut wb.0); - } - Right(uo) => { - euler::RHS_upwind(&**uo, fut, prev, metrics, &bc, &mut wb.0); - } + if op.upwind().is_some() { + euler::RHS_upwind(&**op, fut, prev, metrics, &bc, &mut wb.0); + } else { + euler::RHS_trad(&**op, fut, prev, metrics, &bc, &mut wb.0); } }) } @@ -130,12 +121,10 @@ impl System { /// Suggested maximum dt for this problem fn max_dt(&self) -> Float { - let is_h2 = self.operators.iter().any(|op| { - op.as_ref().either( - |op| op.is_h2xi() || op.is_h2eta(), - |op| op.is_h2xi() || op.is_h2eta(), - ) - }); + let is_h2 = self + .operators + .iter() + .any(|op| op.is_h2xi() || op.is_h2eta()); let c_max = if is_h2 { 0.5 } else { 1.0 }; let mut max_dt: Float = Float::INFINITY; @@ -283,8 +272,7 @@ fn main() { for ((fmod, grid), op) in sys.fnow.iter().zip(&sys.grids).zip(&sys.operators) { let mut fvort = fmod.clone(); fvort.vortex(grid.x(), grid.y(), time, &vortexparams); - let sbpop: &dyn SbpOperator2d = op.as_ref().either(|op| &**op, |uo| uo.as_sbp()); - e += fmod.h2_err(&fvort, sbpop); + e += fmod.h2_err(&fvort, &**op); } println!("Total error: {:e}", e); } diff --git a/multigrid/src/parsing.rs b/multigrid/src/parsing.rs index 506bdf3..241592b 100644 --- a/multigrid/src/parsing.rs +++ b/multigrid/src/parsing.rs @@ -1,5 +1,4 @@ -use super::DiffOp; -use either::*; +use sbp::operators::SbpOperator2d; use sbp::utils::h2linspace; use sbp::Float; @@ -148,7 +147,7 @@ pub struct RuntimeConfiguration { pub names: Vec, pub grids: Vec, pub bc: Vec, - pub op: Vec, + pub op: Vec>, pub integration_time: Float, pub vortex: euler::VortexParameters, } @@ -223,32 +222,19 @@ impl Configuration { use sbp::operators::*; use Operator as op; - match (eta, xi) { - (op::Upwind4, op::Upwind4) => { - Right(Box::new(Upwind4) as Box) + + let matcher = |op| -> Box { + match op { + op::Upwind4 => Box::new(Upwind4), + op::Upwind4h2 => Box::new(Upwind4h2), + op::Upwind9 => Box::new(Upwind9), + op::Upwind9h2 => Box::new(Upwind9h2), + op::Sbp4 => Box::new(SBP4), + op::Sbp8 => Box::new(SBP8), } - (op::Upwind4h2, op::Upwind4h2) => { - Right(Box::new(Upwind4h2) as Box) - } - (op::Upwind9, op::Upwind9) => { - Right(Box::new(Upwind9) as Box) - } - (op::Upwind9h2, op::Upwind9h2) => { - Right(Box::new(Upwind9h2) as Box) - } - (op::Upwind4, op::Upwind4h2) => { - Right(Box::new((&Upwind4, &Upwind4h2)) as Box) - } - (op::Upwind9, op::Upwind9h2) => { - Right(Box::new((&Upwind9, &Upwind9h2)) as Box) - } - (op::Upwind9h2, op::Upwind9) => { - Right(Box::new((&Upwind9h2, &Upwind9)) as Box) - } - (op::Sbp4, op::Sbp4) => Left(Box::new(SBP4) as Box), - (op::Sbp8, op::Sbp8) => Left(Box::new(SBP8) as Box), - _ => todo!("Combination {:?}, {:?} not implemented", eta, xi), - } + }; + + Box::new((matcher(eta), matcher(xi))) as Box }) .collect(); let bc = self diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 41a8864..0c96223 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -40,80 +40,6 @@ pub trait SbpOperator1d2: SbpOperator1d { fn d1_vec(&self, n: usize, front: bool) -> sprs::CsMat; } -pub trait SbpOperator2d: Send + Sync { - fn diffxi(&self, prev: ArrayView2, fut: ArrayViewMut2); - fn diffeta(&self, prev: ArrayView2, fut: ArrayViewMut2); - - fn hxi(&self) -> &'static [Float]; - fn heta(&self) -> &'static [Float]; - - fn is_h2xi(&self) -> bool; - fn is_h2eta(&self) -> bool; - - fn op_xi(&self) -> &dyn SbpOperator1d; - fn op_eta(&self) -> &dyn SbpOperator1d; -} - -impl SbpOperator2d for (&SBPeta, &SBPxi) { - default fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - self.1.diff(r0, r1) - } - } - fn diffeta(&self, prev: ArrayView2, fut: ArrayViewMut2) { - let ba = (self.1, self.0); - ba.diffxi(prev.reversed_axes(), fut.reversed_axes()) - } - fn hxi(&self) -> &'static [Float] { - self.1.h() - } - fn heta(&self) -> &'static [Float] { - self.0.h() - } - fn is_h2xi(&self) -> bool { - self.1.is_h2() - } - fn is_h2eta(&self) -> bool { - self.0.is_h2() - } - - fn op_xi(&self) -> &dyn SbpOperator1d { - self.1 - } - fn op_eta(&self) -> &dyn SbpOperator1d { - self.0 - } -} - -impl SbpOperator2d for SBP { - fn diffxi(&self, prev: ArrayView2, fut: ArrayViewMut2) { - <(&SBP, &SBP) as SbpOperator2d>::diffxi(&(self, self), prev, fut) - } - fn diffeta(&self, prev: ArrayView2, fut: ArrayViewMut2) { - <(&SBP, &SBP) as SbpOperator2d>::diffeta(&(self, self), prev, fut) - } - fn hxi(&self) -> &'static [Float] { - <(&SBP, &SBP) as SbpOperator2d>::hxi(&(self, self)) - } - fn heta(&self) -> &'static [Float] { - <(&SBP, &SBP) as SbpOperator2d>::heta(&(self, self)) - } - fn is_h2xi(&self) -> bool { - <(&SBP, &SBP) as SbpOperator2d>::is_h2xi(&(self, self)) - } - fn is_h2eta(&self) -> bool { - <(&SBP, &SBP) as SbpOperator2d>::is_h2eta(&(self, self)) - } - - fn op_xi(&self) -> &dyn SbpOperator1d { - self - } - fn op_eta(&self) -> &dyn SbpOperator1d { - self - } -} - pub trait UpwindOperator1d: SbpOperator1d + Send + Sync { /// Dissipation operator fn diss(&self, prev: ArrayView1, fut: ArrayViewMut1); @@ -123,57 +49,55 @@ pub trait UpwindOperator1d: SbpOperator1d + Send + Sync { fn diss_matrix(&self, n: usize) -> sprs::CsMat; } -pub trait UpwindOperator2d: SbpOperator2d + Send + Sync { - fn dissxi(&self, prev: ArrayView2, fut: ArrayViewMut2); - fn disseta(&self, prev: ArrayView2, fut: ArrayViewMut2); - fn as_sbp(&self) -> &dyn SbpOperator2d; +pub trait SbpOperator2d: Send + Sync { + fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + for (p, f) in prev.outer_iter().zip(fut.outer_iter_mut()) { + self.op_xi().diff(p, f) + } + } + fn diffeta(&self, prev: ArrayView2, fut: ArrayViewMut2) { + self.diffxi(prev.reversed_axes(), fut.reversed_axes()) + } + + fn hxi(&self) -> &'static [Float] { + self.op_xi().h() + } + fn heta(&self) -> &'static [Float] { + self.op_eta().h() + } + + fn is_h2xi(&self) -> bool { + self.op_xi().is_h2() + } + fn is_h2eta(&self) -> bool { + self.op_eta().is_h2() + } + + fn op_xi(&self) -> &dyn SbpOperator1d; + fn op_eta(&self) -> &dyn SbpOperator1d; + + fn upwind(&self) -> Option> { + None + } +} + +pub trait UpwindOperator2d: Send + Sync { + fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { + assert_eq!(prev.shape(), fut.shape()); + for (p, f) in prev.outer_iter().zip(fut.outer_iter_mut()) { + UpwindOperator2d::op_xi(self).diss(p, f) + } + } + // Assuming operator is symmetrical x/y + fn disseta(&self, prev: ArrayView2, fut: ArrayViewMut2) { + self.dissxi(prev.reversed_axes(), fut.reversed_axes()) + } fn op_xi(&self) -> &dyn UpwindOperator1d; fn op_eta(&self) -> &dyn UpwindOperator1d; } -impl UpwindOperator2d for (&UOeta, &UOxi) { - default fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { - assert_eq!(prev.shape(), fut.shape()); - for (r0, r1) in prev.outer_iter().zip(fut.outer_iter_mut()) { - self.1.diss(r0, r1); - } - } - fn disseta(&self, prev: ArrayView2, fut: ArrayViewMut2) { - let ba = (self.1, self.0); - ba.dissxi(prev.reversed_axes(), fut.reversed_axes()) - } - fn as_sbp(&self) -> &dyn SbpOperator2d { - self - } - - fn op_xi(&self) -> &dyn UpwindOperator1d { - self.1 - } - fn op_eta(&self) -> &dyn UpwindOperator1d { - self.0 - } -} - -impl UpwindOperator2d for UO { - fn dissxi(&self, prev: ArrayView2, fut: ArrayViewMut2) { - <(&UO, &UO) as UpwindOperator2d>::dissxi(&(self, self), prev, fut) - } - fn disseta(&self, prev: ArrayView2, fut: ArrayViewMut2) { - <(&UO, &UO) as UpwindOperator2d>::disseta(&(self, self), prev, fut) - } - fn as_sbp(&self) -> &dyn SbpOperator2d { - self - } - - fn op_xi(&self) -> &dyn UpwindOperator1d { - self - } - fn op_eta(&self) -> &dyn UpwindOperator1d { - self - } -} - pub trait InterpolationOperator: Send + Sync { /// Interpolation from a grid with twice resolution fn fine2coarse(&self, fine: ArrayView1, coarse: ArrayViewMut1); @@ -181,6 +105,37 @@ pub trait InterpolationOperator: Send + Sync { fn coarse2fine(&self, coarse: ArrayView1, fine: ArrayViewMut1); } +impl SbpOperator2d for (Box, Box) { + fn diffxi(&self, prev: ArrayView2, fut: ArrayViewMut2) { + self.1.diffxi(prev, fut) + } + fn diffeta(&self, prev: ArrayView2, fut: ArrayViewMut2) { + self.0.diffeta(prev, fut) + } + + fn op_xi(&self) -> &dyn SbpOperator1d { + self.1.op_xi() + } + fn op_eta(&self) -> &dyn SbpOperator1d { + self.0.op_eta() + } + fn upwind(&self) -> Option> { + match (self.0.upwind(), self.1.upwind()) { + (Some(u), Some(v)) => Some(Box::new((u, v))), + _ => None, + } + } +} + +impl UpwindOperator2d for (Box, Box) { + fn op_xi(&self) -> &dyn UpwindOperator1d { + self.1.op_xi() + } + fn op_eta(&self) -> &dyn UpwindOperator1d { + self.0.op_eta() + } +} + #[inline(always)] fn diff_op_1d( block: &[&[Float]], diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index 0d2af67..649c43c 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -71,7 +71,7 @@ impl SbpOperator1d for SBP4 { } } -impl SbpOperator2d for (&SBP, &SBP4) { +impl SbpOperator2d for SBP4 { fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * SBP4::BLOCK.len()); @@ -95,6 +95,12 @@ impl SbpOperator2d for (&SBP, &SBP4) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn SbpOperator1d { + &Self + } + fn op_eta(&self) -> &dyn SbpOperator1d { + &Self + } } impl super::SbpOperator1d2 for SBP4 { diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs index 2db2e58..6947022 100644 --- a/sbp/src/operators/traditional8.rs +++ b/sbp/src/operators/traditional8.rs @@ -59,7 +59,7 @@ impl SbpOperator1d for SBP8 { } } -impl SbpOperator2d for (&SBP, &SBP8) { +impl SbpOperator2d for SBP8 { fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * SBP8::BLOCK.len()); @@ -83,6 +83,12 @@ impl SbpOperator2d for (&SBP, &SBP8) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn SbpOperator1d { + &Self + } + fn op_eta(&self) -> &dyn SbpOperator1d { + &Self + } } #[test] diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index 3f1aca9..8e577b6 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -229,7 +229,7 @@ impl SbpOperator1d for Upwind4 { } } -impl SbpOperator2d for (&SBP, &Upwind4) { +impl SbpOperator2d for Upwind4 { fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); @@ -259,6 +259,15 @@ impl SbpOperator2d for (&SBP, &Upwind4) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn SbpOperator1d { + &Self + } + fn op_eta(&self) -> &dyn SbpOperator1d { + &Self + } + fn upwind(&self) -> Option> { + Some(Box::new(Self)) + } } #[test] @@ -377,7 +386,7 @@ impl UpwindOperator1d for Upwind4 { } } -impl UpwindOperator2d for (&UO, &Upwind4) { +impl UpwindOperator2d for Upwind4 { fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind4::BLOCK.len()); @@ -407,6 +416,13 @@ impl UpwindOperator2d for (&UO, &Upwind4) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + + fn op_xi(&self) -> &dyn UpwindOperator1d { + &Self + } + fn op_eta(&self) -> &dyn UpwindOperator1d { + &Self + } } #[test] diff --git a/sbp/src/operators/upwind4h2.rs b/sbp/src/operators/upwind4h2.rs index d1cdf75..0ef5f9b 100644 --- a/sbp/src/operators/upwind4h2.rs +++ b/sbp/src/operators/upwind4h2.rs @@ -77,7 +77,7 @@ impl SbpOperator1d for Upwind4h2 { } } -impl SbpOperator2d for (&SBP, &Upwind4h2) { +impl SbpOperator2d for Upwind4h2 { fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind4h2::BLOCK.len()); @@ -101,9 +101,18 @@ impl SbpOperator2d for (&SBP, &Upwind4h2) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn SbpOperator1d { + &Self + } + fn op_eta(&self) -> &dyn SbpOperator1d { + &Self + } + fn upwind(&self) -> Option> { + Some(Box::new(Self)) + } } -impl UpwindOperator2d for (&UO, &Upwind4h2) { +impl UpwindOperator2d for Upwind4h2 { fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind4h2::BLOCK.len()); @@ -137,6 +146,12 @@ impl UpwindOperator2d for (&UO, &Upwind4h2) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn UpwindOperator1d { + &Self + } + fn op_eta(&self) -> &dyn UpwindOperator1d { + &Self + } } #[test] diff --git a/sbp/src/operators/upwind9.rs b/sbp/src/operators/upwind9.rs index 1464b4a..596a354 100644 --- a/sbp/src/operators/upwind9.rs +++ b/sbp/src/operators/upwind9.rs @@ -82,7 +82,7 @@ impl SbpOperator1d for Upwind9 { } } -impl SbpOperator2d for (&SBP, &Upwind9) { +impl SbpOperator2d for Upwind9 { fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind9::BLOCK.len()); @@ -106,6 +106,15 @@ impl SbpOperator2d for (&SBP, &Upwind9) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn SbpOperator1d { + &Self + } + fn op_eta(&self) -> &dyn SbpOperator1d { + &Self + } + fn upwind(&self) -> Option> { + Some(Box::new(Self)) + } } impl UpwindOperator1d for Upwind9 { @@ -136,7 +145,7 @@ impl UpwindOperator1d for Upwind9 { } } -impl UpwindOperator2d for (&UO, &Upwind9) { +impl UpwindOperator2d for Upwind9 { fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind9::BLOCK.len()); @@ -160,6 +169,12 @@ impl UpwindOperator2d for (&UO, &Upwind9) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn UpwindOperator1d { + &Self + } + fn op_eta(&self) -> &dyn UpwindOperator1d { + &Self + } } #[test] diff --git a/sbp/src/operators/upwind9h2.rs b/sbp/src/operators/upwind9h2.rs index 957c549..df2e284 100644 --- a/sbp/src/operators/upwind9h2.rs +++ b/sbp/src/operators/upwind9h2.rs @@ -85,7 +85,7 @@ impl SbpOperator1d for Upwind9h2 { } } -impl SbpOperator2d for (&SBP, &Upwind9h2) { +impl SbpOperator2d for Upwind9h2 { fn diffxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind9h2::BLOCK.len()); @@ -109,6 +109,15 @@ impl SbpOperator2d for (&SBP, &Upwind9h2) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn SbpOperator1d { + &Self + } + fn op_eta(&self) -> &dyn SbpOperator1d { + &Self + } + fn upwind(&self) -> Option> { + Some(Box::new(Self)) + } } #[test] @@ -163,7 +172,7 @@ impl UpwindOperator1d for Upwind9h2 { } } -impl UpwindOperator2d for (&UO, &Upwind9h2) { +impl UpwindOperator2d for Upwind9h2 { fn dissxi(&self, prev: ArrayView2, mut fut: ArrayViewMut2) { assert_eq!(prev.shape(), fut.shape()); assert!(prev.shape()[1] >= 2 * Upwind9h2::BLOCK.len()); @@ -197,4 +206,10 @@ impl UpwindOperator2d for (&UO, &Upwind9h2) { _ => unreachable!("Should only be two elements in the strides vectors"), } } + fn op_xi(&self) -> &dyn UpwindOperator1d { + &Self + } + fn op_eta(&self) -> &dyn UpwindOperator1d { + &Self + } } diff --git a/shallow_water/src/lib.rs b/shallow_water/src/lib.rs index af7f5ad..4e143c1 100644 --- a/shallow_water/src/lib.rs +++ b/shallow_water/src/lib.rs @@ -62,7 +62,7 @@ pub struct System { fnext: Field, x: (Float, Float, usize), y: (Float, Float, usize), - op: Box, + op: Box, k: [Field; 4], } @@ -175,35 +175,37 @@ impl System { // Upwind dissipation if false { - let mut temp_dx = temp_dy; - azip!((dest in &mut temp, eta in now.eta(), etau in now.etau()) { - *dest = -(eta.powf(3.0/2.0)*G.sqrt() + etau.abs())/eta - }); - op.dissxi(temp.view(), temp_dx.view_mut()); - azip!((dest in &mut next_eta, eta in now.eta(), diss in &temp_dx) { - *dest -= eta*diss; - }); - azip!((dest in &mut next_etau, etau in now.etau(), diss in &temp_dx) { - *dest -= etau*diss; - }); - azip!((dest in &mut next_etav, etav in now.etav(), diss in &temp_dx) { - *dest -= etav*diss; - }); + if let Some(op) = op.upwind() { + let mut temp_dx = temp_dy; + azip!((dest in &mut temp, eta in now.eta(), etau in now.etau()) { + *dest = -(eta.powf(3.0/2.0)*G.sqrt() + etau.abs())/eta + }); + op.dissxi(temp.view(), temp_dx.view_mut()); + azip!((dest in &mut next_eta, eta in now.eta(), diss in &temp_dx) { + *dest -= eta*diss; + }); + azip!((dest in &mut next_etau, etau in now.etau(), diss in &temp_dx) { + *dest -= etau*diss; + }); + azip!((dest in &mut next_etav, etav in now.etav(), diss in &temp_dx) { + *dest -= etav*diss; + }); - let mut temp_dy = temp_dx; - azip!((dest in &mut temp, eta in now.eta(), etav in now.etav()) { - *dest = -(eta.powf(3.0/2.0)*G.sqrt() + etav.abs())/eta - }); - op.disseta(temp.view(), temp_dy.view_mut()); - azip!((dest in &mut next_eta, eta in now.eta(), diss in &temp_dy) { - *dest -= eta*diss; - }); - azip!((dest in &mut next_etau, etau in now.etau(), diss in &temp_dy) { - *dest -= etau*diss; - }); - azip!((dest in &mut next_etav, etav in now.etav(), diss in &temp_dy) { - *dest -= etav*diss; - }); + let mut temp_dy = temp_dx; + azip!((dest in &mut temp, eta in now.eta(), etav in now.etav()) { + *dest = -(eta.powf(3.0/2.0)*G.sqrt() + etav.abs())/eta + }); + op.disseta(temp.view(), temp_dy.view_mut()); + azip!((dest in &mut next_eta, eta in now.eta(), diss in &temp_dy) { + *dest -= eta*diss; + }); + azip!((dest in &mut next_etau, etau in now.etau(), diss in &temp_dy) { + *dest -= etau*diss; + }); + azip!((dest in &mut next_etav, etav in now.etav(), diss in &temp_dy) { + *dest -= etav*diss; + }); + } } // SAT boundaries