checkpoint
This commit is contained in:
parent
c104082ac0
commit
94e8fb5b7c
|
@ -1,4 +1,5 @@
|
||||||
#![feature(core_intrinsics)]
|
#![feature(core_intrinsics)]
|
||||||
|
#![feature(array_windows)]
|
||||||
|
|
||||||
/// Type used for floats, configure with the `f32` feature
|
/// Type used for floats, configure with the `f32` feature
|
||||||
#[cfg(feature = "f32")]
|
#[cfg(feature = "f32")]
|
||||||
|
|
|
@ -6,6 +6,8 @@ pub(crate) mod constmatrix {
|
||||||
pub struct Matrix<T, const M: usize, const N: usize> {
|
pub struct Matrix<T, const M: usize, const N: usize> {
|
||||||
data: [[T; N]; M],
|
data: [[T; N]; M],
|
||||||
}
|
}
|
||||||
|
pub type RowVector<T, const N: usize> = Matrix<T, 1, N>;
|
||||||
|
pub type ColVector<T, const N: usize> = Matrix<T, N, 1>;
|
||||||
|
|
||||||
impl<T: Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
|
impl<T: Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
|
@ -89,11 +91,61 @@ pub(crate) mod constmatrix {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
pub fn iter(&self) -> impl ExactSizeIterator<Item = &T> + DoubleEndedIterator<Item = &T> {
|
||||||
|
(0..N * M).map(move |x| {
|
||||||
|
let i = x / N;
|
||||||
|
let j = x % N;
|
||||||
|
&self[(i, j)]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> {
|
||||||
|
self.data.iter_mut().flatten()
|
||||||
|
}
|
||||||
pub fn iter_rows(
|
pub fn iter_rows(
|
||||||
&self,
|
&self,
|
||||||
) -> impl ExactSizeIterator<Item = &[T; N]> + DoubleEndedIterator<Item = &[T; N]> {
|
) -> impl ExactSizeIterator<Item = &[T; N]> + DoubleEndedIterator<Item = &[T; N]> {
|
||||||
(0..M).map(move |i| &self[i])
|
(0..M).map(move |i| &self[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn flip(&self) -> Self
|
||||||
|
where
|
||||||
|
T: Default + Clone,
|
||||||
|
{
|
||||||
|
let mut v = Self::default();
|
||||||
|
for i in 0..M {
|
||||||
|
for j in 0..N {
|
||||||
|
v[(i, j)] = self[(N - 1 - i, M - 1 - j)].clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const N: usize> ColVector<T, N> {
|
||||||
|
pub fn map_to_col(slice: &[T; N]) -> &ColVector<T, N> {
|
||||||
|
unsafe { std::mem::transmute::<&[T; N], &Self>(slice) }
|
||||||
|
}
|
||||||
|
pub fn map_to_col_mut(slice: &mut [T; N]) -> &mut ColVector<T, N> {
|
||||||
|
unsafe { std::mem::transmute::<&mut [T; N], &mut Self>(slice) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const N: usize> RowVector<T, N> {
|
||||||
|
pub fn map_to_row(slice: &[T; N]) -> &Self {
|
||||||
|
unsafe { std::mem::transmute::<&[T; N], &Self>(slice) }
|
||||||
|
}
|
||||||
|
pub fn map_to_row_mut(slice: &mut [T; N]) -> &mut Self {
|
||||||
|
unsafe { std::mem::transmute::<&mut [T; N], &mut Self>(slice) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const M: usize, const N: usize> core::ops::MulAssign<&T> for Matrix<T, M, N>
|
||||||
|
where
|
||||||
|
for<'f> T: core::ops::MulAssign<&'f T>,
|
||||||
|
{
|
||||||
|
fn mul_assign(&mut self, other: &T) {
|
||||||
|
self.iter_mut().for_each(|x| *x *= other)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
@ -122,6 +174,137 @@ pub(crate) mod constmatrix {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) use constmatrix::{ColVector, Matrix, RowVector};
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
|
||||||
|
block: &Matrix<Float, M, N>,
|
||||||
|
diag: &RowVector<Float, N>,
|
||||||
|
symmetry: Symmetry,
|
||||||
|
optype: OperatorType,
|
||||||
|
prev: ArrayView1<Float>,
|
||||||
|
mut fut: ArrayViewMut1<Float>,
|
||||||
|
) {
|
||||||
|
assert_eq!(prev.shape(), fut.shape());
|
||||||
|
let nx = prev.shape()[0];
|
||||||
|
assert!(nx >= 2 * M);
|
||||||
|
assert!(nx >= N);
|
||||||
|
|
||||||
|
let dx = if optype == OperatorType::H2 {
|
||||||
|
1.0 / (nx - 2) as Float
|
||||||
|
} else {
|
||||||
|
1.0 / (nx - 1) as Float
|
||||||
|
};
|
||||||
|
let idx = 1.0 / dx;
|
||||||
|
|
||||||
|
for (bl, f) in block.iter_rows().zip(&mut fut) {
|
||||||
|
let diff = bl
|
||||||
|
.iter()
|
||||||
|
.zip(prev.iter())
|
||||||
|
.map(|(x, y)| x * y)
|
||||||
|
.sum::<Float>();
|
||||||
|
*f = diff * idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The window needs to be aligned to the diagonal elements,
|
||||||
|
// based on the block size
|
||||||
|
let window_elems_to_skip = M - ((D - 1) / 2);
|
||||||
|
|
||||||
|
for (window, f) in prev
|
||||||
|
.windows(D)
|
||||||
|
.into_iter()
|
||||||
|
.skip(window_elems_to_skip)
|
||||||
|
.zip(fut.iter_mut().skip(M))
|
||||||
|
.take(nx - 2 * M)
|
||||||
|
{
|
||||||
|
let diff = diag.iter().zip(&window).map(|(x, y)| x * y).sum::<Float>();
|
||||||
|
*f = diff * idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (bl, f) in block.iter_rows().zip(fut.iter_mut().rev()) {
|
||||||
|
let diff = bl
|
||||||
|
.iter()
|
||||||
|
.zip(prev.iter().rev())
|
||||||
|
.map(|(x, y)| x * y)
|
||||||
|
.sum::<Float>();
|
||||||
|
|
||||||
|
*f = idx
|
||||||
|
* if symmetry == Symmetry::Symmetric {
|
||||||
|
diff
|
||||||
|
} else {
|
||||||
|
-diff
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
|
pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: usize>(
|
||||||
|
block: &Matrix<Float, M, N>,
|
||||||
|
diag: &RowVector<Float, D>,
|
||||||
|
symmetry: Symmetry,
|
||||||
|
optype: OperatorType,
|
||||||
|
prev: &[Float],
|
||||||
|
fut: &mut [Float],
|
||||||
|
) {
|
||||||
|
assert_eq!(prev.len(), fut.len());
|
||||||
|
let nx = prev.len();
|
||||||
|
assert!(nx >= 2 * M);
|
||||||
|
assert!(nx >= N);
|
||||||
|
let prev = &prev[..nx];
|
||||||
|
let fut = &mut fut[..nx];
|
||||||
|
|
||||||
|
let dx = if optype == OperatorType::H2 {
|
||||||
|
1.0 / (nx - 2) as Float
|
||||||
|
} else {
|
||||||
|
1.0 / (nx - 1) as Float
|
||||||
|
};
|
||||||
|
let idx = 1.0 / dx;
|
||||||
|
|
||||||
|
use std::convert::TryInto;
|
||||||
|
{
|
||||||
|
let prev = ColVector::<_, N>::map_to_col((&prev[0..N]).try_into().unwrap());
|
||||||
|
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[0..M]).try_into().unwrap());
|
||||||
|
|
||||||
|
block.matmul_into(prev, fut);
|
||||||
|
*fut *= &idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The window needs to be aligned to the diagonal elements,
|
||||||
|
// based on the block size
|
||||||
|
let window_elems_to_skip = M - ((D - 1) / 2);
|
||||||
|
|
||||||
|
for (window, f) in prev
|
||||||
|
.array_windows::<D>()
|
||||||
|
.skip(window_elems_to_skip)
|
||||||
|
.zip(fut.iter_mut().skip(M))
|
||||||
|
.take(nx - 2 * M)
|
||||||
|
{
|
||||||
|
let fut = ColVector::<Float, 1>::map_to_col_mut(unsafe {
|
||||||
|
std::mem::transmute::<&mut Float, &mut [Float; 1]>(f)
|
||||||
|
});
|
||||||
|
let prev = ColVector::<_, D>::map_to_col(window);
|
||||||
|
|
||||||
|
diag.matmul_into(prev, fut);
|
||||||
|
*fut *= &idx;
|
||||||
|
}
|
||||||
|
|
||||||
|
let flipped = {
|
||||||
|
let mut flipped = block.flip();
|
||||||
|
if symmetry != Symmetry::Symmetric {
|
||||||
|
flipped *= &-1.0;
|
||||||
|
}
|
||||||
|
flipped
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
let prev = ColVector::<_, N>::map_to_col((&prev[nx - N..]).try_into().unwrap());
|
||||||
|
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap());
|
||||||
|
|
||||||
|
flipped.matmul_into(prev, fut);
|
||||||
|
*fut *= &idx;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn diff_op_1d(
|
pub(crate) fn diff_op_1d(
|
||||||
block: &[&[Float]],
|
block: &[&[Float]],
|
||||||
|
|
Loading…
Reference in New Issue