add matrix type
This commit is contained in:
parent
3cc7c31ee5
commit
c104082ac0
|
@ -1,5 +1,127 @@
|
|||
use super::*;
|
||||
|
||||
pub(crate) mod constmatrix {
|
||||
/// A row-major matrix
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||
pub struct Matrix<T, const M: usize, const N: usize> {
|
||||
data: [[T; N]; M],
|
||||
}
|
||||
|
||||
impl<T: Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
|
||||
fn default() -> Self {
|
||||
use std::mem::MaybeUninit;
|
||||
let mut d: [[MaybeUninit<T>; N]; M] = unsafe { MaybeUninit::uninit().assume_init() };
|
||||
|
||||
for row in d.iter_mut() {
|
||||
for item in row.iter_mut() {
|
||||
*item = MaybeUninit::new(T::default());
|
||||
}
|
||||
}
|
||||
|
||||
let data = unsafe { std::mem::transmute_copy::<_, [[T; N]; M]>(&d) };
|
||||
Self { data }
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const M: usize, const N: usize> core::ops::Index<(usize, usize)> for Matrix<T, M, N> {
|
||||
type Output = T;
|
||||
#[inline(always)]
|
||||
fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
|
||||
&self.data[i][j]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const M: usize, const N: usize> core::ops::IndexMut<(usize, usize)> for Matrix<T, M, N> {
|
||||
#[inline(always)]
|
||||
fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
|
||||
&mut self.data[i][j]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const M: usize, const N: usize> core::ops::Index<usize> for Matrix<T, M, N> {
|
||||
type Output = [T; N];
|
||||
#[inline(always)]
|
||||
fn index(&self, i: usize) -> &Self::Output {
|
||||
&self.data[i]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const M: usize, const N: usize> core::ops::IndexMut<usize> for Matrix<T, M, N> {
|
||||
#[inline(always)]
|
||||
fn index_mut(&mut self, i: usize) -> &mut Self::Output {
|
||||
&mut self.data[i]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const M: usize, const N: usize> Matrix<T, M, N> {
|
||||
pub const fn new(data: [[T; N]; M]) -> Self {
|
||||
Self { data }
|
||||
}
|
||||
pub const fn nrows(&self) -> usize {
|
||||
M
|
||||
}
|
||||
pub const fn ncols(&self) -> usize {
|
||||
N
|
||||
}
|
||||
pub fn matmul<const P: usize>(&self, other: &Matrix<T, N, P>) -> Matrix<T, M, P>
|
||||
where
|
||||
T: Default + core::ops::AddAssign<T>,
|
||||
for<'f> &'f T: std::ops::Mul<Output = T>,
|
||||
{
|
||||
let mut out = Matrix::default();
|
||||
self.matmul_into(other, &mut out);
|
||||
out
|
||||
}
|
||||
pub fn matmul_into<const P: usize>(
|
||||
&self,
|
||||
other: &Matrix<T, N, P>,
|
||||
out: &mut Matrix<T, M, P>,
|
||||
) where
|
||||
T: Default + core::ops::AddAssign<T>,
|
||||
for<'f> &'f T: std::ops::Mul<Output = T>,
|
||||
{
|
||||
*out = Default::default();
|
||||
for i in 0..M {
|
||||
for j in 0..P {
|
||||
for k in 0..N {
|
||||
out[(i, j)] += &self[(i, k)] * &other[(k, j)];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
pub fn iter_rows(
|
||||
&self,
|
||||
) -> impl ExactSizeIterator<Item = &[T; N]> + DoubleEndedIterator<Item = &[T; N]> {
|
||||
(0..M).map(move |i| &self[i])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::{super::*, *};
|
||||
#[test]
|
||||
fn construct_copy_type() {
|
||||
let _m0 = Matrix::<i32, 4, 3>::default();
|
||||
let _m1: Matrix<u8, 8, 8> = Matrix::default();
|
||||
|
||||
let _m2 = Matrix::new([[1, 2], [3, 4]]);
|
||||
}
|
||||
#[test]
|
||||
fn construct_non_copy() {
|
||||
let _m = Matrix::<String, 2, 1>::default();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matmul() {
|
||||
let m1 = Matrix::new([[1_u8, 2, 3], [4, 5, 6]]);
|
||||
let m2 = Matrix::new([[7_u8, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18]]);
|
||||
|
||||
let m3 = m1.matmul(&m2);
|
||||
assert_eq!(m3, Matrix::new([[74, 80, 86, 92], [173, 188, 203, 218]]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
pub(crate) fn diff_op_1d(
|
||||
block: &[&[Float]],
|
||||
|
|
Loading…
Reference in New Issue