From 9cfd54253f46ec639330d81ccec8ca16d0799760 Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Mon, 13 Apr 2020 20:56:29 +0200 Subject: [PATCH] change trait signature for interpolation --- multigrid/src/main.rs | 19 +++++-- sbp/src/euler.rs | 54 +++++++++++++------ sbp/src/operators.rs | 4 +- .../operators/interpolation/interpolation4.rs | 6 +-- .../operators/interpolation/interpolation9.rs | 6 +-- 5 files changed, 61 insertions(+), 28 deletions(-) diff --git a/multigrid/src/main.rs b/multigrid/src/main.rs index bfda683..7da8da1 100644 --- a/multigrid/src/main.rs +++ b/multigrid/src/main.rs @@ -16,6 +16,7 @@ struct System { bt: Vec, eb: Vec, time: Float, + interpolation_operators: Vec, } enum Metrics { @@ -33,6 +34,7 @@ impl System { fn new( grids: Vec, bt: Vec, + interpolation_operators: Vec, operatorx: &str, operatory: &str, ) -> Self { @@ -95,6 +97,7 @@ impl System { bt, eb, time: 0.0, + interpolation_operators, } } @@ -110,15 +113,14 @@ impl System { let bt = &self.bt; let wb = &mut self.wb; let mut eb = &mut self.eb; + let intops = &self.interpolation_operators; let rhs = move |fut: &mut [euler::Field], prev: &[euler::Field], time: Float, _c: (), _mt: &mut ()| { - let bc = euler::extract_boundaries::( - prev, &bt, &mut eb, &grids, time, - ); + let bc = euler::extract_boundaries(prev, &bt, &mut eb, &grids, time, Some(intops)); pool.scope(|s| { for ((((fut, prev), bc), wb), metrics) in fut .iter_mut() @@ -238,6 +240,15 @@ fn main() { west: determine_bc(grid.dirw.as_ref()), }); } + let interpolation_operators = jgrids + .iter() + .map(|_g| euler::InterpolationOperators { + north: Some(Box::new(operators::Interpolation4)), + south: Some(Box::new(operators::Interpolation4)), + east: Some(Box::new(operators::Interpolation4)), + west: Some(Box::new(operators::Interpolation4)), + }) + .collect::>(); let grids = jgrids.into_iter().map(|egrid| egrid.grid).collect(); let integration_time: Float = json["integration_time"].as_number().unwrap().into(); @@ -254,7 +265,7 @@ fn main() { } }; - let mut sys = System::new(grids, bt, operatorx, operatory); + let mut sys = System::new(grids, bt, interpolation_operators, operatorx, operatory); sys.vortex(0.0, vortexparams); let max_n = { diff --git a/sbp/src/euler.rs b/sbp/src/euler.rs index e619f21..08b824a 100644 --- a/sbp/src/euler.rs +++ b/sbp/src/euler.rs @@ -608,14 +608,15 @@ pub enum BoundaryCharacteristic { Interpolate(usize), } -#[derive(Clone, Debug)] -pub struct BoundaryCharacteristics { - pub north: BoundaryCharacteristic, - pub south: BoundaryCharacteristic, - pub east: BoundaryCharacteristic, - pub west: BoundaryCharacteristic, +pub struct Direction { + pub north: T, + pub south: T, + pub west: T, + pub east: T, } +pub type BoundaryCharacteristics = Direction; + fn boundary_extractor<'a>( field: &'a Field, _grid: &Grid, @@ -653,18 +654,22 @@ fn boundary_extractor<'a>( } } -pub fn extract_boundaries<'a, IO: InterpolationOperator>( +pub type InterpolationOperators = Direction>>; + +pub fn extract_boundaries<'a>( fields: &'a [Field], bt: &[BoundaryCharacteristics], eb: &'a mut [BoundaryStorage], grids: &[Grid], time: Float, + interpolation_operators: Option<&[InterpolationOperators]>, ) -> Vec> { bt.iter() .zip(eb) .zip(grids) .zip(fields) - .map(|(((bt, eb), grid), field)| BoundaryTerms { + .enumerate() + .map(|(ig, (((bt, eb), grid), field))| BoundaryTerms { north: match bt.north { BoundaryCharacteristic::This => field.south(), BoundaryCharacteristic::Grid(g) => fields[g].south(), @@ -677,11 +682,16 @@ pub fn extract_boundaries<'a, IO: InterpolationOperator>( let to = eb.n.as_mut().unwrap(); let fine2coarse = field.nx() < fields[g].nx(); + let operator = interpolation_operators.as_ref().unwrap()[ig] + .north + .as_ref() + .unwrap(); + for (mut to, from) in to.outer_iter_mut().zip(fields[g].south().outer_iter()) { if fine2coarse { - IO::fine2coarse(from.view(), to.view_mut()); + operator.fine2coarse(from.view(), to.view_mut()); } else { - IO::coarse2fine(from.view(), to.view_mut()); + operator.coarse2fine(from.view(), to.view_mut()); } } to.view() @@ -698,12 +708,16 @@ pub fn extract_boundaries<'a, IO: InterpolationOperator>( BoundaryCharacteristic::Interpolate(g) => { let to = eb.s.as_mut().unwrap(); let fine2coarse = field.nx() < fields[g].nx(); + let operator = interpolation_operators.as_ref().unwrap()[ig] + .south + .as_ref() + .unwrap(); for (mut to, from) in to.outer_iter_mut().zip(fields[g].north().outer_iter()) { if fine2coarse { - IO::fine2coarse(from.view(), to.view_mut()); + operator.fine2coarse(from.view(), to.view_mut()); } else { - IO::coarse2fine(from.view(), to.view_mut()); + operator.coarse2fine(from.view(), to.view_mut()); } } to.view() @@ -720,12 +734,16 @@ pub fn extract_boundaries<'a, IO: InterpolationOperator>( BoundaryCharacteristic::Interpolate(g) => { let to = eb.w.as_mut().unwrap(); let fine2coarse = field.ny() < fields[g].ny(); + let operator = interpolation_operators.as_ref().unwrap()[ig] + .west + .as_ref() + .unwrap(); for (mut to, from) in to.outer_iter_mut().zip(fields[g].east().outer_iter()) { if fine2coarse { - IO::fine2coarse(from.view(), to.view_mut()); + operator.fine2coarse(from.view(), to.view_mut()); } else { - IO::coarse2fine(from.view(), to.view_mut()); + operator.coarse2fine(from.view(), to.view_mut()); } } to.view() @@ -742,12 +760,16 @@ pub fn extract_boundaries<'a, IO: InterpolationOperator>( BoundaryCharacteristic::Interpolate(g) => { let to = eb.e.as_mut().unwrap(); let fine2coarse = field.ny() < fields[g].ny(); + let operator = interpolation_operators.as_ref().unwrap()[ig] + .east + .as_ref() + .unwrap(); for (mut to, from) in to.outer_iter_mut().zip(fields[g].west().outer_iter()) { if fine2coarse { - IO::fine2coarse(from.view(), to.view_mut()); + operator.fine2coarse(from.view(), to.view_mut()); } else { - IO::coarse2fine(from.view(), to.view_mut()); + operator.coarse2fine(from.view(), to.view_mut()); } } to.view() diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 8264a42..9cdf243 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -36,8 +36,8 @@ pub trait UpwindOperator: SbpOperator { } pub trait InterpolationOperator: Send + Sync { - fn fine2coarse(fine: ArrayView1, coarse: ArrayViewMut1); - fn coarse2fine(coarse: ArrayView1, fine: ArrayViewMut1); + fn fine2coarse(&self, fine: ArrayView1, coarse: ArrayViewMut1); + fn coarse2fine(&self, coarse: ArrayView1, fine: ArrayViewMut1); } #[inline(always)] diff --git a/sbp/src/operators/interpolation/interpolation4.rs b/sbp/src/operators/interpolation/interpolation4.rs index e01b67e..ff03956 100644 --- a/sbp/src/operators/interpolation/interpolation4.rs +++ b/sbp/src/operators/interpolation/interpolation4.rs @@ -1,6 +1,6 @@ use super::*; -pub struct Interpolation4 {} +pub struct Interpolation4; impl Interpolation4 { const F2C_DIAG: &'static [[Float; 7]] = &[[ @@ -71,7 +71,7 @@ impl Interpolation4 { } impl InterpolationOperator for Interpolation4 { - fn fine2coarse(fine: ArrayView1, coarse: ArrayViewMut1) { + fn fine2coarse(&self, fine: ArrayView1, coarse: ArrayViewMut1) { assert_eq!(fine.len(), 2 * coarse.len() - 1); super::interpolate( fine, @@ -81,7 +81,7 @@ impl InterpolationOperator for Interpolation4 { (3, 2), ) } - fn coarse2fine(coarse: ArrayView1, fine: ArrayViewMut1) { + fn coarse2fine(&self, coarse: ArrayView1, fine: ArrayViewMut1) { assert_eq!(fine.len(), 2 * coarse.len() - 1); super::interpolate( coarse, diff --git a/sbp/src/operators/interpolation/interpolation9.rs b/sbp/src/operators/interpolation/interpolation9.rs index 9f99291..8847a26 100644 --- a/sbp/src/operators/interpolation/interpolation9.rs +++ b/sbp/src/operators/interpolation/interpolation9.rs @@ -1,6 +1,6 @@ use super::*; -pub struct Interpolation9 {} +pub struct Interpolation9; impl Interpolation9 { #[rustfmt::skip] @@ -50,7 +50,7 @@ impl Interpolation9 { } impl InterpolationOperator for Interpolation9 { - fn fine2coarse(fine: ArrayView1, coarse: ArrayViewMut1) { + fn fine2coarse(&self, fine: ArrayView1, coarse: ArrayViewMut1) { assert_eq!(fine.len(), 2 * coarse.len() - 1); use ndarray::prelude::*; use std::iter::FromIterator; @@ -62,7 +62,7 @@ impl InterpolationOperator for Interpolation9 { .unwrap(); super::interpolate(fine, coarse, block.view(), diag.view(), (5, 2)) } - fn coarse2fine(coarse: ArrayView1, fine: ArrayViewMut1) { + fn coarse2fine(&self, coarse: ArrayView1, fine: ArrayViewMut1) { assert_eq!(fine.len(), 2 * coarse.len() - 1); use ndarray::prelude::*; use std::iter::FromIterator;