change order in matmul_into

This commit is contained in:
Magnus Ulimoen 2021-01-31 13:23:15 +01:00
parent 481f2d607e
commit b0e1ec62f8
1 changed files with 43 additions and 29 deletions

View File

@ -61,31 +61,6 @@ pub(crate) mod constmatrix {
pub const fn ncols(&self) -> usize { pub const fn ncols(&self) -> usize {
N N
} }
pub fn matmul<const P: usize>(&self, other: &Matrix<T, N, P>) -> Matrix<T, M, P>
where
T: Copy + Default + core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
{
let mut out = Matrix::default();
self.matmul_into(other, &mut out);
out
}
pub fn matmul_into<const P: usize>(
&self,
other: &Matrix<T, N, P>,
out: &mut Matrix<T, M, P>,
) where
T: Copy + Default + core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
{
for i in 0..M {
for j in 0..P {
let mut t = T::default();
for k in 0..N {
t = t + self[(i, k)] * other[(k, j)];
}
out[(i, j)] = t;
}
}
}
#[inline(always)] #[inline(always)]
pub fn iter(&self) -> impl Iterator<Item = &T> { pub fn iter(&self) -> impl Iterator<Item = &T> {
self.data.iter().flatten() self.data.iter().flatten()
@ -135,6 +110,45 @@ pub(crate) mod constmatrix {
} }
} }
impl<T, const M: usize, const P: usize> Matrix<T, M, P> {
#[inline(always)]
pub fn matmul_into<const N: usize>(&mut self, lhs: &Matrix<T, M, N>, rhs: &Matrix<T, N, P>)
where
T: Default + Copy + core::ops::Mul<Output = T> + core::ops::Add<Output = T>,
{
for i in 0..M {
for j in 0..P {
let mut t = T::default();
for k in 0..N {
t = t + lhs[(i, k)] * rhs[(k, j)];
}
self[(i, j)] = t;
}
}
}
}
macro_rules! impl_op_mul_mul {
($lhs:ty, $rhs:ty) => {
impl<T, const N: usize, const M: usize, const P: usize> core::ops::Mul<$rhs> for $lhs
where
T: Copy + Default + core::ops::Add<Output = T> + core::ops::Mul<Output = T>,
{
type Output = Matrix<T, M, P>;
fn mul(self, rhs: $rhs) -> Self::Output {
let mut out = Matrix::default();
out.matmul_into(&self, &rhs);
out
}
}
};
}
impl_op_mul_mul! { Matrix<T, M, N>, Matrix<T, N, P> }
impl_op_mul_mul! { &Matrix<T, M, N>, Matrix<T, N, P> }
impl_op_mul_mul! { Matrix<T, M, N>, &Matrix<T, N, P> }
impl_op_mul_mul! { &Matrix<T, M, N>, &Matrix<T, N, P> }
impl<T, const M: usize, const N: usize> core::ops::MulAssign<T> for Matrix<T, M, N> impl<T, const M: usize, const N: usize> core::ops::MulAssign<T> for Matrix<T, M, N>
where where
T: Copy + core::ops::MulAssign<T>, T: Copy + core::ops::MulAssign<T>,
@ -161,7 +175,7 @@ pub(crate) mod constmatrix {
let m1 = Matrix::new([[1_u8, 2, 3], [4, 5, 6]]); let m1 = Matrix::new([[1_u8, 2, 3], [4, 5, 6]]);
let m2 = Matrix::new([[7_u8, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18]]); let m2 = Matrix::new([[7_u8, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18]]);
let m3 = m1.matmul(&m2); let m3 = m1 * m2;
assert_eq!(m3, Matrix::new([[74, 80, 86, 92], [173, 188, 203, 218]])); assert_eq!(m3, Matrix::new([[74, 80, 86, 92], [173, 188, 203, 218]]));
} }
} }
@ -325,7 +339,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
let prev = ColVector::<_, N>::map_to_col(prev.array_windows::<N>().next().unwrap()); let prev = ColVector::<_, N>::map_to_col(prev.array_windows::<N>().next().unwrap());
let fut = ColVector::<_, M>::map_to_col_mut(futb1.try_into().unwrap()); let fut = ColVector::<_, M>::map_to_col_mut(futb1.try_into().unwrap());
block.matmul_into(prev, fut); fut.matmul_into(block, prev);
*fut *= idx; *fut *= idx;
} }
@ -341,7 +355,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
let fut = ColVector::<_, 1>::map_to_col_mut(f); let fut = ColVector::<_, 1>::map_to_col_mut(f);
let prev = ColVector::<_, D>::map_to_col(window); let prev = ColVector::<_, D>::map_to_col(window);
diag.matmul_into(prev, fut); fut.matmul_into(diag, prev);
*fut *= idx; *fut *= idx;
} }
@ -350,7 +364,7 @@ pub(crate) fn diff_op_1d_slice_matrix<const M: usize, const N: usize, const D: u
let prev = ColVector::<_, N>::map_to_col(prev); let prev = ColVector::<_, N>::map_to_col(prev);
let fut = ColVector::<_, M>::map_to_col_mut(futb2.try_into().unwrap()); let fut = ColVector::<_, M>::map_to_col_mut(futb2.try_into().unwrap());
endblock.matmul_into(prev, fut); fut.matmul_into(endblock, prev);
*fut *= idx; *fut *= idx;
} }
} }