diff --git a/sbp/src/operators/traditional4.rs b/sbp/src/operators/traditional4.rs index 9d0c559..f17d5cd 100644 --- a/sbp/src/operators/traditional4.rs +++ b/sbp/src/operators/traditional4.rs @@ -51,7 +51,7 @@ impl SbpOperator1d for SBP4 { } #[cfg(feature = "sparse")] fn h_matrix(&self, n: usize) -> sprs::CsMat { - super::h_matrix(Self::DIAG, n, self.is_h2()) + super::h_matrix(Self::HBLOCK, n, self.is_h2()) } } diff --git a/sbp/src/operators/traditional8.rs b/sbp/src/operators/traditional8.rs index e387796..2db2e58 100644 --- a/sbp/src/operators/traditional8.rs +++ b/sbp/src/operators/traditional8.rs @@ -55,7 +55,7 @@ impl SbpOperator1d for SBP8 { } #[cfg(feature = "sparse")] fn h_matrix(&self, n: usize) -> sprs::CsMat { - super::h_matrix(Self::DIAG, n, self.is_h2()) + super::h_matrix(Self::HBLOCK, n, self.is_h2()) } } diff --git a/sbp/src/operators/upwind4.rs b/sbp/src/operators/upwind4.rs index c11a122..d943dfc 100644 --- a/sbp/src/operators/upwind4.rs +++ b/sbp/src/operators/upwind4.rs @@ -164,7 +164,7 @@ diff_simd_col_7_47!(diss_simd_col, Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true impl Upwind4 { #[rustfmt::skip] - const HBLOCK: &'static [Float] = &[ + pub const HBLOCK: &'static [Float] = &[ 49.0 / 144.0, 61.0 / 48.0, 41.0 / 48.0, 149.0 / 144.0 ]; #[rustfmt::skip] @@ -219,7 +219,7 @@ impl SbpOperator1d for Upwind4 { } #[cfg(feature = "sparse")] fn h_matrix(&self, n: usize) -> sprs::CsMat { - super::h_matrix(Self::DIAG, n, self.is_h2()) + super::h_matrix(Self::HBLOCK, n, self.is_h2()) } } diff --git a/sbp/src/operators/upwind4h2.rs b/sbp/src/operators/upwind4h2.rs index 8c65602..5c6ee61 100644 --- a/sbp/src/operators/upwind4h2.rs +++ b/sbp/src/operators/upwind4h2.rs @@ -69,7 +69,7 @@ impl SbpOperator1d for Upwind4h2 { } #[cfg(feature = "sparse")] fn h_matrix(&self, n: usize) -> sprs::CsMat { - super::h_matrix(Self::DIAG, n, self.is_h2()) + super::h_matrix(Self::HBLOCK, n, self.is_h2()) } } diff --git a/sbp/src/operators/upwind9.rs b/sbp/src/operators/upwind9.rs index 4869037..62722b4 100644 --- a/sbp/src/operators/upwind9.rs +++ b/sbp/src/operators/upwind9.rs @@ -74,7 +74,7 @@ impl SbpOperator1d for Upwind9 { } #[cfg(feature = "sparse")] fn h_matrix(&self, n: usize) -> sprs::CsMat { - super::h_matrix(Self::DIAG, n, self.is_h2()) + super::h_matrix(Self::HBLOCK, n, self.is_h2()) } } diff --git a/sbp/src/operators/upwind9h2.rs b/sbp/src/operators/upwind9h2.rs index 26b2468..15c3290 100644 --- a/sbp/src/operators/upwind9h2.rs +++ b/sbp/src/operators/upwind9h2.rs @@ -77,7 +77,7 @@ impl SbpOperator1d for Upwind9h2 { } #[cfg(feature = "sparse")] fn h_matrix(&self, n: usize) -> sprs::CsMat { - super::h_matrix(Self::DIAG, n, self.is_h2()) + super::h_matrix(Self::HBLOCK, n, self.is_h2()) } } diff --git a/sbp/src/utils/outer_product.rs b/sbp/src/utils/outer_product.rs index 47012d1..39c1ec6 100644 --- a/sbp/src/utils/outer_product.rs +++ b/sbp/src/utils/outer_product.rs @@ -10,8 +10,9 @@ pub fn sparse_sparse_outer_product< A: sprs::CsMatViewI, B: sprs::CsMatViewI, ) -> sprs::CsMatI { + use sprs::{CSC, CSR}; match (A.storage(), B.storage()) { - (sprs::CompressedStorage::CSR, sprs::CompressedStorage::CSR) => { + (CSR, CSR) => { let nnz = A.nnz() * B.nnz(); let a_shape = A.shape(); let b_shape = B.shape(); @@ -32,7 +33,7 @@ pub fn sparse_sparse_outer_product< debug_assert_eq!(mat.nnz(), nnz); mat } - (sprs::CompressedStorage::CSC, sprs::CompressedStorage::CSC) => { + (CSC, CSC) => { let nnz = A.nnz() * B.nnz(); let a_shape = A.shape(); let b_shape = B.shape(); @@ -54,7 +55,7 @@ pub fn sparse_sparse_outer_product< debug_assert_eq!(mat.nnz(), nnz); mat } - (sprs::CompressedStorage::CSR, sprs::CompressedStorage::CSC) => { + (CSR, CSC) => { let nnz = A.nnz() * B.nnz(); let a_shape = A.shape(); let b_shape = B.shape(); @@ -75,7 +76,7 @@ pub fn sparse_sparse_outer_product< debug_assert_eq!(mat.nnz(), nnz); mat } - (sprs::CompressedStorage::CSC, sprs::CompressedStorage::CSR) => { + (CSC, CSR) => { let nnz = A.nnz() * B.nnz(); let a_shape = A.shape(); let b_shape = B.shape(); @@ -83,10 +84,10 @@ pub fn sparse_sparse_outer_product< let mat = sprs::CsMatI::zero(shape); let mut mat = mat.to_csc(); - for (aj, a) in A.outer_iterator().enumerate() { - for (bi, b) in B.outer_iterator().enumerate() { - for (ai, &a) in a.iter() { - for (bj, &b) in b.iter() { + for (ai, a) in A.outer_iterator().enumerate() { + for (bj, b) in B.outer_iterator().enumerate() { + for (aj, &a) in a.iter() { + for (bi, &b) in b.iter() { let i = ai * b_shape.1 + bi; let j = aj * b_shape.0 + bj; mat.insert(j, i, a * b) @@ -116,82 +117,39 @@ fn test_outer_product() { b.add_triplet(2, 1, -3); let b = b.to_csr(); - let c = sparse_sparse_outer_product(a.view(), b.view()); - for (&n, (j, i)) in c.iter() { - match (j, i) { - (0, 2) => assert_eq!(n, 2), - (0, 4) => assert_eq!(n, 3), - (1, 2) => assert_eq!(n, 4), - (1, 4) => assert_eq!(n, 6), - (2, 2) => assert_eq!(n, 6), - (2, 3) => assert_eq!(n, -6), - (2, 4) => assert_eq!(n, 9), - (2, 5) => assert_eq!(n, -9), - (3, 0) => assert_eq!(n, 6), - (3, 4) => assert_eq!(n, 8), - (4, 0) => assert_eq!(n, 12), - (4, 4) => assert_eq!(n, 16), - (5, 0) => assert_eq!(n, 18), - (5, 1) => assert_eq!(n, -18), - (5, 4) => assert_eq!(n, 24), - (5, 5) => assert_eq!(n, -24), - _ => panic!("index ({},{}) should be 0, found {}", j, i, n), + let check = |c: sprs::CsMatView| { + for (&n, (j, i)) in c.iter() { + match (j, i) { + (0, 2) => assert_eq!(n, 2), + (0, 4) => assert_eq!(n, 3), + (1, 2) => assert_eq!(n, 4), + (1, 4) => assert_eq!(n, 6), + (2, 2) => assert_eq!(n, 6), + (2, 3) => assert_eq!(n, -6), + (2, 4) => assert_eq!(n, 9), + (2, 5) => assert_eq!(n, -9), + (3, 0) => assert_eq!(n, 6), + (3, 4) => assert_eq!(n, 8), + (4, 0) => assert_eq!(n, 12), + (4, 4) => assert_eq!(n, 16), + (5, 0) => assert_eq!(n, 18), + (5, 1) => assert_eq!(n, -18), + (5, 4) => assert_eq!(n, 24), + (5, 5) => assert_eq!(n, -24), + _ => panic!("index ({},{}) should be 0, found {}", j, i, n), + } } - } -} + }; -#[test] -fn test_outer_product_csc() { - let mut a = sprs::TriMat::new((2, 3)); - a.add_triplet(0, 1, 2); - a.add_triplet(0, 2, 3); - a.add_triplet(1, 0, 6); - a.add_triplet(1, 2, 8); - let a = a.to_csc(); - - let mut b = sprs::TriMat::new((3, 2)); - b.add_triplet(0, 0, 1); - b.add_triplet(1, 0, 2); - b.add_triplet(2, 0, 3); - b.add_triplet(2, 1, -3); + let c = sparse_sparse_outer_product(a.view(), b.view()); + check(c.view()); let b = b.to_csc(); - let c = sparse_sparse_outer_product(a.view(), b.view()); - for (&n, (j, i)) in c.iter() { - match (j, i) { - (0, 2) => assert_eq!(n, 2), - (0, 4) => assert_eq!(n, 3), - (1, 2) => assert_eq!(n, 4), - (1, 4) => assert_eq!(n, 6), - (2, 2) => assert_eq!(n, 6), - (2, 3) => assert_eq!(n, -6), - (2, 4) => assert_eq!(n, 9), - (2, 5) => assert_eq!(n, -9), - (3, 0) => assert_eq!(n, 6), - (3, 4) => assert_eq!(n, 8), - (4, 0) => assert_eq!(n, 12), - (4, 4) => assert_eq!(n, 16), - (5, 0) => assert_eq!(n, 18), - (5, 1) => assert_eq!(n, -18), - (5, 4) => assert_eq!(n, 24), - (5, 5) => assert_eq!(n, -24), - _ => panic!("index ({},{}) should be 0, found {}", j, i, n), - } - } -} - -#[test] -fn test_outer_product_2() { - let mut e0 = sprs::CsMat::zero((10, 1)); - e0.insert(0, 0, 1); - let mut en = sprs::CsMat::zero((11, 1)); - en.insert(10, 0, 1); - - let v = sparse_sparse_outer_product(e0.view(), en.transpose_view()); - for (&val, (j, i)) in v.iter() { - match (j, i) { - (0, 10) => assert_eq!(val, 1), - _ => panic!("Unexpected element: ({},{}): {}", j, i, val), - } - } + check(c.view()); + let a = a.to_csc(); + let c = sparse_sparse_outer_product(a.view(), b.view()); + check(c.view()); + let b = b.to_csr(); + let c = sparse_sparse_outer_product(a.view(), b.view()); + check(c.view()); }