Move integrate to separate crate
This commit is contained in:
parent
be1330ec02
commit
7aadda3de9
|
@ -9,6 +9,7 @@ members = [
|
|||
"gridgeneration",
|
||||
"heat-equation",
|
||||
"utils/float",
|
||||
"utils/integrate",
|
||||
"utils/constmatrix",
|
||||
]
|
||||
|
||||
|
|
|
@ -15,6 +15,7 @@ ndarray = "0.14.0"
|
|||
sbp = { path = "../sbp" }
|
||||
arrayvec = "0.5.1"
|
||||
serde = { version = "1.0.115", default-features = false, optional = true, features = ["derive"] }
|
||||
integrate = { path = "../utils/integrate" }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3.2"
|
||||
|
|
|
@ -2,7 +2,6 @@ pub use arrayvec::ArrayVec;
|
|||
use ndarray::azip;
|
||||
use ndarray::prelude::*;
|
||||
use sbp::grid::{Grid, Metrics};
|
||||
use sbp::integrate;
|
||||
use sbp::operators::{InterpolationOperator, SbpOperator2d, UpwindOperator2d};
|
||||
use sbp::utils::Direction;
|
||||
use sbp::Float;
|
||||
|
@ -57,7 +56,7 @@ impl<SBP: SbpOperator2d> System<SBP> {
|
|||
let boundaries = boundary_extractor(y, grid, &bc);
|
||||
RHS_trad(op, k, y, metrics, &boundaries, wb)
|
||||
};
|
||||
integrate::integrate::<integrate::Rk4, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, Field, _>(
|
||||
rhs_trad,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
|
@ -131,7 +130,7 @@ impl<UO: UpwindOperator2d + SbpOperator2d> System<UO> {
|
|||
let boundaries = boundary_extractor(y, grid, &bc);
|
||||
RHS_upwind(op, k, y, metrics, &boundaries, wb)
|
||||
};
|
||||
integrate::integrate::<integrate::Rk4, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, Field, _>(
|
||||
rhs_upwind,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
|
@ -159,7 +158,7 @@ impl<UO: UpwindOperator2d + SbpOperator2d> System<UO> {
|
|||
let mut time = 0.0;
|
||||
let mut sys2 = self.sys.0.clone();
|
||||
while time < dt {
|
||||
integrate::integrate_embedded_rk::<integrate::BogackiShampine, _, _, _>(
|
||||
integrate::integrate_embedded_rk::<integrate::BogackiShampine, Field, _>(
|
||||
&mut rhs_upwind,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
|
@ -184,27 +183,15 @@ impl<UO: UpwindOperator2d + SbpOperator2d> System<UO> {
|
|||
/// A 4 x ny x nx array
|
||||
pub struct Field(pub(crate) Array3<Float>);
|
||||
|
||||
impl std::ops::Deref for Field {
|
||||
type Target = Array3<Float>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl integrate::Integrable for Field {
|
||||
type State = Field;
|
||||
type Diff = Field;
|
||||
|
||||
impl std::ops::DerefMut for Field {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
fn assign(s: &mut Self::State, o: &Self::State) {
|
||||
s.0.assign(&o.0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> {
|
||||
fn from(f: &'a Field) -> Self {
|
||||
f.0.view()
|
||||
}
|
||||
}
|
||||
impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> {
|
||||
fn from(f: &'a mut Field) -> Self {
|
||||
f.0.view_mut()
|
||||
fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float) {
|
||||
s.0.scaled_add(scale, &o.0);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -222,6 +209,20 @@ impl Field {
|
|||
self.0.shape()[1]
|
||||
}
|
||||
|
||||
pub(crate) fn slice<Do: Dimension>(
|
||||
&self,
|
||||
info: &ndarray::SliceInfo<[ndarray::SliceOrIndex; 3], Do>,
|
||||
) -> ArrayView<Float, Do> {
|
||||
self.0.slice(info)
|
||||
}
|
||||
|
||||
pub(crate) fn slice_mut<Do: Dimension>(
|
||||
&mut self,
|
||||
info: &ndarray::SliceInfo<[ndarray::SliceOrIndex; 3], Do>,
|
||||
) -> ArrayViewMut<Float, Do> {
|
||||
self.0.slice_mut(info)
|
||||
}
|
||||
|
||||
pub fn rho(&self) -> ArrayView2<Float> {
|
||||
self.slice(s![0, .., ..])
|
||||
}
|
||||
|
@ -613,9 +614,9 @@ fn upwind_dissipation(
|
|||
tmp: (&mut Field, &mut Field),
|
||||
) {
|
||||
let n = y.nx() * y.ny();
|
||||
let yview = y.view().into_shape((4, n)).unwrap();
|
||||
let mut tmp0 = tmp.0.view_mut().into_shape((4, n)).unwrap();
|
||||
let mut tmp1 = tmp.1.view_mut().into_shape((4, n)).unwrap();
|
||||
let yview = y.0.view().into_shape((4, n)).unwrap();
|
||||
let mut tmp0 = tmp.0 .0.view_mut().into_shape((4, n)).unwrap();
|
||||
let mut tmp1 = tmp.1 .0.view_mut().into_shape((4, n)).unwrap();
|
||||
|
||||
for ((((((y, mut tmp0), mut tmp1), detj_dxi_dx), detj_dxi_dy), detj_deta_dx), detj_deta_dy) in
|
||||
yview
|
||||
|
|
|
@ -10,6 +10,7 @@ sbp = { path = "../sbp", features = ["sparse"] }
|
|||
ndarray = "0.14.0"
|
||||
plotters = { version = "0.3.0", default-features = false, features = ["bitmap_gif", "bitmap_backend", "line_series"] }
|
||||
sprs = { version = "0.10.0", default-features = false }
|
||||
integrate = { path = "../utils/integrate" }
|
||||
|
||||
[dev-dependencies]
|
||||
arpack = { git = "https://github.com/mulimoen/arpack-rs", branch = "main" }
|
||||
|
|
|
@ -1,11 +1,24 @@
|
|||
use integrate::{integrate, Rk4};
|
||||
use ndarray::{Array1, ArrayView1};
|
||||
use plotters::prelude::*;
|
||||
use sbp::{
|
||||
integrate::{integrate, Rk4},
|
||||
operators::{SbpOperator1d, SbpOperator1d2, SBP4},
|
||||
Float,
|
||||
};
|
||||
|
||||
struct Field(Array1<Float>);
|
||||
|
||||
impl integrate::Integrable for Field {
|
||||
type State = Array1<Float>;
|
||||
type Diff = Array1<Float>;
|
||||
fn assign(s: &mut Self::State, o: &Self::State) {
|
||||
s.assign(o)
|
||||
}
|
||||
fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float) {
|
||||
s.scaled_add(scale, o)
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let nx: usize = 101;
|
||||
let x = Array1::from_shape_fn((nx,), |i| i as Float / (nx - 1) as Float);
|
||||
|
@ -76,7 +89,7 @@ fn dual_dirichlet(v: ArrayView1<Float>, v0: Float, vn: Float) {
|
|||
.unwrap();
|
||||
drawing_area.present().unwrap();
|
||||
}
|
||||
integrate::<Rk4, _, _, _>(rhs, &v1, &mut v2, &mut 0.0, dt, &mut k);
|
||||
integrate::<Rk4, Field, _>(rhs, &v1, &mut v2, &mut 0.0, dt, &mut k);
|
||||
std::mem::swap(&mut v1, &mut v2);
|
||||
}
|
||||
}
|
||||
|
@ -143,7 +156,7 @@ fn neumann_dirichlet(v: ArrayView1<Float>, v0: Float, vn: Float) {
|
|||
.unwrap();
|
||||
drawing_area.present().unwrap();
|
||||
}
|
||||
integrate::<Rk4, _, _, _>(rhs, &v1, &mut v2, &mut 0.0, dt, &mut k);
|
||||
integrate::<Rk4, Field, _>(rhs, &v1, &mut v2, &mut 0.0, dt, &mut k);
|
||||
std::mem::swap(&mut v1, &mut v2);
|
||||
}
|
||||
}
|
||||
|
@ -231,7 +244,7 @@ fn dual_dirichlet_sparse(v: ArrayView1<Float>, v0: Float, vn: Float) {
|
|||
.unwrap();
|
||||
drawing_area.present().unwrap();
|
||||
}
|
||||
integrate::<Rk4, _, _, _>(rhs, &v1, &mut v2, &mut 0.0, dt, &mut k);
|
||||
integrate::<Rk4, Field, _>(rhs, &v1, &mut v2, &mut 0.0, dt, &mut k);
|
||||
std::mem::swap(&mut v1, &mut v2);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ sparse = ["sbp/sparse", "sprs"]
|
|||
ndarray = "0.14.0"
|
||||
sbp = { path = "../sbp" }
|
||||
sprs = { version = "0.10.0", optional = true, default-features = false }
|
||||
integrate = { path = "../utils/integrate" }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = "0.3.2"
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
use ndarray::azip;
|
||||
use ndarray::prelude::*;
|
||||
use sbp::grid::{Grid, Metrics};
|
||||
use sbp::integrate;
|
||||
use sbp::operators::{SbpOperator2d, UpwindOperator2d};
|
||||
use sbp::Float;
|
||||
|
||||
|
@ -11,27 +10,15 @@ pub mod sparse;
|
|||
#[derive(Clone, Debug)]
|
||||
pub struct Field(pub(crate) Array3<Float>);
|
||||
|
||||
impl std::ops::Deref for Field {
|
||||
type Target = Array3<Float>;
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
impl integrate::Integrable for Field {
|
||||
type State = Field;
|
||||
type Diff = Field;
|
||||
|
||||
impl std::ops::DerefMut for Field {
|
||||
fn deref_mut(&mut self) -> &mut Self::Target {
|
||||
&mut self.0
|
||||
fn assign(s: &mut Self::State, o: &Self::State) {
|
||||
s.0.assign(&o.0);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> std::convert::From<&'a Field> for ArrayView3<'a, Float> {
|
||||
fn from(f: &'a Field) -> Self {
|
||||
f.0.view()
|
||||
}
|
||||
}
|
||||
impl<'a> std::convert::From<&'a mut Field> for ArrayViewMut3<'a, Float> {
|
||||
fn from(f: &'a mut Field) -> Self {
|
||||
f.0.view_mut()
|
||||
fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float) {
|
||||
s.0.scaled_add(scale, &o.0);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,6 +36,20 @@ impl Field {
|
|||
self.0.shape()[1]
|
||||
}
|
||||
|
||||
pub(crate) fn slice<Do: Dimension>(
|
||||
&self,
|
||||
info: &ndarray::SliceInfo<[ndarray::SliceOrIndex; 3], Do>,
|
||||
) -> ArrayView<Float, Do> {
|
||||
self.0.slice(info)
|
||||
}
|
||||
|
||||
pub(crate) fn slice_mut<Do: Dimension>(
|
||||
&mut self,
|
||||
info: &ndarray::SliceInfo<[ndarray::SliceOrIndex; 3], Do>,
|
||||
) -> ArrayViewMut<Float, Do> {
|
||||
self.0.slice_mut(info)
|
||||
}
|
||||
|
||||
pub fn ex(&self) -> ArrayView2<Float> {
|
||||
self.slice(s![0, .., ..])
|
||||
}
|
||||
|
@ -147,7 +148,7 @@ impl<SBP: SbpOperator2d> System<SBP> {
|
|||
RHS(op, fut, prev, grid, metrics, wb);
|
||||
};
|
||||
let mut _time = 0.0;
|
||||
integrate::integrate::<integrate::Rk4, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, Field, _>(
|
||||
rhs_adaptor,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
|
@ -162,15 +163,15 @@ impl<SBP: SbpOperator2d> System<SBP> {
|
|||
let rhs = self.rhs.view();
|
||||
//let lhs = self.explicit.view();
|
||||
let rhs_f = |next: &mut Field, now: &Field, _t: Float| {
|
||||
next.fill(0.0);
|
||||
next.0.fill(0.0);
|
||||
sprs::prod::mul_acc_mat_vec_csr(
|
||||
rhs,
|
||||
now.as_slice().unwrap(),
|
||||
next.as_slice_mut().unwrap(),
|
||||
now.0.as_slice().unwrap(),
|
||||
next.0.as_slice_mut().unwrap(),
|
||||
);
|
||||
// sprs::lingalg::dsolve(..)
|
||||
};
|
||||
sbp::integrate::integrate::<sbp::integrate::Rk4, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, Field, _>(
|
||||
rhs_f,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
|
@ -188,9 +189,9 @@ impl<SBP: SbpOperator2d> System<SBP> {
|
|||
|
||||
sbp::utils::jacobi_method(
|
||||
lhs,
|
||||
b.as_slice().unwrap(),
|
||||
self.sys.0.as_slice_mut().unwrap(),
|
||||
self.sys.1.as_slice_mut().unwrap(),
|
||||
b.0.as_slice().unwrap(),
|
||||
self.sys.0 .0.as_slice_mut().unwrap(),
|
||||
self.sys.1 .0.as_slice_mut().unwrap(),
|
||||
10,
|
||||
);
|
||||
}
|
||||
|
@ -207,7 +208,7 @@ impl<UO: SbpOperator2d + UpwindOperator2d> System<UO> {
|
|||
RHS_upwind(op, fut, prev, grid, metrics, wb);
|
||||
};
|
||||
let mut _time = 0.0;
|
||||
integrate::integrate::<integrate::Rk4, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, Field, _>(
|
||||
rhs_adaptor,
|
||||
&self.sys.0,
|
||||
&mut self.sys.1,
|
||||
|
|
|
@ -6,9 +6,10 @@ edition = "2018"
|
|||
|
||||
|
||||
[dependencies]
|
||||
sbp = { path = "../sbp", features = ["rayon", "serde1", "fast-float"] }
|
||||
sbp = { path = "../sbp", features = ["serde1", "fast-float"] }
|
||||
euler = { path = "../euler", features = ["serde1"] }
|
||||
hdf5 = "0.7.0"
|
||||
integrate = { path = "../utils/integrate", features = ["rayon"] }
|
||||
rayon = "1.3.0"
|
||||
indicatif = "0.15.0"
|
||||
structopt = "0.3.14"
|
||||
|
|
|
@ -38,7 +38,8 @@ impl OutputThread {
|
|||
match self.rx.as_ref().unwrap().try_recv() {
|
||||
Ok(mut copy_fields) => {
|
||||
for (from, to) in fields.iter().zip(copy_fields.iter_mut()) {
|
||||
to.assign(&from);
|
||||
use integrate::Integrable;
|
||||
euler::Field::assign(to, from);
|
||||
}
|
||||
self.tx
|
||||
.as_ref()
|
||||
|
|
|
@ -106,7 +106,7 @@ impl System {
|
|||
.iter_mut()
|
||||
.map(|k| k.as_mut_slice())
|
||||
.collect::<Vec<_>>();
|
||||
sbp::integrate::integrate_multigrid::<sbp::integrate::Rk4, _, _, _>(
|
||||
integrate::integrate_multigrid::<integrate::Rk4, euler::Field, _>(
|
||||
rhs,
|
||||
&self.fnow,
|
||||
&mut self.fnext,
|
||||
|
|
|
@ -8,7 +8,6 @@ edition = "2018"
|
|||
ndarray = { version = "0.14.0", features = ["approx"] }
|
||||
approx = "0.4.0"
|
||||
packed_simd = { version = "0.3.3", package = "packed_simd_2" }
|
||||
rayon = { version = "1.3.0", optional = true }
|
||||
sprs = { version = "0.10.0", optional = true, default-features = false }
|
||||
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
|
||||
num-traits = "0.2.14"
|
||||
|
|
|
@ -7,8 +7,6 @@ pub use float::{consts, Float};
|
|||
|
||||
/// Grid and grid metrics
|
||||
pub mod grid;
|
||||
/// RK operators and methods for implicit integration
|
||||
pub mod integrate;
|
||||
/// SBP and interpolation operators
|
||||
pub mod operators;
|
||||
/// General utilities
|
||||
|
|
|
@ -8,3 +8,4 @@ edition = "2018"
|
|||
ndarray = "0.14.0"
|
||||
sbp = { path = "../sbp" }
|
||||
log = "0.4.8"
|
||||
integrate = { path = "../utils/integrate" }
|
||||
|
|
|
@ -9,15 +9,15 @@ const G: Float = 1.0;
|
|||
#[derive(Clone, Debug)]
|
||||
pub struct Field(Array3<Float>);
|
||||
|
||||
impl<'a> Into<ArrayView3<'a, Float>> for &'a Field {
|
||||
fn into(self) -> ArrayView3<'a, Float> {
|
||||
self.0.view()
|
||||
}
|
||||
}
|
||||
impl integrate::Integrable for Field {
|
||||
type State = Field;
|
||||
type Diff = Field;
|
||||
|
||||
impl<'a> Into<ArrayViewMut3<'a, Float>> for &'a mut Field {
|
||||
fn into(self) -> ArrayViewMut3<'a, Float> {
|
||||
self.0.view_mut()
|
||||
fn assign(s: &mut Self::State, o: &Self::State) {
|
||||
s.0.assign(&o.0);
|
||||
}
|
||||
fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float) {
|
||||
s.0.scaled_add(scale, &o.0);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -301,7 +301,7 @@ impl System {
|
|||
}
|
||||
log::trace!("Iteration complete");
|
||||
};
|
||||
integrate::integrate::<integrate::Rk4, _, _, _>(
|
||||
integrate::integrate::<integrate::Rk4, Field, _>(
|
||||
rhs,
|
||||
&self.fnow,
|
||||
&mut self.fnext,
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
[package]
|
||||
name = "integrate"
|
||||
version = "0.1.0"
|
||||
authors = ["Magnus Ulimoen <magnus@ulimoen.dev>"]
|
||||
edition = "2018"
|
||||
|
||||
[dependencies]
|
||||
float = { path = "../float/" }
|
||||
rayon = { version = "1.5.0", optional = true }
|
|
@ -8,8 +8,7 @@
|
|||
//! on the `k` parameter to hold the system state differences.
|
||||
//! This parameter is tied to the Butcher Tableau
|
||||
|
||||
use super::Float;
|
||||
use ndarray::{ArrayView, ArrayViewMut};
|
||||
use float::Float;
|
||||
|
||||
/// The Butcher Tableau, with the state transitions described as
|
||||
/// [on wikipedia](https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods#Explicit_Runge%E2%80%93Kutta_methods).
|
||||
|
@ -150,6 +149,14 @@ impl EmbeddedButcherTableau for BogackiShampine {
|
|||
const BSTAR: &'static [Float] = &[7.0 / 24.0, 1.0 / 4.0, 1.0 / 3.0, 1.0 / 8.0];
|
||||
}
|
||||
|
||||
pub trait Integrable {
|
||||
type State;
|
||||
type Diff;
|
||||
|
||||
fn assign(s: &mut Self::State, o: &Self::State);
|
||||
fn scaled_add(s: &mut Self::State, o: &Self::Diff, scale: Float);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Integrates using the [`ButcherTableau`] specified. `rhs` should be the result
|
||||
/// of the right hand side of $u_t = rhs$
|
||||
|
@ -159,48 +166,44 @@ impl EmbeddedButcherTableau for BogackiShampine {
|
|||
///
|
||||
/// Should be called as
|
||||
/// ```rust,ignore
|
||||
/// integrate::<Rk4, _, _, _, _>(...)
|
||||
/// integrate::<Rk4, System, _>(...)
|
||||
/// ```
|
||||
pub fn integrate<BTableau: ButcherTableau, F, RHS, D>(
|
||||
pub fn integrate<BTableau: ButcherTableau, F: Integrable, RHS>(
|
||||
mut rhs: RHS,
|
||||
prev: &F,
|
||||
fut: &mut F,
|
||||
prev: &F::State,
|
||||
fut: &mut F::State,
|
||||
time: &mut Float,
|
||||
dt: Float,
|
||||
k: &mut [F],
|
||||
k: &mut [F::Diff],
|
||||
) where
|
||||
for<'r> &'r F: std::convert::Into<ArrayView<'r, Float, D>>,
|
||||
for<'r> &'r mut F: std::convert::Into<ArrayViewMut<'r, Float, D>>,
|
||||
D: ndarray::Dimension,
|
||||
RHS: FnMut(&mut F, &F, Float),
|
||||
RHS: FnMut(&mut F::Diff, &F::State, Float),
|
||||
{
|
||||
assert_eq!(prev.into().shape(), fut.into().shape());
|
||||
assert!(k.len() >= BTableau::S);
|
||||
|
||||
for i in 0.. {
|
||||
let simtime;
|
||||
match i {
|
||||
0 => {
|
||||
fut.into().assign(&prev.into());
|
||||
F::assign(fut, prev);
|
||||
simtime = *time;
|
||||
}
|
||||
i if i < BTableau::S => {
|
||||
fut.into().assign(&prev.into());
|
||||
F::assign(fut, prev);
|
||||
for (&a, k) in BTableau::A[i - 1].iter().zip(k.iter()) {
|
||||
if a == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut.into().scaled_add(a * dt, &k.into());
|
||||
F::scaled_add(fut, k, a * dt);
|
||||
}
|
||||
simtime = *time + dt * BTableau::C[i - 1];
|
||||
}
|
||||
_ if i == BTableau::S => {
|
||||
fut.into().assign(&prev.into());
|
||||
F::assign(fut, prev);
|
||||
for (&b, k) in BTableau::B.iter().zip(k.iter()) {
|
||||
if b == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut.into().scaled_add(b * dt, &k.into());
|
||||
F::scaled_add(fut, k, b * dt);
|
||||
}
|
||||
*time += dt;
|
||||
return;
|
||||
|
@ -219,27 +222,24 @@ pub fn integrate<BTableau: ButcherTableau, F, RHS, D>(
|
|||
///
|
||||
/// This produces two results, the most accurate result in `fut`, and the less accurate
|
||||
/// result in `fut2`. This can be used for convergence testing and adaptive timesteps.
|
||||
pub fn integrate_embedded_rk<BTableau: EmbeddedButcherTableau, F, RHS, D>(
|
||||
pub fn integrate_embedded_rk<BTableau: EmbeddedButcherTableau, F: Integrable, RHS>(
|
||||
rhs: RHS,
|
||||
prev: &F,
|
||||
fut: &mut F,
|
||||
fut2: &mut F,
|
||||
prev: &F::State,
|
||||
fut: &mut F::State,
|
||||
fut2: &mut F::State,
|
||||
time: &mut Float,
|
||||
dt: Float,
|
||||
k: &mut [F],
|
||||
k: &mut [F::Diff],
|
||||
) where
|
||||
for<'r> &'r F: std::convert::Into<ArrayView<'r, Float, D>>,
|
||||
for<'r> &'r mut F: std::convert::Into<ArrayViewMut<'r, Float, D>>,
|
||||
RHS: FnMut(&mut F, &F, Float),
|
||||
D: ndarray::Dimension,
|
||||
RHS: FnMut(&mut F::Diff, &F::State, Float),
|
||||
{
|
||||
integrate::<BTableau, F, RHS, D>(rhs, prev, fut, time, dt, k);
|
||||
fut2.into().assign(&prev.into());
|
||||
integrate::<BTableau, F, RHS>(rhs, prev, fut, time, dt, k);
|
||||
F::assign(fut2, prev);
|
||||
for (&b, k) in BTableau::BSTAR.iter().zip(k.iter()) {
|
||||
if b == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut2.into().scaled_add(b * dt, &k.into());
|
||||
F::scaled_add(fut2, k, b * dt);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -255,21 +255,19 @@ pub fn integrate_embedded_rk<BTableau: EmbeddedButcherTableau, F, RHS, D>(
|
|||
///
|
||||
/// This function requires the `rayon` feature, and is not callable in
|
||||
/// a `wasm` context.
|
||||
pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS, D>(
|
||||
pub fn integrate_multigrid<BTableau: ButcherTableau, F: Integrable, RHS>(
|
||||
mut rhs: RHS,
|
||||
prev: &[F],
|
||||
fut: &mut [F],
|
||||
prev: &[F::State],
|
||||
fut: &mut [F::State],
|
||||
time: &mut Float,
|
||||
dt: Float,
|
||||
k: &mut [&mut [F]],
|
||||
k: &mut [&mut [F::Diff]],
|
||||
|
||||
pool: &rayon::ThreadPool,
|
||||
) where
|
||||
for<'r> &'r F: std::convert::Into<ArrayView<'r, Float, D>>,
|
||||
for<'r> &'r mut F: std::convert::Into<ArrayViewMut<'r, Float, D>>,
|
||||
RHS: FnMut(&mut [F], &[F], Float),
|
||||
F: Send + Sync,
|
||||
D: ndarray::Dimension,
|
||||
RHS: FnMut(&mut [F::Diff], &[F::State], Float),
|
||||
F::State: Send + Sync,
|
||||
F::Diff: Send + Sync,
|
||||
{
|
||||
for i in 0.. {
|
||||
let simtime;
|
||||
|
@ -279,8 +277,7 @@ pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS, D>(
|
|||
assert!(k.len() >= BTableau::S);
|
||||
for (prev, fut) in prev.iter().zip(fut.iter_mut()) {
|
||||
s.spawn(move |_| {
|
||||
assert_eq!(prev.into().shape(), fut.into().shape());
|
||||
fut.into().assign(&prev.into());
|
||||
F::assign(fut, prev);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
@ -291,12 +288,12 @@ pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS, D>(
|
|||
for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() {
|
||||
let k = &k;
|
||||
s.spawn(move |_| {
|
||||
fut.into().assign(&prev.into());
|
||||
F::assign(fut, prev);
|
||||
for (ik, &a) in BTableau::A[i - 1].iter().enumerate() {
|
||||
if a == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut.into().scaled_add(a * dt, &(&k[ik][ig]).into());
|
||||
F::scaled_add(fut, &k[ik][ig], a * dt);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -308,12 +305,12 @@ pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS, D>(
|
|||
for (ig, (prev, fut)) in prev.iter().zip(fut.iter_mut()).enumerate() {
|
||||
let k = &k;
|
||||
s.spawn(move |_| {
|
||||
fut.into().assign(&prev.into());
|
||||
F::assign(fut, prev);
|
||||
for (ik, &b) in BTableau::B.iter().enumerate() {
|
||||
if b == 0.0 {
|
||||
continue;
|
||||
}
|
||||
fut.into().scaled_add(b * dt, &(&k[ik][ig]).into());
|
||||
F::scaled_add(fut, &k[ik][ig], b * dt);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
@ -329,3 +326,47 @@ pub fn integrate_multigrid<BTableau: ButcherTableau, F, RHS, D>(
|
|||
rhs(&mut k[i], &fut, simtime);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
/// Solving a second order PDE
|
||||
fn ballistic() {
|
||||
#[derive(Clone, Debug)]
|
||||
struct Ball {
|
||||
z: Float,
|
||||
v: Float,
|
||||
}
|
||||
impl Integrable for Ball {
|
||||
type State = Ball;
|
||||
type Diff = (Float, Float);
|
||||
fn assign(s: &mut Self::State, o: &Self::State) {
|
||||
s.z = o.z;
|
||||
s.v = o.v;
|
||||
}
|
||||
fn scaled_add(s: &mut Self::State, o: &Self::Diff, sc: Float) {
|
||||
s.z += o.0 * sc;
|
||||
s.v += o.1 * sc;
|
||||
}
|
||||
}
|
||||
|
||||
let mut t = 0.0;
|
||||
let dt = 0.001;
|
||||
let initial = Ball { z: 0.0, v: 10.0 };
|
||||
let g = -9.81;
|
||||
|
||||
let mut k = [(0.0, 0.0); 4];
|
||||
let gravity = |d: &mut (Float, Float), s: &Ball, _time: Float| {
|
||||
d.1 = g;
|
||||
d.0 = s.v
|
||||
};
|
||||
let mut next = initial.clone();
|
||||
//while next.z >= 0.0 {
|
||||
while t < 1.0 {
|
||||
let mut next2 = next.clone();
|
||||
integrate::<EulerMethod, Ball, _>(gravity, &next, &mut next2, &mut t, dt, &mut k);
|
||||
std::mem::swap(&mut next, &mut next2);
|
||||
}
|
||||
let expected_vel = initial.v + g * t;
|
||||
assert!((next.v - expected_vel).abs() < 1e-3);
|
||||
let expected_pos = initial.z + initial.v * t + g / 2.0 * t.powi(2);
|
||||
assert!((next.z - expected_pos).abs() < 1e-2);
|
||||
}
|
Loading…
Reference in New Issue