use Matrix in SBP diff
This commit is contained in:
parent
3c7cc4605a
commit
c133557459
|
@ -3,7 +3,7 @@ use super::*;
|
||||||
pub(crate) mod constmatrix {
|
pub(crate) mod constmatrix {
|
||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
/// A row-major matrix
|
/// A row-major matrix
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
pub struct Matrix<T, const M: usize, const N: usize> {
|
pub struct Matrix<T, const M: usize, const N: usize> {
|
||||||
data: [[T; N]; M],
|
data: [[T; N]; M],
|
||||||
|
@ -11,20 +11,12 @@ pub(crate) mod constmatrix {
|
||||||
pub type RowVector<T, const N: usize> = Matrix<T, 1, N>;
|
pub type RowVector<T, const N: usize> = Matrix<T, 1, N>;
|
||||||
pub type ColVector<T, const N: usize> = Matrix<T, N, 1>;
|
pub type ColVector<T, const N: usize> = Matrix<T, N, 1>;
|
||||||
|
|
||||||
impl<T: Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
|
impl<T: Copy + Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
use std::mem::MaybeUninit;
|
Self {
|
||||||
let mut d: [[MaybeUninit<T>; N]; M] = unsafe { MaybeUninit::uninit().assume_init() };
|
data: [[T::default(); N]; M],
|
||||||
|
|
||||||
for row in d.iter_mut() {
|
|
||||||
for item in row.iter_mut() {
|
|
||||||
*item = MaybeUninit::new(T::default());
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let data = unsafe { std::mem::transmute_copy::<_, [[T; N]; M]>(&d) };
|
|
||||||
Self { data }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, const M: usize, const N: usize> core::ops::Index<(usize, usize)> for Matrix<T, M, N> {
|
impl<T, const M: usize, const N: usize> core::ops::Index<(usize, usize)> for Matrix<T, M, N> {
|
||||||
|
@ -111,12 +103,12 @@ pub(crate) mod constmatrix {
|
||||||
|
|
||||||
pub fn flip(&self) -> Self
|
pub fn flip(&self) -> Self
|
||||||
where
|
where
|
||||||
T: Default + Clone,
|
T: Default + Copy,
|
||||||
{
|
{
|
||||||
let mut v = Self::default();
|
let mut v = Self::default();
|
||||||
for i in 0..M {
|
for i in 0..M {
|
||||||
for j in 0..N {
|
for j in 0..N {
|
||||||
v[(i, j)] = self[(M - 1 - i, N - 1 - j)].clone()
|
v[(i, j)] = self[(M - 1 - i, N - 1 - j)]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
v
|
v
|
||||||
|
@ -163,10 +155,6 @@ pub(crate) mod constmatrix {
|
||||||
|
|
||||||
let _m2 = Matrix::new([[1, 2], [3, 4]]);
|
let _m2 = Matrix::new([[1, 2], [3, 4]]);
|
||||||
}
|
}
|
||||||
#[test]
|
|
||||||
fn construct_non_copy() {
|
|
||||||
let _m = Matrix::<String, 2, 1>::default();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn matmul() {
|
fn matmul() {
|
||||||
|
@ -184,8 +172,8 @@ pub(crate) use constmatrix::{ColVector, Matrix, RowVector};
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
|
pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
|
||||||
block: &Matrix<Float, M, N>,
|
block: &Matrix<Float, M, N>,
|
||||||
diag: &RowVector<Float, N>,
|
blockend: &Matrix<Float, M, N>,
|
||||||
symmetry: Symmetry,
|
diag: &RowVector<Float, D>,
|
||||||
optype: OperatorType,
|
optype: OperatorType,
|
||||||
prev: ArrayView1<Float>,
|
prev: ArrayView1<Float>,
|
||||||
mut fut: ArrayViewMut1<Float>,
|
mut fut: ArrayViewMut1<Float>,
|
||||||
|
@ -226,19 +214,14 @@ pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
|
||||||
*f = diff * idx;
|
*f = diff * idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
for (bl, f) in block.iter_rows().zip(fut.iter_mut().rev()) {
|
for (bl, f) in blockend.iter_rows().zip(fut.iter_mut().rev().take(M).rev()) {
|
||||||
let diff = bl
|
let diff = bl
|
||||||
.iter()
|
.iter()
|
||||||
.zip(prev.iter().rev())
|
.zip(prev.iter())
|
||||||
.map(|(x, y)| x * y)
|
.map(|(x, y)| x * y)
|
||||||
.sum::<Float>();
|
.sum::<Float>();
|
||||||
|
|
||||||
*f = idx
|
*f = diff * idx;
|
||||||
* if symmetry == Symmetry::Symmetric {
|
|
||||||
diff
|
|
||||||
} else {
|
|
||||||
-diff
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -708,6 +691,174 @@ fn product_fast<'a>(
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub(crate) fn diff_op_col_matrix<const M: usize, const N: usize, const D: usize>(
|
||||||
|
block: &Matrix<Float, M, N>,
|
||||||
|
block2: &Matrix<Float, M, N>,
|
||||||
|
diag: &RowVector<Float, D>,
|
||||||
|
optype: OperatorType,
|
||||||
|
prev: ArrayView2<Float>,
|
||||||
|
mut fut: ArrayViewMut2<Float>,
|
||||||
|
) {
|
||||||
|
assert_eq!(prev.shape(), fut.shape());
|
||||||
|
let nx = prev.shape()[1];
|
||||||
|
assert!(nx >= 2 * M);
|
||||||
|
|
||||||
|
assert_eq!(prev.strides()[0], 1);
|
||||||
|
assert_eq!(fut.strides()[0], 1);
|
||||||
|
|
||||||
|
let dx = if optype == OperatorType::H2 {
|
||||||
|
1.0 / (nx - 2) as Float
|
||||||
|
} else {
|
||||||
|
1.0 / (nx - 1) as Float
|
||||||
|
};
|
||||||
|
let idx = 1.0 / dx;
|
||||||
|
|
||||||
|
#[cfg(not(feature = "f32"))]
|
||||||
|
type SimdT = packed_simd::f64x8;
|
||||||
|
#[cfg(feature = "f32")]
|
||||||
|
type SimdT = packed_simd::f32x16;
|
||||||
|
|
||||||
|
let ny = prev.shape()[0];
|
||||||
|
// How many elements that can be simdified
|
||||||
|
let simdified = SimdT::lanes() * (ny / SimdT::lanes());
|
||||||
|
|
||||||
|
let half_diag_width = (D - 1) / 2;
|
||||||
|
assert!(half_diag_width <= M);
|
||||||
|
|
||||||
|
let fut_base_ptr = fut.as_mut_ptr();
|
||||||
|
let fut_stride = fut.strides()[1];
|
||||||
|
let fut_ptr = |j, i| {
|
||||||
|
debug_assert!(j < ny && i < nx);
|
||||||
|
unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) }
|
||||||
|
};
|
||||||
|
|
||||||
|
let prev_base_ptr = prev.as_ptr();
|
||||||
|
let prev_stride = prev.strides()[1];
|
||||||
|
let prev_ptr = |j, i| {
|
||||||
|
debug_assert!(j < ny && i < nx);
|
||||||
|
unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Not algo necessary, but gives performance increase
|
||||||
|
assert_eq!(fut_stride, prev_stride);
|
||||||
|
|
||||||
|
// First block
|
||||||
|
{
|
||||||
|
for (ifut, &bl) in block.iter_rows().enumerate() {
|
||||||
|
for j in (0..simdified).step_by(SimdT::lanes()) {
|
||||||
|
let index_to_simd = |i| unsafe {
|
||||||
|
// j never moves past end of slice due to step_by and
|
||||||
|
// rounding down
|
||||||
|
SimdT::from_slice_unaligned(std::slice::from_raw_parts(
|
||||||
|
prev_ptr(j, i),
|
||||||
|
SimdT::lanes(),
|
||||||
|
))
|
||||||
|
};
|
||||||
|
let mut f = SimdT::splat(0.0);
|
||||||
|
for (iprev, &bl) in bl.iter().enumerate() {
|
||||||
|
f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
|
||||||
|
}
|
||||||
|
f *= idx;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
|
||||||
|
fut_ptr(j, ifut),
|
||||||
|
SimdT::lanes(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in simdified..ny {
|
||||||
|
unsafe {
|
||||||
|
let mut f = 0.0;
|
||||||
|
for (iprev, bl) in bl.iter().enumerate() {
|
||||||
|
f += bl * *prev_ptr(j, iprev);
|
||||||
|
}
|
||||||
|
*fut_ptr(j, ifut) = f * idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Diagonal elements
|
||||||
|
{
|
||||||
|
for ifut in M..nx - M {
|
||||||
|
for j in (0..simdified).step_by(SimdT::lanes()) {
|
||||||
|
let index_to_simd = |i| unsafe {
|
||||||
|
// j never moves past end of slice due to step_by and
|
||||||
|
// rounding down
|
||||||
|
SimdT::from_slice_unaligned(std::slice::from_raw_parts(
|
||||||
|
prev_ptr(j, i),
|
||||||
|
SimdT::lanes(),
|
||||||
|
))
|
||||||
|
};
|
||||||
|
let mut f = SimdT::splat(0.0);
|
||||||
|
for (id, &d) in diag.iter().enumerate() {
|
||||||
|
let offset = ifut - half_diag_width + id;
|
||||||
|
f = index_to_simd(offset).mul_adde(SimdT::splat(d), f);
|
||||||
|
}
|
||||||
|
f *= idx;
|
||||||
|
unsafe {
|
||||||
|
// puts simd along stride 1, j never goes past end of slice
|
||||||
|
f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
|
||||||
|
fut_ptr(j, ifut),
|
||||||
|
SimdT::lanes(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in simdified..ny {
|
||||||
|
let mut f = 0.0;
|
||||||
|
for (id, &d) in diag.iter().enumerate() {
|
||||||
|
let offset = ifut - half_diag_width + id;
|
||||||
|
unsafe {
|
||||||
|
f += d * *prev_ptr(j, offset);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unsafe {
|
||||||
|
*fut_ptr(j, ifut) = idx * f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// End block
|
||||||
|
{
|
||||||
|
for (ifut, &bl) in (nx - M..nx).zip(block2.iter_rows()) {
|
||||||
|
for j in (0..simdified).step_by(SimdT::lanes()) {
|
||||||
|
let index_to_simd = |i| unsafe {
|
||||||
|
// j never moves past end of slice due to step_by and
|
||||||
|
// rounding down
|
||||||
|
SimdT::from_slice_unaligned(std::slice::from_raw_parts(
|
||||||
|
prev_ptr(j, i),
|
||||||
|
SimdT::lanes(),
|
||||||
|
))
|
||||||
|
};
|
||||||
|
let mut f = SimdT::splat(0.0);
|
||||||
|
for (iprev, &bl) in (nx - N..nx).zip(bl.iter()) {
|
||||||
|
f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
|
||||||
|
}
|
||||||
|
f *= idx;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
|
||||||
|
fut_ptr(j, ifut),
|
||||||
|
SimdT::lanes(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for j in simdified..ny {
|
||||||
|
unsafe {
|
||||||
|
let mut f = 0.0;
|
||||||
|
for (iprev, bl) in (nx - N..nx).zip(bl.iter()) {
|
||||||
|
f += bl * *prev_ptr(j, iprev);
|
||||||
|
}
|
||||||
|
*fut_ptr(j, ifut) = f * idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn diff_op_row(
|
pub(crate) fn diff_op_row(
|
||||||
block: &'static [&'static [Float]],
|
block: &'static [&'static [Float]],
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use super::{diff_op_col, diff_op_row, SbpOperator1d, SbpOperator2d};
|
use super::{SbpOperator1d, SbpOperator2d};
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
|
use ndarray::{ArrayView1, ArrayView2, ArrayViewMut1, ArrayViewMut2};
|
||||||
|
|
||||||
|
@ -55,10 +55,10 @@ impl SBP4 {
|
||||||
|
|
||||||
impl SbpOperator1d for SBP4 {
|
impl SbpOperator1d for SBP4 {
|
||||||
fn diff(&self, prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
|
fn diff(&self, prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
|
||||||
super::diff_op_1d(
|
super::diff_op_1d_matrix(
|
||||||
Self::BLOCK,
|
&Self::BLOCK_MATRIX,
|
||||||
Self::DIAG,
|
&Self::BLOCKEND_MATRIX,
|
||||||
super::Symmetry::AntiSymmetric,
|
&Self::DIAG_MATRIX,
|
||||||
super::OperatorType::Normal,
|
super::OperatorType::Normal,
|
||||||
prev,
|
prev,
|
||||||
fut,
|
fut,
|
||||||
|
@ -104,6 +104,17 @@ fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayVi
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fn diff_op_col_local(prev: ndarray::ArrayView2<Float>, fut: ndarray::ArrayViewMut2<Float>) {
|
||||||
|
let optype = super::OperatorType::Normal;
|
||||||
|
super::diff_op_col_matrix(
|
||||||
|
&SBP4::BLOCK_MATRIX,
|
||||||
|
&SBP4::BLOCKEND_MATRIX,
|
||||||
|
&SBP4::DIAG_MATRIX,
|
||||||
|
optype,
|
||||||
|
prev,
|
||||||
|
fut,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
impl SbpOperator2d for SBP4 {
|
impl SbpOperator2d for SBP4 {
|
||||||
fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
|
fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
|
||||||
|
@ -112,14 +123,14 @@ impl SbpOperator2d for SBP4 {
|
||||||
|
|
||||||
let symmetry = super::Symmetry::AntiSymmetric;
|
let symmetry = super::Symmetry::AntiSymmetric;
|
||||||
let optype = super::OperatorType::Normal;
|
let optype = super::OperatorType::Normal;
|
||||||
|
|
||||||
match (prev.strides(), fut.strides()) {
|
match (prev.strides(), fut.strides()) {
|
||||||
([_, 1], [_, 1]) => {
|
([_, 1], [_, 1]) => {
|
||||||
//diff_op_row(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
|
//diff_op_row(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
|
||||||
diff_op_row_local(prev, fut)
|
diff_op_row_local(prev, fut)
|
||||||
}
|
}
|
||||||
([1, _], [1, _]) => {
|
([1, _], [1, _]) => {
|
||||||
diff_op_col(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
|
//diff_op_col(SBP4::BLOCK, SBP4::DIAG, symmetry, optype)(prev, fut);
|
||||||
|
diff_op_col_local(prev, fut)
|
||||||
}
|
}
|
||||||
([_, _], [_, _]) => {
|
([_, _], [_, _]) => {
|
||||||
// Fallback, work row by row
|
// Fallback, work row by row
|
||||||
|
|
|
@ -55,10 +55,10 @@ impl SBP8 {
|
||||||
|
|
||||||
impl SbpOperator1d for SBP8 {
|
impl SbpOperator1d for SBP8 {
|
||||||
fn diff(&self, prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
|
fn diff(&self, prev: ArrayView1<Float>, fut: ArrayViewMut1<Float>) {
|
||||||
super::diff_op_1d(
|
super::diff_op_1d_matrix(
|
||||||
Self::BLOCK,
|
&Self::BLOCK_MATRIX,
|
||||||
Self::DIAG,
|
&Self::BLOCKEND_MATRIX,
|
||||||
super::Symmetry::AntiSymmetric,
|
&Self::DIAG_MATRIX,
|
||||||
super::OperatorType::Normal,
|
super::OperatorType::Normal,
|
||||||
prev,
|
prev,
|
||||||
fut,
|
fut,
|
||||||
|
@ -107,6 +107,18 @@ fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayVi
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn diff_op_col_local(prev: ndarray::ArrayView2<Float>, fut: ndarray::ArrayViewMut2<Float>) {
|
||||||
|
let optype = super::OperatorType::Normal;
|
||||||
|
super::diff_op_col_matrix(
|
||||||
|
&SBP8::BLOCK_MATRIX,
|
||||||
|
&SBP8::BLOCKEND_MATRIX,
|
||||||
|
&SBP8::DIAG_MATRIX,
|
||||||
|
optype,
|
||||||
|
prev,
|
||||||
|
fut,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
impl SbpOperator2d for SBP8 {
|
impl SbpOperator2d for SBP8 {
|
||||||
fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
|
fn diffxi(&self, prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>) {
|
||||||
assert_eq!(prev.shape(), fut.shape());
|
assert_eq!(prev.shape(), fut.shape());
|
||||||
|
@ -114,14 +126,14 @@ impl SbpOperator2d for SBP8 {
|
||||||
|
|
||||||
let symmetry = super::Symmetry::AntiSymmetric;
|
let symmetry = super::Symmetry::AntiSymmetric;
|
||||||
let optype = super::OperatorType::Normal;
|
let optype = super::OperatorType::Normal;
|
||||||
|
|
||||||
match (prev.strides(), fut.strides()) {
|
match (prev.strides(), fut.strides()) {
|
||||||
([_, 1], [_, 1]) => {
|
([_, 1], [_, 1]) => {
|
||||||
//diff_op_row(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
|
//diff_op_row(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
|
||||||
diff_op_row_local(prev, fut);
|
diff_op_row_local(prev, fut);
|
||||||
}
|
}
|
||||||
([1, _], [1, _]) => {
|
([1, _], [1, _]) => {
|
||||||
diff_op_col(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
|
//diff_op_col(SBP8::BLOCK, SBP8::DIAG, symmetry, optype)(prev, fut);
|
||||||
|
diff_op_col_local(prev, fut)
|
||||||
}
|
}
|
||||||
([_, _], [_, _]) => {
|
([_, _], [_, _]) => {
|
||||||
// Fallback, work row by row
|
// Fallback, work row by row
|
||||||
|
|
Loading…
Reference in New Issue