simplify integrate function

This commit is contained in:
Magnus Ulimoen 2020-04-16 20:40:22 +02:00
parent 41935728e1
commit cfff49107c
4 changed files with 70 additions and 81 deletions

View File

@ -84,11 +84,7 @@ impl System {
let mut eb = &mut self.eb;
let operators = &self.operators;
let rhs = move |fut: &mut [euler::Field],
prev: &[euler::Field],
time: Float,
_c: (),
_mt: &mut ()| {
let rhs = move |fut: &mut [euler::Field], prev: &[euler::Field], time: Float| {
let bc = euler::extract_boundaries(prev, &bt, &mut eb, &grids, time);
pool.scope(|s| {
for (((((fut, prev), bc), wb), metrics), op) in fut
@ -116,15 +112,13 @@ impl System {
.iter_mut()
.map(|k| k.as_mut_slice())
.collect::<Vec<_>>();
sbp::integrate::integrate_multigrid::<sbp::integrate::Rk4, _, _, _, _>(
sbp::integrate::integrate_multigrid::<sbp::integrate::Rk4, _, _>(
rhs,
&self.fnow,
&mut self.fnext,
&mut self.time,
dt,
&mut k,
(),
&mut (),
pool,
);

View File

@ -49,20 +49,20 @@ impl<SBP: SbpOperator2d> System<SBP> {
west: BoundaryCharacteristic::This,
};
let op = &self.op;
let rhs_trad = |k: &mut Field, y: &Field, _time: Float, gm: &(_, _), wb: &mut _| {
let (grid, metrics) = gm;
let wb = &mut self.wb.0;
let grid = &self.grid.0;
let metrics = &self.grid.1;
let rhs_trad = |k: &mut Field, y: &Field, _time: Float| {
let boundaries = boundary_extractor(y, grid, &bc);
RHS_trad(op, k, y, metrics, &boundaries, wb)
};
integrate::integrate::<integrate::Rk4, _, _, _, _>(
integrate::integrate::<integrate::Rk4, _, _>(
rhs_trad,
&self.sys.0,
&mut self.sys.1,
&mut 0.0,
dt,
&mut self.k,
&self.grid,
&mut self.wb.0,
);
std::mem::swap(&mut self.sys.0, &mut self.sys.1);
}
@ -110,20 +110,20 @@ impl<UO: UpwindOperator2d> System<UO> {
west: BoundaryCharacteristic::This,
};
let op = &self.op;
let rhs_upwind = |k: &mut Field, y: &Field, _time: Float, gm: &(_, _), wb: &mut _| {
let (grid, metrics) = gm;
let grid = &self.grid;
let wb = &mut self.wb.0;
let rhs_upwind = |k: &mut Field, y: &Field, _time: Float| {
let (grid, metrics) = grid;
let boundaries = boundary_extractor(y, grid, &bc);
RHS_upwind(op, k, y, metrics, &boundaries, wb)
};
integrate::integrate::<integrate::Rk4, _, _, _, _>(
integrate::integrate::<integrate::Rk4, _, _>(
rhs_upwind,
&self.sys.0,
&mut self.sys.1,
&mut 0.0,
dt,
&mut self.k,
&self.grid,
&mut self.wb.0,
);
std::mem::swap(&mut self.sys.0, &mut self.sys.1);
}
@ -146,6 +146,17 @@ impl std::ops::DerefMut for Field {
}
}
impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> {
fn from(f: &'a Field) -> Self {
f.0.view()
}
}
impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> {
fn from(f: &'a mut Field) -> Self {
f.0.view_mut()
}
}
impl Field {
pub fn new(ny: usize, nx: usize) -> Self {
let field = Array3::zeros((4, ny, nx));

View File

@ -1,5 +1,5 @@
use super::Float;
use ndarray::Array3;
use ndarray::{ArrayView3, ArrayViewMut3};
pub trait ButcherTableau {
const S: usize = Self::B.len();
@ -88,49 +88,45 @@ impl ButcherTableau for Rk6 {
}
#[allow(clippy::too_many_arguments)]
pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>(
pub fn integrate<BTableau: ButcherTableau, F, RHS>(
mut rhs: RHS,
prev: &F,
fut: &mut F,
time: &mut Float,
dt: Float,
k: &mut [F],
constants: C,
mut mutables: &mut MT,
) where
C: Copy,
F: std::ops::Deref<Target = Array3<Float>> + std::ops::DerefMut<Target = Array3<Float>>,
RHS: FnMut(&mut F, &F, Float, C, &mut MT),
BTableau: ButcherTableau,
for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>,
for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>,
RHS: FnMut(&mut F, &F, Float),
{
assert_eq!(prev.shape(), fut.shape());
assert_eq!(prev.into().shape(), fut.into().shape());
assert!(k.len() >= BTableau::S);
for i in 0.. {
let simtime;
match i {
0 => {
fut.assign(prev);
fut.into().assign(&prev.into());
simtime = *time;
}
i if i < BTableau::S => {
fut.assign(prev);
fut.into().assign(&prev.into());
for (&a, k) in BTableau::A[i - 1].iter().zip(k.iter()) {
if a == 0.0 {
continue;
}
fut.scaled_add(a * dt, &k);
fut.into().scaled_add(a * dt, &k.into());
}
simtime = *time + dt * BTableau::C[i - 1];
}
_ if i == BTableau::S => {
fut.assign(prev);
fut.into().assign(&prev.into());
for (&b, k) in BTableau::B.iter().zip(k.iter()) {
if b == 0.0 {
continue;
}
fut.scaled_add(b * dt, &k);
fut.into().scaled_add(b * dt, &k.into());
}
*time += dt;
return;
@ -140,13 +136,13 @@ pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>(
}
};
rhs(&mut k[i], &fut, simtime, constants, &mut mutables);
rhs(&mut k[i], &fut, simtime);
}
}
#[cfg(feature = "rayon")]
#[allow(clippy::too_many_arguments)]
pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS>(
mut rhs: RHS,
prev: &[F],
fut: &mut [F],
@ -154,17 +150,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
dt: Float,
k: &mut [&mut [F]],
constants: C,
mut mutables: &mut MT,
pool: &rayon::ThreadPool,
) where
C: Copy,
F: std::ops::Deref<Target = Array3<Float>>
+ std::ops::DerefMut<Target = Array3<Float>>
+ Send
+ Sync,
RHS: FnMut(&mut [F], &[F], Float, C, &mut MT),
BTableau: ButcherTableau,
for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>,
for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>,
RHS: FnMut(&mut [F], &[F], Float),
F: Send + Sync,
{
for i in 0.. {
let simtime;
@ -174,8 +165,8 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
assert!(k.len() >= BTableau::S);
for (prev, fut) in prev.iter().zip(fut.iter_mut()) {
s.spawn(move |_| {
assert_eq!(prev.shape(), fut.shape());
fut.assign(prev);
assert_eq!(prev.into().shape(), fut.into().shape());
fut.into().assign(&prev.into());
});
}
});
@ -186,12 +177,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() {
let k = &k;
s.spawn(move |_| {
fut.assign(prev);
fut.into().assign(&prev.into());
for (ik, &a) in BTableau::A[i - 1].iter().enumerate() {
if a == 0.0 {
continue;
}
fut.scaled_add(a * dt, &k[ik][ig]);
fut.into().scaled_add(a * dt, &(&k[ik][ig]).into());
}
});
}
@ -203,12 +194,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() {
let k = &k;
s.spawn(move |_| {
fut.assign(prev);
fut.into().assign(&prev.into());
for (ik, &b) in BTableau::B.iter().enumerate() {
if b == 0.0 {
continue;
}
fut.scaled_add(b * dt, &k[ik][ig]);
fut.into().scaled_add(b * dt, &(&k[ik][ig]).into());
}
});
}
@ -221,6 +212,6 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
}
};
rhs(&mut k[i], &fut, simtime, constants, &mut mutables);
rhs(&mut k[i], &fut, simtime);
}
}

View File

@ -21,6 +21,17 @@ impl std::ops::DerefMut for Field {
}
}
impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> {
fn from(f: &'a Field) -> Self {
f.0.view()
}
}
impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> {
fn from(f: &'a mut Field) -> Self {
f.0.view_mut()
}
}
impl Field {
pub fn new(height: usize, width: usize) -> Self {
let field = Array3::zeros((3, height, width));
@ -118,29 +129,20 @@ impl<SBP: SbpOperator2d> System<SBP> {
pub fn advance(&mut self, dt: Float) {
let op = &self.op;
let rhs_adaptor = move |fut: &mut Field,
prev: &Field,
_time: Float,
c: &(&Grid, &Metrics),
m: &mut (
Array2<Float>,
Array2<Float>,
Array2<Float>,
Array2<Float>,
)| {
let (grid, metrics) = c;
RHS(op, fut, prev, grid, metrics, m);
let grid = &self.grid;
let metrics = &self.metrics;
let wb = &mut self.wb.tmp;
let rhs_adaptor = move |fut: &mut Field, prev: &Field, _time: Float| {
RHS(op, fut, prev, grid, metrics, wb);
};
let mut _time = 0.0;
integrate::integrate::<integrate::Rk4, _, _, _, _>(
integrate::integrate::<integrate::Rk4, _, _>(
rhs_adaptor,
&self.sys.0,
&mut self.sys.1,
&mut _time,
dt,
&mut self.wb.k,
&(&self.grid, &self.metrics),
&mut self.wb.tmp,
);
std::mem::swap(&mut self.sys.0, &mut self.sys.1);
}
@ -150,29 +152,20 @@ impl<UO: UpwindOperator2d> System<UO> {
/// Using artificial dissipation with the upwind operator
pub fn advance_upwind(&mut self, dt: Float) {
let op = &self.op;
let rhs_adaptor = move |fut: &mut Field,
prev: &Field,
_time: Float,
c: &(&Grid, &Metrics),
m: &mut (
Array2<Float>,
Array2<Float>,
Array2<Float>,
Array2<Float>,
)| {
let (grid, metrics) = c;
RHS_upwind(op, fut, prev, grid, metrics, m);
let grid = &self.grid;
let metrics = &self.metrics;
let wb = &mut self.wb.tmp;
let rhs_adaptor = move |fut: &mut Field, prev: &Field, _time: Float| {
RHS_upwind(op, fut, prev, grid, metrics, wb);
};
let mut _time = 0.0;
integrate::integrate::<integrate::Rk4, _, _, _, _>(
integrate::integrate::<integrate::Rk4, _, _>(
rhs_adaptor,
&self.sys.0,
&mut self.sys.1,
&mut _time,
dt,
&mut self.wb.k,
&(&self.grid, &self.metrics),
&mut self.wb.tmp,
);
std::mem::swap(&mut self.sys.0, &mut self.sys.1);
}