From 74d99a4a181dd1d7f87882167f9c1c79f7074122 Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Mon, 1 Feb 2021 23:58:13 +0100 Subject: [PATCH] try ndarray transmute --- sbp/Cargo.toml | 1 + sbp/src/operators/algos.rs | 25 ++++++---- sbp/src/operators/algos/fastfloat.rs | 71 +++++++++++++++++++++++++++- 3 files changed, 87 insertions(+), 10 deletions(-) diff --git a/sbp/Cargo.toml b/sbp/Cargo.toml index 9b7834f..9cb7a50 100644 --- a/sbp/Cargo.toml +++ b/sbp/Cargo.toml @@ -11,6 +11,7 @@ packed_simd = { version = "0.3.3", package = "packed_simd_2" } rayon = { version = "1.3.0", optional = true } sprs = { version = "0.9.0", optional = true, default-features = false } serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] } +num-traits = "0.2.14" [features] # Use f32 as precision, default is f64 diff --git a/sbp/src/operators/algos.rs b/sbp/src/operators/algos.rs index 9d24e9d..a6485a8 100644 --- a/sbp/src/operators/algos.rs +++ b/sbp/src/operators/algos.rs @@ -1,5 +1,6 @@ use super::*; use ndarray::s; +use num_traits::Zero; pub(crate) mod constmatrix; pub(crate) use constmatrix::{flip_lr, flip_sign, flip_ud, ColVector, Matrix, RowVector}; @@ -194,14 +195,24 @@ pub(crate) fn diff_op_1d( } #[inline(always)] -#[allow(unused)] /// 2D diff fallback for when matrices are not slicable pub(crate) fn diff_op_2d_fallback( matrix: &BlockMatrix, optype: OperatorType, prev: ArrayView2, - mut fut: ArrayViewMut2, + fut: ArrayViewMut2, ) { + #[cfg(feature = "fast-float")] + let (matrix, prev, mut fut) = unsafe { + ( + std::mem::transmute::<_, &BlockMatrix>(matrix), + std::mem::transmute::<_, ArrayView2>(prev), + std::mem::transmute::<_, ArrayViewMut2>(fut), + ) + }; + #[cfg(not(feature = "fast-float"))] + let mut fut = fut; + assert_eq!(prev.shape(), fut.shape()); let nx = prev.shape()[1]; let ny = prev.shape()[0]; @@ -214,8 +225,7 @@ pub(crate) fn diff_op_2d_fallback for FastFloat { + type Output = FastFloat; + #[inline(always)] + fn mul(self, o: Float) -> Self::Output { + unsafe { Self(fmul_fast(self.0, o)) } + } +} + +impl core::ops::Mul for Float { + type Output = FastFloat; + #[inline(always)] + fn mul(self, o: FastFloat) -> Self::Output { + unsafe { FastFloat(fmul_fast(self, o.0)) } + } +} + +impl core::ops::Add for FastFloat { + type Output = FastFloat; #[inline(always)] fn add(self, o: FastFloat) -> Self::Output { unsafe { Self(fadd_fast(self.0, o.0)) } } } +impl core::ops::Add for FastFloat { + type Output = FastFloat; + #[inline(always)] + fn add(self, o: Float) -> Self::Output { + unsafe { Self(fadd_fast(self.0, o)) } + } +} + +impl core::ops::Add for Float { + type Output = FastFloat; + #[inline(always)] + fn add(self, o: FastFloat) -> Self::Output { + unsafe { FastFloat(fadd_fast(self, o.0)) } + } +} + +impl core::ops::Sub for FastFloat { + type Output = Self; + #[inline(always)] + fn sub(self, o: FastFloat) -> Self::Output { + unsafe { Self(fadd_fast(self.0, -o.0)) } + } +} + +impl core::ops::Div for FastFloat { + type Output = Self; + #[inline(always)] + fn div(self, o: FastFloat) -> Self::Output { + Self(self.0 / o.0) + } +} + impl core::ops::MulAssign for FastFloat { #[inline(always)] fn mul_assign(&mut self, o: FastFloat) { @@ -51,3 +99,22 @@ impl From for Float { f.0 } } + +mod numt { + use super::{FastFloat, Float}; + use num_traits::identities::{One, Zero}; + + impl One for FastFloat { + fn one() -> Self { + Self(Float::one()) + } + } + impl Zero for FastFloat { + fn zero() -> Self { + Self(Float::zero()) + } + fn is_zero(&self) -> bool { + self.0.is_zero() + } + } +}