simplify integrate function
This commit is contained in:
parent
41935728e1
commit
cfff49107c
|
@ -84,11 +84,7 @@ impl System {
|
|||
let mut eb = &mut self.eb;
|
||||
let operators = &self.operators;
|
||||
|
||||
let rhs = move |fut: &mut [euler::Field],
|
||||
prev: &[euler::Field],
|
||||
time: Float,
|
||||
_c: (),
|
||||
_mt: &mut ()| {
|
||||
let rhs = move |fut: &mut [euler::Field], prev: &[euler::Field], time: Float| {
|
||||
let bc = euler::extract_boundaries(prev, &bt, &mut eb, &grids, time);
|
||||
pool.scope(|s| {
|
||||
for (((((fut, prev), bc), wb), metrics), op) in fut
|
||||
|
@ -116,15 +112,13 @@ impl System {
|
|||
.iter_mut()
|
||||
.map(|k| k.as_mut_slice())
|
||||
.collect::<Vec<_>>();
|
||||
sbp::integrate::integrate_multigrid::<sbp::integrate::Rk4, _, _, _, _>(
|
||||
sbp::integrate::integrate_multigrid::<sbp::integrate::Rk4, _, _>(
|
||||
rhs,
|
||||
&self.fnow,
|
||||
&mut self.fnext,
|
||||
&mut self.time,
|
||||
dt,
|
||||
&mut k,
|
||||
(),
|
||||
&mut (),
|
||||
pool,
|
||||
);
|
||||
|
||||
|
|
|
@ -49,20 +49,20 @@ impl<SBP: SbpOperator2d> System<SBP> {
|
|||
west: BoundaryCharacteristic::This,
|
||||
};
|
||||
let op = &self.op;
|
||||
let rhs_trad = |k: &mut Field, y: &Field, _time: Float, gm: &(_, _), wb: &mut _| {
|
||||
let (grid, metrics) = gm;
|
||||
let wb = &mut self.wb.0;
|
||||
let grid = &self.grid.0;
|
||||
let metrics = &self.grid.1;
|
||||
let rhs_trad = |k: &mut Field, y: &Field, _time: Float| {
|
||||
let boundaries = boundary_extractor(y, grid, &bc);
|
||||
RHS_trad(op, k, y, metrics, &boundaries, wb)
|
||||
};
|
||||
integrate::integrate::<integrate::Rk4, _, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, _, _>(
|
||||
rhs_trad,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
&mut 0.0,
|
||||
dt,
|
||||
&mut self.k,
|
||||
&self.grid,
|
||||
&mut self.wb.0,
|
||||
);
|
||||
std::mem::swap(&mut self.sys.0, &mut self.sys.1);
|
||||
}
|
||||
|
@ -110,20 +110,20 @@ impl<UO: UpwindOperator2d> System<UO> {
|
|||
west: BoundaryCharacteristic::This,
|
||||
};
|
||||
let op = &self.op;
|
||||
let rhs_upwind = |k: &mut Field, y: &Field, _time: Float, gm: &(_, _), wb: &mut _| {
|
||||
let (grid, metrics) = gm;
|
||||
let grid = &self.grid;
|
||||
let wb = &mut self.wb.0;
|
||||
let rhs_upwind = |k: &mut Field, y: &Field, _time: Float| {
|
||||
let (grid, metrics) = grid;
|
||||
let boundaries = boundary_extractor(y, grid, &bc);
|
||||
RHS_upwind(op, k, y, metrics, &boundaries, wb)
|
||||
};
|
||||
integrate::integrate::<integrate::Rk4, _, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, _, _>(
|
||||
rhs_upwind,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
&mut 0.0,
|
||||
dt,
|
||||
&mut self.k,
|
||||
&self.grid,
|
||||
&mut self.wb.0,
|
||||
);
|
||||
std::mem::swap(&mut self.sys.0, &mut self.sys.1);
|
||||
}
|
||||
|
@ -146,6 +146,17 @@ impl std::ops::DerefMut for Field {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> {
|
||||
fn from(f: &'a Field) -> Self {
|
||||
f.0.view()
|
||||
}
|
||||
}
|
||||
impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> {
|
||||
fn from(f: &'a mut Field) -> Self {
|
||||
f.0.view_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl Field {
|
||||
pub fn new(ny: usize, nx: usize) -> Self {
|
||||
let field = Array3::zeros((4, ny, nx));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::Float;
|
||||
use ndarray::Array3;
|
||||
use ndarray::{ArrayView3, ArrayViewMut3};
|
||||
|
||||
pub trait ButcherTableau {
|
||||
const S: usize = Self::B.len();
|
||||
|
@ -88,49 +88,45 @@ impl ButcherTableau for Rk6 {
|
|||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>(
|
||||
pub fn integrate<BTableau: ButcherTableau, F, RHS>(
|
||||
mut rhs: RHS,
|
||||
prev: &F,
|
||||
fut: &mut F,
|
||||
time: &mut Float,
|
||||
dt: Float,
|
||||
k: &mut [F],
|
||||
|
||||
constants: C,
|
||||
mut mutables: &mut MT,
|
||||
) where
|
||||
C: Copy,
|
||||
F: std::ops::Deref<Target = Array3<Float>> + std::ops::DerefMut<Target = Array3<Float>>,
|
||||
RHS: FnMut(&mut F, &F, Float, C, &mut MT),
|
||||
BTableau: ButcherTableau,
|
||||
for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>,
|
||||
for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>,
|
||||
RHS: FnMut(&mut F, &F, Float),
|
||||
{
|
||||
assert_eq!(prev.shape(), fut.shape());
|
||||
assert_eq!(prev.into().shape(), fut.into().shape());
|
||||
assert!(k.len() >= BTableau::S);
|
||||
|
||||
for i in 0.. {
|
||||
let simtime;
|
||||
match i {
|
||||
0 => {
|
||||
fut.assign(prev);
|
||||
fut.into().assign(&prev.into());
|
||||
simtime = *time;
|
||||
}
|
||||
i if i < BTableau::S => {
|
||||
fut.assign(prev);
|
||||
fut.into().assign(&prev.into());
|
||||
for (&a, k) in BTableau::A[i - 1].iter().zip(k.iter()) {
|
||||
if a == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut.scaled_add(a * dt, &k);
|
||||
fut.into().scaled_add(a * dt, &k.into());
|
||||
}
|
||||
simtime = *time + dt * BTableau::C[i - 1];
|
||||
}
|
||||
_ if i == BTableau::S => {
|
||||
fut.assign(prev);
|
||||
fut.into().assign(&prev.into());
|
||||
for (&b, k) in BTableau::B.iter().zip(k.iter()) {
|
||||
if b == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut.scaled_add(b * dt, &k);
|
||||
fut.into().scaled_add(b * dt, &k.into());
|
||||
}
|
||||
*time += dt;
|
||||
return;
|
||||
|
@ -140,13 +136,13 @@ pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>(
|
|||
}
|
||||
};
|
||||
|
||||
rhs(&mut k[i], &fut, simtime, constants, &mut mutables);
|
||||
rhs(&mut k[i], &fut, simtime);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "rayon")]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
|
||||
pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS>(
|
||||
mut rhs: RHS,
|
||||
prev: &[F],
|
||||
fut: &mut [F],
|
||||
|
@ -154,17 +150,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
|
|||
dt: Float,
|
||||
k: &mut [&mut [F]],
|
||||
|
||||
constants: C,
|
||||
mut mutables: &mut MT,
|
||||
pool: &rayon::ThreadPool,
|
||||
) where
|
||||
C: Copy,
|
||||
F: std::ops::Deref<Target = Array3<Float>>
|
||||
+ std::ops::DerefMut<Target = Array3<Float>>
|
||||
+ Send
|
||||
+ Sync,
|
||||
RHS: FnMut(&mut [F], &[F], Float, C, &mut MT),
|
||||
BTableau: ButcherTableau,
|
||||
for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>,
|
||||
for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>,
|
||||
RHS: FnMut(&mut [F], &[F], Float),
|
||||
F: Send + Sync,
|
||||
{
|
||||
for i in 0.. {
|
||||
let simtime;
|
||||
|
@ -174,8 +165,8 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
|
|||
assert!(k.len() >= BTableau::S);
|
||||
for (prev, fut) in prev.iter().zip(fut.iter_mut()) {
|
||||
s.spawn(move |_| {
|
||||
assert_eq!(prev.shape(), fut.shape());
|
||||
fut.assign(prev);
|
||||
assert_eq!(prev.into().shape(), fut.into().shape());
|
||||
fut.into().assign(&prev.into());
|
||||
});
|
||||
}
|
||||
});
|
||||
|
@ -186,12 +177,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
|
|||
for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() {
|
||||
let k = &k;
|
||||
s.spawn(move |_| {
|
||||
fut.assign(prev);
|
||||
fut.into().assign(&prev.into());
|
||||
for (ik, &a) in BTableau::A[i - 1].iter().enumerate() {
|
||||
if a == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut.scaled_add(a * dt, &k[ik][ig]);
|
||||
fut.into().scaled_add(a * dt, &(&k[ik][ig]).into());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -203,12 +194,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
|
|||
for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() {
|
||||
let k = &k;
|
||||
s.spawn(move |_| {
|
||||
fut.assign(prev);
|
||||
fut.into().assign(&prev.into());
|
||||
for (ik, &b) in BTableau::B.iter().enumerate() {
|
||||
if b == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut.scaled_add(b * dt, &k[ik][ig]);
|
||||
fut.into().scaled_add(b * dt, &(&k[ik][ig]).into());
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -221,6 +212,6 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
|
|||
}
|
||||
};
|
||||
|
||||
rhs(&mut k[i], &fut, simtime, constants, &mut mutables);
|
||||
rhs(&mut k[i], &fut, simtime);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,17 @@ impl std::ops::DerefMut for Field {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> {
|
||||
fn from(f: &'a Field) -> Self {
|
||||
f.0.view()
|
||||
}
|
||||
}
|
||||
impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> {
|
||||
fn from(f: &'a mut Field) -> Self {
|
||||
f.0.view_mut()
|
||||
}
|
||||
}
|
||||
|
||||
impl Field {
|
||||
pub fn new(height: usize, width: usize) -> Self {
|
||||
let field = Array3::zeros((3, height, width));
|
||||
|
@ -118,29 +129,20 @@ impl<SBP: SbpOperator2d> System<SBP> {
|
|||
|
||||
pub fn advance(&mut self, dt: Float) {
|
||||
let op = &self.op;
|
||||
let rhs_adaptor = move |fut: &mut Field,
|
||||
prev: &Field,
|
||||
_time: Float,
|
||||
c: &(&Grid, &Metrics),
|
||||
m: &mut (
|
||||
Array2<Float>,
|
||||
Array2<Float>,
|
||||
Array2<Float>,
|
||||
Array2<Float>,
|
||||
)| {
|
||||
let (grid, metrics) = c;
|
||||
RHS(op, fut, prev, grid, metrics, m);
|
||||
let grid = &self.grid;
|
||||
let metrics = &self.metrics;
|
||||
let wb = &mut self.wb.tmp;
|
||||
let rhs_adaptor = move |fut: &mut Field, prev: &Field, _time: Float| {
|
||||
RHS(op, fut, prev, grid, metrics, wb);
|
||||
};
|
||||
let mut _time = 0.0;
|
||||
integrate::integrate::<integrate::Rk4, _, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, _, _>(
|
||||
rhs_adaptor,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
&mut _time,
|
||||
dt,
|
||||
&mut self.wb.k,
|
||||
&(&self.grid, &self.metrics),
|
||||
&mut self.wb.tmp,
|
||||
);
|
||||
std::mem::swap(&mut self.sys.0, &mut self.sys.1);
|
||||
}
|
||||
|
@ -150,29 +152,20 @@ impl<UO: UpwindOperator2d> System<UO> {
|
|||
/// Using artificial dissipation with the upwind operator
|
||||
pub fn advance_upwind(&mut self, dt: Float) {
|
||||
let op = &self.op;
|
||||
let rhs_adaptor = move |fut: &mut Field,
|
||||
prev: &Field,
|
||||
_time: Float,
|
||||
c: &(&Grid, &Metrics),
|
||||
m: &mut (
|
||||
Array2<Float>,
|
||||
Array2<Float>,
|
||||
Array2<Float>,
|
||||
Array2<Float>,
|
||||
)| {
|
||||
let (grid, metrics) = c;
|
||||
RHS_upwind(op, fut, prev, grid, metrics, m);
|
||||
let grid = &self.grid;
|
||||
let metrics = &self.metrics;
|
||||
let wb = &mut self.wb.tmp;
|
||||
let rhs_adaptor = move |fut: &mut Field, prev: &Field, _time: Float| {
|
||||
RHS_upwind(op, fut, prev, grid, metrics, wb);
|
||||
};
|
||||
let mut _time = 0.0;
|
||||
integrate::integrate::<integrate::Rk4, _, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, _, _>(
|
||||
rhs_adaptor,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
&mut _time,
|
||||
dt,
|
||||
&mut self.wb.k,
|
||||
&(&self.grid, &self.metrics),
|
||||
&mut self.wb.tmp,
|
||||
);
|
||||
std::mem::swap(&mut self.sys.0, &mut self.sys.1);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue