checkpoint
This commit is contained in:
parent
c104082ac0
commit
94e8fb5b7c
|
@ -1,4 +1,5 @@
|
|||
#![feature(core_intrinsics)]
|
||||
#![feature(array_windows)]
|
||||
|
||||
/// Type used for floats, configure with the `f32` feature
|
||||
#[cfg(feature = "f32")]
|
||||
|
|
|
@ -6,6 +6,8 @@ pub(crate) mod constmatrix {
|
|||
pub struct Matrix<T, const M: usize, const N: usize> {
|
||||
data: [[T; N]; M],
|
||||
}
|
||||
pub type RowVector<T, const N: usize> = Matrix<T, 1, N>;
|
||||
pub type ColVector<T, const N: usize> = Matrix<T, N, 1>;
|
||||
|
||||
impl<T: Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
|
||||
fn default() -> Self {
|
||||
|
@ -89,11 +91,61 @@ pub(crate) mod constmatrix {
|
|||
}
|
||||
}
|
||||
}
|
||||
pub fn iter(&self) -> impl ExactSizeIterator<Item = &T> + DoubleEndedIterator<Item = &T> {
|
||||
(0..N * M).map(move |x| {
|
||||
let i = x / N;
|
||||
let j = x % N;
|
||||
&self[(i, j)]
|
||||
})
|
||||
}
|
||||
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> {
|
||||
self.data.iter_mut().flatten()
|
||||
}
|
||||
pub fn iter_rows(
|
||||
&self,
|
||||
) -> impl ExactSizeIterator<Item = &[T; N]> + DoubleEndedIterator<Item = &[T; N]> {
|
||||
(0..M).map(move |i| &self[i])
|
||||
}
|
||||
|
||||
pub fn flip(&self) -> Self
|
||||
where
|
||||
T: Default + Clone,
|
||||
{
|
||||
let mut v = Self::default();
|
||||
for i in 0..M {
|
||||
for j in 0..N {
|
||||
v[(i, j)] = self[(N - 1 - i, M - 1 - j)].clone()
|
||||
}
|
||||
}
|
||||
v
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const N: usize> ColVector<T, N> {
|
||||
pub fn map_to_col(slice: &[T; N]) -> &ColVector<T, N> {
|
||||
unsafe { std::mem::transmute::<&[T; N], &Self>(slice) }
|
||||
}
|
||||
pub fn map_to_col_mut(slice: &mut [T; N]) -> &mut ColVector<T, N> {
|
||||
unsafe { std::mem::transmute::<&mut [T; N], &mut Self>(slice) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const N: usize> RowVector<T, N> {
|
||||
pub fn map_to_row(slice: &[T; N]) -> &Self {
|
||||
unsafe { std::mem::transmute::<&[T; N], &Self>(slice) }
|
||||
}
|
||||
pub fn map_to_row_mut(slice: &mut [T; N]) -> &mut Self {
|
||||
unsafe { std::mem::transmute::<&mut [T; N], &mut Self>(slice) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const M: usize, const N: usize> core::ops::MulAssign<&T> for Matrix<T, M, N>
|
||||
where
|
||||
for<'f> T: core::ops::MulAssign<&'f T>,
|
||||
{
|
||||
fn mul_assign(&mut self, other: &T) {
|
||||
self.iter_mut().for_each(|x| *x *= other)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -122,6 +174,137 @@ pub(crate) mod constmatrix {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) use constmatrix::{ColVector, Matrix, RowVector};
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
|
||||
block: &Matrix<Float, M, N>,
|
||||
diag: &RowVector<Float, N>,
|
||||
symmetry: Symmetry,
|
||||
optype: OperatorType,
|
||||
prev: ArrayView1<Float>,
|
||||
mut fut: ArrayViewMut1<Float>,
|
||||
) {
|
||||
assert_eq!(prev.shape(), fut.shape());
|
||||
let nx = prev.shape()[0];
|
||||
assert!(nx >= 2 * M);
|
||||
assert!(nx >= N);
|
||||
|
||||
let dx = if optype == OperatorType::H2 {
|
||||
1.0 / (nx - 2) as Float
|
||||
} else {
|
||||
1.0 / (nx - 1) as Float
|
||||
};
|
||||
let idx = 1.0 / dx;
|
||||
|
||||
for (bl, f) in block.iter_rows().zip(&mut fut) {
|
||||
let diff = bl
|
||||
.iter()
|
||||
.zip(prev.iter())
|
||||
.map(|(x, y)| x * y)
|
||||
.sum::<Float>();
|
||||
*f = diff * idx;
|
||||
}
|
||||
|
||||
// The window needs to be aligned to the diagonal elements,
|
||||
// based on the block size
|
||||
let window_elems_to_skip = M - ((D - 1) / 2);
|
||||
|
||||
for (window, f) in prev
|
||||
.windows(D)
|
||||
.into_iter()
|
||||
.skip(window_elems_to_skip)
|
||||
.zip(fut.iter_mut().skip(M))
|
||||
.take(nx - 2 * M)
|
||||
{
|
||||
let diff = diag.iter().zip(&window).map(|(x, y)| x * y).sum::<Float>();
|
||||
*f = diff * idx;
|
||||
}
|
||||
|
||||
for (bl, f) in block.iter_rows().zip(fut.iter_mut().rev()) {
|
||||
let diff = bl
|
||||
.iter()
|
||||
.zip(prev.iter().rev())
|
||||
.map(|(x, y)| x * y)
|
||||
.sum::<Float>();
|
||||
|
||||
*f = idx
|
||||
* if symmetry == Symmetry::Symmetric {
|
||||
diff
|
||||
} else {
|
||||
-diff
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: usize>(
|
||||
block: &Matrix<Float, M, N>,
|
||||
diag: &RowVector<Float, D>,
|
||||
symmetry: Symmetry,
|
||||
optype: OperatorType,
|
||||
prev: &[Float],
|
||||
fut: &mut [Float],
|
||||
) {
|
||||
assert_eq!(prev.len(), fut.len());
|
||||
let nx = prev.len();
|
||||
assert!(nx >= 2 * M);
|
||||
assert!(nx >= N);
|
||||
let prev = &prev[..nx];
|
||||
let fut = &mut fut[..nx];
|
||||
|
||||
let dx = if optype == OperatorType::H2 {
|
||||
1.0 / (nx - 2) as Float
|
||||
} else {
|
||||
1.0 / (nx - 1) as Float
|
||||
};
|
||||
let idx = 1.0 / dx;
|
||||
|
||||
use std::convert::TryInto;
|
||||
{
|
||||
let prev = ColVector::<_, N>::map_to_col((&prev[0..N]).try_into().unwrap());
|
||||
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[0..M]).try_into().unwrap());
|
||||
|
||||
block.matmul_into(prev, fut);
|
||||
*fut *= &idx;
|
||||
}
|
||||
|
||||
// The window needs to be aligned to the diagonal elements,
|
||||
// based on the block size
|
||||
let window_elems_to_skip = M - ((D - 1) / 2);
|
||||
|
||||
for (window, f) in prev
|
||||
.array_windows::<D>()
|
||||
.skip(window_elems_to_skip)
|
||||
.zip(fut.iter_mut().skip(M))
|
||||
.take(nx - 2 * M)
|
||||
{
|
||||
let fut = ColVector::<Float, 1>::map_to_col_mut(unsafe {
|
||||
std::mem::transmute::<&mut Float, &mut [Float; 1]>(f)
|
||||
});
|
||||
let prev = ColVector::<_, D>::map_to_col(window);
|
||||
|
||||
diag.matmul_into(prev, fut);
|
||||
*fut *= &idx;
|
||||
}
|
||||
|
||||
let flipped = {
|
||||
let mut flipped = block.flip();
|
||||
if symmetry != Symmetry::Symmetric {
|
||||
flipped *= &-1.0;
|
||||
}
|
||||
flipped
|
||||
};
|
||||
|
||||
{
|
||||
let prev = ColVector::<_, N>::map_to_col((&prev[nx - N..]).try_into().unwrap());
|
||||
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap());
|
||||
|
||||
flipped.matmul_into(prev, fut);
|
||||
*fut *= &idx;
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn diff_op_1d(
|
||||
block: &[&[Float]],
|
||||
|
|
Loading…
Reference in New Issue