From b0e1ec62f86b3e19a98cdc066b3288fcd74be8bb Mon Sep 17 00:00:00 2001 From: Magnus Ulimoen Date: Sun, 31 Jan 2021 13:23:15 +0100 Subject: [PATCH] change order in matmul_into --- sbp/src/operators/algos.rs | 72 +++++++++++++++++++++++--------------- 1 file changed, 43 insertions(+), 29 deletions(-) diff --git a/sbp/src/operators/algos.rs b/sbp/src/operators/algos.rs index 1928f2d..1681501 100644 --- a/sbp/src/operators/algos.rs +++ b/sbp/src/operators/algos.rs @@ -61,31 +61,6 @@ pub(crate) mod constmatrix { pub const fn ncols(&self) -> usize { N } - pub fn matmul(&self, other: &Matrix) -> Matrix - where - T: Copy + Default + core::ops::Add + core::ops::Mul, - { - let mut out = Matrix::default(); - self.matmul_into(other, &mut out); - out - } - pub fn matmul_into( - &self, - other: &Matrix, - out: &mut Matrix, - ) where - T: Copy + Default + core::ops::Add + core::ops::Mul, - { - for i in 0..M { - for j in 0..P { - let mut t = T::default(); - for k in 0..N { - t = t + self[(i, k)] * other[(k, j)]; - } - out[(i, j)] = t; - } - } - } #[inline(always)] pub fn iter(&self) -> impl Iterator { self.data.iter().flatten() @@ -135,6 +110,45 @@ pub(crate) mod constmatrix { } } + impl Matrix { + #[inline(always)] + pub fn matmul_into(&mut self, lhs: &Matrix, rhs: &Matrix) + where + T: Default + Copy + core::ops::Mul + core::ops::Add, + { + for i in 0..M { + for j in 0..P { + let mut t = T::default(); + for k in 0..N { + t = t + lhs[(i, k)] * rhs[(k, j)]; + } + self[(i, j)] = t; + } + } + } + } + + macro_rules! impl_op_mul_mul { + ($lhs:ty, $rhs:ty) => { + impl core::ops::Mul<$rhs> for $lhs + where + T: Copy + Default + core::ops::Add + core::ops::Mul, + { + type Output = Matrix; + fn mul(self, rhs: $rhs) -> Self::Output { + let mut out = Matrix::default(); + out.matmul_into(&self, &rhs); + out + } + } + }; + } + + impl_op_mul_mul! { Matrix, Matrix } + impl_op_mul_mul! { &Matrix, Matrix } + impl_op_mul_mul! { Matrix, &Matrix } + impl_op_mul_mul! { &Matrix, &Matrix } + impl core::ops::MulAssign for Matrix where T: Copy + core::ops::MulAssign, @@ -161,7 +175,7 @@ pub(crate) mod constmatrix { let m1 = Matrix::new([[1_u8, 2, 3], [4, 5, 6]]); let m2 = Matrix::new([[7_u8, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18]]); - let m3 = m1.matmul(&m2); + let m3 = m1 * m2; assert_eq!(m3, Matrix::new([[74, 80, 86, 92], [173, 188, 203, 218]])); } } @@ -325,7 +339,7 @@ pub(crate) fn diff_op_1d_slice_matrix::map_to_col(prev.array_windows::().next().unwrap()); let fut = ColVector::<_, M>::map_to_col_mut(futb1.try_into().unwrap()); - block.matmul_into(prev, fut); + fut.matmul_into(block, prev); *fut *= idx; } @@ -341,7 +355,7 @@ pub(crate) fn diff_op_1d_slice_matrix::map_to_col_mut(f); let prev = ColVector::<_, D>::map_to_col(window); - diag.matmul_into(prev, fut); + fut.matmul_into(diag, prev); *fut *= idx; } @@ -350,7 +364,7 @@ pub(crate) fn diff_op_1d_slice_matrix::map_to_col(prev); let fut = ColVector::<_, M>::map_to_col_mut(futb2.try_into().unwrap()); - endblock.matmul_into(prev, fut); + fut.matmul_into(endblock, prev); *fut *= idx; } }