Replace FastFloat with mul_add
This commit is contained in:
		@@ -4,9 +4,6 @@ use num_traits::Zero;
 | 
			
		||||
 | 
			
		||||
pub(crate) use constmatrix::{ColVector, Matrix, RowVector};
 | 
			
		||||
 | 
			
		||||
#[cfg(feature = "fast-float")]
 | 
			
		||||
use float::FastFloat;
 | 
			
		||||
 | 
			
		||||
#[derive(Clone, Debug, PartialEq)]
 | 
			
		||||
pub(crate) struct DiagonalMatrix<const B: usize> {
 | 
			
		||||
    pub start: [Float; B],
 | 
			
		||||
@@ -105,17 +102,14 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
 | 
			
		||||
    prev: &[Float],
 | 
			
		||||
    fut: &mut [Float],
 | 
			
		||||
) {
 | 
			
		||||
    #[cfg(feature = "fast-float")]
 | 
			
		||||
    let (matrix, prev, fut) = {
 | 
			
		||||
        use std::mem::transmute;
 | 
			
		||||
        unsafe {
 | 
			
		||||
            (
 | 
			
		||||
                transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix),
 | 
			
		||||
                transmute::<_, &[FastFloat]>(prev),
 | 
			
		||||
                transmute::<_, &mut [FastFloat]>(fut),
 | 
			
		||||
            )
 | 
			
		||||
        }
 | 
			
		||||
    };
 | 
			
		||||
    // Thin wrapper that forwards to `c.matmul_float_into(a, b)`, writing its
    // result through the mutable column vector `c`.
    // NOTE(review): judging by the name and the `#[inline(never)]` attribute,
    // the intent appears to be that the `matrix.start` and `matrix.end` call
    // sites share one compiled copy of the kernel instead of each inlining its
    // own — confirm with a codegen/benchmark check before changing the
    // attribute.
    #[inline(never)]
 | 
			
		||||
    fn dedup_matmul<const M: usize, const N: usize>(
 | 
			
		||||
        c: &mut ColVector<Float, M>,
 | 
			
		||||
        a: &Matrix<Float, M, N>,
 | 
			
		||||
        b: &ColVector<Float, N>,
 | 
			
		||||
    ) {
 | 
			
		||||
        c.matmul_float_into(a, b)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    assert_eq!(prev.len(), fut.len());
 | 
			
		||||
    let nx = prev.len();
 | 
			
		||||
@@ -130,8 +124,6 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
 | 
			
		||||
        1.0 / (nx - 1) as Float
 | 
			
		||||
    };
 | 
			
		||||
    let idx = 1.0 / dx;
 | 
			
		||||
    #[cfg(feature = "fast-float")]
 | 
			
		||||
    let idx = FastFloat::from(idx);
 | 
			
		||||
 | 
			
		||||
    // Help aliasing analysis
 | 
			
		||||
    let (futb1, fut) = fut.split_at_mut(M);
 | 
			
		||||
