diff --git a/multigrid/Cargo.toml b/multigrid/Cargo.toml index e8f7132..ceda6bf 100644 --- a/multigrid/Cargo.toml +++ b/multigrid/Cargo.toml @@ -4,8 +4,9 @@ version = "0.1.0" authors = ["Magnus Ulimoen "] edition = "2018" + [dependencies] -sbp = { path = "../sbp" } +sbp = { path = "../sbp", features = ["rayon"] } hdf5 = "0.6.0" rayon = "1.3.0" indicatif = "0.14.0" diff --git a/multigrid/src/main.rs b/multigrid/src/main.rs index 3d5aa2d..1fdcd39 100644 --- a/multigrid/src/main.rs +++ b/multigrid/src/main.rs @@ -63,105 +63,62 @@ impl System { } } - fn advance(&mut self, dt: Float, s: &rayon::ThreadPool) { - for i in 0.. { - let time; - match i { - 0 => { - s.scope(|s| { - for (prev, fut) in self.fnow.iter().zip(self.fnext.iter_mut()) { - s.spawn(move |_| { - fut.assign(prev); - }); - } - }); - time = self.time; - } - 1 | 2 => { - s.scope(|s| { - for ((prev, fut), k) in self - .fnow - .iter() - .zip(self.fnext.iter_mut()) - .zip(&self.k[i - 1]) - { - s.spawn(move |_| { - fut.assign(prev); - fut.scaled_add(1.0 / 2.0 * dt, k); - }); - } - }); - time = self.time + dt / 2.0; - } - 3 => { - s.scope(|s| { - for ((prev, fut), k) in self - .fnow - .iter() - .zip(self.fnext.iter_mut()) - .zip(&self.k[i - 1]) - { - s.spawn(move |_| { - fut.assign(prev); - fut.scaled_add(dt, k); - }); - } - }); - time = self.time + dt; - } - 4 => { - s.scope(|s| { - for (((((prev, fut), k0), k1), k2), k3) in self - .fnow - .iter() - .zip(self.fnext.iter_mut()) - .zip(&self.k[0]) - .zip(&self.k[1]) - .zip(&self.k[2]) - .zip(&self.k[3]) - { - s.spawn(move |_| { - ndarray::Zip::from(&mut **fut) - .and(&**prev) - .and(&**k0) - .and(&**k1) - .and(&**k2) - .and(&**k3) - .apply(|y1, &y0, &k1, &k2, &k3, &k4| { - *y1 = y0 + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4) - }); - }); - } - }); - std::mem::swap(&mut self.fnext, &mut self.fnow); - self.time += dt; - return; - } - _ => { - unreachable!(); - } - } + fn advance(&mut self, dt: Float, pool: &rayon::ThreadPool) { + let rhs = move |fut: &mut [euler::Field], + prev: &[euler::Field], + time: Float, + c: &( + &[grid::Grid], + &[grid::Metrics<_>], + &[euler::BoundaryCharacteristics], + ), + mt: &mut ( + &mut [( + euler::Field, + euler::Field, + euler::Field, + euler::Field, + euler::Field, + euler::Field, + )], + &mut [euler::BoundaryStorage], + )| { + let (grids, metrics, bt) = c; + let (wb, eb) = mt; - s.scope(|s| { - let fields = &self.fnext; - let bt = euler::extract_boundaries::( - &fields, - &mut self.bt, - &mut self.eb, - &self.grids, - time, - ); - for ((((prev, fut), metrics), wb), bt) in fields - .iter() - .zip(&mut self.k[i]) - .zip(&self.metrics) - .zip(&mut self.wb) - .zip(bt) + let bc = euler::extract_boundaries::( + prev, *bt, *eb, *grids, time, + ); + pool.scope(|s| { + for ((((fut, prev), bc), wb), metrics) in fut + .iter_mut() + .zip(prev.iter()) + .zip(bc) + .zip(wb.iter_mut()) + .zip(metrics.iter()) { - s.spawn(move |_| euler::RHS_upwind(fut, prev, metrics, &bt, wb)); + s.spawn(move |_| euler::RHS_upwind(fut, prev, metrics, &bc, wb)); } }); - } + }; + let mut k = self + .k + .iter_mut() + .map(|k| k.as_mut_slice()) + .collect::>(); + sbp::integrate::integrate_multigrid::( + rhs, + &self.fnow, + &mut self.fnext, + &mut self.time, + dt, + &mut k, + &(&self.grids, &self.metrics, &self.bt), + &mut (&mut self.wb, &mut self.eb), + pool, + ); + + std::mem::swap(&mut self.fnow, &mut self.fnext); } } diff --git a/sbp/Cargo.toml b/sbp/Cargo.toml index f374827..142f76d 100644 --- a/sbp/Cargo.toml +++ b/sbp/Cargo.toml @@ -9,6 +9,7 @@ ndarray = { version = "0.13.0", features = ["approx"] } approx = "0.3.2" packed_simd = "0.3.3" json = "0.12.4" +rayon = { version = "1.3.0", optional = true } [features] # Internal feature flag to gate the expensive tests diff --git a/sbp/src/integrate.rs b/sbp/src/integrate.rs index 9214f6e..aacf43d 100644 --- a/sbp/src/integrate.rs +++ b/sbp/src/integrate.rs @@ -140,3 +140,83 @@ pub fn integrate<'a, BTableau, F: 'a, RHS, MT, C>( rhs(&mut k[i], &fut, simtime, constants, &mut mutables); } } + +#[cfg(feature = "rayon")] +pub fn integrate_multigrid<'a, BTableau, F: 'a, RHS, MT, C>( + rhs: RHS, + prev: &[F], + fut: &mut [F], + time: &mut Float, + dt: Float, + k: &mut [&mut [F]], + + constants: C, + mut mutables: &mut MT, + pool: &rayon::ThreadPool, +) where + C: Copy, + F: std::ops::Deref> + + std::ops::DerefMut> + + Send + + Sync, + RHS: Fn(&mut [F], &[F], Float, C, &mut MT), + BTableau: ButcherTableau, +{ + for i in 0.. { + let simtime; + match i { + 0 => { + pool.scope(|s| { + assert!(k.len() >= BTableau::S); + for (prev, fut) in prev.iter().zip(fut.iter_mut()) { + s.spawn(move |_| { + assert_eq!(prev.shape(), fut.shape()); + fut.assign(prev); + }); + } + }); + simtime = *time; + } + i if i < BTableau::S => { + pool.scope(|s| { + for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() { + let k = &k; + s.spawn(move |_| { + fut.assign(prev); + for (ik, &a) in BTableau::A[i - 1].iter().enumerate() { + if a == 0.0 { + continue; + } + fut.scaled_add(a * dt, &k[ik][ig]); + } + }); + } + }); + simtime = *time + dt * BTableau::C[i - 1]; + } + _ if i == BTableau::S => { + pool.scope(|s| { + for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() { + let k = &k; + s.spawn(move |_| { + fut.assign(prev); + for (ik, &b) in BTableau::B.iter().enumerate() { + if b == 0.0 { + continue; + } + fut.scaled_add(b * dt, &k[ik][ig]); + } + }); + } + }); + *time += dt; + return; + } + _ => { + unreachable!(); + } + }; + + rhs(&mut k[i], &fut, simtime, constants, &mut mutables); + } +}