diff --git a/src/lib.rs b/src/lib.rs index 0a0c9c8..339bb5c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ use wasm_bindgen::prelude::*; mod maxwell; mod operators; use maxwell::{System, WorkBuffers}; +use operators::Upwind4; #[cfg(feature = "wee_alloc")] #[global_allocator] @@ -34,7 +35,7 @@ impl Universe { } pub fn advance(&mut self, dt: f32) { - self.sys.0.advance(&mut self.sys.1, dt, Some(&mut self.wb)); + System::advance::(&self.sys.0, &mut self.sys.1, dt, Some(&mut self.wb)); std::mem::swap(&mut self.sys.0, &mut self.sys.1); } diff --git a/src/maxwell.rs b/src/maxwell.rs index 2f87226..89ed152 100644 --- a/src/maxwell.rs +++ b/src/maxwell.rs @@ -1,4 +1,4 @@ -use super::operators::{diffx, diffy}; +use super::operators::SbpOperator; use ndarray::{Array2, Zip}; pub struct System { @@ -43,7 +43,10 @@ impl System { } } - pub fn advance(&self, fut: &mut System, dt: f32, work_buffers: Option<&mut WorkBuffers>) { + pub fn advance(&self, fut: &mut System, dt: f32, work_buffers: Option<&mut WorkBuffers>) + where + SBP: SbpOperator, + { assert_eq!(self.ex.shape(), fut.ex.shape()); let mut wb: WorkBuffers; @@ -84,18 +87,18 @@ impl System { // ex = hz_y k[i].0.fill(0.0); - diffy(y.1.view(), k[i].0.view_mut()); + SBP::diffy(y.1.view(), k[i].0.view_mut()); // ey = -hz_x k[i].2.fill(0.0); - diffx(y.1.view(), k[i].2.view_mut()); + SBP::diffx(y.1.view(), k[i].2.view_mut()); k[i].2.mapv_inplace(|v| -v); // hz = -ey_x + ex_y k[i].1.fill(0.0); - diffx(y.2.view(), k[i].1.view_mut()); + SBP::diffx(y.2.view(), k[i].1.view_mut()); k[i].1.mapv_inplace(|v| -v); - diffy(y.0.view(), k[i].1.view_mut()); + SBP::diffy(y.0.view(), k[i].1.view_mut()); // Boundary conditions (SAT) let ny = y.0.shape()[0]; diff --git a/src/operators.rs b/src/operators.rs index 3f9109d..9fd66c4 100644 --- a/src/operators.rs +++ b/src/operators.rs @@ -1,344 +1,10 @@ -#![allow(dead_code)] -use ndarray::{arr1, arr2, s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; +use ndarray::{ArrayView2, ArrayViewMut2}; -pub(crate) fn diffx(prev: ArrayView2, mut fut: ArrayViewMut2) { - for j in 0..prev.shape()[0] { - upwind4(prev.slice(s!(j, ..)), fut.slice_mut(s!(j, ..))); - } +pub trait SbpOperator { + fn diffx(prev: ArrayView2, fut: ArrayViewMut2); + fn diffy(prev: ArrayView2, fut: ArrayViewMut2); + fn h() -> &'static [f32]; } -pub(crate) fn diffx_periodic(prev: ArrayView2, mut fut: ArrayViewMut2) { - for j in 0..prev.shape()[0] { - upwind4_periodic(prev.slice(s!(j, ..)), fut.slice_mut(s!(j, ..))); - } -} - -pub(crate) fn diffy(prev: ArrayView2, mut fut: ArrayViewMut2) { - for i in 0..prev.shape()[1] { - upwind4(prev.slice(s!(.., i)), fut.slice_mut(s!(.., i))); - } -} - -pub(crate) fn diffy_periodic(prev: ArrayView2, mut fut: ArrayViewMut2) { - for i in 0..prev.shape()[1] { - upwind4_periodic(prev.slice(s!(.., i)), fut.slice_mut(s!(.., i))); - } -} - -pub(crate) fn dissx(prev: ArrayView2, mut fut: ArrayViewMut2) { - for j in 0..prev.shape()[0] { - upwind4_diss(prev.slice(s!(j, ..)), fut.slice_mut(s!(j, ..))); - } -} -pub(crate) fn dissy(prev: ArrayView2, mut fut: ArrayViewMut2) { - for i in 0..prev.shape()[1] { - upwind4_diss(prev.slice(s!(.., i)), fut.slice_mut(s!(.., i))); - } -} - -fn trad4_periodic(prev: ArrayView1, mut fut: ArrayViewMut1) { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[0]; - - let dx = 1.0 / (nx - 1) as f32; - - let diag = [1.0 / 12.0, -2.0 / 3.0, 0.0, 2.0 / 3.0, -1.0 / 12.0]; - - let diff = diag[0] * prev[(nx - 2)] - + diag[1] * prev[(nx - 1)] - + diag[2] * prev[(0)] - + diag[3] * prev[(1)] - + diag[4] * prev[(2)]; - fut[(0)] += diff / dx; - let diff = diag[0] * prev[(nx - 1)] - + diag[1] * prev[(0)] - + diag[2] * prev[(1)] - + diag[3] * prev[(2)] - + diag[4] * prev[(3)]; - fut[(1)] += diff / dx; - for i in 2..nx - 2 { - let diff = diag[0] * prev[(i - 2)] - + diag[1] * prev[(i - 1)] - + diag[2] * prev[(i)] - + diag[3] * prev[(i + 1)] - + diag[4] * prev[(i + 2)]; - fut[(i)] += diff / dx; - } - let diff = diag[0] * prev[(nx - 4)] - + diag[1] * prev[(nx - 3)] - + diag[2] * prev[(nx - 2)] - + diag[3] * prev[(nx - 1)] - + diag[4] * prev[(0)]; - fut[(nx - 2)] += diff / dx; - let diff = diag[0] * prev[(nx - 3)] - + diag[1] * prev[(nx - 2)] - + diag[2] * prev[(nx - 1)] - + diag[3] * prev[(0)] - + diag[4] * prev[(1)]; - fut[(nx - 1)] += diff / dx; -} - -fn upwind4(prev: ArrayView1, mut fut: ArrayViewMut1) { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[0]; - - let dx = 1.0 / (nx - 1) as f32; - - let diag = arr1(&[ - -1.0 / 24.0, - 1.0 / 4.0, - -7.0 / 8.0, - 0.0, - 7.0 / 8.0, - -1.0 / 4.0, - 1.0 / 24.0, - ]); - - let block = arr2(&[ - [ - -72.0 / 49.0f32, - 187.0 / 98.0, - -20.0 / 49.0, - -3.0 / 98.0, - 0.0, - 0.0, - 0.0, - ], - [ - -187.0 / 366.0, - 0.0, - 69.0 / 122.0, - -16.0 / 183.0, - 2.0 / 61.0, - 0.0, - 0.0, - ], - [ - 20.0 / 123.0, - -69.0 / 82.0, - 0.0, - 227.0 / 246.0, - -12.0 / 41.0, - 2.0 / 41.0, - 0.0, - ], - [ - 3.0 / 298.0, - 16.0 / 149.0, - -227.0 / 298.0, - 0.0, - 126.0 / 149.0, - -36.0 / 149.0, - 6.0 / 149.0, - ], - ]); - - // let h_block = [49.0 / 144.0, 61.0 / 48.0, 41.0 / 48.0, 149.0 / 144.0]; - - let first_elems = prev.slice(s!(..7)); - for i in 0..4 { - let diff = first_elems.dot(&block.slice(s!(i, ..))); - fut[i] += diff / dx; - } - - for i in 4..nx - 4 { - let diff = diag.dot(&prev.slice(s!(i - 3..i + 3 + 1))); - fut[(i)] += diff / dx; - } - let last_elems = prev.slice(s!(nx - 7..)); - for i in 0..4 { - let ii = nx - 4 + i; - let block = block.slice(s!(3 - i, ..;-1)); - let diff = last_elems.dot(&block); - fut[ii] += -diff / dx; - } -} - -fn upwind4_periodic(prev: ArrayView1, mut fut: ArrayViewMut1) { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[0]; - - let dx = 1.0 / (nx - 1) as f32; - - let diag = [ - -1.0 / 24.0, - 1.0 / 4.0, - -7.0 / 8.0, - 0.0, - 7.0 / 8.0, - -1.0 / 4.0, - 1.0 / 24.0, - ]; - - let diff = diag[0] * prev[(nx - 3)] - + diag[1] * prev[(nx - 2)] - + diag[2] * prev[(nx - 1)] - + diag[3] * prev[(0)] - + diag[4] * prev[(1)] - + diag[5] * prev[(2)] - + diag[6] * prev[(3)]; - fut[0] += diff / dx; - let diff = diag[0] * prev[(nx - 2)] - + diag[1] * prev[(nx - 1)] - + diag[2] * prev[(0)] - + diag[3] * prev[(1)] - + diag[4] * prev[(2)] - + diag[5] * prev[(3)] - + diag[6] * prev[(4)]; - fut[1] += diff / dx; - let diff = diag[0] * prev[(nx - 1)] - + diag[1] * prev[(0)] - + diag[2] * prev[(1)] - + diag[3] * prev[(2)] - + diag[4] * prev[(3)] - + diag[5] * prev[(4)] - + diag[6] * prev[(5)]; - fut[2] += diff / dx; - - for i in 3..nx - 3 { - let diff = diag[0] * prev[(i - 3)] - + diag[1] * prev[(i - 2)] - + diag[2] * prev[(i - 1)] - + diag[3] * prev[(i)] - + diag[4] * prev[(i + 1)] - + diag[5] * prev[(i + 2)] - + diag[6] * prev[(i + 3)]; - fut[(i)] += diff / dx; - } - let diff = diag[0] * prev[(nx - 6)] - + diag[1] * prev[(nx - 5)] - + diag[2] * prev[(nx - 4)] - + diag[3] * prev[(nx - 3)] - + diag[4] * prev[(nx - 2)] - + diag[5] * prev[(nx - 1)] - + diag[6] * prev[(0)]; - fut[(nx - 3)] += diff / dx; - let diff = diag[0] * prev[(nx - 5)] - + diag[1] * prev[(nx - 4)] - + diag[2] * prev[(nx - 3)] - + diag[3] * prev[(nx - 2)] - + diag[4] * prev[(nx - 1)] - + diag[5] * prev[(0)] - + diag[6] * prev[(1)]; - fut[(nx - 2)] += diff / dx; - let diff = diag[0] * prev[(nx - 4)] - + diag[1] * prev[(nx - 3)] - + diag[2] * prev[(nx - 2)] - + diag[3] * prev[(nx - 1)] - + diag[4] * prev[(0)] - + diag[5] * prev[(1)] - + diag[6] * prev[(2)]; - fut[(nx - 1)] += diff / dx; -} - -#[test] -fn upwind4_test() { - let nx = 20; - let dx = 1.0 / (nx - 1) as f32; - let mut source: ndarray::Array1 = ndarray::Array1::zeros((nx)); - let mut res = ndarray::Array1::zeros((nx)); - let mut target = ndarray::Array1::zeros((nx)); - - for i in 0..nx { - source[i] = i as f32 * dx; - target[i] = 1.0; - } - res.fill(0.0); - upwind4(source.view(), res.view_mut()); - assert!(res.all_close(&target, 1e-4)); - - for i in 0..nx { - let x = i as f32 * dx; - source[i] = x * x; - target[i] = 2.0 * x; - } - res.fill(0.0); - upwind4(source.view(), res.view_mut()); - assert!(res.all_close(&target, 1e-4)); - - for i in 0..nx { - let x = i as f32 * dx; - source[i] = x * x * x; - target[i] = 3.0 * x * x; - } - res.fill(0.0); - upwind4(source.view(), res.view_mut()); - assert!(res.all_close(&target, 1e-2)); -} - -fn upwind4_diss(prev: ArrayView1, mut fut: ArrayViewMut1) { - assert_eq!(prev.shape(), fut.shape()); - let nx = prev.shape()[0]; - - let dx = 1.0 / (nx - 1) as f32; - - let diag = [ - 1.0 / 24.0, - -1.0 / 4.0, - 5.0 / 8.0, - -5.0 / 6.0, - 5.0 / 8.0, - -1.0 / 4.0, - 1.0 / 24.0, - ]; - - let diff = diag[0] * prev[(nx - 3)] - + diag[1] * prev[(nx - 2)] - + diag[2] * prev[(nx - 1)] - + diag[3] * prev[(0)] - + diag[4] * prev[(1)] - + diag[5] * prev[(2)] - + diag[6] * prev[(3)]; - fut[0] += diff / dx; - let diff = diag[0] * prev[(nx - 2)] - + diag[1] * prev[(nx - 1)] - + diag[2] * prev[(0)] - + diag[3] * prev[(1)] - + diag[4] * prev[(2)] - + diag[5] * prev[(3)] - + diag[6] * prev[(4)]; - fut[1] += diff / dx; - let diff = diag[0] * prev[(nx - 1)] - + diag[1] * prev[(0)] - + diag[2] * prev[(1)] - + diag[3] * prev[(2)] - + diag[4] * prev[(3)] - + diag[5] * prev[(4)] - + diag[6] * prev[(5)]; - fut[2] += diff / dx; - - for i in 3..nx - 3 { - let diff = diag[0] * prev[(i - 3)] - + diag[1] * prev[(i - 2)] - + diag[2] * prev[(i - 1)] - + diag[3] * prev[(i)] - + diag[4] * prev[(i + 1)] - + diag[5] * prev[(i + 2)] - + diag[6] * prev[(i + 3)]; - fut[(i)] += diff / dx; - } - let diff = diag[0] * prev[(nx - 6)] - + diag[1] * prev[(nx - 5)] - + diag[2] * prev[(nx - 4)] - + diag[3] * prev[(nx - 3)] - + diag[4] * prev[(nx - 2)] - + diag[5] * prev[(nx - 1)] - + diag[6] * prev[(0)]; - fut[(nx - 3)] += diff / dx; - let diff = diag[0] * prev[(nx - 5)] - + diag[1] * prev[(nx - 4)] - + diag[2] * prev[(nx - 3)] - + diag[3] * prev[(nx - 2)] - + diag[4] * prev[(nx - 1)] - + diag[5] * prev[(0)] - + diag[6] * prev[(1)]; - fut[(nx - 2)] += diff / dx; - let diff = diag[0] * prev[(nx - 4)] - + diag[1] * prev[(nx - 3)] - + diag[2] * prev[(nx - 2)] - + diag[3] * prev[(nx - 1)] - + diag[4] * prev[(0)] - + diag[5] * prev[(1)] - + diag[6] * prev[(2)]; - fut[(nx - 1)] += diff / dx; -} +mod upwind4; +pub use upwind4::Upwind4; diff --git a/src/operators/upwind4.rs b/src/operators/upwind4.rs new file mode 100644 index 0000000..760ad3c --- /dev/null +++ b/src/operators/upwind4.rs @@ -0,0 +1,136 @@ +use super::SbpOperator; +use ndarray::{arr1, arr2, s, ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2}; + +pub struct Upwind4 {} + +impl Upwind4 { + const HBLOCK: &'static [f32] = &[49.0 / 144.0, 61.0 / 48.0, 41.0 / 48.0, 149.0 / 144.0]; + const DIAG: &'static [f32] = &[ + -1.0 / 24.0, + 1.0 / 4.0, + -7.0 / 8.0, + 0.0, + 7.0 / 8.0, + -1.0 / 4.0, + 1.0 / 24.0, + ]; + const BLOCK: &'static [[f32; 7]] = &[ + [ + -72.0 / 49.0f32, + 187.0 / 98.0, + -20.0 / 49.0, + -3.0 / 98.0, + 0.0, + 0.0, + 0.0, + ], + [ + -187.0 / 366.0, + 0.0, + 69.0 / 122.0, + -16.0 / 183.0, + 2.0 / 61.0, + 0.0, + 0.0, + ], + [ + 20.0 / 123.0, + -69.0 / 82.0, + 0.0, + 227.0 / 246.0, + -12.0 / 41.0, + 2.0 / 41.0, + 0.0, + ], + [ + 3.0 / 298.0, + 16.0 / 149.0, + -227.0 / 298.0, + 0.0, + 126.0 / 149.0, + -36.0 / 149.0, + 6.0 / 149.0, + ], + ]; + + fn diff(prev: ArrayView1, mut fut: ArrayViewMut1) { + assert_eq!(prev.shape(), fut.shape()); + let nx = prev.shape()[0]; + + let dx = 1.0 / (nx - 1) as f32; + + let diag = arr1(Upwind4::DIAG); + let block = arr2(Upwind4::BLOCK); + + let first_elems = prev.slice(s!(..7)); + for i in 0..4 { + let diff = first_elems.dot(&block.slice(s!(i, ..))); + fut[i] += diff / dx; + } + + for i in 4..nx - 4 { + let diff = diag.dot(&prev.slice(s!(i - 3..i + 3 + 1))); + fut[(i)] += diff / dx; + } + let last_elems = prev.slice(s!(nx - 7..)); + for i in 0..4 { + let ii = nx - 4 + i; + let block = block.slice(s!(3 - i, ..;-1)); + let diff = last_elems.dot(&block); + fut[ii] += -diff / dx; + } + } +} + +#[test] +fn upwind4_test() { + let nx = 20; + let dx = 1.0 / (nx - 1) as f32; + let mut source: ndarray::Array1 = ndarray::Array1::zeros(nx); + let mut res = ndarray::Array1::zeros(nx); + let mut target = ndarray::Array1::zeros(nx); + + for i in 0..nx { + source[i] = i as f32 * dx; + target[i] = 1.0; + } + res.fill(0.0); + Upwind4::diff(source.view(), res.view_mut()); + assert!(res.all_close(&target, 1e-4)); + + for i in 0..nx { + let x = i as f32 * dx; + source[i] = x * x; + target[i] = 2.0 * x; + } + res.fill(0.0); + Upwind4::diff(source.view(), res.view_mut()); + assert!(res.all_close(&target, 1e-4)); + + for i in 0..nx { + let x = i as f32 * dx; + source[i] = x * x * x; + target[i] = 3.0 * x * x; + } + res.fill(0.0); + Upwind4::diff(source.view(), res.view_mut()); + assert!(res.all_close(&target, 1e-2)); +} + +impl SbpOperator for Upwind4 { + fn diffx(prev: ArrayView2, mut fut: ArrayViewMut2) { + for j in 0..prev.shape()[0] { + Upwind4::diff(prev.slice(s!(j, ..)), fut.slice_mut(s!(j, ..))); + } + } + + fn diffy(prev: ArrayView2, mut fut: ArrayViewMut2) { + for i in 0..prev.shape()[1] { + Upwind4::diff(prev.slice(s!(.., i)), fut.slice_mut(s!(.., i))); + } + } + + fn h() -> &'static [f32] { + Upwind4::HBLOCK + } +}