@@ -142,7 +134,7 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
 | 
			
		||||
        let prev = ColVector::<_, N>::map_to_col(prev.array_windows::<N>().next().unwrap());
 | 
			
		||||
        let fut = ColVector::<_, M>::map_to_col_mut(futb1.try_into().unwrap());
 | 
			
		||||
 | 
			
		||||
        fut.matmul_into(&matrix.start, prev);
 | 
			
		||||
        dedup_matmul(fut, &matrix.start, prev);
 | 
			
		||||
        *fut *= idx;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -158,7 +150,7 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
 | 
			
		||||
        let fut = ColVector::<_, 1>::map_to_col_mut(f);
 | 
			
		||||
        let prev = ColVector::<_, D>::map_to_col(window);
 | 
			
		||||
 | 
			
		||||
        fut.matmul_into(&matrix.diag, prev);
 | 
			
		||||
        fut.matmul_float_into(&matrix.diag, prev);
 | 
			
		||||
        *fut *= idx;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -167,7 +159,7 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
 | 
			
		||||
        let prev = ColVector::<_, N>::map_to_col(prev);
 | 
			
		||||
        let fut = ColVector::<_, M>::map_to_col_mut(futb2.try_into().unwrap());
 | 
			
		||||
 | 
			
		||||
        fut.matmul_into(&matrix.end, prev);
 | 
			
		||||
        dedup_matmul(fut, &matrix.end, prev);
 | 
			
		||||
        *fut *= idx;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -199,19 +191,6 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
 | 
			
		||||
    prev: ArrayView2<Float>,
 | 
			
		||||
    mut fut: ArrayViewMut2<Float>,
 | 
			
		||||
) {
 | 
			
		||||
    /* Does not increase the perf...
 | 
			
		||||
    #[cfg(feature = "fast-float")]
 | 
			
		||||
    let (matrix, prev, mut fut) = unsafe {
 | 
			
		||||
        (
 | 
			
		||||
            std::mem::transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix),
 | 
			
		||||
            std::mem::transmute::<_, ArrayView2<FastFloat>>(prev),
 | 
			
		||||
            std::mem::transmute::<_, ArrayViewMut2<FastFloat>>(fut),
 | 
			
		||||
        )
 | 
			
		||||
    };
 | 
			
		||||
    #[cfg(not(feature = "fast-float"))]
 | 
			
		||||
    let mut fut = fut;
 | 
			
		||||
    */
 | 
			
		||||
 | 
			
		||||
    assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
    let nx = prev.shape()[1];
 | 
			
		||||
    let ny = prev.shape()[0];
 | 
			
		||||
@@ -287,19 +266,6 @@ pub(crate) fn diff_op_2d_sliceable_y<const M: usize, const N: usize, const D: us
 | 
			
		||||
    prev: ArrayView2<Float>,
 | 
			
		||||
    mut fut: ArrayViewMut2<Float>,
 | 
			
		||||
) {
 | 
			
		||||
    /* Does not increase the perf...
 | 
			
		||||
    #[cfg(feature = "fast-float")]
 | 
			
		||||
    let (matrix, prev, mut fut) = unsafe {
 | 
			
		||||
        (
 | 
			
		||||
            std::mem::transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix),
 | 
			
		||||
            std::mem::transmute::<_, ArrayView2<FastFloat>>(prev),
 | 
			
		||||
            std::mem::transmute::<_, ArrayViewMut2<FastFloat>>(fut),
 | 
			
		||||
        )
 | 
			
		||||
    };
 | 
			
		||||
    #[cfg(not(feature = "fast-float"))]
 | 
			
		||||
    let mut fut = fut;
 | 
			
		||||
    */
 | 
			
		||||
 | 
			
		||||
    assert_eq!(prev.shape(), fut.shape());
 | 
			
		||||
    let nx = prev.shape()[1];
 | 
			
		||||
    let ny = prev.shape()[0];
 | 
			
		||||
@@ -733,17 +699,9 @@ fn dotproduct<'a>(
 | 
			
		||||
    u: impl IntoIterator<Item = &'a Float>,
 | 
			
		||||
    v: impl IntoIterator<Item = &'a Float>,
 | 
			
		||||
) -> Float {
 | 
			
		||||
    u.into_iter().zip(v.into_iter()).fold(0.0, |acc, (&u, &v)| {
 | 
			
		||||
        #[cfg(feature = "fast-float")]
 | 
			
		||||
        {
 | 
			
		||||
            // We do not care about the order of multiplication nor addition
 | 
			
		||||
            (FastFloat::from(acc) + FastFloat::from(u) * FastFloat::from(v)).into()
 | 
			
		||||
        }
 | 
			
		||||
        #[cfg(not(feature = "fast-float"))]
 | 
			
		||||
        {
 | 
			
		||||
            acc + u * v
 | 
			
		||||
        }
 | 
			
		||||
    })
 | 
			
		||||
    u.into_iter()
 | 
			
		||||
        .zip(v.into_iter())
 | 
			
		||||
        .fold(0.0, |acc, (&u, &v)| Float::mul_add(u, v, acc))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#[cfg(feature = "sparse")]
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user