try ndarray transmute
This commit is contained in:
parent
b15ea57e6d
commit
74d99a4a18
|
@ -11,6 +11,7 @@ packed_simd = { version = "0.3.3", package = "packed_simd_2" }
|
|||
rayon = { version = "1.3.0", optional = true }
|
||||
sprs = { version = "0.9.0", optional = true, default-features = false }
|
||||
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
|
||||
num-traits = "0.2.14"
|
||||
|
||||
[features]
|
||||
# Use f32 as precision, default is f64
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::*;
|
||||
use ndarray::s;
|
||||
use num_traits::Zero;
|
||||
|
||||
pub(crate) mod constmatrix;
|
||||
pub(crate) use constmatrix::{flip_lr, flip_sign, flip_ud, ColVector, Matrix, RowVector};
|
||||
|
@ -194,14 +195,24 @@ pub(crate) fn diff_op_1d<const M: usize, const N: usize, const D: usize>(
|
|||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[allow(unused)]
|
||||
/// 2D diff fallback for when matrices are not slicable
|
||||
pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize>(
|
||||
matrix: &BlockMatrix<Float, M, N, D>,
|
||||
optype: OperatorType,
|
||||
prev: ArrayView2<Float>,
|
||||
mut fut: ArrayViewMut2<Float>,
|
||||
fut: ArrayViewMut2<Float>,
|
||||
) {
|
||||
#[cfg(feature = "fast-float")]
|
||||
let (matrix, prev, mut fut) = unsafe {
|
||||
(
|
||||
std::mem::transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix),
|
||||
std::mem::transmute::<_, ArrayView2<FastFloat>>(prev),
|
||||
std::mem::transmute::<_, ArrayViewMut2<FastFloat>>(fut),
|
||||
)
|
||||
};
|
||||
#[cfg(not(feature = "fast-float"))]
|
||||
let mut fut = fut;
|
||||
|
||||
assert_eq!(prev.shape(), fut.shape());
|
||||
let nx = prev.shape()[1];
|
||||
let ny = prev.shape()[0];
|
||||
|
@ -214,8 +225,7 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
|
|||
};
|
||||
let idx = 1.0 / dx;
|
||||
|
||||
fut.fill(0.0);
|
||||
|
||||
fut.fill(0.0.into());
|
||||
let (mut fut0, mut futmid, mut futn) = fut.multi_slice_mut((
|
||||
ndarray::s![.., ..M],
|
||||
ndarray::s![.., M..nx - M],
|
||||
|
@ -230,7 +240,7 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
|
|||
{
|
||||
debug_assert_eq!(fut.len(), prev.shape()[0]);
|
||||
for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
|
||||
if bl == 0.0 {
|
||||
if bl.is_zero() {
|
||||
continue;
|
||||
}
|
||||
debug_assert_eq!(prev.len(), fut.len());
|
||||
|
@ -246,7 +256,7 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
|
|||
.zip(prev.windows((ny, D)).into_iter().skip(window_elems_to_skip))
|
||||
{
|
||||
for (&d, id) in matrix.diag.iter().zip(id.axis_iter(ndarray::Axis(1))) {
|
||||
if d == 0.0 {
|
||||
if d.is_zero() {
|
||||
continue;
|
||||
}
|
||||
fut.scaled_add(idx * d, &id)
|
||||
|
@ -260,9 +270,8 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
|
|||
.iter_rows()
|
||||
.zip(futn.axis_iter_mut(ndarray::Axis(1)))
|
||||
{
|
||||
fut.fill(0.0);
|
||||
for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
|
||||
if bl == 0.0 {
|
||||
if bl.is_zero() {
|
||||
continue;
|
||||
}
|
||||
fut.scaled_add(idx * bl, &prev);
|
||||
|
|
|
@ -14,14 +14,62 @@ impl core::ops::Mul for FastFloat {
|
|||
}
|
||||
}
|
||||
|
||||
impl core::ops::Add for FastFloat {
|
||||
type Output = Self;
|
||||
impl core::ops::Mul<Float> for FastFloat {
|
||||
type Output = FastFloat;
|
||||
#[inline(always)]
|
||||
fn mul(self, o: Float) -> Self::Output {
|
||||
unsafe { Self(fmul_fast(self.0, o)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Mul<FastFloat> for Float {
|
||||
type Output = FastFloat;
|
||||
#[inline(always)]
|
||||
fn mul(self, o: FastFloat) -> Self::Output {
|
||||
unsafe { FastFloat(fmul_fast(self, o.0)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Add<FastFloat> for FastFloat {
|
||||
type Output = FastFloat;
|
||||
#[inline(always)]
|
||||
fn add(self, o: FastFloat) -> Self::Output {
|
||||
unsafe { Self(fadd_fast(self.0, o.0)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Add<Float> for FastFloat {
|
||||
type Output = FastFloat;
|
||||
#[inline(always)]
|
||||
fn add(self, o: Float) -> Self::Output {
|
||||
unsafe { Self(fadd_fast(self.0, o)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Add<FastFloat> for Float {
|
||||
type Output = FastFloat;
|
||||
#[inline(always)]
|
||||
fn add(self, o: FastFloat) -> Self::Output {
|
||||
unsafe { FastFloat(fadd_fast(self, o.0)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Sub for FastFloat {
|
||||
type Output = Self;
|
||||
#[inline(always)]
|
||||
fn sub(self, o: FastFloat) -> Self::Output {
|
||||
unsafe { Self(fadd_fast(self.0, -o.0)) }
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::Div for FastFloat {
|
||||
type Output = Self;
|
||||
#[inline(always)]
|
||||
fn div(self, o: FastFloat) -> Self::Output {
|
||||
Self(self.0 / o.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl core::ops::MulAssign<FastFloat> for FastFloat {
|
||||
#[inline(always)]
|
||||
fn mul_assign(&mut self, o: FastFloat) {
|
||||
|
@ -51,3 +99,22 @@ impl From<FastFloat> for Float {
|
|||
f.0
|
||||
}
|
||||
}
|
||||
|
||||
mod numt {
|
||||
use super::{FastFloat, Float};
|
||||
use num_traits::identities::{One, Zero};
|
||||
|
||||
impl One for FastFloat {
|
||||
fn one() -> Self {
|
||||
Self(Float::one())
|
||||
}
|
||||
}
|
||||
impl Zero for FastFloat {
|
||||
fn zero() -> Self {
|
||||
Self(Float::zero())
|
||||
}
|
||||
fn is_zero(&self) -> bool {
|
||||
self.0.is_zero()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue