simplify traits

This commit is contained in:
Magnus Ulimoen 2021-01-29 16:42:49 +01:00
parent f7c238f6a7
commit 3c7cc4605a
4 changed files with 26 additions and 61 deletions

View File

@ -10,7 +10,6 @@ approx = "0.3.2"
packed_simd = { version = "0.3.3", package = "packed_simd_2" } packed_simd = { version = "0.3.3", package = "packed_simd_2" }
rayon = { version = "1.3.0", optional = true } rayon = { version = "1.3.0", optional = true }
sprs = { version = "0.9.0", optional = true, default-features = false } sprs = { version = "0.9.0", optional = true, default-features = false }
num-traits = "0.2.11"
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] } serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
[features] [features]

View File

@ -1,6 +1,7 @@
use super::*; use super::*;
pub(crate) mod constmatrix { pub(crate) mod constmatrix {
#![allow(unused)]
/// A row-major matrix /// A row-major matrix
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[repr(transparent)] #[repr(transparent)]
@ -70,8 +71,7 @@ pub(crate) mod constmatrix {
} }
pub fn matmul<const P: usize>(&self, other: &Matrix<T, N, P>) -> Matrix<T, M, P> pub fn matmul<const P: usize>(&self, other: &Matrix<T, N, P>) -> Matrix<T, M, P>
where where
T: Default + core::ops::AddAssign<T>, T: Copy + Default + core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
for<'f> &'f T: std::ops::Mul<Output = T>,
{ {
let mut out = Matrix::default(); let mut out = Matrix::default();
self.matmul_into(other, &mut out); self.matmul_into(other, &mut out);
@ -82,14 +82,13 @@ pub(crate) mod constmatrix {
other: &Matrix<T, N, P>, other: &Matrix<T, N, P>,
out: &mut Matrix<T, M, P>, out: &mut Matrix<T, M, P>,
) where ) where
T: Default + core::ops::AddAssign<T>, T: Copy + Default + core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
for<'f> &'f T: std::ops::Mul<Output = T>,
{ {
for i in 0..M { for i in 0..M {
for j in 0..P { for j in 0..P {
let mut t = T::default(); let mut t = T::default();
for k in 0..N { for k in 0..N {
t += &self[(i, k)] * &other[(k, j)]; t = t + self[(i, k)] * other[(k, j)];
} }
out[(i, j)] = t; out[(i, j)] = t;
} }
@ -144,12 +143,12 @@ pub(crate) mod constmatrix {
} }
} }
impl<T, const M: usize, const N: usize> core::ops::MulAssign<&T> for Matrix<T, M, N> impl<T, const M: usize, const N: usize> core::ops::MulAssign<T> for Matrix<T, M, N>
where where
for<'f> T: core::ops::MulAssign<&'f T>, T: Copy + core::ops::MulAssign<T>,
{ {
#[inline(always)] #[inline(always)]
fn mul_assign(&mut self, other: &T) { fn mul_assign(&mut self, other: T) {
self.iter_mut().for_each(|x| *x *= other) self.iter_mut().for_each(|x| *x *= other)
} }
} }
@ -243,22 +242,33 @@ pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
} }
} }
#[cfg(feature = "fast-float")]
mod fastfloat { mod fastfloat {
use super::*; use super::*;
#[repr(transparent)] #[repr(transparent)]
#[derive(Debug, PartialEq, Default)] #[derive(Copy, Clone, Debug, PartialEq, Default)]
pub(crate) struct FastFloat(Float); pub(crate) struct FastFloat(Float);
use core::intrinsics::{fadd_fast, fmul_fast}; use core::intrinsics::{fadd_fast, fmul_fast};
impl core::ops::Mul for FastFloat { impl core::ops::Mul for FastFloat {
type Output = Self; type Output = Self;
#[inline(always)]
fn mul(self, o: Self) -> Self::Output { fn mul(self, o: Self) -> Self::Output {
unsafe { Self(fmul_fast(self.0, o.0)) } unsafe { Self(fmul_fast(self.0, o.0)) }
} }
} }
impl core::ops::Add for FastFloat {
type Output = Self;
#[inline(always)]
fn add(self, o: FastFloat) -> Self::Output {
unsafe { Self(fadd_fast(self.0, o.0)) }
}
}
impl core::ops::MulAssign<FastFloat> for FastFloat { impl core::ops::MulAssign<FastFloat> for FastFloat {
#[inline(always)]
fn mul_assign(&mut self, o: FastFloat) { fn mul_assign(&mut self, o: FastFloat) {
unsafe { unsafe {
self.0 = fmul_fast(self.0, o.0); self.0 = fmul_fast(self.0, o.0);
@ -266,60 +276,16 @@ mod fastfloat {
} }
} }
impl core::ops::MulAssign<&FastFloat> for FastFloat {
fn mul_assign(&mut self, o: &FastFloat) {
unsafe {
self.0 = fmul_fast(self.0, o.0);
}
}
}
impl core::ops::MulAssign<FastFloat> for &mut FastFloat {
fn mul_assign(&mut self, o: FastFloat) {
unsafe {
self.0 = fmul_fast(self.0, o.0);
}
}
}
impl core::ops::MulAssign<&FastFloat> for &mut FastFloat {
fn mul_assign(&mut self, o: &FastFloat) {
unsafe {
self.0 = fmul_fast(self.0, o.0);
}
}
}
impl core::ops::AddAssign for FastFloat {
fn add_assign(&mut self, o: Self) {
unsafe {
self.0 = fadd_fast(self.0, o.0);
}
}
}
impl core::ops::Mul for &FastFloat { impl core::ops::Mul for &FastFloat {
type Output = FastFloat; type Output = FastFloat;
#[inline(always)]
fn mul(self, o: Self) -> Self::Output { fn mul(self, o: Self) -> Self::Output {
unsafe { FastFloat(fmul_fast(self.0, o.0)) } unsafe { FastFloat(fmul_fast(self.0, o.0)) }
} }
} }
impl core::ops::MulAssign<&Float> for FastFloat {
fn mul_assign(&mut self, o: &Float) {
unsafe {
self.0 = fmul_fast(self.0, *o);
}
}
}
impl core::ops::MulAssign<&Float> for &mut FastFloat {
fn mul_assign(&mut self, o: &Float) {
unsafe {
self.0 = fmul_fast(self.0, *o);
}
}
}
impl From<Float> for FastFloat { impl From<Float> for FastFloat {
#[inline(always)]
fn from(f: Float) -> Self { fn from(f: Float) -> Self {
Self(f) Self(f)
} }
@ -373,7 +339,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[0..M]).try_into().unwrap()); let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[0..M]).try_into().unwrap());
block.matmul_into(prev, fut); block.matmul_into(prev, fut);
*fut *= &idx; *fut *= idx;
} }
// The window needs to be aligned to the diagonal elements, // The window needs to be aligned to the diagonal elements,
@ -391,7 +357,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
let prev = ColVector::<_, D>::map_to_col(window); let prev = ColVector::<_, D>::map_to_col(window);
diag.matmul_into(prev, fut); diag.matmul_into(prev, fut);
*fut *= &idx; *fut *= idx;
} }
{ {
@ -400,7 +366,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap()); let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap());
endblock.matmul_into(prev, fut); endblock.matmul_into(prev, fut);
*fut *= &idx; *fut *= idx;
} }
} }

View File

@ -227,7 +227,7 @@ fn test_trad4() {
#[test] #[test]
fn block_equality() { fn block_equality() {
let mut flipped_inverted = SBP4::BLOCK_MATRIX.flip(); let mut flipped_inverted = SBP4::BLOCK_MATRIX.flip();
flipped_inverted *= &-1.0; flipped_inverted *= -1.0;
assert!(flipped_inverted assert!(flipped_inverted
.iter() .iter()

View File

@ -207,7 +207,7 @@ fn test_trad8() {
#[test] #[test]
fn block_equality() { fn block_equality() {
let mut flipped_inverted = SBP8::BLOCK_MATRIX.flip(); let mut flipped_inverted = SBP8::BLOCK_MATRIX.flip();
flipped_inverted *= &-1.0; flipped_inverted *= -1.0;
assert!(flipped_inverted assert!(flipped_inverted
.iter() .iter()