diff --git a/sbp/src/operators/algos/constmatrix.rs b/sbp/src/operators/algos/constmatrix.rs index b409bed..7203d64 100644 --- a/sbp/src/operators/algos/constmatrix.rs +++ b/sbp/src/operators/algos/constmatrix.rs @@ -1,4 +1,7 @@ #![allow(unused)] + +use num_traits::identities::Zero; + /// A row-major matrix #[derive(Debug, Clone, PartialEq, Eq, Hash)] #[repr(C)] @@ -16,6 +19,17 @@ impl Default for Matrix Zero for Matrix { + fn zero() -> Self { + Self { + data: [[T::zero(); N]; M], + } + } + fn is_zero(&self) -> bool { + self == &Self::zero() + } +} + impl core::ops::Index<(usize, usize)> for Matrix { type Output = T; #[inline(always)] @@ -123,9 +137,9 @@ macro_rules! impl_op_mul_mul { T: Copy + Default + core::ops::Add + core::ops::Mul, { type Output = Matrix; - fn mul(self, rhs: $rhs) -> Self::Output { + fn mul(self, lhs: $rhs) -> Self::Output { let mut out = Matrix::default(); - out.matmul_into(&self, &rhs); + out.matmul_into(&self, &lhs); out } } @@ -147,6 +161,22 @@ where } } +impl core::ops::Add> for Matrix +where + T: Copy + Zero + core::ops::Add + PartialEq, +{ + type Output = Self; + fn add(self, lhs: Self) -> Self::Output { + let mut out = Matrix::zero(); + for i in 0..M { + for j in 0..N { + out[(i, j)] = self[(i, j)] + lhs[(i, j)]; + } + } + out + } +} + #[cfg(test)] mod tests { use super::{super::*, *};