//! Integration of explicit PDEs using different Butcher Tableaus //! //! Integration can be performed on all systems that can be represented //! as using a transform into an [`ndarray::ArrayView`] for both the state //! and the state difference. //! //! The integration functions are memory efficient, and relies //! on the `k` parameter to hold the system state differences. //! This parameter is tied to the Butcher Tableau use float::Float; /// The Butcher Tableau, with the state transitions described as /// [on wikipedia](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#Explicit_Runge%E2%80%93Kutta_methods). pub trait ButcherTableau { /// This bound should not be overridden const S: usize = Self::B.len(); /// Only the lower triangle will be used (explicit integration) const A: &'static [&'static [Float]]; const B: &'static [Float]; const C: &'static [Float]; } pub trait EmbeddedButcherTableau: ButcherTableau { const BSTAR: &'static [Float]; } pub struct Rk4; impl ButcherTableau for Rk4 { const A: &'static [&'static [Float]] = &[&[0.5], &[0.0, 0.5], &[0.0, 0.0, 1.0]]; const B: &'static [Float] = &[1.0 / 6.0, 1.0 / 3.0, 1.0 / 3.0, 1.0 / 6.0]; const C: &'static [Float] = &[0.5, 0.5, 1.0]; } pub struct Rk4_38; impl ButcherTableau for Rk4_38 { const A: &'static [&'static [Float]] = &[&[1.0 / 3.0], &[-1.0 / 3.0, 1.0], &[1.0, -1.0, 1.0]]; const B: &'static [Float] = &[1.0 / 8.0, 3.0 / 8.0, 3.0 / 8.0, 1.0 / 8.0]; const C: &'static [Float] = &[1.0 / 3.0, 2.0 / 3.0, 1.0]; } pub struct EulerMethod; impl ButcherTableau for EulerMethod { const A: &'static [&'static [Float]] = &[&[]]; const B: &'static [Float] = &[1.0]; const C: &'static [Float] = &[]; } pub struct MidpointMethod; impl ButcherTableau for MidpointMethod { const A: &'static [&'static [Float]] = &[&[0.5]]; const B: &'static [Float] = &[0.0, 1.0]; const C: &'static [Float] = &[1.0 / 2.0]; } /// Bit excessive... #[allow(clippy::excessive_precision)] #[allow(clippy::unreadable_literal)] const SQRT_5: Float = 2.236067977499789696409173668731276235440618359611525724270897245410520925637804899414414408378782275; pub struct Rk6; impl ButcherTableau for Rk6 { const A: &'static [&'static [Float]] = &[ &[4.0 / 7.0], &[115.0 / 112.0, -5.0 / 16.0], &[589.0 / 630.0, 5.0 / 18.0, -16.0 / 45.0], &[ 229.0 / 1200.0 - 29.0 / 6000.0 * SQRT_5, 119.0 / 240.0 - 187.0 / 1200.0 * SQRT_5, -14.0 / 75.0 + 34.0 / 375.0 * SQRT_5, -3.0 / 100.0 * SQRT_5, ], &[ 71.0 / 2400.0 - 587.0 / 12000.0 * SQRT_5, 187.0 / 480.0 - 391.0 / 2400.0 * SQRT_5, -38.0 / 75.0 + 26.0 / 375.0 * SQRT_5, 27.0 / 80.0 - 3.0 / 400.0 * SQRT_5, (1.0 + SQRT_5) / 4.0, ], &[ -49.0 / 480.0 + 43.0 / 160.0 * SQRT_5, -425.0 / 96.0 + 51.0 / 32.0 * SQRT_5, 52.0 / 15.0 - 4.0 / 5.0 * SQRT_5, -27.0 / 16.0 + 3.0 / 16.0 * SQRT_5, 5.0 / 4.0 - 3.0 / 4.0 * SQRT_5, 5.0 / 2.0 - 1.0 / 2.0 * SQRT_5, ], ]; const B: &'static [Float] = &[ 1.0 / 12.0, 0.0, 0.0, 0.0, 5.0 / 12.0, 5.0 / 12.0, 1.0 / 12.0, ]; const C: &'static [Float] = &[ 4.0 / 7.0, 5.0 / 7.0, 6.0 / 7.0, (5.0 - SQRT_5) / 10.0, (5.0 + SQRT_5) / 10.0, 1.0, ]; } pub struct Fehlberg; impl ButcherTableau for Fehlberg { #[rustfmt::skip] const A: &'static [&'static [Float]] = &[ &[1.0 / 4.0], &[3.0 / 32.0, 9.0 / 32.0], &[1932.0 / 2197.0, -7200.0 / 2197.0, 7296.0 / 2197.0], &[439.0 / 216.0, -8.0, 3680.0 / 513.0, -845.0 / 4104.0], &[-8.0 / 27.0, 2.0, -3544.0 / 2565.0, 1859.0 / 4104.0, -11.0 / 40.0], ]; #[rustfmt::skip] const B: &'static [Float] = &[ 16.0 / 135.0, 0.0, 6656.0 / 12825.0, 28561.0 / 56430.0, -9.0 / 50.0, 2.0 / 55.0, ]; const C: &'static [Float] = &[0.0, 1.0 / 4.0, 3.0 / 8.0, 12.0 / 13.0, 1.0, 1.0 / 2.0]; } impl EmbeddedButcherTableau for Fehlberg { const BSTAR: &'static [Float] = &[ 25.0 / 216.0, 0.0, 1408.0 / 2565.0, 2197.0 / 4104.0, -1.0 / 5.0, 0.0, ]; } pub struct BogackiShampine; impl ButcherTableau for BogackiShampine { const A: &'static [&'static [Float]] = &[ &[1.0 / 2.0], &[0.0, 3.0 / 4.0], &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0], ]; const B: &'static [Float] = &[2.0 / 9.0, 1.0 / 3.0, 4.0 / 9.0, 0.0]; const C: &'static [Float] = &[0.0, 1.0 / 2.0, 3.0 / 4.0, 1.0]; } impl EmbeddedButcherTableau for BogackiShampine { const BSTAR: &'static [Float] = &[7.0 / 24.0, 1.0 / 4.0, 1.0 / 3.0, 1.0 / 8.0]; } pub trait Integrable { type State; type Diff; fn assign(s: &mut Self::State, o: &Self::State); fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float); } #[allow(clippy::too_many_arguments)] /// Integrates using the [`ButcherTableau`] specified. `rhs` should be the result /// of the right hand side of $u_t = rhs$ /// /// rhs takes the old state and the current time, and outputs the state difference /// in the first parameter /// /// Should be called as /// ```rust,ignore /// integrate::(...) /// ``` pub fn integrate( mut rhs: RHS, prev: &F::State, fut: &mut F::State, time: &mut Float, dt: Float, k: &mut [F::Diff], ) where RHS: FnMut(&mut F::Diff, &F::State, Float), { assert!(k.len() >= BTableau::S); for i in 0.. { let simtime; match i { 0 => { F::assign(fut, prev); simtime = *time; } i if i < BTableau::S => { F::assign(fut, prev); for (&a, k) in BTableau::A[i - 1].iter().zip(k.iter()) { if a == 0.0 { continue; } F::scaled_add(fut, k, a * dt); } simtime = *time + dt * BTableau::C[i - 1]; } _ if i == BTableau::S => { F::assign(fut, prev); for (&b, k) in BTableau::B.iter().zip(k.iter()) { if b == 0.0 { continue; } F::scaled_add(fut, k, b * dt); } *time += dt; return; } _ => { unreachable!(); } }; rhs(&mut k[i], &fut, simtime); } } #[allow(clippy::too_many_arguments)] /// Integrate using an [`EmbeddedButcherTableau`], else similar to [`integrate`] /// /// This produces two results, the most accurate result in `fut`, and the less accurate /// result in `fut2`. This can be used for convergence testing and adaptive timesteps. pub fn integrate_embedded_rk( rhs: RHS, prev: &F::State, fut: &mut F::State, fut2: &mut F::State, time: &mut Float, dt: Float, k: &mut [F::Diff], ) where RHS: FnMut(&mut F::Diff, &F::State, Float), { integrate::(rhs, prev, fut, time, dt, k); F::assign(fut2, prev); for (&b, k) in BTableau::BSTAR.iter().zip(k.iter()) { if b == 0.0 { continue; } F::scaled_add(fut2, k, b * dt); } } #[cfg(feature = "rayon")] #[allow(clippy::too_many_arguments)] /// Integrates a multigrid problem, much the same as [`integrate`], /// using a `rayon` threadpool for parallelisation. /// /// note that `rhs` accepts the full system state, and is responsible /// for computing the full state difference. /// `rhs` can be a mutable closure, so buffers can be used /// and mutated inside the closure. /// /// This function requires the `rayon` feature, and is not callable in /// a `wasm` context. pub fn integrate_multigrid( mut rhs: RHS, prev: &[F::State], fut: &mut [F::State], time: &mut Float, dt: Float, k: &mut [&mut [F::Diff]], pool: &rayon::ThreadPool, ) where RHS: FnMut(&mut [F::Diff], &[F::State], Float), F::State: Send + Sync, F::Diff: Send + Sync, { for i in 0.. { let simtime; match i { 0 => { pool.scope(|s| { assert!(k.len() >= BTableau::S); for (prev, fut) in prev.iter().zip(fut.iter_mut()) { s.spawn(move |_| { F::assign(fut, prev); }); } }); simtime = *time; } i if i < BTableau::S => { pool.scope(|s| { for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() { let k = &k; s.spawn(move |_| { F::assign(fut, prev); for (ik, &a) in BTableau::A[i - 1].iter().enumerate() { if a == 0.0 { continue; } F::scaled_add(fut, &k[ik][ig], a * dt); } }); } }); simtime = *time + dt * BTableau::C[i - 1]; } _ if i == BTableau::S => { pool.scope(|s| { for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() { let k = &k; s.spawn(move |_| { F::assign(fut, prev); for (ik, &b) in BTableau::B.iter().enumerate() { if b == 0.0 { continue; } F::scaled_add(fut, &k[ik][ig], b * dt); } }); } }); *time += dt; return; } _ => { unreachable!(); } }; rhs(&mut k[i], &fut, simtime); } } #[test] /// Solving a second order PDE fn ballistic() { #[derive(Clone, Debug)] struct Ball { z: Float, v: Float, } impl Integrable for Ball { type State = Ball; type Diff = (Float, Float); fn assign(s: &mut Self::State, o: &Self::State) { s.z = o.z; s.v = o.v; } fn scaled_add(s: &mut Self::State, o: &Self::Diff, sc: Float) { s.z += o.0 * sc; s.v += o.1 * sc; } } let mut t = 0.0; let dt = 0.001; let initial = Ball { z: 0.0, v: 10.0 }; let g = -9.81; let mut k = [(0.0, 0.0); 4]; let gravity = |d: &mut (Float, Float), s: &Ball, _time: Float| { d.1 = g; d.0 = s.v }; let mut next = initial.clone(); //while next.z >= 0.0 { while t < 1.0 { let mut next2 = next.clone(); integrate::(gravity, &next, &mut next2, &mut t, dt, &mut k); std::mem::swap(&mut next, &mut next2); } let expected_vel = initial.v + g * t; assert!((next.v - expected_vel).abs() < 1e-3); let expected_pos = initial.z + initial.v * t + g / 2.0 * t.powi(2); assert!((next.z - expected_pos).abs() < 1e-2); }