minor thingys

This commit is contained in:
Magnus Ulimoen 2021-01-29 00:08:31 +01:00
parent db552af4ff
commit 4c2daf5933
2 changed files with 18 additions and 9 deletions

View File

@ -3,6 +3,7 @@ use super::*;
pub(crate) mod constmatrix { pub(crate) mod constmatrix {
/// A row-major matrix /// A row-major matrix
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(transparent)]
pub struct Matrix<T, const M: usize, const N: usize> { pub struct Matrix<T, const M: usize, const N: usize> {
data: [[T; N]; M], data: [[T; N]; M],
} }
@ -94,12 +95,8 @@ pub(crate) mod constmatrix {
} }
} }
} }
pub fn iter(&self) -> impl ExactSizeIterator<Item = &T> + DoubleEndedIterator<Item = &T> { pub fn iter(&self) -> impl Iterator<Item = &T> {
(0..N * M).map(move |x| { self.data.iter().flatten()
let i = x / N;
let j = x % N;
&self[(i, j)]
})
} }
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> { pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut T> {
self.data.iter_mut().flatten() self.data.iter_mut().flatten()
@ -352,7 +349,8 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
} else { } else {
1.0 / (nx - 1) as Float 1.0 / (nx - 1) as Float
}; };
let idx = FastFloat::from(1.0 / dx); let idx = 1.0 / dx;
let idx = FastFloat::from(idx);
use std::convert::TryInto; use std::convert::TryInto;
{ {
@ -382,7 +380,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
} }
{ {
let prev = prev.array_windows::<N>().last().unwrap(); let prev = prev.array_windows::<N>().next_back().unwrap();
let prev = ColVector::<_, N>::map_to_col(prev); let prev = ColVector::<_, N>::map_to_col(prev);
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap()); let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap());

View File

@ -92,7 +92,7 @@ impl SbpOperator1d for SBP4 {
fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayViewMut2<Float>) { fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayViewMut2<Float>) {
// Magic two lines that prevents or enables optimisation // Magic two lines that prevents or enables optimisation
// (doubles instructions when not included) // (doubles instructions when not included)
let mut flipmatrix = SBP4::BLOCK_MATRIX.flip(); let mut flipmatrix = SBP4::BLOCK_MATRIX;
flipmatrix *= &-1.0; flipmatrix *= &-1.0;
for (p, mut f) in prev for (p, mut f) in prev
@ -228,3 +228,14 @@ fn test_trad4() {
1e-1, 1e-1,
); );
} }
#[test]
fn block_equality() {
let mut flipped_inverted = SBP4::BLOCK_MATRIX.flip();
flipped_inverted *= &-1.0;
assert!(flipped_inverted
.iter()
.zip(SBP4::BLOCKEND_MATRIX.iter())
.all(|(x, y)| (x - y).abs() < 1e-3))
}