Make integrate generic over D

This commit is contained in:
Magnus Ulimoen 2020-09-17 20:30:58 +02:00
parent 5bbc13158a
commit 0fc9ec64ec
5 changed files with 22 additions and 19 deletions

View File

@ -57,7 +57,7 @@ impl<SBP: SbpOperator2d> System<SBP> {
let boundaries = boundary_extractor(y, grid, &bc); let boundaries = boundary_extractor(y, grid, &bc);
RHS_trad(op, k, y, metrics, &boundaries, wb) RHS_trad(op, k, y, metrics, &boundaries, wb)
}; };
integrate::integrate::<integrate::Rk4, _, _>( integrate::integrate::<integrate::Rk4, _, _, _>(
rhs_trad, rhs_trad,
&self.sys.0, &self.sys.0,
&mut self.sys.1, &mut self.sys.1,
@ -131,7 +131,7 @@ impl<UO: UpwindOperator2d> System<UO> {
let boundaries = boundary_extractor(y, grid, &bc); let boundaries = boundary_extractor(y, grid, &bc);
RHS_upwind(op, k, y, metrics, &boundaries, wb) RHS_upwind(op, k, y, metrics, &boundaries, wb)
}; };
integrate::integrate::<integrate::Rk4, _, _>( integrate::integrate::<integrate::Rk4, _, _, _>(
rhs_upwind, rhs_upwind,
&self.sys.0, &self.sys.0,
&mut self.sys.1, &mut self.sys.1,
@ -159,7 +159,7 @@ impl<UO: UpwindOperator2d> System<UO> {
let mut time = 0.0; let mut time = 0.0;
let mut sys2 = self.sys.0.clone(); let mut sys2 = self.sys.0.clone();
while time < dt { while time < dt {
integrate::integrate_embedded_rk::<integrate::BogackiShampine, _, _>( integrate::integrate_embedded_rk::<integrate::BogackiShampine, _, _, _>(
&mut rhs_upwind, &mut rhs_upwind,
&self.sys.0, &self.sys.0,
&mut self.sys.1, &mut self.sys.1,

View File

@ -147,7 +147,7 @@ impl<SBP: SbpOperator2d> System<SBP> {
RHS(op, fut, prev, grid, metrics, wb); RHS(op, fut, prev, grid, metrics, wb);
}; };
let mut _time = 0.0; let mut _time = 0.0;
integrate::integrate::<integrate::Rk4, _, _>( integrate::integrate::<integrate::Rk4, _, _, _>(
rhs_adaptor, rhs_adaptor,
&self.sys.0, &self.sys.0,
&mut self.sys.1, &mut self.sys.1,
@ -170,7 +170,7 @@ impl<SBP: SbpOperator2d> System<SBP> {
); );
// sprs::lingalg::dsolve(..) // sprs::lingalg::dsolve(..)
}; };
sbp::integrate::integrate::<sbp::integrate::Rk4, _, _>( sbp::integrate::integrate::<sbp::integrate::Rk4, _, _, _>(
rhs_f, rhs_f,
&self.sys.0, &self.sys.0,
&mut self.sys.1, &mut self.sys.1,
@ -207,7 +207,7 @@ impl<UO: UpwindOperator2d> System<UO> {
RHS_upwind(op, fut, prev, grid, metrics, wb); RHS_upwind(op, fut, prev, grid, metrics, wb);
}; };
let mut _time = 0.0; let mut _time = 0.0;
integrate::integrate::<integrate::Rk4, _, _>( integrate::integrate::<integrate::Rk4, _, _, _>(
rhs_adaptor, rhs_adaptor,
&self.sys.0, &self.sys.0,
&mut self.sys.1, &mut self.sys.1,

View File

@ -115,7 +115,7 @@ impl System {
.iter_mut() .iter_mut()
.map(|k| k.as_mut_slice()) .map(|k| k.as_mut_slice())
.collect::<Vec<_>>(); .collect::<Vec<_>>();
sbp::integrate::integrate_multigrid::<sbp::integrate::Rk4, _, _>( sbp::integrate::integrate_multigrid::<sbp::integrate::Rk4, _, _, _>(
rhs, rhs,
&self.fnow, &self.fnow,
&mut self.fnext, &mut self.fnext,

View File

@ -1,5 +1,5 @@
use super::Float; use super::Float;
use ndarray::{ArrayView3, ArrayViewMut3}; use ndarray::{ArrayView, ArrayViewMut};
pub trait ButcherTableau { pub trait ButcherTableau {
const S: usize = Self::B.len(); const S: usize = Self::B.len();
@ -137,7 +137,7 @@ impl EmbeddedButcherTableau for BogackiShampine {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn integrate<BTableau: ButcherTableau, F, RHS>( pub fn integrate<BTableau: ButcherTableau, F, RHS, D>(
mut rhs: RHS, mut rhs: RHS,
prev: &F, prev: &F,
fut: &mut F, fut: &mut F,
@ -145,8 +145,9 @@ pub fn integrate<BTableau: ButcherTableau, F, RHS>(
dt: Float, dt: Float,
k: &mut [F], k: &mut [F],
) where ) where
for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>, for<'r> &'r F: std::convert::Into<ArrayView<'r, Float, D>>,
for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>, for<'r> &'r mut F: std::convert::Into<ArrayViewMut<'r, Float, D>>,
D: ndarray::Dimension,
RHS: FnMut(&mut F, &F, Float), RHS: FnMut(&mut F, &F, Float),
{ {
assert_eq!(prev.into().shape(), fut.into().shape()); assert_eq!(prev.into().shape(), fut.into().shape());
@ -190,7 +191,7 @@ pub fn integrate<BTableau: ButcherTableau, F, RHS>(
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn integrate_embedded_rk<BTableau: EmbeddedButcherTableau, F, RHS>( pub fn integrate_embedded_rk<BTableau: EmbeddedButcherTableau, F, RHS, D>(
rhs: RHS, rhs: RHS,
prev: &F, prev: &F,
fut: &mut F, fut: &mut F,
@ -199,11 +200,12 @@ pub fn integrate_embedded_rk<BTableau: EmbeddedButcherTableau, F, RHS>(
dt: Float, dt: Float,
k: &mut [F], k: &mut [F],
) where ) where
for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>, for<'r> &'r F: std::convert::Into<ArrayView<'r, Float, D>>,
for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>, for<'r> &'r mut F: std::convert::Into<ArrayViewMut<'r, Float, D>>,
RHS: FnMut(&mut F, &F, Float), RHS: FnMut(&mut F, &F, Float),
D: ndarray::Dimension,
{ {
integrate::<BTableau, F, RHS>(rhs, prev, fut, time, dt, k); integrate::<BTableau, F, RHS, D>(rhs, prev, fut, time, dt, k);
fut2.into().assign(&prev.into()); fut2.into().assign(&prev.into());
for (&b, k) in BTableau::BSTAR.iter().zip(k.iter()) { for (&b, k) in BTableau::BSTAR.iter().zip(k.iter()) {
if b == 0.0 { if b == 0.0 {
@ -215,7 +217,7 @@ pub fn integrate_embedded_rk<BTableau: EmbeddedButcherTableau, F, RHS>(
#[cfg(feature = "rayon")] #[cfg(feature = "rayon")]
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS>( pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS, D>(
mut rhs: RHS, mut rhs: RHS,
prev: &[F], prev: &[F],
fut: &mut [F], fut: &mut [F],
@ -225,10 +227,11 @@ pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS>(
pool: &rayon::ThreadPool, pool: &rayon::ThreadPool,
) where ) where
for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>, for<'r> &'r F: std::convert::Into<ArrayView<'r, Float, D>>,
for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>, for<'r> &'r mut F: std::convert::Into<ArrayViewMut<'r, Float, D>>,
RHS: FnMut(&mut [F], &[F], Float), RHS: FnMut(&mut [F], &[F], Float),
F: Send + Sync, F: Send + Sync,
D: ndarray::Dimension,
{ {
for i in 0.. { for i in 0.. {
let simtime; let simtime;

View File

@ -298,7 +298,7 @@ impl System {
} }
log::trace!("Iteration complete"); log::trace!("Iteration complete");
}; };
integrate::integrate::<integrate::Rk4, _, _>( integrate::integrate::<integrate::Rk4, _, _, _>(
rhs, rhs,
&self.fnow, &self.fnow,
&mut self.fnext, &mut self.fnext,