diff --git a/sbp/Cargo.toml b/sbp/Cargo.toml index 7b3738d..69259f1 100644 --- a/sbp/Cargo.toml +++ b/sbp/Cargo.toml @@ -21,6 +21,7 @@ f32 = [] criterion = "0.3.0" structopt = "0.3.12" indicatif = "0.14.0" +rayon = "1.3.0" [[bench]] name = "maxwell" diff --git a/sbp/examples/multigrid.rs b/sbp/examples/multigrid.rs index 5fe4d11..aebd03d 100644 --- a/sbp/examples/multigrid.rs +++ b/sbp/examples/multigrid.rs @@ -153,6 +153,122 @@ impl System { } } } + + fn advance_parallel(&mut self, dt: Float, s: &rayon::ThreadPool) { + for i in 0.. { + match i { + 0 => { + s.scope(|s| { + for (prev, fut) in self.fnow.iter().zip(self.fnext.iter_mut()) { + s.spawn(move |_| { + fut.assign(prev); + }); + } + }); + } + 1 | 2 => { + s.scope(|s| { + for ((prev, fut), k) in self + .fnow + .iter() + .zip(self.fnext.iter_mut()) + .zip(&self.k[i - 1]) + { + s.spawn(move |_| { + fut.assign(prev); + fut.scaled_add(1.0 / 2.0 * dt, k); + }); + } + }); + } + 3 => { + s.scope(|s| { + for ((prev, fut), k) in self + .fnow + .iter() + .zip(self.fnext.iter_mut()) + .zip(&self.k[i - 1]) + { + s.spawn(move |_| { + fut.assign(prev); + fut.scaled_add(dt, k); + }); + } + }); + } + 4 => { + s.scope(|s| { + for (((((prev, fut), k0), k1), k2), k3) in self + .fnow + .iter() + .zip(self.fnext.iter_mut()) + .zip(&self.k[0]) + .zip(&self.k[1]) + .zip(&self.k[2]) + .zip(&self.k[3]) + { + s.spawn(move |_| { + ndarray::Zip::from(&mut **fut) + .and(&**prev) + .and(&**k0) + .and(&**k1) + .and(&**k2) + .and(&**k3) + .apply(|y1, &y0, &k1, &k2, &k3, &k4| { + *y1 = y0 + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4) + }); + }); + } + }); + std::mem::swap(&mut self.fnext, &mut self.fnow); + return; + } + _ => { + unreachable!(); + } + } + + s.scope(|s| { + let fields = &self.fnext; + let bt = self + .bt + .iter() + .enumerate() + .map(|(i, bt)| euler::BoundaryTerms { + north: match bt.north { + euler::BoundaryCharacteristic::This => fields[i].south(), + euler::BoundaryCharacteristic::Grid(g) => fields[g].south(), + euler::BoundaryCharacteristic::Vortex(_) => todo!(), + }, + south: match bt.south { + euler::BoundaryCharacteristic::This => fields[i].north(), + euler::BoundaryCharacteristic::Grid(g) => fields[g].north(), + euler::BoundaryCharacteristic::Vortex(_) => todo!(), + }, + west: match bt.west { + euler::BoundaryCharacteristic::This => fields[i].east(), + euler::BoundaryCharacteristic::Grid(g) => fields[g].east(), + euler::BoundaryCharacteristic::Vortex(_) => todo!(), + }, + east: match bt.east { + euler::BoundaryCharacteristic::This => fields[i].west(), + euler::BoundaryCharacteristic::Grid(g) => fields[g].west(), + euler::BoundaryCharacteristic::Vortex(_) => todo!(), + }, + }) + .collect::>(); + for ((((prev, fut), grid), wb), bt) in fields + .iter() + .zip(&mut self.k[i]) + .zip(&self.grids) + .zip(&mut self.wb) + .zip(bt) + { + s.spawn(move |_| euler::RHS_upwind(fut, prev, grid, &bt, wb)); + } + }); + } + } } #[derive(Debug, StructOpt)] @@ -160,6 +276,8 @@ struct Options { json: std::path::PathBuf, #[structopt(long, help = "Disable the progressbar")] no_progressbar: bool, + #[structopt(short, long, help = "Number of simultaneous threads")] + jobs: Option, } fn main() { @@ -207,6 +325,17 @@ fn main() { let ntime = (integration_time / dt).round() as usize; + let pool = if let Some(j) = opt.jobs { + Some( + rayon::ThreadPoolBuilder::new() + .num_threads(j) + .build() + .unwrap(), + ) + } else { + None + }; + let bar = if opt.no_progressbar { indicatif::ProgressBar::hidden() } else { @@ -218,7 +347,11 @@ fn main() { }; for _ in 0..ntime { bar.inc(1); - sys.advance(dt); + if let Some(pool) = pool.as_ref() { + sys.advance_parallel(dt, &pool); + } else { + sys.advance(dt); + } } bar.finish(); diff --git a/sbp/src/operators.rs b/sbp/src/operators.rs index 83090ab..a624f2b 100644 --- a/sbp/src/operators.rs +++ b/sbp/src/operators.rs @@ -5,7 +5,7 @@ use crate::Float; use ndarray::{ArrayView2, ArrayViewMut2}; -pub trait SbpOperator { +pub trait SbpOperator: Send + Sync { fn diffxi(prev: ArrayView2, fut: ArrayViewMut2); fn diffeta(prev: ArrayView2, fut: ArrayViewMut2); fn h() -> &'static [Float];