try ndarray transmute
This commit is contained in:
parent
b15ea57e6d
commit
74d99a4a18
|
@ -11,6 +11,7 @@ packed_simd = { version = "0.3.3", package = "packed_simd_2" }
|
||||||
rayon = { version = "1.3.0", optional = true }
|
rayon = { version = "1.3.0", optional = true }
|
||||||
sprs = { version = "0.9.0", optional = true, default-features = false }
|
sprs = { version = "0.9.0", optional = true, default-features = false }
|
||||||
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
|
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
|
||||||
|
num-traits = "0.2.14"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
# Use f32 as precision, default is f64
|
# Use f32 as precision, default is f64
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
use ndarray::s;
|
use ndarray::s;
|
||||||
|
use num_traits::Zero;
|
||||||
|
|
||||||
pub(crate) mod constmatrix;
|
pub(crate) mod constmatrix;
|
||||||
pub(crate) use constmatrix::{flip_lr, flip_sign, flip_ud, ColVector, Matrix, RowVector};
|
pub(crate) use constmatrix::{flip_lr, flip_sign, flip_ud, ColVector, Matrix, RowVector};
|
||||||
|
@ -194,14 +195,24 @@ pub(crate) fn diff_op_1d<const M: usize, const N: usize, const D: usize>(
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
#[allow(unused)]
|
|
||||||
/// 2D diff fallback for when matrices are not slicable
|
/// 2D diff fallback for when matrices are not slicable
|
||||||
pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize>(
|
pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize>(
|
||||||
matrix: &BlockMatrix<Float, M, N, D>,
|
matrix: &BlockMatrix<Float, M, N, D>,
|
||||||
optype: OperatorType,
|
optype: OperatorType,
|
||||||
prev: ArrayView2<Float>,
|
prev: ArrayView2<Float>,
|
||||||
mut fut: ArrayViewMut2<Float>,
|
fut: ArrayViewMut2<Float>,
|
||||||
) {
|
) {
|
||||||
|
#[cfg(feature = "fast-float")]
|
||||||
|
let (matrix, prev, mut fut) = unsafe {
|
||||||
|
(
|
||||||
|
std::mem::transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix),
|
||||||
|
std::mem::transmute::<_, ArrayView2<FastFloat>>(prev),
|
||||||
|
std::mem::transmute::<_, ArrayViewMut2<FastFloat>>(fut),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
#[cfg(not(feature = "fast-float"))]
|
||||||
|
let mut fut = fut;
|
||||||
|
|
||||||
assert_eq!(prev.shape(), fut.shape());
|
assert_eq!(prev.shape(), fut.shape());
|
||||||
let nx = prev.shape()[1];
|
let nx = prev.shape()[1];
|
||||||
let ny = prev.shape()[0];
|
let ny = prev.shape()[0];
|
||||||
|
@ -214,8 +225,7 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
|
||||||
};
|
};
|
||||||
let idx = 1.0 / dx;
|
let idx = 1.0 / dx;
|
||||||
|
|
||||||
fut.fill(0.0);
|
fut.fill(0.0.into());
|
||||||
|
|
||||||
let (mut fut0, mut futmid, mut futn) = fut.multi_slice_mut((
|
let (mut fut0, mut futmid, mut futn) = fut.multi_slice_mut((
|
||||||
ndarray::s![.., ..M],
|
ndarray::s![.., ..M],
|
||||||
ndarray::s![.., M..nx - M],
|
ndarray::s![.., M..nx - M],
|
||||||
|
@ -230,7 +240,7 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
|
||||||
{
|
{
|
||||||
debug_assert_eq!(fut.len(), prev.shape()[0]);
|
debug_assert_eq!(fut.len(), prev.shape()[0]);
|
||||||
for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
|
for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
|
||||||
if bl == 0.0 {
|
if bl.is_zero() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
debug_assert_eq!(prev.len(), fut.len());
|
debug_assert_eq!(prev.len(), fut.len());
|
||||||
|
@ -246,7 +256,7 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
|
||||||
.zip(prev.windows((ny, D)).into_iter().skip(window_elems_to_skip))
|
.zip(prev.windows((ny, D)).into_iter().skip(window_elems_to_skip))
|
||||||
{
|
{
|
||||||
for (&d, id) in matrix.diag.iter().zip(id.axis_iter(ndarray::Axis(1))) {
|
for (&d, id) in matrix.diag.iter().zip(id.axis_iter(ndarray::Axis(1))) {
|
||||||
if d == 0.0 {
|
if d.is_zero() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
fut.scaled_add(idx * d, &id)
|
fut.scaled_add(idx * d, &id)
|
||||||
|
@ -260,9 +270,8 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
|
||||||
.iter_rows()
|
.iter_rows()
|
||||||
.zip(futn.axis_iter_mut(ndarray::Axis(1)))
|
.zip(futn.axis_iter_mut(ndarray::Axis(1)))
|
||||||
{
|
{
|
||||||
fut.fill(0.0);
|
|
||||||
for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
|
for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
|
||||||
if bl == 0.0 {
|
if bl.is_zero() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
fut.scaled_add(idx * bl, &prev);
|
fut.scaled_add(idx * bl, &prev);
|
||||||
|
|
|
@ -14,14 +14,62 @@ impl core::ops::Mul for FastFloat {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl core::ops::Add for FastFloat {
|
impl core::ops::Mul<Float> for FastFloat {
|
||||||
type Output = Self;
|
type Output = FastFloat;
|
||||||
|
#[inline(always)]
|
||||||
|
fn mul(self, o: Float) -> Self::Output {
|
||||||
|
unsafe { Self(fmul_fast(self.0, o)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl core::ops::Mul<FastFloat> for Float {
|
||||||
|
type Output = FastFloat;
|
||||||
|
#[inline(always)]
|
||||||
|
fn mul(self, o: FastFloat) -> Self::Output {
|
||||||
|
unsafe { FastFloat(fmul_fast(self, o.0)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl core::ops::Add<FastFloat> for FastFloat {
|
||||||
|
type Output = FastFloat;
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn add(self, o: FastFloat) -> Self::Output {
|
fn add(self, o: FastFloat) -> Self::Output {
|
||||||
unsafe { Self(fadd_fast(self.0, o.0)) }
|
unsafe { Self(fadd_fast(self.0, o.0)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl core::ops::Add<Float> for FastFloat {
|
||||||
|
type Output = FastFloat;
|
||||||
|
#[inline(always)]
|
||||||
|
fn add(self, o: Float) -> Self::Output {
|
||||||
|
unsafe { Self(fadd_fast(self.0, o)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl core::ops::Add<FastFloat> for Float {
|
||||||
|
type Output = FastFloat;
|
||||||
|
#[inline(always)]
|
||||||
|
fn add(self, o: FastFloat) -> Self::Output {
|
||||||
|
unsafe { FastFloat(fadd_fast(self, o.0)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl core::ops::Sub for FastFloat {
|
||||||
|
type Output = Self;
|
||||||
|
#[inline(always)]
|
||||||
|
fn sub(self, o: FastFloat) -> Self::Output {
|
||||||
|
unsafe { Self(fadd_fast(self.0, -o.0)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl core::ops::Div for FastFloat {
|
||||||
|
type Output = Self;
|
||||||
|
#[inline(always)]
|
||||||
|
fn div(self, o: FastFloat) -> Self::Output {
|
||||||
|
Self(self.0 / o.0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl core::ops::MulAssign<FastFloat> for FastFloat {
|
impl core::ops::MulAssign<FastFloat> for FastFloat {
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn mul_assign(&mut self, o: FastFloat) {
|
fn mul_assign(&mut self, o: FastFloat) {
|
||||||
|
@ -51,3 +99,22 @@ impl From<FastFloat> for Float {
|
||||||
f.0
|
f.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mod numt {
|
||||||
|
use super::{FastFloat, Float};
|
||||||
|
use num_traits::identities::{One, Zero};
|
||||||
|
|
||||||
|
impl One for FastFloat {
|
||||||
|
fn one() -> Self {
|
||||||
|
Self(Float::one())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Zero for FastFloat {
|
||||||
|
fn zero() -> Self {
|
||||||
|
Self(Float::zero())
|
||||||
|
}
|
||||||
|
fn is_zero(&self) -> bool {
|
||||||
|
self.0.is_zero()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue