simple rayon parallelisation
This commit is contained in:
		@@ -21,6 +21,7 @@ f32 = []
 | 
			
		||||
criterion = "0.3.0"
 | 
			
		||||
structopt = "0.3.12"
 | 
			
		||||
indicatif = "0.14.0"
 | 
			
		||||
rayon = "1.3.0"
 | 
			
		||||
 | 
			
		||||
[[bench]]
 | 
			
		||||
name = "maxwell"
 | 
			
		||||
 
 | 
			
		||||
@@ -153,6 +153,122 @@ impl<T: operators::UpwindOperator> System<T> {
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn advance_parallel(&mut self, dt: Float, s: &rayon::ThreadPool) {
 | 
			
		||||
        for i in 0.. {
 | 
			
		||||
            match i {
 | 
			
		||||
                0 => {
 | 
			
		||||
                    s.scope(|s| {
 | 
			
		||||
                        for (prev, fut) in self.fnow.iter().zip(self.fnext.iter_mut()) {
 | 
			
		||||
                            s.spawn(move |_| {
 | 
			
		||||
                                fut.assign(prev);
 | 
			
		||||
                            });
 | 
			
		||||
                        }
 | 
			
		||||
                    });
 | 
			
		||||
                }
 | 
			
		||||
                1 | 2 => {
 | 
			
		||||
                    s.scope(|s| {
 | 
			
		||||
                        for ((prev, fut), k) in self
 | 
			
		||||
                            .fnow
 | 
			
		||||
                            .iter()
 | 
			
		||||
                            .zip(self.fnext.iter_mut())
 | 
			
		||||
                            .zip(&self.k[i - 1])
 | 
			
		||||
                        {
 | 
			
		||||
                            s.spawn(move |_| {
 | 
			
		||||
                                fut.assign(prev);
 | 
			
		||||
                                fut.scaled_add(1.0 / 2.0 * dt, k);
 | 
			
		||||
                            });
 | 
			
		||||
                        }
 | 
			
		||||
                    });
 | 
			
		||||
                }
 | 
			
		||||
                3 => {
 | 
			
		||||
                    s.scope(|s| {
 | 
			
		||||
                        for ((prev, fut), k) in self
 | 
			
		||||
                            .fnow
 | 
			
		||||
                            .iter()
 | 
			
		||||
                            .zip(self.fnext.iter_mut())
 | 
			
		||||
                            .zip(&self.k[i - 1])
 | 
			
		||||
                        {
 | 
			
		||||
                            s.spawn(move |_| {
 | 
			
		||||
                                fut.assign(prev);
 | 
			
		||||
                                fut.scaled_add(dt, k);
 | 
			
		||||
                            });
 | 
			
		||||
                        }
 | 
			
		||||
                    });
 | 
			
		||||
                }
 | 
			
		||||
                4 => {
 | 
			
		||||
                    s.scope(|s| {
 | 
			
		||||
                        for (((((prev, fut), k0), k1), k2), k3) in self
 | 
			
		||||
                            .fnow
 | 
			
		||||
                            .iter()
 | 
			
		||||
                            .zip(self.fnext.iter_mut())
 | 
			
		||||
                            .zip(&self.k[0])
 | 
			
		||||
                            .zip(&self.k[1])
 | 
			
		||||
                            .zip(&self.k[2])
 | 
			
		||||
                            .zip(&self.k[3])
 | 
			
		||||
                        {
 | 
			
		||||
                            s.spawn(move |_| {
 | 
			
		||||
                                ndarray::Zip::from(&mut **fut)
 | 
			
		||||
                                    .and(&**prev)
 | 
			
		||||
                                    .and(&**k0)
 | 
			
		||||
                                    .and(&**k1)
 | 
			
		||||
                                    .and(&**k2)
 | 
			
		||||
                                    .and(&**k3)
 | 
			
		||||
                                    .apply(|y1, &y0, &k1, &k2, &k3, &k4| {
 | 
			
		||||
                                        *y1 = y0 + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
 | 
			
		||||
                                    });
 | 
			
		||||
                            });
 | 
			
		||||
                        }
 | 
			
		||||
                    });
 | 
			
		||||
                    std::mem::swap(&mut self.fnext, &mut self.fnow);
 | 
			
		||||
                    return;
 | 
			
		||||
                }
 | 
			
		||||
                _ => {
 | 
			
		||||
                    unreachable!();
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            s.scope(|s| {
 | 
			
		||||
                let fields = &self.fnext;
 | 
			
		||||
                let bt = self
 | 
			
		||||
                    .bt
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .enumerate()
 | 
			
		||||
                    .map(|(i, bt)| euler::BoundaryTerms {
 | 
			
		||||
                        north: match bt.north {
 | 
			
		||||
                            euler::BoundaryCharacteristic::This => fields[i].south(),
 | 
			
		||||
                            euler::BoundaryCharacteristic::Grid(g) => fields[g].south(),
 | 
			
		||||
                            euler::BoundaryCharacteristic::Vortex(_) => todo!(),
 | 
			
		||||
                        },
 | 
			
		||||
                        south: match bt.south {
 | 
			
		||||
                            euler::BoundaryCharacteristic::This => fields[i].north(),
 | 
			
		||||
                            euler::BoundaryCharacteristic::Grid(g) => fields[g].north(),
 | 
			
		||||
                            euler::BoundaryCharacteristic::Vortex(_) => todo!(),
 | 
			
		||||
                        },
 | 
			
		||||
                        west: match bt.west {
 | 
			
		||||
                            euler::BoundaryCharacteristic::This => fields[i].east(),
 | 
			
		||||
                            euler::BoundaryCharacteristic::Grid(g) => fields[g].east(),
 | 
			
		||||
                            euler::BoundaryCharacteristic::Vortex(_) => todo!(),
 | 
			
		||||
                        },
 | 
			
		||||
                        east: match bt.east {
 | 
			
		||||
                            euler::BoundaryCharacteristic::This => fields[i].west(),
 | 
			
		||||
                            euler::BoundaryCharacteristic::Grid(g) => fields[g].west(),
 | 
			
		||||
                            euler::BoundaryCharacteristic::Vortex(_) => todo!(),
 | 
			
		||||
                        },
 | 
			
		||||
                    })
 | 
			
		||||
                    .collect::<Vec<_>>();
 | 
			
		||||
                for ((((prev, fut), grid), wb), bt) in fields
 | 
			
		||||
                    .iter()
 | 
			
		||||
                    .zip(&mut self.k[i])
 | 
			
		||||
                    .zip(&self.grids)
 | 
			
		||||
                    .zip(&mut self.wb)
 | 
			
		||||
                    .zip(bt)
 | 
			
		||||
                {
 | 
			
		||||
                    s.spawn(move |_| euler::RHS_upwind(fut, prev, grid, &bt, wb));
 | 
			
		||||
                }
 | 
			
		||||
            });
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[derive(Debug, StructOpt)]
 | 
			
		||||
@@ -160,6 +276,8 @@ struct Options {
 | 
			
		||||
    json: std::path::PathBuf,
 | 
			
		||||
    #[structopt(long, help = "Disable the progressbar")]
 | 
			
		||||
    no_progressbar: bool,
 | 
			
		||||
    #[structopt(short, long, help = "Number of simultaneous threads")]
 | 
			
		||||
    jobs: Option<usize>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
fn main() {
 | 
			
		||||
@@ -207,6 +325,17 @@ fn main() {
 | 
			
		||||
 | 
			
		||||
    let ntime = (integration_time / dt).round() as usize;
 | 
			
		||||
 | 
			
		||||
    let pool = if let Some(j) = opt.jobs {
 | 
			
		||||
        Some(
 | 
			
		||||
            rayon::ThreadPoolBuilder::new()
 | 
			
		||||
                .num_threads(j)
 | 
			
		||||
                .build()
 | 
			
		||||
                .unwrap(),
 | 
			
		||||
        )
 | 
			
		||||
    } else {
 | 
			
		||||
        None
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    let bar = if opt.no_progressbar {
 | 
			
		||||
        indicatif::ProgressBar::hidden()
 | 
			
		||||
    } else {
 | 
			
		||||
@@ -218,8 +347,12 @@ fn main() {
 | 
			
		||||
    };
 | 
			
		||||
    for _ in 0..ntime {
 | 
			
		||||
        bar.inc(1);
 | 
			
		||||
        if let Some(pool) = pool.as_ref() {
 | 
			
		||||
            sys.advance_parallel(dt, &pool);
 | 
			
		||||
        } else {
 | 
			
		||||
            sys.advance(dt);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    bar.finish();
 | 
			
		||||
 | 
			
		||||
    dump_to_file(&sys);
 | 
			
		||||
 
 | 
			
		||||
@@ -5,7 +5,7 @@ use crate::Float;
 | 
			
		||||
 | 
			
		||||
use ndarray::{ArrayView2, ArrayViewMut2};
 | 
			
		||||
 | 
			
		||||
pub trait SbpOperator {
 | 
			
		||||
pub trait SbpOperator: Send + Sync {
 | 
			
		||||
    fn diffxi(prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>);
 | 
			
		||||
    fn diffeta(prev: ArrayView2<Float>, fut: ArrayViewMut2<Float>);
 | 
			
		||||
    fn h() -> &'static [Float];
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user