10% instr reduction with fast_ intr
This commit is contained in:
parent
30c563c19d
commit
36293e75e6
|
@ -59,9 +59,11 @@ pub(crate) mod constmatrix {
|
|||
pub const fn new(data: [[T; N]; M]) -> Self {
|
||||
Self { data }
|
||||
}
|
||||
#[inline]
|
||||
pub const fn nrows(&self) -> usize {
|
||||
M
|
||||
}
|
||||
#[inline]
|
||||
pub const fn ncols(&self) -> usize {
|
||||
N
|
||||
}
|
||||
|
@ -82,12 +84,13 @@ pub(crate) mod constmatrix {
|
|||
T: Default + core::ops::AddAssign<T>,
|
||||
for<'f> &'f T: std::ops::Mul<Output = T>,
|
||||
{
|
||||
*out = Default::default();
|
||||
for i in 0..M {
|
||||
for j in 0..P {
|
||||
let mut t = T::default();
|
||||
for k in 0..N {
|
||||
out[(i, j)] += &self[(i, k)] * &other[(k, j)];
|
||||
t += &self[(i, k)] * &other[(k, j)];
|
||||
}
|
||||
out[(i, j)] = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -148,6 +151,34 @@ pub(crate) mod constmatrix {
|
|||
}
|
||||
}
|
||||
|
||||
use super::Float;
|
||||
impl<const M: usize, const N: usize> Matrix<Float, M, N> {
|
||||
pub fn matmul_fast<const P: usize>(
|
||||
&self,
|
||||
other: &Matrix<Float, N, P>,
|
||||
) -> Matrix<Float, M, P> {
|
||||
let mut out = Matrix::default();
|
||||
self.matmul_into_fast(other, &mut out);
|
||||
out
|
||||
}
|
||||
pub fn matmul_into_fast<const P: usize>(
|
||||
&self,
|
||||
other: &Matrix<Float, N, P>,
|
||||
out: &mut Matrix<Float, M, P>,
|
||||
) {
|
||||
use core::intrinsics::{fadd_fast, fmul_fast};
|
||||
for i in 0..M {
|
||||
for j in 0..P {
|
||||
let mut t = 0.0;
|
||||
for k in 0..N {
|
||||
t = unsafe { fadd_fast(t, fmul_fast(self[(i, k)], other[(k, j)])) }
|
||||
}
|
||||
out[(i, j)] = t;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{super::*, *};
|
||||
|
@ -240,8 +271,8 @@ pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
|
|||
#[inline(always)]
|
||||
pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: usize>(
|
||||
block: &Matrix<Float, M, N>,
|
||||
endblock: &Matrix<Float, M, N>,
|
||||
diag: &RowVector<Float, D>,
|
||||
symmetry: Symmetry,
|
||||
optype: OperatorType,
|
||||
prev: &[Float],
|
||||
fut: &mut [Float],
|
||||
|
@ -262,10 +293,10 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
|
|||
|
||||
use std::convert::TryInto;
|
||||
{
|
||||
let prev = ColVector::<_, N>::map_to_col((&prev[0..N]).try_into().unwrap());
|
||||
let prev = ColVector::<_, N>::map_to_col(prev.array_windows::<N>().nth(0).unwrap());
|
||||
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[0..M]).try_into().unwrap());
|
||||
|
||||
block.matmul_into(prev, fut);
|
||||
block.matmul_into_fast(prev, fut);
|
||||
*fut *= &idx;
|
||||
}
|
||||
|
||||
|
@ -283,24 +314,16 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
|
|||
let fut = ColVector::<Float, 1>::map_to_col_mut(f);
|
||||
let prev = ColVector::<_, D>::map_to_col(window);
|
||||
|
||||
diag.matmul_into(prev, fut);
|
||||
diag.matmul_into_fast(prev, fut);
|
||||
*fut *= &idx;
|
||||
}
|
||||
|
||||
let flipped = {
|
||||
let mut flipped = block.flip();
|
||||
if symmetry != Symmetry::Symmetric {
|
||||
flipped *= &-1.0;
|
||||
}
|
||||
flipped
|
||||
};
|
||||
|
||||
{
|
||||
let prev = prev.array_windows::<N>().last().unwrap();
|
||||
let prev = ColVector::<_, N>::map_to_col(prev);
|
||||
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap());
|
||||
|
||||
flipped.matmul_into(prev, fut);
|
||||
endblock.matmul_into_fast(prev, fut);
|
||||
*fut *= &idx;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -83,14 +83,16 @@ impl SbpOperator1d for SBP4 {
|
|||
}
|
||||
|
||||
fn diff_op_row_local(prev: ndarray::ArrayView2<Float>, mut fut: ndarray::ArrayViewMut2<Float>) {
|
||||
let mut flipmatrix = SBP4::BLOCK_MATRIX.flip();
|
||||
flipmatrix *= &-1.0;
|
||||
for (p, mut f) in prev
|
||||
.axis_iter(ndarray::Axis(0))
|
||||
.zip(fut.axis_iter_mut(ndarray::Axis(0)))
|
||||
{
|
||||
super::diff_op_1d_slice_matrix(
|
||||
&SBP4::BLOCK_MATRIX,
|
||||
&flipmatrix,
|
||||
&SBP4::DIAG_MATRIX,
|
||||
super::Symmetry::AntiSymmetric,
|
||||
super::OperatorType::Normal,
|
||||
p.as_slice().unwrap(),
|
||||
f.as_slice_mut().unwrap(),
|
||||
|
|
Loading…
Reference in New Issue