diff --git a/multigrid/src/main.rs b/multigrid/src/main.rs index 65ed1cc..e831ad7 100644 --- a/multigrid/src/main.rs +++ b/multigrid/src/main.rs @@ -84,11 +84,7 @@ impl System { let mut eb = &mut self.eb; let operators = &self.operators; - let rhs = move |fut: &mut [euler::Field], - prev: &[euler::Field], - time: Float, - _c: (), - _mt: &mut ()| { + let rhs = move |fut: &mut [euler::Field], prev: &[euler::Field], time: Float| { let bc = euler::extract_boundaries(prev, &bt, &mut eb, &grids, time); pool.scope(|s| { for (((((fut, prev), bc), wb), metrics), op) in fut @@ -116,15 +112,13 @@ impl System { .iter_mut() .map(|k| k.as_mut_slice()) .collect::>(); - sbp::integrate::integrate_multigrid::( + sbp::integrate::integrate_multigrid::( rhs, &self.fnow, &mut self.fnext, &mut self.time, dt, &mut k, - (), - &mut (), pool, ); diff --git a/sbp/src/euler.rs b/sbp/src/euler.rs index bab75d9..309afc6 100644 --- a/sbp/src/euler.rs +++ b/sbp/src/euler.rs @@ -49,20 +49,20 @@ impl System { west: BoundaryCharacteristic::This, }; let op = &self.op; - let rhs_trad = |k: &mut Field, y: &Field, _time: Float, gm: &(_, _), wb: &mut _| { - let (grid, metrics) = gm; + let wb = &mut self.wb.0; + let grid = &self.grid.0; + let metrics = &self.grid.1; + let rhs_trad = |k: &mut Field, y: &Field, _time: Float| { let boundaries = boundary_extractor(y, grid, &bc); RHS_trad(op, k, y, metrics, &boundaries, wb) }; - integrate::integrate::( + integrate::integrate::( rhs_trad, &self.sys.0, &mut self.sys.1, &mut 0.0, dt, &mut self.k, - &self.grid, - &mut self.wb.0, ); std::mem::swap(&mut self.sys.0, &mut self.sys.1); } @@ -110,20 +110,20 @@ impl System { west: BoundaryCharacteristic::This, }; let op = &self.op; - let rhs_upwind = |k: &mut Field, y: &Field, _time: Float, gm: &(_, _), wb: &mut _| { - let (grid, metrics) = gm; + let grid = &self.grid; + let wb = &mut self.wb.0; + let rhs_upwind = |k: &mut Field, y: &Field, _time: Float| { + let (grid, metrics) = grid; let boundaries = boundary_extractor(y, grid, &bc); RHS_upwind(op, k, y, metrics, &boundaries, wb) }; - integrate::integrate::( + integrate::integrate::( rhs_upwind, &self.sys.0, &mut self.sys.1, &mut 0.0, dt, &mut self.k, - &self.grid, - &mut self.wb.0, ); std::mem::swap(&mut self.sys.0, &mut self.sys.1); } @@ -146,6 +146,17 @@ impl std::ops::DerefMut for Field { } } +impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> { + fn from(f: &'a Field) -> Self { + f.0.view() + } +} +impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> { + fn from(f: &'a mut Field) -> Self { + f.0.view_mut() + } +} + impl Field { pub fn new(ny: usize, nx: usize) -> Self { let field = Array3::zeros((4, ny, nx)); diff --git a/sbp/src/integrate.rs b/sbp/src/integrate.rs index c11e8b3..9886e5e 100644 --- a/sbp/src/integrate.rs +++ b/sbp/src/integrate.rs @@ -1,5 +1,5 @@ use super::Float; -use ndarray::Array3; +use ndarray::{ArrayView3, ArrayViewMut3}; pub trait ButcherTableau { const S: usize = Self::B.len(); @@ -88,49 +88,45 @@ impl ButcherTableau for Rk6 { } #[allow(clippy::too_many_arguments)] -pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>( +pub fn integrate( mut rhs: RHS, prev: &F, fut: &mut F, time: &mut Float, dt: Float, k: &mut [F], - - constants: C, - mut mutables: &mut MT, ) where - C: Copy, - F: std::ops::Deref> + std::ops::DerefMut>, - RHS: FnMut(&mut F, &F, Float, C, &mut MT), - BTableau: ButcherTableau, + for<'r> &'r F: std::convert::Into>, + for<'r> &'r mut F: std::convert::Into>, + RHS: FnMut(&mut F, &F, Float), { - assert_eq!(prev.shape(), fut.shape()); + assert_eq!(prev.into().shape(), fut.into().shape()); assert!(k.len() >= BTableau::S); for i in 0.. { let simtime; match i { 0 => { - fut.assign(prev); + fut.into().assign(&prev.into()); simtime = *time; } i if i < BTableau::S => { - fut.assign(prev); + fut.into().assign(&prev.into()); for (&a, k) in BTableau::A[i - 1].iter().zip(k.iter()) { if a == 0.0 { continue; } - fut.scaled_add(a * dt, &k); + fut.into().scaled_add(a * dt, &k.into()); } simtime = *time + dt * BTableau::C[i - 1]; } _ if i == BTableau::S => { - fut.assign(prev); + fut.into().assign(&prev.into()); for (&b, k) in BTableau::B.iter().zip(k.iter()) { if b == 0.0 { continue; } - fut.scaled_add(b * dt, &k); + fut.into().scaled_add(b * dt, &k.into()); } *time += dt; return; @@ -140,13 +136,13 @@ pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>( } }; - rhs(&mut k[i], &fut, simtime, constants, &mut mutables); + rhs(&mut k[i], &fut, simtime); } } #[cfg(feature = "rayon")] #[allow(clippy::too_many_arguments)] -pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>( +pub fn integrate_multigrid( mut rhs: RHS, prev: &[F], fut: &mut [F], @@ -154,17 +150,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>( dt: Float, k: &mut [&mut [F]], - constants: C, - mut mutables: &mut MT, pool: &rayon::ThreadPool, ) where - C: Copy, - F: std::ops::Deref> - + std::ops::DerefMut> - + Send - + Sync, - RHS: FnMut(&mut [F], &[F], Float, C, &mut MT), - BTableau: ButcherTableau, + for<'r> &'r F: std::convert::Into>, + for<'r> &'r mut F: std::convert::Into>, + RHS: FnMut(&mut [F], &[F], Float), + F: Send + Sync, { for i in 0.. { let simtime; @@ -174,8 +165,8 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>( assert!(k.len() >= BTableau::S); for (prev, fut) in prev.iter().zip(fut.iter_mut()) { s.spawn(move |_| { - assert_eq!(prev.shape(), fut.shape()); - fut.assign(prev); + assert_eq!(prev.into().shape(), fut.into().shape()); + fut.into().assign(&prev.into()); }); } }); @@ -186,12 +177,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>( for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() { let k = &k; s.spawn(move |_| { - fut.assign(prev); + fut.into().assign(&prev.into()); for (ik, &a) in BTableau::A[i - 1].iter().enumerate() { if a == 0.0 { continue; } - fut.scaled_add(a * dt, &k[ik][ig]); + fut.into().scaled_add(a * dt, &(&k[ik][ig]).into()); } }); } @@ -203,12 +194,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>( for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() { let k = &k; s.spawn(move |_| { - fut.assign(prev); + fut.into().assign(&prev.into()); for (ik, &b) in BTableau::B.iter().enumerate() { if b == 0.0 { continue; } - fut.scaled_add(b * dt, &k[ik][ig]); + fut.into().scaled_add(b * dt, &(&k[ik][ig]).into()); } }); } @@ -221,6 +212,6 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>( } }; - rhs(&mut k[i], &fut, simtime, constants, &mut mutables); + rhs(&mut k[i], &fut, simtime); } } diff --git a/sbp/src/maxwell.rs b/sbp/src/maxwell.rs index 9c4c3dd..67cdf6a 100644 --- a/sbp/src/maxwell.rs +++ b/sbp/src/maxwell.rs @@ -21,6 +21,17 @@ impl std::ops::DerefMut for Field { } } +impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> { + fn from(f: &'a Field) -> Self { + f.0.view() + } +} +impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> { + fn from(f: &'a mut Field) -> Self { + f.0.view_mut() + } +} + impl Field { pub fn new(height: usize, width: usize) -> Self { let field = Array3::zeros((3, height, width)); @@ -118,29 +129,20 @@ impl System { pub fn advance(&mut self, dt: Float) { let op = &self.op; - let rhs_adaptor = move |fut: &mut Field, - prev: &Field, - _time: Float, - c: &(&Grid, &Metrics), - m: &mut ( - Array2, - Array2, - Array2, - Array2, - )| { - let (grid, metrics) = c; - RHS(op, fut, prev, grid, metrics, m); + let grid = &self.grid; + let metrics = &self.metrics; + let wb = &mut self.wb.tmp; + let rhs_adaptor = move |fut: &mut Field, prev: &Field, _time: Float| { + RHS(op, fut, prev, grid, metrics, wb); }; let mut _time = 0.0; - integrate::integrate::( + integrate::integrate::( rhs_adaptor, &self.sys.0, &mut self.sys.1, &mut _time, dt, &mut self.wb.k, - &(&self.grid, &self.metrics), - &mut self.wb.tmp, ); std::mem::swap(&mut self.sys.0, &mut self.sys.1); } @@ -150,29 +152,20 @@ impl System { /// Using artificial dissipation with the upwind operator pub fn advance_upwind(&mut self, dt: Float) { let op = &self.op; - let rhs_adaptor = move |fut: &mut Field, - prev: &Field, - _time: Float, - c: &(&Grid, &Metrics), - m: &mut ( - Array2, - Array2, - Array2, - Array2, - )| { - let (grid, metrics) = c; - RHS_upwind(op, fut, prev, grid, metrics, m); + let grid = &self.grid; + let metrics = &self.metrics; + let wb = &mut self.wb.tmp; + let rhs_adaptor = move |fut: &mut Field, prev: &Field, _time: Float| { + RHS_upwind(op, fut, prev, grid, metrics, wb); }; let mut _time = 0.0; - integrate::integrate::( + integrate::integrate::( rhs_adaptor, &self.sys.0, &mut self.sys.1, &mut _time, dt, &mut self.wb.k, - &(&self.grid, &self.metrics), - &mut self.wb.tmp, ); std::mem::swap(&mut self.sys.0, &mut self.sys.1); }