update sprs
This commit is contained in:
parent
c241603e44
commit
3c2cfe27cd
|
@ -10,7 +10,7 @@ sparse = ["sbp/sparse", "sprs"]
|
||||||
[dependencies]
|
[dependencies]
|
||||||
ndarray = "0.13.1"
|
ndarray = "0.13.1"
|
||||||
sbp = { path = "../sbp" }
|
sbp = { path = "../sbp" }
|
||||||
sprs = { version = "0.7.1", optional = true }
|
sprs = { version = "0.9.0", optional = true, default-features = false }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.3.2"
|
criterion = "0.3.2"
|
||||||
|
|
|
@ -9,7 +9,7 @@ ndarray = { version = "0.13.1", features = ["approx"] }
|
||||||
approx = "0.3.2"
|
approx = "0.3.2"
|
||||||
packed_simd = "0.3.3"
|
packed_simd = "0.3.3"
|
||||||
rayon = { version = "1.3.0", optional = true }
|
rayon = { version = "1.3.0", optional = true }
|
||||||
sprs = { version = "0.7.1", optional = true }
|
sprs = { version = "0.9.0", optional = true, default-features = false }
|
||||||
num-traits = "0.2.11"
|
num-traits = "0.2.11"
|
||||||
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
|
serde = { version = "1.0.115", optional = true, default-features = false, features = ["derive"] }
|
||||||
|
|
||||||
|
|
|
@ -4,12 +4,10 @@ use crate::Float;
|
||||||
mod jacobi;
|
mod jacobi;
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
pub use jacobi::*;
|
pub use jacobi::*;
|
||||||
#[cfg(feature = "sparse")]
|
|
||||||
mod kronecker_product;
|
|
||||||
#[cfg(feature = "sparse")]
|
|
||||||
pub use kronecker_product::kronecker_product;
|
|
||||||
#[cfg(feature = "serde")]
|
#[cfg(feature = "serde")]
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
pub use sprs::kronecker_product;
|
||||||
|
|
||||||
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
|
||||||
#[derive(Copy, Clone, Debug, Default)]
|
#[derive(Copy, Clone, Debug, Default)]
|
||||||
|
|
|
@ -1,167 +0,0 @@
|
||||||
/// Computes the sparse kronecker product
|
|
||||||
/// M = A \kron B
|
|
||||||
#[allow(non_snake_case)]
|
|
||||||
#[must_use]
|
|
||||||
pub fn kronecker_product<
|
|
||||||
N: num_traits::Num + Copy + Default,
|
|
||||||
I: sprs::SpIndex,
|
|
||||||
Iptr: sprs::SpIndex,
|
|
||||||
>(
|
|
||||||
A: sprs::CsMatViewI<N, I, Iptr>,
|
|
||||||
B: sprs::CsMatViewI<N, I, Iptr>,
|
|
||||||
) -> sprs::CsMatI<N, I, Iptr> {
|
|
||||||
use sprs::{CSC, CSR};
|
|
||||||
match (A.storage(), B.storage()) {
|
|
||||||
(CSR, CSR) => {
|
|
||||||
let nnz = A.nnz() * B.nnz();
|
|
||||||
let a_shape = A.shape();
|
|
||||||
let b_shape = B.shape();
|
|
||||||
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
|
||||||
let mut values = Vec::with_capacity(nnz);
|
|
||||||
let mut indices = Vec::with_capacity(nnz);
|
|
||||||
let mut indptr = Vec::with_capacity(shape.1 + 1);
|
|
||||||
|
|
||||||
let mut element_count = Iptr::zero();
|
|
||||||
indptr.push(element_count);
|
|
||||||
for a in A.outer_iterator() {
|
|
||||||
for b in B.outer_iterator() {
|
|
||||||
for (ai, &a) in a.iter() {
|
|
||||||
for (bi, &b) in b.iter() {
|
|
||||||
indices.push(I::from(ai * b_shape.1 + bi).unwrap());
|
|
||||||
element_count += Iptr::one();
|
|
||||||
values.push(a * b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
indptr.push(element_count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mat = sprs::CsMatBase::new(shape, indptr, indices, values);
|
|
||||||
debug_assert_eq!(mat.nnz(), nnz);
|
|
||||||
mat
|
|
||||||
}
|
|
||||||
(CSC, CSC) => {
|
|
||||||
let nnz = A.nnz() * B.nnz();
|
|
||||||
let a_shape = A.shape();
|
|
||||||
let b_shape = B.shape();
|
|
||||||
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
|
||||||
let mut values = Vec::with_capacity(nnz);
|
|
||||||
let mut indices = Vec::with_capacity(nnz);
|
|
||||||
let mut indptr = Vec::with_capacity(shape.0 + 1);
|
|
||||||
|
|
||||||
let mut element_count = Iptr::zero();
|
|
||||||
indptr.push(element_count);
|
|
||||||
for a in A.outer_iterator() {
|
|
||||||
for b in B.outer_iterator() {
|
|
||||||
for (aj, &a) in a.iter() {
|
|
||||||
for (bj, &b) in b.iter() {
|
|
||||||
indices.push(I::from(aj * b_shape.0 + bj).unwrap());
|
|
||||||
element_count += Iptr::one();
|
|
||||||
values.push(a * b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
indptr.push(element_count);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let mat = sprs::CsMatBase::new_csc(shape, indptr, indices, values);
|
|
||||||
debug_assert_eq!(mat.nnz(), nnz);
|
|
||||||
mat
|
|
||||||
}
|
|
||||||
(CSR, CSC) => {
|
|
||||||
let nnz = A.nnz() * B.nnz();
|
|
||||||
let a_shape = A.shape();
|
|
||||||
let b_shape = B.shape();
|
|
||||||
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
|
||||||
let mut mat = sprs::CsMatI::zero(shape);
|
|
||||||
|
|
||||||
mat.reserve_nnz_exact(nnz);
|
|
||||||
for (aj, a) in A.outer_iterator().enumerate() {
|
|
||||||
for (bi, b) in B.outer_iterator().enumerate() {
|
|
||||||
for (ai, &a) in a.iter() {
|
|
||||||
for (bj, &b) in b.iter() {
|
|
||||||
let i = ai * b_shape.1 + bi;
|
|
||||||
let j = aj * b_shape.0 + bj;
|
|
||||||
mat.insert(j, i, a * b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
debug_assert_eq!(mat.nnz(), nnz);
|
|
||||||
mat
|
|
||||||
}
|
|
||||||
(CSC, CSR) => {
|
|
||||||
let nnz = A.nnz() * B.nnz();
|
|
||||||
let a_shape = A.shape();
|
|
||||||
let b_shape = B.shape();
|
|
||||||
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
|
||||||
let mat = sprs::CsMatI::zero(shape);
|
|
||||||
let mut mat = mat.to_csc();
|
|
||||||
|
|
||||||
for (ai, a) in A.outer_iterator().enumerate() {
|
|
||||||
for (bj, b) in B.outer_iterator().enumerate() {
|
|
||||||
for (aj, &a) in a.iter() {
|
|
||||||
for (bi, &b) in b.iter() {
|
|
||||||
let i = ai * b_shape.1 + bi;
|
|
||||||
let j = aj * b_shape.0 + bj;
|
|
||||||
mat.insert(j, i, a * b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
debug_assert_eq!(mat.nnz(), nnz);
|
|
||||||
mat
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_kronecker_product() {
|
|
||||||
let mut a = sprs::TriMat::new((2, 3));
|
|
||||||
a.add_triplet(0, 1, 2);
|
|
||||||
a.add_triplet(0, 2, 3);
|
|
||||||
a.add_triplet(1, 0, 6);
|
|
||||||
a.add_triplet(1, 2, 8);
|
|
||||||
let a = a.to_csr();
|
|
||||||
|
|
||||||
let mut b = sprs::TriMat::new((3, 2));
|
|
||||||
b.add_triplet(0, 0, 1);
|
|
||||||
b.add_triplet(1, 0, 2);
|
|
||||||
b.add_triplet(2, 0, 3);
|
|
||||||
b.add_triplet(2, 1, -3);
|
|
||||||
let b = b.to_csr();
|
|
||||||
|
|
||||||
let check = |c: sprs::CsMatView<i32>| {
|
|
||||||
for (&n, (j, i)) in c.iter() {
|
|
||||||
match (j, i) {
|
|
||||||
(0, 2) => assert_eq!(n, 2),
|
|
||||||
(0, 4) => assert_eq!(n, 3),
|
|
||||||
(1, 2) => assert_eq!(n, 4),
|
|
||||||
(1, 4) => assert_eq!(n, 6),
|
|
||||||
(2, 2) => assert_eq!(n, 6),
|
|
||||||
(2, 3) => assert_eq!(n, -6),
|
|
||||||
(2, 4) => assert_eq!(n, 9),
|
|
||||||
(2, 5) => assert_eq!(n, -9),
|
|
||||||
(3, 0) => assert_eq!(n, 6),
|
|
||||||
(3, 4) => assert_eq!(n, 8),
|
|
||||||
(4, 0) => assert_eq!(n, 12),
|
|
||||||
(4, 4) => assert_eq!(n, 16),
|
|
||||||
(5, 0) => assert_eq!(n, 18),
|
|
||||||
(5, 1) => assert_eq!(n, -18),
|
|
||||||
(5, 4) => assert_eq!(n, 24),
|
|
||||||
(5, 5) => assert_eq!(n, -24),
|
|
||||||
_ => panic!("index ({},{}) should be 0, found {}", j, i, n),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let c = kronecker_product(a.view(), b.view());
|
|
||||||
check(c.view());
|
|
||||||
let b = b.to_csc();
|
|
||||||
let c = kronecker_product(a.view(), b.view());
|
|
||||||
check(c.view());
|
|
||||||
let a = a.to_csc();
|
|
||||||
let c = kronecker_product(a.view(), b.view());
|
|
||||||
check(c.view());
|
|
||||||
let b = b.to_csr();
|
|
||||||
let c = kronecker_product(a.view(), b.view());
|
|
||||||
check(c.view());
|
|
||||||
}
|
|
Loading…
Reference in New Issue