Improve the performance of the outer product

This commit is contained in:
Magnus Ulimoen 2020-08-21 23:52:26 +02:00
parent 3671ba5e1f
commit feeb254468
1 changed files with 26 additions and 14 deletions

View File

@ -17,19 +17,25 @@ pub fn sparse_sparse_outer_product<
let a_shape = A.shape(); let a_shape = A.shape();
let b_shape = B.shape(); let b_shape = B.shape();
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
let mut mat = sprs::CsMatI::zero(shape); let mut values = Vec::with_capacity(nnz);
mat.reserve_nnz_exact(nnz); let mut indices = Vec::with_capacity(nnz);
for (aj, a) in A.outer_iterator().enumerate() { let mut indptr = Vec::with_capacity(shape.1 + 1);
for (bj, b) in B.outer_iterator().enumerate() {
let mut element_count = Iptr::zero();
indptr.push(element_count);
for a in A.outer_iterator() {
for b in B.outer_iterator() {
for (ai, &a) in a.iter() { for (ai, &a) in a.iter() {
for (bi, &b) in b.iter() { for (bi, &b) in b.iter() {
let i = ai * b_shape.1 + bi; indices.push(I::from(ai * b_shape.1 + bi).unwrap());
let j = aj * b_shape.0 + bj; element_count += Iptr::one();
mat.insert(j, i, a * b) values.push(a * b);
} }
} }
indptr.push(element_count);
} }
} }
let mat = sprs::CsMatBase::new(shape, indptr, indices, values);
debug_assert_eq!(mat.nnz(), nnz); debug_assert_eq!(mat.nnz(), nnz);
mat mat
} }
@ -38,20 +44,25 @@ pub fn sparse_sparse_outer_product<
let a_shape = A.shape(); let a_shape = A.shape();
let b_shape = B.shape(); let b_shape = B.shape();
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
let mat = sprs::CsMatI::zero(shape); let mut values = Vec::with_capacity(nnz);
let mut mat = mat.to_csc(); let mut indices = Vec::with_capacity(nnz);
let mut indptr = Vec::with_capacity(shape.0 + 1);
for (ai, a) in A.outer_iterator().enumerate() { let mut element_count = Iptr::zero();
for (bi, b) in B.outer_iterator().enumerate() { indptr.push(element_count);
for a in A.outer_iterator() {
for b in B.outer_iterator() {
for (aj, &a) in a.iter() { for (aj, &a) in a.iter() {
for (bj, &b) in b.iter() { for (bj, &b) in b.iter() {
let i = ai * b_shape.1 + bi; indices.push(I::from(aj * b_shape.0 + bj).unwrap());
let j = aj * b_shape.0 + bj; element_count += Iptr::one();
mat.insert(j, i, a * b) values.push(a * b);
} }
} }
indptr.push(element_count);
} }
} }
let mat = sprs::CsMatBase::new_csc(shape, indptr, indices, values);
debug_assert_eq!(mat.nnz(), nnz); debug_assert_eq!(mat.nnz(), nnz);
mat mat
} }
@ -61,6 +72,7 @@ pub fn sparse_sparse_outer_product<
let b_shape = B.shape(); let b_shape = B.shape();
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
let mut mat = sprs::CsMatI::zero(shape); let mut mat = sprs::CsMatI::zero(shape);
mat.reserve_nnz_exact(nnz); mat.reserve_nnz_exact(nnz);
for (aj, a) in A.outer_iterator().enumerate() { for (aj, a) in A.outer_iterator().enumerate() {
for (bi, b) in B.outer_iterator().enumerate() { for (bi, b) in B.outer_iterator().enumerate() {