Add boolean for switching serial/parallel execution

This commit is contained in:
Magnus Ulimoen 2021-03-26 00:00:42 +01:00
parent a33e1d37ba
commit 52c21dbbe9
1 changed files with 58 additions and 20 deletions

View File

@ -21,18 +21,33 @@ struct System {
operators: Vec<Box<dyn SbpOperator2d>>, operators: Vec<Box<dyn SbpOperator2d>>,
} }
use std::sync::atomic::{AtomicBool, Ordering};
pub(crate) static MULTITHREAD: AtomicBool = AtomicBool::new(false);
impl integrate::Integrable for System { impl integrate::Integrable for System {
type State = Vec<euler::Field>; type State = Vec<euler::Field>;
type Diff = Vec<euler::Field>; type Diff = Vec<euler::Field>;
fn assign(s: &mut Self::State, o: &Self::State) { fn assign(s: &mut Self::State, o: &Self::State) {
if MULTITHREAD.load(Ordering::Acquire) {
s.par_iter_mut() s.par_iter_mut()
.zip(o.par_iter()) .zip(o.par_iter())
.for_each(|(s, o)| euler::Field::assign(s, o)) .for_each(|(s, o)| euler::Field::assign(s, o))
} else {
s.iter_mut()
.zip(o.iter())
.for_each(|(s, o)| euler::Field::assign(s, o))
}
} }
fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float) { fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float) {
if MULTITHREAD.load(Ordering::Acquire) {
s.par_iter_mut() s.par_iter_mut()
.zip(o.par_iter()) .zip(o.par_iter())
.for_each(|(s, o)| euler::Field::scaled_add(s, o, scale)) .for_each(|(s, o)| euler::Field::scaled_add(s, o, scale))
} else {
s.iter_mut()
.zip(o.iter())
.for_each(|(s, o)| euler::Field::scaled_add(s, o, scale))
}
} }
} }
@ -94,6 +109,7 @@ impl System {
let rhs = move |fut: &mut Vec<euler::Field>, prev: &Vec<euler::Field>, time: Float| { let rhs = move |fut: &mut Vec<euler::Field>, prev: &Vec<euler::Field>, time: Float| {
let prev_all = &prev; let prev_all = &prev;
if MULTITHREAD.load(Ordering::Acquire) {
rayon::scope(|s| { rayon::scope(|s| {
for (((((((fut, prev), wb), grid), metrics), op), bt), eb) in fut for (((((((fut, prev), wb), grid), metrics), op), bt), eb) in fut
.iter_mut() .iter_mut()
@ -115,6 +131,25 @@ impl System {
}) })
} }
}); });
} else {
for (((((((fut, prev), wb), grid), metrics), op), bt), eb) in fut
.iter_mut()
.zip(prev.iter())
.zip(wb.iter_mut())
.zip(grids)
.zip(metrics.iter())
.zip(operators.iter())
.zip(bt.iter())
.zip(eb.iter_mut())
{
let bc = euler::boundary_extracts(prev_all, bt, prev, grid, eb, time);
if op.upwind().is_some() {
euler::RHS_upwind(&**op, fut, prev, metrics, &bc, &mut wb.0);
} else {
euler::RHS_trad(&**op, fut, prev, metrics, &bc, &mut wb.0);
}
}
}
}; };
integrate::integrate::<integrate::Rk4, System, _>( integrate::integrate::<integrate::Rk4, System, _>(
@ -259,11 +294,14 @@ fn main() {
{ {
let nthreads = opt.jobs.unwrap_or(1); let nthreads = opt.jobs.unwrap_or(1);
if nthreads > 1 {
MULTITHREAD.store(true, Ordering::Release);
rayon::ThreadPoolBuilder::new() rayon::ThreadPoolBuilder::new()
.num_threads(nthreads) .num_threads(nthreads)
.build_global() .build_global()
.unwrap(); .unwrap();
} }
}
let should_output = |itime| { let should_output = |itime| {
opt.number_of_outputs.map_or(false, |num_out| { opt.number_of_outputs.map_or(false, |num_out| {