Add boolean for switching serial/parallel execution
This commit is contained in:
		@@ -21,18 +21,33 @@ struct System {
 | 
			
		||||
    operators: Vec<Box<dyn SbpOperator2d>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
use std::sync::atomic::{AtomicBool, Ordering};
 | 
			
		||||
pub(crate) static MULTITHREAD: AtomicBool = AtomicBool::new(false);
 | 
			
		||||
 | 
			
		||||
impl integrate::Integrable for System {
 | 
			
		||||
    type State = Vec<euler::Field>;
 | 
			
		||||
    type Diff = Vec<euler::Field>;
 | 
			
		||||
    fn assign(s: &mut Self::State, o: &Self::State) {
 | 
			
		||||
        s.par_iter_mut()
 | 
			
		||||
            .zip(o.par_iter())
 | 
			
		||||
            .for_each(|(s, o)| euler::Field::assign(s, o))
 | 
			
		||||
        if MULTITHREAD.load(Ordering::Acquire) {
 | 
			
		||||
            s.par_iter_mut()
 | 
			
		||||
                .zip(o.par_iter())
 | 
			
		||||
                .for_each(|(s, o)| euler::Field::assign(s, o))
 | 
			
		||||
        } else {
 | 
			
		||||
            s.iter_mut()
 | 
			
		||||
                .zip(o.iter())
 | 
			
		||||
                .for_each(|(s, o)| euler::Field::assign(s, o))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float) {
 | 
			
		||||
        s.par_iter_mut()
 | 
			
		||||
            .zip(o.par_iter())
 | 
			
		||||
            .for_each(|(s, o)| euler::Field::scaled_add(s, o, scale))
 | 
			
		||||
        if MULTITHREAD.load(Ordering::Acquire) {
 | 
			
		||||
            s.par_iter_mut()
 | 
			
		||||
                .zip(o.par_iter())
 | 
			
		||||
                .for_each(|(s, o)| euler::Field::scaled_add(s, o, scale))
 | 
			
		||||
        } else {
 | 
			
		||||
            s.iter_mut()
 | 
			
		||||
                .zip(o.iter())
 | 
			
		||||
                .for_each(|(s, o)| euler::Field::scaled_add(s, o, scale))
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -94,7 +109,29 @@ impl System {
 | 
			
		||||
 | 
			
		||||
        let rhs = move |fut: &mut Vec<euler::Field>, prev: &Vec<euler::Field>, time: Float| {
 | 
			
		||||
            let prev_all = &prev;
 | 
			
		||||
            rayon::scope(|s| {
 | 
			
		||||
            if MULTITHREAD.load(Ordering::Acquire) {
 | 
			
		||||
                rayon::scope(|s| {
 | 
			
		||||
                    for (((((((fut, prev), wb), grid), metrics), op), bt), eb) in fut
 | 
			
		||||
                        .iter_mut()
 | 
			
		||||
                        .zip(prev.iter())
 | 
			
		||||
                        .zip(wb.iter_mut())
 | 
			
		||||
                        .zip(grids)
 | 
			
		||||
                        .zip(metrics.iter())
 | 
			
		||||
                        .zip(operators.iter())
 | 
			
		||||
                        .zip(bt.iter())
 | 
			
		||||
                        .zip(eb.iter_mut())
 | 
			
		||||
                    {
 | 
			
		||||
                        s.spawn(move |_| {
 | 
			
		||||
                            let bc = euler::boundary_extracts(prev_all, bt, prev, grid, eb, time);
 | 
			
		||||
                            if op.upwind().is_some() {
 | 
			
		||||
                                euler::RHS_upwind(&**op, fut, prev, metrics, &bc, &mut wb.0);
 | 
			
		||||
                            } else {
 | 
			
		||||
                                euler::RHS_trad(&**op, fut, prev, metrics, &bc, &mut wb.0);
 | 
			
		||||
                            }
 | 
			
		||||
                        })
 | 
			
		||||
                    }
 | 
			
		||||
                });
 | 
			
		||||
            } else {
 | 
			
		||||
                for (((((((fut, prev), wb), grid), metrics), op), bt), eb) in fut
 | 
			
		||||
                    .iter_mut()
 | 
			
		||||
                    .zip(prev.iter())
 | 
			
		||||
@@ -105,16 +142,14 @@ impl System {
 | 
			
		||||
                    .zip(bt.iter())
 | 
			
		||||
                    .zip(eb.iter_mut())
 | 
			
		||||
                {
 | 
			
		||||
                    s.spawn(move |_| {
 | 
			
		||||
                        let bc = euler::boundary_extracts(prev_all, bt, prev, grid, eb, time);
 | 
			
		||||
                        if op.upwind().is_some() {
 | 
			
		||||
                            euler::RHS_upwind(&**op, fut, prev, metrics, &bc, &mut wb.0);
 | 
			
		||||
                        } else {
 | 
			
		||||
                            euler::RHS_trad(&**op, fut, prev, metrics, &bc, &mut wb.0);
 | 
			
		||||
                        }
 | 
			
		||||
                    })
 | 
			
		||||
                    let bc = euler::boundary_extracts(prev_all, bt, prev, grid, eb, time);
 | 
			
		||||
                    if op.upwind().is_some() {
 | 
			
		||||
                        euler::RHS_upwind(&**op, fut, prev, metrics, &bc, &mut wb.0);
 | 
			
		||||
                    } else {
 | 
			
		||||
                        euler::RHS_trad(&**op, fut, prev, metrics, &bc, &mut wb.0);
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        integrate::integrate::<integrate::Rk4, System, _>(
 | 
			
		||||
@@ -259,10 +294,13 @@ fn main() {
 | 
			
		||||
 | 
			
		||||
    {
 | 
			
		||||
        let nthreads = opt.jobs.unwrap_or(1);
 | 
			
		||||
        rayon::ThreadPoolBuilder::new()
 | 
			
		||||
            .num_threads(nthreads)
 | 
			
		||||
            .build_global()
 | 
			
		||||
            .unwrap();
 | 
			
		||||
        if nthreads > 1 {
 | 
			
		||||
            MULTITHREAD.store(true, Ordering::Release);
 | 
			
		||||
            rayon::ThreadPoolBuilder::new()
 | 
			
		||||
                .num_threads(nthreads)
 | 
			
		||||
                .build_global()
 | 
			
		||||
                .unwrap();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    let should_output = |itime| {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user