add sparse matrix creating to all diff ops
This commit is contained in:
		@@ -10,6 +10,7 @@ approx = "0.3.2"
 | 
				
			|||||||
packed_simd = "0.3.3"
 | 
					packed_simd = "0.3.3"
 | 
				
			||||||
rayon = { version = "1.3.0", optional = true }
 | 
					rayon = { version = "1.3.0", optional = true }
 | 
				
			||||||
sprs = { version = "0.7.1", optional = true }
 | 
					sprs = { version = "0.7.1", optional = true }
 | 
				
			||||||
 | 
					num-traits = "0.2.11"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[features]
 | 
					[features]
 | 
				
			||||||
# Use f32 as precision, default is f64
 | 
					# Use f32 as precision, default is f64
 | 
				
			||||||
@@ -17,7 +18,7 @@ f32 = []
 | 
				
			|||||||
sparse = ["sprs"]
 | 
					sparse = ["sprs"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[dev-dependencies]
 | 
					[dev-dependencies]
 | 
				
			||||||
criterion = "0.3.1"
 | 
					criterion = "0.3.2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[[bench]]
 | 
					[[bench]]
 | 
				
			||||||
name = "sbpoperators"
 | 
					name = "sbpoperators"
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -12,13 +12,9 @@ pub trait SbpOperator1d: Send + Sync {
 | 
				
			|||||||
        false
 | 
					        false
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    #[cfg(feature = "sparse")]
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
    fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
					    fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float>;
 | 
				
			||||||
        unimplemented!()
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
    #[cfg(feature = "sparse")]
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
    fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
					    fn h_matrix(&self, n: usize) -> sprs::CsMat<Float>;
 | 
				
			||||||
        unimplemented!()
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub trait SbpOperator2d: Send + Sync {
 | 
					pub trait SbpOperator2d: Send + Sync {
 | 
				
			||||||
@@ -30,6 +26,9 @@ pub trait SbpOperator2d: Send + Sync {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    fn is_h2xi(&self) -> bool;
 | 
					    fn is_h2xi(&self) -> bool;
 | 
				
			||||||
    fn is_h2eta(&self) -> bool;
 | 
					    fn is_h2eta(&self) -> bool;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fn op_xi(&self) -> &dyn SbpOperator1d;
 | 
				
			||||||
 | 
					    fn op_eta(&self) -> &dyn SbpOperator1d;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (&SBPeta, &SBPxi) {
 | 
					impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (&SBPeta, &SBPxi) {
 | 
				
			||||||
@@ -55,6 +54,13 @@ impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (&SBPeta, &S
 | 
				
			|||||||
    fn is_h2eta(&self) -> bool {
 | 
					    fn is_h2eta(&self) -> bool {
 | 
				
			||||||
        self.0.is_h2()
 | 
					        self.0.is_h2()
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fn op_xi(&self) -> &dyn SbpOperator1d {
 | 
				
			||||||
 | 
					        self.1
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    fn op_eta(&self) -> &dyn SbpOperator1d {
 | 
				
			||||||
 | 
					        self.0
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl<SBP: SbpOperator1d + Copy> SbpOperator2d for SBP {
 | 
					impl<SBP: SbpOperator1d + Copy> SbpOperator2d for SBP {
 | 
				
			||||||
@@ -76,6 +82,13 @@ impl<SBP: SbpOperator1d + Copy> SbpOperator2d for SBP {
 | 
				
			|||||||
    fn is_h2eta(&self) -> bool {
 | 
					    fn is_h2eta(&self) -> bool {
 | 
				
			||||||
        <(&SBP, &SBP) as SbpOperator2d>::is_h2eta(&(self, self))
 | 
					        <(&SBP, &SBP) as SbpOperator2d>::is_h2eta(&(self, self))
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fn op_xi(&self) -> &dyn SbpOperator1d {
 | 
				
			||||||
 | 
					        self
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    fn op_eta(&self) -> &dyn SbpOperator1d {
 | 
				
			||||||
 | 
					        self
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub trait UpwindOperator1d: SbpOperator1d + Send + Sync {
 | 
					pub trait UpwindOperator1d: SbpOperator1d + Send + Sync {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -38,6 +38,21 @@ impl SbpOperator1d for SBP4 {
 | 
				
			|||||||
    fn h(&self) -> &'static [Float] {
 | 
					    fn h(&self) -> &'static [Float] {
 | 
				
			||||||
        Self::HBLOCK
 | 
					        Self::HBLOCK
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::sparse_from_block(
 | 
				
			||||||
 | 
					            Self::BLOCK,
 | 
				
			||||||
 | 
					            Self::DIAG,
 | 
				
			||||||
 | 
					            super::Symmetry::AntiSymmetric,
 | 
				
			||||||
 | 
					            super::OperatorType::Normal,
 | 
				
			||||||
 | 
					            n,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::h_matrix(Self::DIAG, n, self.is_h2())
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &SBP4) {
 | 
					impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &SBP4) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -42,6 +42,21 @@ impl SbpOperator1d for SBP8 {
 | 
				
			|||||||
    fn h(&self) -> &'static [Float] {
 | 
					    fn h(&self) -> &'static [Float] {
 | 
				
			||||||
        Self::HBLOCK
 | 
					        Self::HBLOCK
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::sparse_from_block(
 | 
				
			||||||
 | 
					            Self::BLOCK,
 | 
				
			||||||
 | 
					            Self::DIAG,
 | 
				
			||||||
 | 
					            super::Symmetry::AntiSymmetric,
 | 
				
			||||||
 | 
					            super::OperatorType::Normal,
 | 
				
			||||||
 | 
					            n,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::h_matrix(Self::DIAG, n, self.is_h2())
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &SBP8) {
 | 
					impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &SBP8) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -56,6 +56,21 @@ impl SbpOperator1d for Upwind4h2 {
 | 
				
			|||||||
    fn is_h2(&self) -> bool {
 | 
					    fn is_h2(&self) -> bool {
 | 
				
			||||||
        true
 | 
					        true
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::sparse_from_block(
 | 
				
			||||||
 | 
					            Self::BLOCK,
 | 
				
			||||||
 | 
					            Self::DIAG,
 | 
				
			||||||
 | 
					            super::Symmetry::AntiSymmetric,
 | 
				
			||||||
 | 
					            super::OperatorType::H2,
 | 
				
			||||||
 | 
					            n,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::h_matrix(Self::DIAG, n, self.is_h2())
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind4h2) {
 | 
					impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind4h2) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -61,6 +61,21 @@ impl SbpOperator1d for Upwind9 {
 | 
				
			|||||||
    fn h(&self) -> &'static [Float] {
 | 
					    fn h(&self) -> &'static [Float] {
 | 
				
			||||||
        Self::HBLOCK
 | 
					        Self::HBLOCK
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::sparse_from_block(
 | 
				
			||||||
 | 
					            Self::BLOCK,
 | 
				
			||||||
 | 
					            Self::DIAG,
 | 
				
			||||||
 | 
					            super::Symmetry::AntiSymmetric,
 | 
				
			||||||
 | 
					            super::OperatorType::Normal,
 | 
				
			||||||
 | 
					            n,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::h_matrix(Self::DIAG, n, self.is_h2())
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind9) {
 | 
					impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind9) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -64,6 +64,21 @@ impl SbpOperator1d for Upwind9h2 {
 | 
				
			|||||||
    fn is_h2(&self) -> bool {
 | 
					    fn is_h2(&self) -> bool {
 | 
				
			||||||
        true
 | 
					        true
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::sparse_from_block(
 | 
				
			||||||
 | 
					            Self::BLOCK,
 | 
				
			||||||
 | 
					            Self::DIAG,
 | 
				
			||||||
 | 
					            super::Symmetry::AntiSymmetric,
 | 
				
			||||||
 | 
					            super::OperatorType::H2,
 | 
				
			||||||
 | 
					            n,
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					    #[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					    fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
 | 
				
			||||||
 | 
					        super::h_matrix(Self::DIAG, n, self.is_h2())
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind9h2) {
 | 
					impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind9h2) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,14 @@
 | 
				
			|||||||
use crate::Float;
 | 
					use crate::Float;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					mod jacobi;
 | 
				
			||||||
 | 
					#[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					pub use jacobi::*;
 | 
				
			||||||
 | 
					#[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					mod outer_product;
 | 
				
			||||||
 | 
					#[cfg(feature = "sparse")]
 | 
				
			||||||
 | 
					pub use outer_product::sparse_sparse_outer_product;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
pub struct Direction<T> {
 | 
					pub struct Direction<T> {
 | 
				
			||||||
    pub north: T,
 | 
					    pub north: T,
 | 
				
			||||||
    pub south: T,
 | 
					    pub south: T,
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										130
									
								
								sbp/src/utils/jacobi.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										130
									
								
								sbp/src/utils/jacobi.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,130 @@
 | 
				
			|||||||
 | 
					use crate::Float;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					/// A x = b
 | 
				
			||||||
 | 
					/// with A and b known
 | 
				
			||||||
 | 
					/// x should contain a first guess of
 | 
				
			||||||
 | 
					pub fn jacobi_method(
 | 
				
			||||||
 | 
					    a: sprs::CsMatView<Float>,
 | 
				
			||||||
 | 
					    b: &[Float],
 | 
				
			||||||
 | 
					    x: &mut [Float],
 | 
				
			||||||
 | 
					    tmp: &mut [Float],
 | 
				
			||||||
 | 
					    iter_count: usize,
 | 
				
			||||||
 | 
					) {
 | 
				
			||||||
 | 
					    for _ in 0..iter_count {
 | 
				
			||||||
 | 
					        jacobi_step(a, b, x, tmp);
 | 
				
			||||||
 | 
					        x.copy_from_slice(tmp);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					pub fn jacobi_step(a: sprs::CsMatView<Float>, b: &[Float], x0: &[Float], x: &mut [Float]) {
 | 
				
			||||||
 | 
					    let n = a.shape().0;
 | 
				
			||||||
 | 
					    assert_eq!(n, a.shape().1);
 | 
				
			||||||
 | 
					    let b = &b[..n];
 | 
				
			||||||
 | 
					    let x0 = &x0[..n];
 | 
				
			||||||
 | 
					    let x = &mut x[..n];
 | 
				
			||||||
 | 
					    for (((i, ai), xi), &bi) in a
 | 
				
			||||||
 | 
					        .outer_iterator()
 | 
				
			||||||
 | 
					        .enumerate()
 | 
				
			||||||
 | 
					        .zip(x.iter_mut())
 | 
				
			||||||
 | 
					        .zip(b.iter())
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        let mut summa = 0.0;
 | 
				
			||||||
 | 
					        let mut aii = None;
 | 
				
			||||||
 | 
					        for (j, aij) in ai.iter() {
 | 
				
			||||||
 | 
					            if i == j {
 | 
				
			||||||
 | 
					                aii = Some(aij);
 | 
				
			||||||
 | 
					                continue;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            summa += aij * x0[j];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        *xi = 1.0 / aii.unwrap() * (bi - summa);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[test]
 | 
				
			||||||
 | 
					fn test_jacobi_2x2() {
 | 
				
			||||||
 | 
					    let mut a = sprs::CsMat::zero((2, 2));
 | 
				
			||||||
 | 
					    a.insert(0, 0, 2.0);
 | 
				
			||||||
 | 
					    a.insert(0, 1, 1.0);
 | 
				
			||||||
 | 
					    a.insert(1, 0, 5.0);
 | 
				
			||||||
 | 
					    a.insert(1, 1, 7.0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let b = ndarray::arr1(&[11.0, 13.0]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let mut x0 = ndarray::arr1(&[1.0; 2]);
 | 
				
			||||||
 | 
					    let mut tmp = x0.clone();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    jacobi_method(
 | 
				
			||||||
 | 
					        a.view(),
 | 
				
			||||||
 | 
					        b.as_slice().unwrap(),
 | 
				
			||||||
 | 
					        x0.as_slice_mut().unwrap(),
 | 
				
			||||||
 | 
					        tmp.as_slice_mut().unwrap(),
 | 
				
			||||||
 | 
					        25,
 | 
				
			||||||
 | 
					    );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    approx::assert_abs_diff_eq!(x0, ndarray::arr1(&[7.111, -3.222]), epsilon = 1e-2);
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[test]
 | 
				
			||||||
 | 
					fn test_jacobi_4x4() {
 | 
				
			||||||
 | 
					    let mut a = sprs::CsMat::zero((4, 4));
 | 
				
			||||||
 | 
					    a.insert(0, 0, 10.0);
 | 
				
			||||||
 | 
					    a.insert(0, 1, -1.0);
 | 
				
			||||||
 | 
					    a.insert(0, 2, 2.0);
 | 
				
			||||||
 | 
					    a.insert(1, 0, -1.0);
 | 
				
			||||||
 | 
					    a.insert(1, 1, 11.0);
 | 
				
			||||||
 | 
					    a.insert(1, 2, -1.0);
 | 
				
			||||||
 | 
					    a.insert(1, 3, 3.0);
 | 
				
			||||||
 | 
					    a.insert(2, 0, 2.0);
 | 
				
			||||||
 | 
					    a.insert(2, 1, -1.0);
 | 
				
			||||||
 | 
					    a.insert(2, 2, 10.0);
 | 
				
			||||||
 | 
					    a.insert(2, 3, -1.0);
 | 
				
			||||||
 | 
					    a.insert(3, 1, 3.0);
 | 
				
			||||||
 | 
					    a.insert(3, 2, -1.0);
 | 
				
			||||||
 | 
					    a.insert(3, 3, 8.0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let b = ndarray::arr1(&[6.0, 25.0, -11.0, 15.0]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let mut x0 = ndarray::Array::zeros(b.len());
 | 
				
			||||||
 | 
					    let mut tmp = x0.clone();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for iter in 0.. {
 | 
				
			||||||
 | 
					        jacobi_step(
 | 
				
			||||||
 | 
					            a.view(),
 | 
				
			||||||
 | 
					            b.as_slice().unwrap(),
 | 
				
			||||||
 | 
					            x0.as_slice().unwrap(),
 | 
				
			||||||
 | 
					            tmp.as_slice_mut().unwrap(),
 | 
				
			||||||
 | 
					        );
 | 
				
			||||||
 | 
					        x0.as_slice_mut()
 | 
				
			||||||
 | 
					            .unwrap()
 | 
				
			||||||
 | 
					            .copy_from_slice(tmp.as_slice().unwrap());
 | 
				
			||||||
 | 
					        match iter {
 | 
				
			||||||
 | 
					            0 => approx::assert_abs_diff_eq!(
 | 
				
			||||||
 | 
					                x0,
 | 
				
			||||||
 | 
					                ndarray::arr1(&[0.6, 2.27272, -1.1, 1.875]),
 | 
				
			||||||
 | 
					                epsilon = 1e-4
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            1 => approx::assert_abs_diff_eq!(
 | 
				
			||||||
 | 
					                x0,
 | 
				
			||||||
 | 
					                ndarray::arr1(&[1.04727, 1.7159, -0.80522, 0.88522]),
 | 
				
			||||||
 | 
					                epsilon = 1e-4
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            2 => approx::assert_abs_diff_eq!(
 | 
				
			||||||
 | 
					                x0,
 | 
				
			||||||
 | 
					                ndarray::arr1(&[0.93263, 2.05330, -1.0493, 1.13088]),
 | 
				
			||||||
 | 
					                epsilon = 1e-4
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            3 => approx::assert_abs_diff_eq!(
 | 
				
			||||||
 | 
					                x0,
 | 
				
			||||||
 | 
					                ndarray::arr1(&[1.01519, 1.95369, -0.9681, 0.97384]),
 | 
				
			||||||
 | 
					                epsilon = 1e-4
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            4 => approx::assert_abs_diff_eq!(
 | 
				
			||||||
 | 
					                x0,
 | 
				
			||||||
 | 
					                ndarray::arr1(&[0.98899, 2.0114, -1.0102, 1.02135]),
 | 
				
			||||||
 | 
					                epsilon = 1e-4
 | 
				
			||||||
 | 
					            ),
 | 
				
			||||||
 | 
					            _ => break,
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										197
									
								
								sbp/src/utils/outer_product.rs
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										197
									
								
								sbp/src/utils/outer_product.rs
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,197 @@
 | 
				
			|||||||
 | 
					/// Computes the sparse kronecker product
 | 
				
			||||||
 | 
					/// M = A \kron B
 | 
				
			||||||
 | 
					#[allow(non_snake_case)]
 | 
				
			||||||
 | 
					#[must_use]
 | 
				
			||||||
 | 
					pub fn sparse_sparse_outer_product<
 | 
				
			||||||
 | 
					    N: num_traits::Num + Copy + Default,
 | 
				
			||||||
 | 
					    I: sprs::SpIndex,
 | 
				
			||||||
 | 
					    Iptr: sprs::SpIndex,
 | 
				
			||||||
 | 
					>(
 | 
				
			||||||
 | 
					    A: sprs::CsMatViewI<N, I, Iptr>,
 | 
				
			||||||
 | 
					    B: sprs::CsMatViewI<N, I, Iptr>,
 | 
				
			||||||
 | 
					) -> sprs::CsMatI<N, I, Iptr> {
 | 
				
			||||||
 | 
					    match (A.storage(), B.storage()) {
 | 
				
			||||||
 | 
					        (sprs::CompressedStorage::CSR, sprs::CompressedStorage::CSR) => {
 | 
				
			||||||
 | 
					            let nnz = A.nnz() * B.nnz();
 | 
				
			||||||
 | 
					            let a_shape = A.shape();
 | 
				
			||||||
 | 
					            let b_shape = B.shape();
 | 
				
			||||||
 | 
					            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
				
			||||||
 | 
					            let mut mat = sprs::CsMatI::zero(shape);
 | 
				
			||||||
 | 
					            mat.reserve_nnz_exact(nnz);
 | 
				
			||||||
 | 
					            for (aj, a) in A.outer_iterator().enumerate() {
 | 
				
			||||||
 | 
					                for (bj, b) in B.outer_iterator().enumerate() {
 | 
				
			||||||
 | 
					                    for (ai, &a) in a.iter() {
 | 
				
			||||||
 | 
					                        for (bi, &b) in b.iter() {
 | 
				
			||||||
 | 
					                            let i = ai * b_shape.1 + bi;
 | 
				
			||||||
 | 
					                            let j = aj * b_shape.0 + bj;
 | 
				
			||||||
 | 
					                            mat.insert(j, i, a * b)
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            debug_assert_eq!(mat.nnz(), nnz);
 | 
				
			||||||
 | 
					            mat
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        (sprs::CompressedStorage::CSC, sprs::CompressedStorage::CSC) => {
 | 
				
			||||||
 | 
					            let nnz = A.nnz() * B.nnz();
 | 
				
			||||||
 | 
					            let a_shape = A.shape();
 | 
				
			||||||
 | 
					            let b_shape = B.shape();
 | 
				
			||||||
 | 
					            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
				
			||||||
 | 
					            let mat = sprs::CsMatI::zero(shape);
 | 
				
			||||||
 | 
					            let mut mat = mat.to_csc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for (ai, a) in A.outer_iterator().enumerate() {
 | 
				
			||||||
 | 
					                for (bi, b) in B.outer_iterator().enumerate() {
 | 
				
			||||||
 | 
					                    for (aj, &a) in a.iter() {
 | 
				
			||||||
 | 
					                        for (bj, &b) in b.iter() {
 | 
				
			||||||
 | 
					                            let i = ai * b_shape.1 + bi;
 | 
				
			||||||
 | 
					                            let j = aj * b_shape.0 + bj;
 | 
				
			||||||
 | 
					                            mat.insert(j, i, a * b)
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            debug_assert_eq!(mat.nnz(), nnz);
 | 
				
			||||||
 | 
					            mat
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        (sprs::CompressedStorage::CSR, sprs::CompressedStorage::CSC) => {
 | 
				
			||||||
 | 
					            let nnz = A.nnz() * B.nnz();
 | 
				
			||||||
 | 
					            let a_shape = A.shape();
 | 
				
			||||||
 | 
					            let b_shape = B.shape();
 | 
				
			||||||
 | 
					            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
				
			||||||
 | 
					            let mut mat = sprs::CsMatI::zero(shape);
 | 
				
			||||||
 | 
					            mat.reserve_nnz_exact(nnz);
 | 
				
			||||||
 | 
					            for (aj, a) in A.outer_iterator().enumerate() {
 | 
				
			||||||
 | 
					                for (bi, b) in B.outer_iterator().enumerate() {
 | 
				
			||||||
 | 
					                    for (ai, &a) in a.iter() {
 | 
				
			||||||
 | 
					                        for (bj, &b) in b.iter() {
 | 
				
			||||||
 | 
					                            let i = ai * b_shape.1 + bi;
 | 
				
			||||||
 | 
					                            let j = aj * b_shape.0 + bj;
 | 
				
			||||||
 | 
					                            mat.insert(j, i, a * b)
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            debug_assert_eq!(mat.nnz(), nnz);
 | 
				
			||||||
 | 
					            mat
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        (sprs::CompressedStorage::CSC, sprs::CompressedStorage::CSR) => {
 | 
				
			||||||
 | 
					            let nnz = A.nnz() * B.nnz();
 | 
				
			||||||
 | 
					            let a_shape = A.shape();
 | 
				
			||||||
 | 
					            let b_shape = B.shape();
 | 
				
			||||||
 | 
					            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
				
			||||||
 | 
					            let mat = sprs::CsMatI::zero(shape);
 | 
				
			||||||
 | 
					            let mut mat = mat.to_csc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for (aj, a) in A.outer_iterator().enumerate() {
 | 
				
			||||||
 | 
					                for (bi, b) in B.outer_iterator().enumerate() {
 | 
				
			||||||
 | 
					                    for (ai, &a) in a.iter() {
 | 
				
			||||||
 | 
					                        for (bj, &b) in b.iter() {
 | 
				
			||||||
 | 
					                            let i = ai * b_shape.1 + bi;
 | 
				
			||||||
 | 
					                            let j = aj * b_shape.0 + bj;
 | 
				
			||||||
 | 
					                            mat.insert(j, i, a * b)
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            debug_assert_eq!(mat.nnz(), nnz);
 | 
				
			||||||
 | 
					            mat
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[test]
 | 
				
			||||||
 | 
					fn test_outer_product() {
 | 
				
			||||||
 | 
					    let mut a = sprs::TriMat::new((2, 3));
 | 
				
			||||||
 | 
					    a.add_triplet(0, 1, 2);
 | 
				
			||||||
 | 
					    a.add_triplet(0, 2, 3);
 | 
				
			||||||
 | 
					    a.add_triplet(1, 0, 6);
 | 
				
			||||||
 | 
					    a.add_triplet(1, 2, 8);
 | 
				
			||||||
 | 
					    let a = a.to_csr();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let mut b = sprs::TriMat::new((3, 2));
 | 
				
			||||||
 | 
					    b.add_triplet(0, 0, 1);
 | 
				
			||||||
 | 
					    b.add_triplet(1, 0, 2);
 | 
				
			||||||
 | 
					    b.add_triplet(2, 0, 3);
 | 
				
			||||||
 | 
					    b.add_triplet(2, 1, -3);
 | 
				
			||||||
 | 
					    let b = b.to_csr();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let c = sparse_sparse_outer_product(a.view(), b.view());
 | 
				
			||||||
 | 
					    for (&n, (j, i)) in c.iter() {
 | 
				
			||||||
 | 
					        match (j, i) {
 | 
				
			||||||
 | 
					            (0, 2) => assert_eq!(n, 2),
 | 
				
			||||||
 | 
					            (0, 4) => assert_eq!(n, 3),
 | 
				
			||||||
 | 
					            (1, 2) => assert_eq!(n, 4),
 | 
				
			||||||
 | 
					            (1, 4) => assert_eq!(n, 6),
 | 
				
			||||||
 | 
					            (2, 2) => assert_eq!(n, 6),
 | 
				
			||||||
 | 
					            (2, 3) => assert_eq!(n, -6),
 | 
				
			||||||
 | 
					            (2, 4) => assert_eq!(n, 9),
 | 
				
			||||||
 | 
					            (2, 5) => assert_eq!(n, -9),
 | 
				
			||||||
 | 
					            (3, 0) => assert_eq!(n, 6),
 | 
				
			||||||
 | 
					            (3, 4) => assert_eq!(n, 8),
 | 
				
			||||||
 | 
					            (4, 0) => assert_eq!(n, 12),
 | 
				
			||||||
 | 
					            (4, 4) => assert_eq!(n, 16),
 | 
				
			||||||
 | 
					            (5, 0) => assert_eq!(n, 18),
 | 
				
			||||||
 | 
					            (5, 1) => assert_eq!(n, -18),
 | 
				
			||||||
 | 
					            (5, 4) => assert_eq!(n, 24),
 | 
				
			||||||
 | 
					            (5, 5) => assert_eq!(n, -24),
 | 
				
			||||||
 | 
					            _ => panic!("index ({},{}) should be 0, found {}", j, i, n),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[test]
 | 
				
			||||||
 | 
					fn test_outer_product_csc() {
 | 
				
			||||||
 | 
					    let mut a = sprs::TriMat::new((2, 3));
 | 
				
			||||||
 | 
					    a.add_triplet(0, 1, 2);
 | 
				
			||||||
 | 
					    a.add_triplet(0, 2, 3);
 | 
				
			||||||
 | 
					    a.add_triplet(1, 0, 6);
 | 
				
			||||||
 | 
					    a.add_triplet(1, 2, 8);
 | 
				
			||||||
 | 
					    let a = a.to_csc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let mut b = sprs::TriMat::new((3, 2));
 | 
				
			||||||
 | 
					    b.add_triplet(0, 0, 1);
 | 
				
			||||||
 | 
					    b.add_triplet(1, 0, 2);
 | 
				
			||||||
 | 
					    b.add_triplet(2, 0, 3);
 | 
				
			||||||
 | 
					    b.add_triplet(2, 1, -3);
 | 
				
			||||||
 | 
					    let b = b.to_csc();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let c = sparse_sparse_outer_product(a.view(), b.view());
 | 
				
			||||||
 | 
					    for (&n, (j, i)) in c.iter() {
 | 
				
			||||||
 | 
					        match (j, i) {
 | 
				
			||||||
 | 
					            (0, 2) => assert_eq!(n, 2),
 | 
				
			||||||
 | 
					            (0, 4) => assert_eq!(n, 3),
 | 
				
			||||||
 | 
					            (1, 2) => assert_eq!(n, 4),
 | 
				
			||||||
 | 
					            (1, 4) => assert_eq!(n, 6),
 | 
				
			||||||
 | 
					            (2, 2) => assert_eq!(n, 6),
 | 
				
			||||||
 | 
					            (2, 3) => assert_eq!(n, -6),
 | 
				
			||||||
 | 
					            (2, 4) => assert_eq!(n, 9),
 | 
				
			||||||
 | 
					            (2, 5) => assert_eq!(n, -9),
 | 
				
			||||||
 | 
					            (3, 0) => assert_eq!(n, 6),
 | 
				
			||||||
 | 
					            (3, 4) => assert_eq!(n, 8),
 | 
				
			||||||
 | 
					            (4, 0) => assert_eq!(n, 12),
 | 
				
			||||||
 | 
					            (4, 4) => assert_eq!(n, 16),
 | 
				
			||||||
 | 
					            (5, 0) => assert_eq!(n, 18),
 | 
				
			||||||
 | 
					            (5, 1) => assert_eq!(n, -18),
 | 
				
			||||||
 | 
					            (5, 4) => assert_eq!(n, 24),
 | 
				
			||||||
 | 
					            (5, 5) => assert_eq!(n, -24),
 | 
				
			||||||
 | 
					            _ => panic!("index ({},{}) should be 0, found {}", j, i, n),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[test]
 | 
				
			||||||
 | 
					fn test_outer_product_2() {
 | 
				
			||||||
 | 
					    let mut e0 = sprs::CsMat::zero((10, 1));
 | 
				
			||||||
 | 
					    e0.insert(0, 0, 1);
 | 
				
			||||||
 | 
					    let mut en = sprs::CsMat::zero((11, 1));
 | 
				
			||||||
 | 
					    en.insert(10, 0, 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let v = sparse_sparse_outer_product(e0.view(), en.transpose_view());
 | 
				
			||||||
 | 
					    for (&val, (j, i)) in v.iter() {
 | 
				
			||||||
 | 
					        match (j, i) {
 | 
				
			||||||
 | 
					            (0, 10) => assert_eq!(val, 1),
 | 
				
			||||||
 | 
					            _ => panic!("Unexpected element: ({},{}): {}", j, i, val),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
		Reference in New Issue
	
	Block a user