add matrix type
This commit is contained in:
parent
3cc7c31ee5
commit
c104082ac0
|
@ -1,5 +1,127 @@
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
pub(crate) mod constmatrix {
|
||||||
|
/// A row-major matrix
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
|
||||||
|
pub struct Matrix<T, const M: usize, const N: usize> {
|
||||||
|
data: [[T; N]; M],
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Default, const M: usize, const N: usize> Default for Matrix<T, M, N> {
|
||||||
|
fn default() -> Self {
|
||||||
|
use std::mem::MaybeUninit;
|
||||||
|
let mut d: [[MaybeUninit<T>; N]; M] = unsafe { MaybeUninit::uninit().assume_init() };
|
||||||
|
|
||||||
|
for row in d.iter_mut() {
|
||||||
|
for item in row.iter_mut() {
|
||||||
|
*item = MaybeUninit::new(T::default());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = unsafe { std::mem::transmute_copy::<_, [[T; N]; M]>(&d) };
|
||||||
|
Self { data }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const M: usize, const N: usize> core::ops::Index<(usize, usize)> for Matrix<T, M, N> {
|
||||||
|
type Output = T;
|
||||||
|
#[inline(always)]
|
||||||
|
fn index(&self, (i, j): (usize, usize)) -> &Self::Output {
|
||||||
|
&self.data[i][j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const M: usize, const N: usize> core::ops::IndexMut<(usize, usize)> for Matrix<T, M, N> {
|
||||||
|
#[inline(always)]
|
||||||
|
fn index_mut(&mut self, (i, j): (usize, usize)) -> &mut Self::Output {
|
||||||
|
&mut self.data[i][j]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const M: usize, const N: usize> core::ops::Index<usize> for Matrix<T, M, N> {
|
||||||
|
type Output = [T; N];
|
||||||
|
#[inline(always)]
|
||||||
|
fn index(&self, i: usize) -> &Self::Output {
|
||||||
|
&self.data[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const M: usize, const N: usize> core::ops::IndexMut<usize> for Matrix<T, M, N> {
|
||||||
|
#[inline(always)]
|
||||||
|
fn index_mut(&mut self, i: usize) -> &mut Self::Output {
|
||||||
|
&mut self.data[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T, const M: usize, const N: usize> Matrix<T, M, N> {
|
||||||
|
pub const fn new(data: [[T; N]; M]) -> Self {
|
||||||
|
Self { data }
|
||||||
|
}
|
||||||
|
pub const fn nrows(&self) -> usize {
|
||||||
|
M
|
||||||
|
}
|
||||||
|
pub const fn ncols(&self) -> usize {
|
||||||
|
N
|
||||||
|
}
|
||||||
|
pub fn matmul<const P: usize>(&self, other: &Matrix<T, N, P>) -> Matrix<T, M, P>
|
||||||
|
where
|
||||||
|
T: Default + core::ops::AddAssign<T>,
|
||||||
|
for<'f> &'f T: std::ops::Mul<Output = T>,
|
||||||
|
{
|
||||||
|
let mut out = Matrix::default();
|
||||||
|
self.matmul_into(other, &mut out);
|
||||||
|
out
|
||||||
|
}
|
||||||
|
pub fn matmul_into<const P: usize>(
|
||||||
|
&self,
|
||||||
|
other: &Matrix<T, N, P>,
|
||||||
|
out: &mut Matrix<T, M, P>,
|
||||||
|
) where
|
||||||
|
T: Default + core::ops::AddAssign<T>,
|
||||||
|
for<'f> &'f T: std::ops::Mul<Output = T>,
|
||||||
|
{
|
||||||
|
*out = Default::default();
|
||||||
|
for i in 0..M {
|
||||||
|
for j in 0..P {
|
||||||
|
for k in 0..N {
|
||||||
|
out[(i, j)] += &self[(i, k)] * &other[(k, j)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn iter_rows(
|
||||||
|
&self,
|
||||||
|
) -> impl ExactSizeIterator<Item = &[T; N]> + DoubleEndedIterator<Item = &[T; N]> {
|
||||||
|
(0..M).map(move |i| &self[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::{super::*, *};
|
||||||
|
#[test]
|
||||||
|
fn construct_copy_type() {
|
||||||
|
let _m0 = Matrix::<i32, 4, 3>::default();
|
||||||
|
let _m1: Matrix<u8, 8, 8> = Matrix::default();
|
||||||
|
|
||||||
|
let _m2 = Matrix::new([[1, 2], [3, 4]]);
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn construct_non_copy() {
|
||||||
|
let _m = Matrix::<String, 2, 1>::default();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn matmul() {
|
||||||
|
let m1 = Matrix::new([[1_u8, 2, 3], [4, 5, 6]]);
|
||||||
|
let m2 = Matrix::new([[7_u8, 8, 9, 10], [11, 12, 13, 14], [15, 16, 17, 18]]);
|
||||||
|
|
||||||
|
let m3 = m1.matmul(&m2);
|
||||||
|
assert_eq!(m3, Matrix::new([[74, 80, 86, 92], [173, 188, 203, 218]]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
pub(crate) fn diff_op_1d(
|
pub(crate) fn diff_op_1d(
|
||||||
block: &[&[Float]],
|
block: &[&[Float]],
|
||||||
|
|
Loading…
Reference in New Issue