simplify traits
This commit is contained in:
parent
f7c238f6a7
commit
3c7cc4605a
|
@ -10,7 +10,6 @@ approx = "0.3.2"
|
||||||
packed_simd = { version = "0.3.3", package = "packed_simd_2" }
|
packed_simd = { version = "0.3.3", package = "packed_simd_2" }
|
||||||
rayon = { version = "1.3.0", optional = true }
|
rayon = { version = "1.3.0", optional = true }
|
||||||
sprs = { version = "0.9.0", optional = true, default-features = false }
|
sprs = { version = "0.9.0", optional = true, default-features = false }
|
||||||
num-traits = "0.2.11"
|
|
||||||
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
|
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
pub(crate) mod constmatrix {
|
pub(crate) mod constmatrix {
|
||||||
|
#![allow(unused)]
|
||||||
/// A row-major matrix
|
/// A row-major matrix
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
|
@ -70,8 +71,7 @@ pub(crate) mod constmatrix {
|
||||||
}
|
}
|
||||||
pub fn matmul<const P: usize>(&self, other: &Matrix<T, N, P>) -> Matrix<T, M, P>
|
pub fn matmul<const P: usize>(&self, other: &Matrix<T, N, P>) -> Matrix<T, M, P>
|
||||||
where
|
where
|
||||||
T: Default + core::ops::AddAssign<T>,
|
T: Copy + Default + core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
|
||||||
for<'f> &'f T: std::ops::Mul<Output = T>,
|
|
||||||
{
|
{
|
||||||
let mut out = Matrix::default();
|
let mut out = Matrix::default();
|
||||||
self.matmul_into(other, &mut out);
|
self.matmul_into(other, &mut out);
|
||||||
|
@ -82,14 +82,13 @@ pub(crate) mod constmatrix {
|
||||||
other: &Matrix<T, N, P>,
|
other: &Matrix<T, N, P>,
|
||||||
out: &mut Matrix<T, M, P>,
|
out: &mut Matrix<T, M, P>,
|
||||||
) where
|
) where
|
||||||
T: Default + core::ops::AddAssign<T>,
|
T: Copy + Default + core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
|
||||||
for<'f> &'f T: std::ops::Mul<Output = T>,
|
|
||||||
{
|
{
|
||||||
for i in 0..M {
|
for i in 0..M {
|
||||||
for j in 0..P {
|
for j in 0..P {
|
||||||
let mut t = T::default();
|
let mut t = T::default();
|
||||||
for k in 0..N {
|
for k in 0..N {
|
||||||
t += &self[(i, k)] * &other[(k, j)];
|
t = t + self[(i, k)] * other[(k, j)];
|
||||||
}
|
}
|
||||||
out[(i, j)] = t;
|
out[(i, j)] = t;
|
||||||
}
|
}
|
||||||
|
@ -144,12 +143,12 @@ pub(crate) mod constmatrix {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T, const M: usize, const N: usize> core::ops::MulAssign<&T> for Matrix<T, M, N>
|
impl<T, const M: usize, const N: usize> core::ops::MulAssign<T> for Matrix<T, M, N>
|
||||||
where
|
where
|
||||||
for<'f> T: core::ops::MulAssign<&'f T>,
|
T: Copy + core::ops::MulAssign<T>,
|
||||||
{
|
{
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn mul_assign(&mut self, other: &T) {
|
fn mul_assign(&mut self, other: T) {
|
||||||
self.iter_mut().for_each(|x| *x *= other)
|
self.iter_mut().for_each(|x| *x *= other)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -243,22 +242,33 @@ pub(crate) fn diff_op_1d_matrix<const M: usize, const N: usize, const D: usize>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "fast-float")]
|
||||||
mod fastfloat {
|
mod fastfloat {
|
||||||
use super::*;
|
use super::*;
|
||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
#[derive(Debug, PartialEq, Default)]
|
#[derive(Copy, Clone, Debug, PartialEq, Default)]
|
||||||
pub(crate) struct FastFloat(Float);
|
pub(crate) struct FastFloat(Float);
|
||||||
|
|
||||||
use core::intrinsics::{fadd_fast, fmul_fast};
|
use core::intrinsics::{fadd_fast, fmul_fast};
|
||||||
|
|
||||||
impl core::ops::Mul for FastFloat {
|
impl core::ops::Mul for FastFloat {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
#[inline(always)]
|
||||||
fn mul(self, o: Self) -> Self::Output {
|
fn mul(self, o: Self) -> Self::Output {
|
||||||
unsafe { Self(fmul_fast(self.0, o.0)) }
|
unsafe { Self(fmul_fast(self.0, o.0)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl core::ops::Add for FastFloat {
|
||||||
|
type Output = Self;
|
||||||
|
#[inline(always)]
|
||||||
|
fn add(self, o: FastFloat) -> Self::Output {
|
||||||
|
unsafe { Self(fadd_fast(self.0, o.0)) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl core::ops::MulAssign<FastFloat> for FastFloat {
|
impl core::ops::MulAssign<FastFloat> for FastFloat {
|
||||||
|
#[inline(always)]
|
||||||
fn mul_assign(&mut self, o: FastFloat) {
|
fn mul_assign(&mut self, o: FastFloat) {
|
||||||
unsafe {
|
unsafe {
|
||||||
self.0 = fmul_fast(self.0, o.0);
|
self.0 = fmul_fast(self.0, o.0);
|
||||||
|
@ -266,60 +276,16 @@ mod fastfloat {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl core::ops::MulAssign<&FastFloat> for FastFloat {
|
|
||||||
fn mul_assign(&mut self, o: &FastFloat) {
|
|
||||||
unsafe {
|
|
||||||
self.0 = fmul_fast(self.0, o.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl core::ops::MulAssign<FastFloat> for &mut FastFloat {
|
|
||||||
fn mul_assign(&mut self, o: FastFloat) {
|
|
||||||
unsafe {
|
|
||||||
self.0 = fmul_fast(self.0, o.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl core::ops::MulAssign<&FastFloat> for &mut FastFloat {
|
|
||||||
fn mul_assign(&mut self, o: &FastFloat) {
|
|
||||||
unsafe {
|
|
||||||
self.0 = fmul_fast(self.0, o.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl core::ops::AddAssign for FastFloat {
|
|
||||||
fn add_assign(&mut self, o: Self) {
|
|
||||||
unsafe {
|
|
||||||
self.0 = fadd_fast(self.0, o.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl core::ops::Mul for &FastFloat {
|
impl core::ops::Mul for &FastFloat {
|
||||||
type Output = FastFloat;
|
type Output = FastFloat;
|
||||||
|
#[inline(always)]
|
||||||
fn mul(self, o: Self) -> Self::Output {
|
fn mul(self, o: Self) -> Self::Output {
|
||||||
unsafe { FastFloat(fmul_fast(self.0, o.0)) }
|
unsafe { FastFloat(fmul_fast(self.0, o.0)) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl core::ops::MulAssign<&Float> for FastFloat {
|
|
||||||
fn mul_assign(&mut self, o: &Float) {
|
|
||||||
unsafe {
|
|
||||||
self.0 = fmul_fast(self.0, *o);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl core::ops::MulAssign<&Float> for &mut FastFloat {
|
|
||||||
fn mul_assign(&mut self, o: &Float) {
|
|
||||||
unsafe {
|
|
||||||
self.0 = fmul_fast(self.0, *o);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl From<Float> for FastFloat {
|
impl From<Float> for FastFloat {
|
||||||
|
#[inline(always)]
|
||||||
fn from(f: Float) -> Self {
|
fn from(f: Float) -> Self {
|
||||||
Self(f)
|
Self(f)
|
||||||
}
|
}
|
||||||
|
@ -373,7 +339,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
|
||||||
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[0..M]).try_into().unwrap());
|
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[0..M]).try_into().unwrap());
|
||||||
|
|
||||||
block.matmul_into(prev, fut);
|
block.matmul_into(prev, fut);
|
||||||
*fut *= &idx;
|
*fut *= idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
// The window needs to be aligned to the diagonal elements,
|
// The window needs to be aligned to the diagonal elements,
|
||||||
|
@ -391,7 +357,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
|
||||||
let prev = ColVector::<_, D>::map_to_col(window);
|
let prev = ColVector::<_, D>::map_to_col(window);
|
||||||
|
|
||||||
diag.matmul_into(prev, fut);
|
diag.matmul_into(prev, fut);
|
||||||
*fut *= &idx;
|
*fut *= idx;
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -400,7 +366,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
|
||||||
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap());
|
let fut = ColVector::<_, M>::map_to_col_mut((&mut fut[nx - M..]).try_into().unwrap());
|
||||||
|
|
||||||
endblock.matmul_into(prev, fut);
|
endblock.matmul_into(prev, fut);
|
||||||
*fut *= &idx;
|
*fut *= idx;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -227,7 +227,7 @@ fn test_trad4() {
|
||||||
#[test]
|
#[test]
|
||||||
fn block_equality() {
|
fn block_equality() {
|
||||||
let mut flipped_inverted = SBP4::BLOCK_MATRIX.flip();
|
let mut flipped_inverted = SBP4::BLOCK_MATRIX.flip();
|
||||||
flipped_inverted *= &-1.0;
|
flipped_inverted *= -1.0;
|
||||||
|
|
||||||
assert!(flipped_inverted
|
assert!(flipped_inverted
|
||||||
.iter()
|
.iter()
|
||||||
|
|
|
@ -207,7 +207,7 @@ fn test_trad8() {
|
||||||
#[test]
|
#[test]
|
||||||
fn block_equality() {
|
fn block_equality() {
|
||||||
let mut flipped_inverted = SBP8::BLOCK_MATRIX.flip();
|
let mut flipped_inverted = SBP8::BLOCK_MATRIX.flip();
|
||||||
flipped_inverted *= &-1.0;
|
flipped_inverted *= -1.0;
|
||||||
|
|
||||||
assert!(flipped_inverted
|
assert!(flipped_inverted
|
||||||
.iter()
|
.iter()
|
||||||
|
|
Loading…
Reference in New Issue