simple rayon parallelisation

This commit is contained in:
Magnus Ulimoen 2020-04-02 23:36:20 +02:00
parent a31ca3ff3e
commit 97bdf7b0a5
3 changed files with 136 additions and 2 deletions

View File

@ -21,6 +21,7 @@ f32 = []
criterion = "0.3.0" criterion = "0.3.0"
structopt = "0.3.12" structopt = "0.3.12"
indicatif = "0.14.0" indicatif = "0.14.0"
rayon = "1.3.0"
[[bench]] [[bench]]
name = "maxwell" name = "maxwell"

View File

@ -153,6 +153,122 @@ impl<T: operators::UpwindOperator> System<T> {
} }
} }
} }
fn advance_parallel(&mut self, dt: Float, s: &rayon::ThreadPool) {
for i in 0.. {
match i {
0 => {
s.scope(|s| {
for (prev, fut) in self.fnow.iter().zip(self.fnext.iter_mut()) {
s.spawn(move |_| {
fut.assign(prev);
});
}
});
}
1 | 2 => {
s.scope(|s| {
for ((prev, fut), k) in self
.fnow
.iter()
.zip(self.fnext.iter_mut())
.zip(&self.k[i - 1])
{
s.spawn(move |_| {
fut.assign(prev);
fut.scaled_add(1.0 / 2.0 * dt, k);
});
}
});
}
3 => {
s.scope(|s| {
for ((prev, fut), k) in self
.fnow
.iter()
.zip(self.fnext.iter_mut())
.zip(&self.k[i - 1])
{
s.spawn(move |_| {
fut.assign(prev);
fut.scaled_add(dt, k);
});
}
});
}
4 => {
s.scope(|s| {
for (((((prev, fut), k0), k1), k2), k3) in self
.fnow
.iter()
.zip(self.fnext.iter_mut())
.zip(&self.k[0])
.zip(&self.k[1])
.zip(&self.k[2])
.zip(&self.k[3])
{
s.spawn(move |_| {
ndarray::Zip::from(&mut **fut)
.and(&**prev)
.and(&**k0)
.and(&**k1)
.and(&**k2)
.and(&**k3)
.apply(|y1, &y0, &k1, &k2, &k3, &k4| {
*y1 = y0 + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
});
});
}
});
std::mem::swap(&mut self.fnext, &mut self.fnow);
return;
}
_ => {
unreachable!();
}
}
s.scope(|s| {
let fields = &self.fnext;
let bt = self
.bt
.iter()
.enumerate()
.map(|(i, bt)| euler::BoundaryTerms {
north: match bt.north {
euler::BoundaryCharacteristic::This => fields[i].south(),
euler::BoundaryCharacteristic::Grid(g) => fields[g].south(),
euler::BoundaryCharacteristic::Vortex(_) => todo!(),
},
south: match bt.south {
euler::BoundaryCharacteristic::This => fields[i].north(),
euler::BoundaryCharacteristic::Grid(g) => fields[g].north(),
euler::BoundaryCharacteristic::Vortex(_) => todo!(),
},
west: match bt.west {
euler::BoundaryCharacteristic::This => fields[i].east(),
euler::BoundaryCharacteristic::Grid(g) => fields[g].east(),
euler::BoundaryCharacteristic::Vortex(_) => todo!(),
},
east: match bt.east {
euler::BoundaryCharacteristic::This => fields[i].west(),
euler::BoundaryCharacteristic::Grid(g) => fields[g].west(),
euler::BoundaryCharacteristic::Vortex(_) => todo!(),
},
})
.collect::<Vec<_>>();
for ((((prev, fut), grid), wb), bt) in fields
.iter()
.zip(&mut self.k[i])
.zip(&self.grids)
.zip(&mut self.wb)
.zip(bt)
{
s.spawn(move |_| euler::RHS_upwind(fut, prev, grid, &bt, wb));
}
});
}
}
} }
#[derive(Debug, StructOpt)] #[derive(Debug, StructOpt)]
@ -160,6 +276,8 @@ struct Options {
json: std::path::PathBuf, json: std::path::PathBuf,
#[structopt(long, help = "Disable the progressbar")] #[structopt(long, help = "Disable the progressbar")]
no_progressbar: bool, no_progressbar: bool,
#[structopt(short, long, help = "Number of simultaneous threads")]
jobs: Option<usize>,
} }
fn main() { fn main() {
@ -207,6 +325,17 @@ fn main() {
let ntime = (integration_time / dt).round() as usize; let ntime = (integration_time / dt).round() as usize;
let pool = if let Some(j) = opt.jobs {
Some(
rayon::ThreadPoolBuilder::new()
.num_threads(j)
.build()
.unwrap(),
)
} else {
None
};
let bar = if opt.no_progressbar { let bar = if opt.no_progressbar {
indicatif::ProgressBar::hidden() indicatif::ProgressBar::hidden()
} else { } else {
@ -218,8 +347,12 @@ fn main() {
}; };
for _ in 0..ntime { for _ in 0..ntime {
bar.inc(1); bar.inc(1);
if let Some(pool) = pool.as_ref() {
sys.advance_parallel(dt, &pool);
} else {
sys.advance(dt); sys.advance(dt);
} }
}
bar.finish(); bar.finish();
dump_to_file(&sys); dump_to_file(&sys);

View File

@ -5,7 +5,7 @@ use crate::Float;
use ndarray::{ArrayView2, ArrayViewMut2}; use ndarray::{ArrayView2, ArrayViewMut2};
pub trait SbpOperator { pub trait SbpOperator: Send + Sync {
fn diffxi(prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>); fn diffxi(prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>);
fn diffeta(prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>); fn diffeta(prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>);
fn h() -> &'static [Float]; fn h() -> &'static [Float];