Replace FastFloat with mul_add

This commit is contained in:
Magnus Ulimoen 2021-03-23 19:21:38 +01:00
parent df05c06270
commit 4ae5c02bb1
4 changed files with 16 additions and 59 deletions

View File

@ -6,7 +6,7 @@ edition = "2018"
[dependencies] [dependencies]
sbp = { path = "../sbp", features = ["serde1", "fast-float"] } sbp = { path = "../sbp", features = ["serde1"] }
euler = { path = "../euler", features = ["serde1"] } euler = { path = "../euler", features = ["serde1"] }
hdf5 = "0.7.0" hdf5 = "0.7.0"
integrate = { path = "../utils/integrate" } integrate = { path = "../utils/integrate" }

View File

@ -17,7 +17,6 @@ constmatrix = { path = "../utils/constmatrix" }
[features] [features]
# Use f32 as precision, default is f64 # Use f32 as precision, default is f64
f32 = ["float/f32"] f32 = ["float/f32"]
fast-float = ["float/fast-float"]
sparse = ["sprs"] sparse = ["sprs"]
serde1 = ["serde", "ndarray/serde"] serde1 = ["serde", "ndarray/serde"]

View File

@ -4,9 +4,6 @@ use num_traits::Zero;
pub(crate) use constmatrix::{ColVector, Matrix, RowVector}; pub(crate) use constmatrix::{ColVector, Matrix, RowVector};
#[cfg(feature = "fast-float")]
use float::FastFloat;
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub(crate) struct DiagonalMatrix<const B: usize> { pub(crate) struct DiagonalMatrix<const B: usize> {
pub start: [Float; B], pub start: [Float; B],
@ -105,17 +102,14 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
prev: &[Float], prev: &[Float],
fut: &mut [Float], fut: &mut [Float],
) { ) {
#[cfg(feature = "fast-float")] #[inline(never)]
let (matrix, prev, fut) = { fn dedup_matmul<const M: usize, const N: usize>(
use std::mem::transmute; c: &mut ColVector<Float, M>,
unsafe { a: &Matrix<Float, M, N>,
( b: &ColVector<Float, N>,
transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix), ) {
transmute::<_, &[FastFloat]>(prev), c.matmul_float_into(a, b)
transmute::<_, &mut [FastFloat]>(fut), }
)
}
};
assert_eq!(prev.len(), fut.len()); assert_eq!(prev.len(), fut.len());
let nx = prev.len(); let nx = prev.len();
@ -130,8 +124,6 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
1.0 / (nx - 1) as Float 1.0 / (nx - 1) as Float
}; };
let idx = 1.0 / dx; let idx = 1.0 / dx;
#[cfg(feature = "fast-float")]
let idx = FastFloat::from(idx);
// Help aliasing analysis // Help aliasing analysis
let (futb1, fut) = fut.split_at_mut(M); let (futb1, fut) = fut.split_at_mut(M);
@ -142,7 +134,7 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
let prev = ColVector::<_, N>::map_to_col(prev.array_windows::<N>().next().unwrap()); let prev = ColVector::<_, N>::map_to_col(prev.array_windows::<N>().next().unwrap());
let fut = ColVector::<_, M>::map_to_col_mut(futb1.try_into().unwrap()); let fut = ColVector::<_, M>::map_to_col_mut(futb1.try_into().unwrap());
fut.matmul_into(&matrix.start, prev); dedup_matmul(fut, &matrix.start, prev);
*fut *= idx; *fut *= idx;
} }
@ -158,7 +150,7 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
let fut = ColVector::<_, 1>::map_to_col_mut(f); let fut = ColVector::<_, 1>::map_to_col_mut(f);
let prev = ColVector::<_, D>::map_to_col(window); let prev = ColVector::<_, D>::map_to_col(window);
fut.matmul_into(&matrix.diag, prev); fut.matmul_float_into(&matrix.diag, prev);
*fut *= idx; *fut *= idx;
} }
@ -167,7 +159,7 @@ pub(crate) fn diff_op_1d_slice<const M: usize, const N: usize, const D: usize>(
let prev = ColVector::<_, N>::map_to_col(prev); let prev = ColVector::<_, N>::map_to_col(prev);
let fut = ColVector::<_, M>::map_to_col_mut(futb2.try_into().unwrap()); let fut = ColVector::<_, M>::map_to_col_mut(futb2.try_into().unwrap());
fut.matmul_into(&matrix.end, prev); dedup_matmul(fut, &matrix.end, prev);
*fut *= idx; *fut *= idx;
} }
} }
@ -199,19 +191,6 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
prev: ArrayView2<Float>, prev: ArrayView2<Float>,
mut fut: ArrayViewMut2<Float>, mut fut: ArrayViewMut2<Float>,
) { ) {
/* Does not increase the perf...
#[cfg(feature = "fast-float")]
let (matrix, prev, mut fut) = unsafe {
(
std::mem::transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix),
std::mem::transmute::<_, ArrayView2<FastFloat>>(prev),
std::mem::transmute::<_, ArrayViewMut2<FastFloat>>(fut),
)
};
#[cfg(not(feature = "fast-float"))]
let mut fut = fut;
*/
assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.shape(), fut.shape());
let nx = prev.shape()[1]; let nx = prev.shape()[1];
let ny = prev.shape()[0]; let ny = prev.shape()[0];
@ -287,19 +266,6 @@ pub(crate) fn diff_op_2d_sliceable_y<const M: usize, const N: usize, const D: us
prev: ArrayView2<Float>, prev: ArrayView2<Float>,
mut fut: ArrayViewMut2<Float>, mut fut: ArrayViewMut2<Float>,
) { ) {
/* Does not increase the perf...
#[cfg(feature = "fast-float")]
let (matrix, prev, mut fut) = unsafe {
(
std::mem::transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix),
std::mem::transmute::<_, ArrayView2<FastFloat>>(prev),
std::mem::transmute::<_, ArrayViewMut2<FastFloat>>(fut),
)
};
#[cfg(not(feature = "fast-float"))]
let mut fut = fut;
*/
assert_eq!(prev.shape(), fut.shape()); assert_eq!(prev.shape(), fut.shape());
let nx = prev.shape()[1]; let nx = prev.shape()[1];
let ny = prev.shape()[0]; let ny = prev.shape()[0];
@ -733,17 +699,9 @@ fn dotproduct<'a>(
u: impl IntoIterator<Item = &'a Float>, u: impl IntoIterator<Item = &'a Float>,
v: impl IntoIterator<Item = &'a Float>, v: impl IntoIterator<Item = &'a Float>,
) -> Float { ) -> Float {
u.into_iter().zip(v.into_iter()).fold(0.0, |acc, (&u, &v)| { u.into_iter()
#[cfg(feature = "fast-float")] .zip(v.into_iter())
{ .fold(0.0, |acc, (&u, &v)| Float::mul_add(u, v, acc))
// We do not care about the order of multiplication nor addition
(FastFloat::from(acc) + FastFloat::from(u) * FastFloat::from(v)).into()
}
#[cfg(not(feature = "fast-float"))]
{
acc + u * v
}
})
} }
#[cfg(feature = "sparse")] #[cfg(feature = "sparse")]

View File

@ -11,7 +11,7 @@ crate-type = ["cdylib"]
wasm-bindgen = "0.2.63" wasm-bindgen = "0.2.63"
console_error_panic_hook = "0.1.6" console_error_panic_hook = "0.1.6"
wee_alloc = "0.4.5" wee_alloc = "0.4.5"
sbp = { path = "../sbp", features = ["f32", "fast-float"] } sbp = { path = "../sbp", features = ["f32"] }
ndarray = "0.14.0" ndarray = "0.14.0"
euler = { path = "../euler" } euler = { path = "../euler" }
maxwell = { path = "../maxwell" } maxwell = { path = "../maxwell" }