diff --git a/sbp/src/lib.rs b/sbp/src/lib.rs index 1d09854..62cb872 100644 --- a/sbp/src/lib.rs +++ b/sbp/src/lib.rs @@ -1,4 +1,4 @@ -#![feature(core_intrinsics)] +#![cfg_attr(feature = "fast-float", feature(core_intrinsics))] #![feature(array_windows)] #![feature(array_chunks)] diff --git a/sbp/src/operators/algos.rs b/sbp/src/operators/algos.rs index b32425c..11b9923 100644 --- a/sbp/src/operators/algos.rs +++ b/sbp/src/operators/algos.rs @@ -353,6 +353,12 @@ mod fastfloat { Self(f) } } + impl From for Float { + #[inline(always)] + fn from(f: FastFloat) -> Self { + f.0 + } + } } #[cfg(feature = "fast-float")] use fastfloat::FastFloat; @@ -838,14 +844,17 @@ pub(crate) fn diff_op_col_simd( } #[inline(always)] -fn product_fast<'a>( - u: impl Iterator, - v: impl Iterator, -) -> Float { - use std::intrinsics::{fadd_fast, fmul_fast}; - u.zip(v).fold(0.0, |acc, (&u, &v)| unsafe { - // We do not care about the order of multiplication nor addition - fadd_fast(acc, fmul_fast(u, v)) +fn dotproduct<'a>(u: impl Iterator, v: impl Iterator) -> Float { + u.zip(v).fold(0.0, |acc, (&u, &v)| { + #[cfg(feature = "fast-float")] + unsafe { + // We do not care about the order of multiplication nor addition + (FastFloat::from(acc) + FastFloat::from(u) * FastFloat::from(v)).into() + } + #[cfg(not(feature = "fast-float"))] + { + acc + u * v + } }) } @@ -1043,7 +1052,7 @@ pub(crate) fn diff_op_row( assert!(prev.len() >= 2 * block.len()); for (bl, f) in block.iter().zip(fut.iter_mut()) { - let diff = product_fast(bl.iter(), prev[..bl.len()].iter()); + let diff = dotproduct(bl.iter(), prev[..bl.len()].iter()); *f = diff * idx; } @@ -1057,12 +1066,12 @@ pub(crate) fn diff_op_row( .zip(fut.iter_mut().skip(block.len())) .take(nx - 2 * block.len()) { - let diff = product_fast(diag.iter(), window.iter()); + let diff = dotproduct(diag.iter(), window.iter()); *f = diff * idx; } for (bl, f) in block.iter().zip(fut.iter_mut().rev()) { - let diff = product_fast(bl.iter(), prev.iter().rev()); + let diff = dotproduct(bl.iter(), prev.iter().rev()); *f = idx * if symmetry == Symmetry::Symmetric {