simplify integrate function
This commit is contained in:
		@@ -49,20 +49,20 @@ impl<SBP: SbpOperator2d> System<SBP> {
 | 
			
		||||
            west: BoundaryCharacteristic::This,
 | 
			
		||||
        };
 | 
			
		||||
        let op = &self.op;
 | 
			
		||||
        let rhs_trad = |k: &mut Field, y: &Field, _time: Float, gm: &(_, _), wb: &mut _| {
 | 
			
		||||
            let (grid, metrics) = gm;
 | 
			
		||||
        let wb = &mut self.wb.0;
 | 
			
		||||
        let grid = &self.grid.0;
 | 
			
		||||
        let metrics = &self.grid.1;
 | 
			
		||||
        let rhs_trad = |k: &mut Field, y: &Field, _time: Float| {
 | 
			
		||||
            let boundaries = boundary_extractor(y, grid, &bc);
 | 
			
		||||
            RHS_trad(op, k, y, metrics, &boundaries, wb)
 | 
			
		||||
        };
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, _, _, _, _>(
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, _, _>(
 | 
			
		||||
            rhs_trad,
 | 
			
		||||
            &self.sys.0,
 | 
			
		||||
            &mut self.sys.1,
 | 
			
		||||
            &mut 0.0,
 | 
			
		||||
            dt,
 | 
			
		||||
            &mut self.k,
 | 
			
		||||
            &self.grid,
 | 
			
		||||
            &mut self.wb.0,
 | 
			
		||||
        );
 | 
			
		||||
        std::mem::swap(&mut self.sys.0, &mut self.sys.1);
 | 
			
		||||
    }
 | 
			
		||||
@@ -110,20 +110,20 @@ impl<UO: UpwindOperator2d> System<UO> {
 | 
			
		||||
            west: BoundaryCharacteristic::This,
 | 
			
		||||
        };
 | 
			
		||||
        let op = &self.op;
 | 
			
		||||
        let rhs_upwind = |k: &mut Field, y: &Field, _time: Float, gm: &(_, _), wb: &mut _| {
 | 
			
		||||
            let (grid, metrics) = gm;
 | 
			
		||||
        let grid = &self.grid;
 | 
			
		||||
        let wb = &mut self.wb.0;
 | 
			
		||||
        let rhs_upwind = |k: &mut Field, y: &Field, _time: Float| {
 | 
			
		||||
            let (grid, metrics) = grid;
 | 
			
		||||
            let boundaries = boundary_extractor(y, grid, &bc);
 | 
			
		||||
            RHS_upwind(op, k, y, metrics, &boundaries, wb)
 | 
			
		||||
        };
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, _, _, _, _>(
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, _, _>(
 | 
			
		||||
            rhs_upwind,
 | 
			
		||||
            &self.sys.0,
 | 
			
		||||
            &mut self.sys.1,
 | 
			
		||||
            &mut 0.0,
 | 
			
		||||
            dt,
 | 
			
		||||
            &mut self.k,
 | 
			
		||||
            &self.grid,
 | 
			
		||||
            &mut self.wb.0,
 | 
			
		||||
        );
 | 
			
		||||
        std::mem::swap(&mut self.sys.0, &mut self.sys.1);
 | 
			
		||||
    }
 | 
			
		||||
@@ -146,6 +146,17 @@ impl std::ops::DerefMut for Field {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> {
 | 
			
		||||
    fn from(f: &'a Field) -> Self {
 | 
			
		||||
        f.0.view()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> {
 | 
			
		||||
    fn from(f: &'a mut Field) -> Self {
 | 
			
		||||
        f.0.view_mut()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Field {
 | 
			
		||||
    pub fn new(ny: usize, nx: usize) -> Self {
 | 
			
		||||
        let field = Array3::zeros((4, ny, nx));
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,5 @@
 | 
			
		||||
use super::Float;
 | 
			
		||||
use ndarray::Array3;
 | 
			
		||||
use ndarray::{ArrayView3, ArrayViewMut3};
 | 
			
		||||
 | 
			
		||||
pub trait ButcherTableau {
 | 
			
		||||
    const S: usize = Self::B.len();
 | 
			
		||||
@@ -88,49 +88,45 @@ impl ButcherTableau for Rk6 {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[allow(clippy::too_many_arguments)]
 | 
			
		||||
pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>(
 | 
			
		||||
pub fn integrate<BTableau: ButcherTableau, F, RHS>(
 | 
			
		||||
    mut rhs: RHS,
 | 
			
		||||
    prev: &F,
 | 
			
		||||
    fut: &mut F,
 | 
			
		||||
    time: &mut Float,
 | 
			
		||||
    dt: Float,
 | 
			
		||||
    k: &mut [F],
 | 
			
		||||
 | 
			
		||||
    constants: C,
 | 
			
		||||
    mut mutables: &mut MT,
 | 
			
		||||
) where
 | 
			
		||||
    C: Copy,
 | 
			
		||||
    F: std::ops::Deref<Target = Array3<Float>> + std::ops::DerefMut<Target = Array3<Float>>,
 | 
			
		||||
    RHS: FnMut(&mut F, &F, Float, C, &mut MT),
 | 
			
		||||
    BTableau: ButcherTableau,
 | 
			
		||||
    for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>,
 | 
			
		||||
    for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>,
 | 
			
		||||
    RHS: FnMut(&mut F, &F, Float),
 | 
			
		||||
{
 | 
			
		||||
    assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
    assert_eq!(prev.into().shape(), fut.into().shape());
 | 
			
		||||
    assert!(k.len() >= BTableau::S);
 | 
			
		||||
 | 
			
		||||
    for i in 0.. {
 | 
			
		||||
        let simtime;
 | 
			
		||||
        match i {
 | 
			
		||||
            0 => {
 | 
			
		||||
                fut.assign(prev);
 | 
			
		||||
                fut.into().assign(&prev.into());
 | 
			
		||||
                simtime = *time;
 | 
			
		||||
            }
 | 
			
		||||
            i if i < BTableau::S => {
 | 
			
		||||
                fut.assign(prev);
 | 
			
		||||
                fut.into().assign(&prev.into());
 | 
			
		||||
                for (&a, k) in BTableau::A[i - 1].iter().zip(k.iter()) {
 | 
			
		||||
                    if a == 0.0 {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    }
 | 
			
		||||
                    fut.scaled_add(a * dt, &k);
 | 
			
		||||
                    fut.into().scaled_add(a * dt, &k.into());
 | 
			
		||||
                }
 | 
			
		||||
                simtime = *time + dt * BTableau::C[i - 1];
 | 
			
		||||
            }
 | 
			
		||||
            _ if i == BTableau::S => {
 | 
			
		||||
                fut.assign(prev);
 | 
			
		||||
                fut.into().assign(&prev.into());
 | 
			
		||||
                for (&b, k) in BTableau::B.iter().zip(k.iter()) {
 | 
			
		||||
                    if b == 0.0 {
 | 
			
		||||
                        continue;
 | 
			
		||||
                    }
 | 
			
		||||
                    fut.scaled_add(b * dt, &k);
 | 
			
		||||
                    fut.into().scaled_add(b * dt, &k.into());
 | 
			
		||||
                }
 | 
			
		||||
                *time += dt;
 | 
			
		||||
                return;
 | 
			
		||||
@@ -140,13 +136,13 @@ pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>(
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        rhs(&mut k[i], &fut, simtime, constants, &mut mutables);
 | 
			
		||||
        rhs(&mut k[i], &fut, simtime);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(feature = "rayon")]
 | 
			
		||||
#[allow(clippy::too_many_arguments)]
 | 
			
		||||
pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
 | 
			
		||||
pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS>(
 | 
			
		||||
    mut rhs: RHS,
 | 
			
		||||
    prev: &[F],
 | 
			
		||||
    fut: &mut [F],
 | 
			
		||||
@@ -154,17 +150,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
 | 
			
		||||
    dt: Float,
 | 
			
		||||
    k: &mut [&mut [F]],
 | 
			
		||||
 | 
			
		||||
    constants: C,
 | 
			
		||||
    mut mutables: &mut MT,
 | 
			
		||||
    pool: &rayon::ThreadPool,
 | 
			
		||||
) where
 | 
			
		||||
    C: Copy,
 | 
			
		||||
    F: std::ops::Deref<Target = Array3<Float>>
 | 
			
		||||
        + std::ops::DerefMut<Target = Array3<Float>>
 | 
			
		||||
        + Send
 | 
			
		||||
        + Sync,
 | 
			
		||||
    RHS: FnMut(&mut [F], &[F], Float, C, &mut MT),
 | 
			
		||||
    BTableau: ButcherTableau,
 | 
			
		||||
    for<'r> &'r F: std::convert::Into<ArrayView3<'r, Float>>,
 | 
			
		||||
    for<'r> &'r mut F: std::convert::Into<ArrayViewMut3<'r, Float>>,
 | 
			
		||||
    RHS: FnMut(&mut [F], &[F], Float),
 | 
			
		||||
    F: Send + Sync,
 | 
			
		||||
{
 | 
			
		||||
    for i in 0.. {
 | 
			
		||||
        let simtime;
 | 
			
		||||
@@ -174,8 +165,8 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
 | 
			
		||||
                    assert!(k.len() >= BTableau::S);
 | 
			
		||||
                    for (prev, fut) in prev.iter().zip(fut.iter_mut()) {
 | 
			
		||||
                        s.spawn(move |_| {
 | 
			
		||||
                            assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
                            fut.assign(prev);
 | 
			
		||||
                            assert_eq!(prev.into().shape(), fut.into().shape());
 | 
			
		||||
                            fut.into().assign(&prev.into());
 | 
			
		||||
                        });
 | 
			
		||||
                    }
 | 
			
		||||
                });
 | 
			
		||||
@@ -186,12 +177,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
 | 
			
		||||
                    for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() {
 | 
			
		||||
                        let k = &k;
 | 
			
		||||
                        s.spawn(move |_| {
 | 
			
		||||
                            fut.assign(prev);
 | 
			
		||||
                            fut.into().assign(&prev.into());
 | 
			
		||||
                            for (ik, &a) in BTableau::A[i - 1].iter().enumerate() {
 | 
			
		||||
                                if a == 0.0 {
 | 
			
		||||
                                    continue;
 | 
			
		||||
                                }
 | 
			
		||||
                                fut.scaled_add(a * dt, &k[ik][ig]);
 | 
			
		||||
                                fut.into().scaled_add(a * dt, &(&k[ik][ig]).into());
 | 
			
		||||
                            }
 | 
			
		||||
                        });
 | 
			
		||||
                    }
 | 
			
		||||
@@ -203,12 +194,12 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
 | 
			
		||||
                    for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() {
 | 
			
		||||
                        let k = &k;
 | 
			
		||||
                        s.spawn(move |_| {
 | 
			
		||||
                            fut.assign(prev);
 | 
			
		||||
                            fut.into().assign(&prev.into());
 | 
			
		||||
                            for (ik, &b) in BTableau::B.iter().enumerate() {
 | 
			
		||||
                                if b == 0.0 {
 | 
			
		||||
                                    continue;
 | 
			
		||||
                                }
 | 
			
		||||
                                fut.scaled_add(b * dt, &k[ik][ig]);
 | 
			
		||||
                                fut.into().scaled_add(b * dt, &(&k[ik][ig]).into());
 | 
			
		||||
                            }
 | 
			
		||||
                        });
 | 
			
		||||
                    }
 | 
			
		||||
@@ -221,6 +212,6 @@ pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>(
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        rhs(&mut k[i], &fut, simtime, constants, &mut mutables);
 | 
			
		||||
        rhs(&mut k[i], &fut, simtime);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -21,6 +21,17 @@ impl std::ops::DerefMut for Field {
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> {
 | 
			
		||||
    fn from(f: &'a Field) -> Self {
 | 
			
		||||
        f.0.view()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> {
 | 
			
		||||
    fn from(f: &'a mut Field) -> Self {
 | 
			
		||||
        f.0.view_mut()
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl Field {
 | 
			
		||||
    pub fn new(height: usize, width: usize) -> Self {
 | 
			
		||||
        let field = Array3::zeros((3, height, width));
 | 
			
		||||
@@ -118,29 +129,20 @@ impl<SBP: SbpOperator2d> System<SBP> {
 | 
			
		||||
 | 
			
		||||
    pub fn advance(&mut self, dt: Float) {
 | 
			
		||||
        let op = &self.op;
 | 
			
		||||
        let rhs_adaptor = move |fut: &mut Field,
 | 
			
		||||
                                prev: &Field,
 | 
			
		||||
                                _time: Float,
 | 
			
		||||
                                c: &(&Grid, &Metrics),
 | 
			
		||||
                                m: &mut (
 | 
			
		||||
            Array2<Float>,
 | 
			
		||||
            Array2<Float>,
 | 
			
		||||
            Array2<Float>,
 | 
			
		||||
            Array2<Float>,
 | 
			
		||||
        )| {
 | 
			
		||||
            let (grid, metrics) = c;
 | 
			
		||||
            RHS(op, fut, prev, grid, metrics, m);
 | 
			
		||||
        let grid = &self.grid;
 | 
			
		||||
        let metrics = &self.metrics;
 | 
			
		||||
        let wb = &mut self.wb.tmp;
 | 
			
		||||
        let rhs_adaptor = move |fut: &mut Field, prev: &Field, _time: Float| {
 | 
			
		||||
            RHS(op, fut, prev, grid, metrics, wb);
 | 
			
		||||
        };
 | 
			
		||||
        let mut _time = 0.0;
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, _, _, _, _>(
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, _, _>(
 | 
			
		||||
            rhs_adaptor,
 | 
			
		||||
            &self.sys.0,
 | 
			
		||||
            &mut self.sys.1,
 | 
			
		||||
            &mut _time,
 | 
			
		||||
            dt,
 | 
			
		||||
            &mut self.wb.k,
 | 
			
		||||
            &(&self.grid, &self.metrics),
 | 
			
		||||
            &mut self.wb.tmp,
 | 
			
		||||
        );
 | 
			
		||||
        std::mem::swap(&mut self.sys.0, &mut self.sys.1);
 | 
			
		||||
    }
 | 
			
		||||
@@ -150,29 +152,20 @@ impl<UO: UpwindOperator2d> System<UO> {
 | 
			
		||||
    /// Using artificial dissipation with the upwind operator
 | 
			
		||||
    pub fn advance_upwind(&mut self, dt: Float) {
 | 
			
		||||
        let op = &self.op;
 | 
			
		||||
        let rhs_adaptor = move |fut: &mut Field,
 | 
			
		||||
                                prev: &Field,
 | 
			
		||||
                                _time: Float,
 | 
			
		||||
                                c: &(&Grid, &Metrics),
 | 
			
		||||
                                m: &mut (
 | 
			
		||||
            Array2<Float>,
 | 
			
		||||
            Array2<Float>,
 | 
			
		||||
            Array2<Float>,
 | 
			
		||||
            Array2<Float>,
 | 
			
		||||
        )| {
 | 
			
		||||
            let (grid, metrics) = c;
 | 
			
		||||
            RHS_upwind(op, fut, prev, grid, metrics, m);
 | 
			
		||||
        let grid = &self.grid;
 | 
			
		||||
        let metrics = &self.metrics;
 | 
			
		||||
        let wb = &mut self.wb.tmp;
 | 
			
		||||
        let rhs_adaptor = move |fut: &mut Field, prev: &Field, _time: Float| {
 | 
			
		||||
            RHS_upwind(op, fut, prev, grid, metrics, wb);
 | 
			
		||||
        };
 | 
			
		||||
        let mut _time = 0.0;
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, _, _, _, _>(
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, _, _>(
 | 
			
		||||
            rhs_adaptor,
 | 
			
		||||
            &self.sys.0,
 | 
			
		||||
            &mut self.sys.1,
 | 
			
		||||
            &mut _time,
 | 
			
		||||
            dt,
 | 
			
		||||
            &mut self.wb.k,
 | 
			
		||||
            &(&self.grid, &self.metrics),
 | 
			
		||||
            &mut self.wb.tmp,
 | 
			
		||||
        );
 | 
			
		||||
        std::mem::swap(&mut self.sys.0, &mut self.sys.1);
 | 
			
		||||
    }
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user