diff --git a/maxwell/Cargo.toml b/maxwell/Cargo.toml index e3d96cf..7797588 100644 --- a/maxwell/Cargo.toml +++ b/maxwell/Cargo.toml @@ -10,7 +10,7 @@ sparse = ["sbp/sparse", "sprs"] [dependencies] ndarray = "0.13.1" sbp = { path = "../sbp" } -sprs = { version = "0.7.1", optional = true } +sprs = { version = "0.9.0", optional = true, default-features = false } [dev-dependencies] criterion = "0.3.2" diff --git a/sbp/Cargo.toml b/sbp/Cargo.toml index 76d5b46..908a684 100644 --- a/sbp/Cargo.toml +++ b/sbp/Cargo.toml @@ -9,7 +9,7 @@ ndarray = { version = "0.13.1", features = ["approx"] } approx = "0.3.2" packed_simd = "0.3.3" rayon = { version = "1.3.0", optional = true } -sprs = { version = "0.7.1", optional = true } +sprs = { version = "0.9.0", optional = true, default-features = false } num-traits = "0.2.11" serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] } diff --git a/sbp/src/utils.rs b/sbp/src/utils.rs index dd59187..8d1f025 100644 --- a/sbp/src/utils.rs +++ b/sbp/src/utils.rs @@ -4,12 +4,10 @@ use crate::Float; mod jacobi; #[cfg(feature = "sparse")] pub use jacobi::*; -#[cfg(feature = "sparse")] -mod kronecker_product; -#[cfg(feature = "sparse")] -pub use kronecker_product::kronecker_product; #[cfg(feature = "serde")] use serde::{Deserialize, Serialize}; +#[cfg(feature = "sparse")] +pub use sprs::kronecker_product; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Copy, Clone, Debug, Default)] diff --git a/sbp/src/utils/kronecker_product.rs b/sbp/src/utils/kronecker_product.rs deleted file mode 100644 index 74bfa88..0000000 --- a/sbp/src/utils/kronecker_product.rs +++ /dev/null @@ -1,167 +0,0 @@ -/// Computes the sparse kronecker product -/// M = A \kron B -#[allow(non_snake_case)] -#[must_use] -pub fn kronecker_product< - N: num_traits::Num + Copy + Default, - I: sprs::SpIndex, - Iptr: sprs::SpIndex, ->( - A: sprs::CsMatViewI, - B: sprs::CsMatViewI, -) -> sprs::CsMatI { - use sprs::{CSC, CSR}; - match (A.storage(), B.storage()) { - (CSR, CSR) => { - let nnz = A.nnz() * B.nnz(); - let a_shape = A.shape(); - let b_shape = B.shape(); - let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); - let mut values = Vec::with_capacity(nnz); - let mut indices = Vec::with_capacity(nnz); - let mut indptr = Vec::with_capacity(shape.1 + 1); - - let mut element_count = Iptr::zero(); - indptr.push(element_count); - for a in A.outer_iterator() { - for b in B.outer_iterator() { - for (ai, &a) in a.iter() { - for (bi, &b) in b.iter() { - indices.push(I::from(ai * b_shape.1 + bi).unwrap()); - element_count += Iptr::one(); - values.push(a * b); - } - } - indptr.push(element_count); - } - } - let mat = sprs::CsMatBase::new(shape, indptr, indices, values); - debug_assert_eq!(mat.nnz(), nnz); - mat - } - (CSC, CSC) => { - let nnz = A.nnz() * B.nnz(); - let a_shape = A.shape(); - let b_shape = B.shape(); - let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); - let mut values = Vec::with_capacity(nnz); - let mut indices = Vec::with_capacity(nnz); - let mut indptr = Vec::with_capacity(shape.0 + 1); - - let mut element_count = Iptr::zero(); - indptr.push(element_count); - for a in A.outer_iterator() { - for b in B.outer_iterator() { - for (aj, &a) in a.iter() { - for (bj, &b) in b.iter() { - indices.push(I::from(aj * b_shape.0 + bj).unwrap()); - element_count += Iptr::one(); - values.push(a * b); - } - } - indptr.push(element_count); - } - } - let mat = sprs::CsMatBase::new_csc(shape, indptr, indices, values); - debug_assert_eq!(mat.nnz(), nnz); - mat - } - (CSR, CSC) => { - let nnz = A.nnz() * B.nnz(); - let a_shape = A.shape(); - let b_shape = B.shape(); - let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); - let mut mat = sprs::CsMatI::zero(shape); - - mat.reserve_nnz_exact(nnz); - for (aj, a) in A.outer_iterator().enumerate() { - for (bi, b) in B.outer_iterator().enumerate() { - for (ai, &a) in a.iter() { - for (bj, &b) in b.iter() { - let i = ai * b_shape.1 + bi; - let j = aj * b_shape.0 + bj; - mat.insert(j, i, a * b) - } - } - } - } - debug_assert_eq!(mat.nnz(), nnz); - mat - } - (CSC, CSR) => { - let nnz = A.nnz() * B.nnz(); - let a_shape = A.shape(); - let b_shape = B.shape(); - let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); - let mat = sprs::CsMatI::zero(shape); - let mut mat = mat.to_csc(); - - for (ai, a) in A.outer_iterator().enumerate() { - for (bj, b) in B.outer_iterator().enumerate() { - for (aj, &a) in a.iter() { - for (bi, &b) in b.iter() { - let i = ai * b_shape.1 + bi; - let j = aj * b_shape.0 + bj; - mat.insert(j, i, a * b) - } - } - } - } - debug_assert_eq!(mat.nnz(), nnz); - mat - } - } -} - -#[test] -fn test_kronecker_product() { - let mut a = sprs::TriMat::new((2, 3)); - a.add_triplet(0, 1, 2); - a.add_triplet(0, 2, 3); - a.add_triplet(1, 0, 6); - a.add_triplet(1, 2, 8); - let a = a.to_csr(); - - let mut b = sprs::TriMat::new((3, 2)); - b.add_triplet(0, 0, 1); - b.add_triplet(1, 0, 2); - b.add_triplet(2, 0, 3); - b.add_triplet(2, 1, -3); - let b = b.to_csr(); - - let check = |c: sprs::CsMatView| { - for (&n, (j, i)) in c.iter() { - match (j, i) { - (0, 2) => assert_eq!(n, 2), - (0, 4) => assert_eq!(n, 3), - (1, 2) => assert_eq!(n, 4), - (1, 4) => assert_eq!(n, 6), - (2, 2) => assert_eq!(n, 6), - (2, 3) => assert_eq!(n, -6), - (2, 4) => assert_eq!(n, 9), - (2, 5) => assert_eq!(n, -9), - (3, 0) => assert_eq!(n, 6), - (3, 4) => assert_eq!(n, 8), - (4, 0) => assert_eq!(n, 12), - (4, 4) => assert_eq!(n, 16), - (5, 0) => assert_eq!(n, 18), - (5, 1) => assert_eq!(n, -18), - (5, 4) => assert_eq!(n, 24), - (5, 5) => assert_eq!(n, -24), - _ => panic!("index ({},{}) should be 0, found {}", j, i, n), - } - } - }; - - let c = kronecker_product(a.view(), b.view()); - check(c.view()); - let b = b.to_csc(); - let c = kronecker_product(a.view(), b.view()); - check(c.view()); - let a = a.to_csc(); - let c = kronecker_product(a.view(), b.view()); - check(c.view()); - let b = b.to_csr(); - let c = kronecker_product(a.view(), b.view()); - check(c.view()); -}