Improve the performance of the outer product
This commit is contained in:
		| @@ -17,19 +17,25 @@ pub fn sparse_sparse_outer_product< | |||||||
|             let a_shape = A.shape(); |             let a_shape = A.shape(); | ||||||
|             let b_shape = B.shape(); |             let b_shape = B.shape(); | ||||||
|             let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); |             let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); | ||||||
|             let mut mat = sprs::CsMatI::zero(shape); |             let mut values = Vec::with_capacity(nnz); | ||||||
|             mat.reserve_nnz_exact(nnz); |             let mut indices = Vec::with_capacity(nnz); | ||||||
|             for (aj, a) in A.outer_iterator().enumerate() { |             let mut indptr = Vec::with_capacity(shape.1 + 1); | ||||||
|                 for (bj, b) in B.outer_iterator().enumerate() { |  | ||||||
|  |             let mut element_count = Iptr::zero(); | ||||||
|  |             indptr.push(element_count); | ||||||
|  |             for a in A.outer_iterator() { | ||||||
|  |                 for b in B.outer_iterator() { | ||||||
|                     for (ai, &a) in a.iter() { |                     for (ai, &a) in a.iter() { | ||||||
|                         for (bi, &b) in b.iter() { |                         for (bi, &b) in b.iter() { | ||||||
|                             let i = ai * b_shape.1 + bi; |                             indices.push(I::from(ai * b_shape.1 + bi).unwrap()); | ||||||
|                             let j = aj * b_shape.0 + bj; |                             element_count += Iptr::one(); | ||||||
|                             mat.insert(j, i, a * b) |                             values.push(a * b); | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  |                     indptr.push(element_count); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |             let mat = sprs::CsMatBase::new(shape, indptr, indices, values); | ||||||
|             debug_assert_eq!(mat.nnz(), nnz); |             debug_assert_eq!(mat.nnz(), nnz); | ||||||
|             mat |             mat | ||||||
|         } |         } | ||||||
| @@ -38,20 +44,25 @@ pub fn sparse_sparse_outer_product< | |||||||
|             let a_shape = A.shape(); |             let a_shape = A.shape(); | ||||||
|             let b_shape = B.shape(); |             let b_shape = B.shape(); | ||||||
|             let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); |             let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); | ||||||
|             let mat = sprs::CsMatI::zero(shape); |             let mut values = Vec::with_capacity(nnz); | ||||||
|             let mut mat = mat.to_csc(); |             let mut indices = Vec::with_capacity(nnz); | ||||||
|  |             let mut indptr = Vec::with_capacity(shape.0 + 1); | ||||||
|  |  | ||||||
|             for (ai, a) in A.outer_iterator().enumerate() { |             let mut element_count = Iptr::zero(); | ||||||
|                 for (bi, b) in B.outer_iterator().enumerate() { |             indptr.push(element_count); | ||||||
|  |             for a in A.outer_iterator() { | ||||||
|  |                 for b in B.outer_iterator() { | ||||||
|                     for (aj, &a) in a.iter() { |                     for (aj, &a) in a.iter() { | ||||||
|                         for (bj, &b) in b.iter() { |                         for (bj, &b) in b.iter() { | ||||||
|                             let i = ai * b_shape.1 + bi; |                             indices.push(I::from(aj * b_shape.0 + bj).unwrap()); | ||||||
|                             let j = aj * b_shape.0 + bj; |                             element_count += Iptr::one(); | ||||||
|                             mat.insert(j, i, a * b) |                             values.push(a * b); | ||||||
|                         } |                         } | ||||||
|                     } |                     } | ||||||
|  |                     indptr.push(element_count); | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|  |             let mat = sprs::CsMatBase::new_csc(shape, indptr, indices, values); | ||||||
|             debug_assert_eq!(mat.nnz(), nnz); |             debug_assert_eq!(mat.nnz(), nnz); | ||||||
|             mat |             mat | ||||||
|         } |         } | ||||||
| @@ -61,6 +72,7 @@ pub fn sparse_sparse_outer_product< | |||||||
|             let b_shape = B.shape(); |             let b_shape = B.shape(); | ||||||
|             let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); |             let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1); | ||||||
|             let mut mat = sprs::CsMatI::zero(shape); |             let mut mat = sprs::CsMatI::zero(shape); | ||||||
|  |  | ||||||
|             mat.reserve_nnz_exact(nnz); |             mat.reserve_nnz_exact(nnz); | ||||||
|             for (aj, a) in A.outer_iterator().enumerate() { |             for (aj, a) in A.outer_iterator().enumerate() { | ||||||
|                 for (bi, b) in B.outer_iterator().enumerate() { |                 for (bi, b) in B.outer_iterator().enumerate() { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user