diff --git a/sbp/src/utils/outer_product.rs b/sbp/src/utils/outer_product.rs index 39c1ec6..c5cf0dc 100644 --- a/sbp/src/utils/outer_product.rs +++ b/sbp/src/utils/outer_product.rs @@ -17,19 +17,25 @@ pub fn sparse_sparse_outer_product< let a_shape = A.shape(); let b_shape = B.shape(); let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); - let mut mat = sprs::CsMatI::zero(shape); - mat.reserve_nnz_exact(nnz); - for (aj, a) in A.outer_iterator().enumerate() { - for (bj, b) in B.outer_iterator().enumerate() { + let mut values = Vec::with_capacity(nnz); + let mut indices = Vec::with_capacity(nnz); + let mut indptr = Vec::with_capacity(shape.1 + 1); + + let mut element_count = Iptr::zero(); + indptr.push(element_count); + for a in A.outer_iterator() { + for b in B.outer_iterator() { for (ai, &a) in a.iter() { for (bi, &b) in b.iter() { - let i = ai * b_shape.1 + bi; - let j = aj * b_shape.0 + bj; - mat.insert(j, i, a * b) + indices.push(I::from(ai * b_shape.1 + bi).unwrap()); + element_count += Iptr::one(); + values.push(a * b); } } + indptr.push(element_count); } } + let mat = sprs::CsMatBase::new(shape, indptr, indices, values); debug_assert_eq!(mat.nnz(), nnz); mat } @@ -38,20 +44,25 @@ pub fn sparse_sparse_outer_product< let a_shape = A.shape(); let b_shape = B.shape(); let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); - let mat = sprs::CsMatI::zero(shape); - let mut mat = mat.to_csc(); + let mut values = Vec::with_capacity(nnz); + let mut indices = Vec::with_capacity(nnz); + let mut indptr = Vec::with_capacity(shape.0 + 1); - for (ai, a) in A.outer_iterator().enumerate() { - for (bi, b) in B.outer_iterator().enumerate() { + let mut element_count = Iptr::zero(); + indptr.push(element_count); + for a in A.outer_iterator() { + for b in B.outer_iterator() { for (aj, &a) in a.iter() { for (bj, &b) in b.iter() { - let i = ai * b_shape.1 + bi; - let j = aj * b_shape.0 + bj; - mat.insert(j, i, a * b) + indices.push(I::from(aj * b_shape.0 + bj).unwrap()); + element_count += Iptr::one(); + values.push(a * b); } } + indptr.push(element_count); } } + let mat = sprs::CsMatBase::new_csc(shape, indptr, indices, values); debug_assert_eq!(mat.nnz(), nnz); mat } @@ -61,6 +72,7 @@ pub fn sparse_sparse_outer_product< let b_shape = B.shape(); let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); let mut mat = sprs::CsMatI::zero(shape); + mat.reserve_nnz_exact(nnz); for (aj, a) in A.outer_iterator().enumerate() { for (bi, b) in B.outer_iterator().enumerate() {