diff --git a/multigrid/Cargo.toml b/multigrid/Cargo.toml index ceda6bf..de81dab 100644 --- a/multigrid/Cargo.toml +++ b/multigrid/Cargo.toml @@ -13,3 +13,4 @@ indicatif = "0.14.0" structopt = "0.3.13" ndarray = "0.13.0" json = "0.12.4" +either = "1.5.3" diff --git a/multigrid/src/main.rs b/multigrid/src/main.rs index 7da8da1..2b8dcf2 100644 --- a/multigrid/src/main.rs +++ b/multigrid/src/main.rs @@ -1,4 +1,6 @@ #![feature(str_strip)] +use either::*; +use sbp::operators::{SbpOperator2d, UpwindOperator2d}; use sbp::utils::json_to_grids; use sbp::*; use structopt::StructOpt; @@ -12,31 +14,20 @@ struct System { wb: Vec, k: [Vec; 4], grids: Vec, - metrics: Vec, + metrics: Vec, bt: Vec, eb: Vec, time: Float, + operators: Vec, Box>>, interpolation_operators: Vec, } -enum Metrics { - Upwind4(grid::Metrics), - Upwind9(grid::Metrics), - Upwind4h2(grid::Metrics), - Trad4(grid::Metrics), - Trad8(grid::Metrics), - - Upwind4Upwind4h2(grid::Metrics), - Upwind4h2Upwind4(grid::Metrics), -} - impl System { fn new( grids: Vec, bt: Vec, interpolation_operators: Vec, - operatorx: &str, - operatory: &str, + operators: Vec, Box>>, ) -> Self { let fnow = grids .iter() @@ -50,34 +41,10 @@ impl System { let k = [fnow.clone(), fnow.clone(), fnow.clone(), fnow.clone()]; let metrics = grids .iter() - .map(|g| match (operatorx, operatory) { - ("upwind4", "upwind4") => Metrics::Upwind4( - g.metrics::() - .unwrap(), - ), - ("upwind9", "upwind9") => Metrics::Upwind9( - g.metrics::() - .unwrap(), - ), - ("upwind4h2", "upwind4h2") => Metrics::Upwind4h2( - g.metrics::() - .unwrap(), - ), - ("trad4", "trad4") => { - Metrics::Trad4(g.metrics::().unwrap()) - } - ("trad8", "trad8") => { - Metrics::Trad8(g.metrics::().unwrap()) - } - ("upwind4", "upwind4h2") => Metrics::Upwind4Upwind4h2( - g.metrics::() - .unwrap(), - ), - ("upwind4h2", "upwind4") => Metrics::Upwind4h2Upwind4( - g.metrics::() - .unwrap(), - ), - (opx, opy) => panic!("operator combination {}x{} not known", opx, opy), + .zip(&operators) + .map(|(g, op)| { + let sbpop: &dyn SbpOperator2d = op.as_ref().either(|op| &**op, |uo| uo.as_sbp()); + g.metrics(sbpop).unwrap() }) .collect::>(); @@ -98,6 +65,7 @@ impl System { eb, time: 0.0, interpolation_operators, + operators, } } @@ -114,6 +82,7 @@ impl System { let wb = &mut self.wb; let mut eb = &mut self.eb; let intops = &self.interpolation_operators; + let operators = &self.operators; let rhs = move |fut: &mut [euler::Field], prev: &[euler::Field], @@ -122,36 +91,22 @@ impl System { _mt: &mut ()| { let bc = euler::extract_boundaries(prev, &bt, &mut eb, &grids, time, Some(intops)); pool.scope(|s| { - for ((((fut, prev), bc), wb), metrics) in fut + for (((((fut, prev), bc), wb), metrics), op) in fut .iter_mut() .zip(prev.iter()) .zip(bc) .zip(wb.iter_mut()) .zip(metrics.iter()) + .zip(operators.iter()) { - s.spawn(move |_| match metrics { - Metrics::Upwind4(metrics) => { - euler::RHS_upwind(fut, prev, metrics, &bc, &mut wb.0) + s.spawn(move |_| match op.as_ref() { + Left(sbp) => { + euler::RHS_trad(&**sbp, fut, prev, metrics, &bc, &mut wb.0); } - Metrics::Upwind9(metrics) => { - euler::RHS_upwind(fut, prev, metrics, &bc, &mut wb.0) + Right(uo) => { + euler::RHS_upwind(&**uo, fut, prev, metrics, &bc, &mut wb.0); } - Metrics::Upwind4h2(metrics) => { - euler::RHS_upwind(fut, prev, metrics, &bc, &mut wb.0) - } - Metrics::Trad4(metrics) => { - euler::RHS_trad(fut, prev, metrics, &bc, &mut wb.0) - } - Metrics::Trad8(metrics) => { - euler::RHS_trad(fut, prev, metrics, &bc, &mut wb.0) - } - Metrics::Upwind4Upwind4h2(metrics) => { - euler::RHS_trad(fut, prev, metrics, &bc, &mut wb.0) - } - Metrics::Upwind4h2Upwind4(metrics) => { - euler::RHS_trad(fut, prev, metrics, &bc, &mut wb.0) - } - }); + }) } }); }; @@ -201,7 +156,6 @@ struct Options { } fn main() { - type SBP = operators::Upwind4; let opt = Options::from_args(); let filecontents = std::fs::read_to_string(&opt.json).unwrap(); @@ -249,23 +203,20 @@ fn main() { west: Some(Box::new(operators::Interpolation4)), }) .collect::>(); - let grids = jgrids.into_iter().map(|egrid| egrid.grid).collect(); + + let grids = jgrids + .into_iter() + .map(|egrid| egrid.grid) + .collect::>(); let integration_time: Float = json["integration_time"].as_number().unwrap().into(); - let (operatorx, operatory) = { - if json["operator"].is_object() { - ( - json["operator"]["x"].as_str().unwrap(), - json["operator"]["y"].as_str().unwrap(), - ) - } else { - let op = json["operator"].as_str().unwrap_or("upwind4"); - (op, op) - } - }; + let operators = grids + .iter() + .map(|_| Right(Box::new(operators::Upwind4) as Box)) + .collect::>(); - let mut sys = System::new(grids, bt, interpolation_operators, operatorx, operatory); + let mut sys = System::new(grids, bt, interpolation_operators, operators); sys.vortex(0.0, vortexparams); let max_n = { @@ -331,10 +282,11 @@ fn main() { if opt.error { let time = ntime as Float * dt; let mut e = 0.0; - for (fmod, grid) in sys.fnow.iter().zip(&sys.grids) { + for ((fmod, grid), op) in sys.fnow.iter().zip(&sys.grids).zip(&sys.operators) { let mut fvort = fmod.clone(); fvort.vortex(grid.x(), grid.y(), time, vortexparams); - e += fmod.h2_err::(&fvort); + let sbpop: &dyn SbpOperator2d = op.as_ref().either(|op| &**op, |uo| uo.as_sbp()); + e += fmod.h2_err(&fvort, sbpop); } println!("Total error: {:e}", e); }