diff --git a/multigrid/Cargo.toml b/multigrid/Cargo.toml index 7d783b1..7657769 100644 --- a/multigrid/Cargo.toml +++ b/multigrid/Cargo.toml @@ -9,7 +9,7 @@ edition = "2018" sbp = { path = "../sbp", features = ["serde1", "fast-float"] } euler = { path = "../euler", features = ["serde1"] } hdf5 = "0.7.0" -integrate = { path = "../utils/integrate", features = ["rayon"] } +integrate = { path = "../utils/integrate" } rayon = "1.3.0" indicatif = "0.15.0" structopt = "0.3.14" diff --git a/multigrid/src/main.rs b/multigrid/src/main.rs index 489f0bb..f364dff 100644 --- a/multigrid/src/main.rs +++ b/multigrid/src/main.rs @@ -1,3 +1,4 @@ +use rayon::prelude::*; use structopt::StructOpt; use sbp::operators::SbpOperator2d; @@ -20,6 +21,21 @@ struct System { operators: Vec>, } +impl integrate::Integrable for System { + type State = Vec; + type Diff = Vec; + fn assign(s: &mut Self::State, o: &Self::State) { + s.par_iter_mut() + .zip(o.par_iter()) + .for_each(|(s, o)| euler::Field::assign(s, o)) + } + fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float) { + s.par_iter_mut() + .zip(o.par_iter()) + .for_each(|(s, o)| euler::Field::scaled_add(s, o, scale)) + } +} + impl System { fn new( grids: Vec, @@ -68,7 +84,7 @@ impl System { } } - fn advance(&mut self, dt: Float, pool: &rayon::ThreadPool) { + fn advance(&mut self, dt: Float) { let metrics = &self.metrics; let grids = &self.grids; let bt = &self.bt; @@ -76,9 +92,9 @@ impl System { let eb = &mut self.eb; let operators = &self.operators; - let rhs = move |fut: &mut [euler::Field], prev: &[euler::Field], time: Float| { + let rhs = move |fut: &mut Vec, prev: &Vec, time: Float| { let prev_all = &prev; - pool.scope(|s| { + rayon::scope(|s| { for (((((((fut, prev), wb), grid), metrics), op), bt), eb) in fut .iter_mut() .zip(prev.iter()) @@ -101,19 +117,13 @@ impl System { }); }; - let mut k = self - .k - .iter_mut() - .map(|k| k.as_mut_slice()) - .collect::>(); - integrate::integrate_multigrid::( + integrate::integrate::( rhs, &self.fnow, &mut self.fnext, &mut self.time, dt, - &mut k, - pool, + &mut self.k, ); std::mem::swap(&mut self.fnow, &mut self.fnext); @@ -178,7 +188,7 @@ struct Options { no_progressbar: bool, /// Number of simultaneous threads #[structopt(short, long)] - jobs: Option>, + jobs: Option, /// Name of output file #[structopt(default_value = "output.hdf", long, short)] output: std::path::PathBuf, @@ -241,20 +251,13 @@ fn main() { let ntime = (integration_time / dt).round() as u64; - let pool = { - let builder = rayon::ThreadPoolBuilder::new(); - if let Some(j) = opt.jobs { - if let Some(j) = j { - builder.num_threads(j) - } else { - builder - } - } else { - builder.num_threads(1) - } - .build() - .unwrap() - }; + { + let nthreads = opt.jobs.unwrap_or(1); + rayon::ThreadPoolBuilder::new() + .num_threads(nthreads) + .build_global() + .unwrap(); + } let should_output = |itime| { opt.number_of_outputs.map_or(false, |num_out| { @@ -282,7 +285,7 @@ fn main() { output.add_timestep(itime, &sys.fnow); } progressbar.inc(1); - sys.advance(dt, &pool); + sys.advance(dt); } progressbar.finish_and_clear(); diff --git a/utils/integrate/Cargo.toml b/utils/integrate/Cargo.toml index 1317c29..ca63675 100644 --- a/utils/integrate/Cargo.toml +++ b/utils/integrate/Cargo.toml @@ -6,4 +6,3 @@ edition = "2018" [dependencies] float = { path = "../float/" } -rayon = { version = "1.5.0", optional = true } diff --git a/utils/integrate/src/lib.rs b/utils/integrate/src/lib.rs index 1c26324..1e49e72 100644 --- a/utils/integrate/src/lib.rs +++ b/utils/integrate/src/lib.rs @@ -243,90 +243,6 @@ pub fn integrate_embedded_rk( - mut rhs: RHS, - prev: &[F::State], - fut: &mut [F::State], - time: &mut Float, - dt: Float, - k: &mut [&mut [F::Diff]], - - pool: &rayon::ThreadPool, -) where - RHS: FnMut(&mut [F::Diff], &[F::State], Float), - F::State: Send + Sync, - F::Diff: Send + Sync, -{ - for i in 0.. { - let simtime; - match i { - 0 => { - pool.scope(|s| { - assert!(k.len() >= BTableau::S); - for (prev, fut) in prev.iter().zip(fut.iter_mut()) { - s.spawn(move |_| { - F::assign(fut, prev); - }); - } - }); - simtime = *time; - } - i if i < BTableau::S => { - pool.scope(|s| { - for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() { - let k = &k; - s.spawn(move |_| { - F::assign(fut, prev); - for (ik, &a) in BTableau::A[i - 1].iter().enumerate() { - if a == 0.0 { - continue; - } - F::scaled_add(fut, &k[ik][ig], a * dt); - } - }); - } - }); - simtime = *time + dt * BTableau::C[i - 1]; - } - _ if i == BTableau::S => { - pool.scope(|s| { - for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() { - let k = &k; - s.spawn(move |_| { - F::assign(fut, prev); - for (ik, &b) in BTableau::B.iter().enumerate() { - if b == 0.0 { - continue; - } - F::scaled_add(fut, &k[ik][ig], b * dt); - } - }); - } - }); - *time += dt; - return; - } - _ => { - unreachable!(); - } - }; - - rhs(&mut k[i], &fut, simtime); - } -} - #[test] /// Solving a second order PDE fn ballistic() {