Improve the performance of the outer product
This commit is contained in:
		@@ -17,19 +17,25 @@ pub fn sparse_sparse_outer_product<
 | 
				
			|||||||
            let a_shape = A.shape();
 | 
					            let a_shape = A.shape();
 | 
				
			||||||
            let b_shape = B.shape();
 | 
					            let b_shape = B.shape();
 | 
				
			||||||
            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
					            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
				
			||||||
            let mut mat = sprs::CsMatI::zero(shape);
 | 
					            let mut values = Vec::with_capacity(nnz);
 | 
				
			||||||
            mat.reserve_nnz_exact(nnz);
 | 
					            let mut indices = Vec::with_capacity(nnz);
 | 
				
			||||||
            for (aj, a) in A.outer_iterator().enumerate() {
 | 
					            let mut indptr = Vec::with_capacity(shape.1 + 1);
 | 
				
			||||||
                for (bj, b) in B.outer_iterator().enumerate() {
 | 
					
 | 
				
			||||||
 | 
					            let mut element_count = Iptr::zero();
 | 
				
			||||||
 | 
					            indptr.push(element_count);
 | 
				
			||||||
 | 
					            for a in A.outer_iterator() {
 | 
				
			||||||
 | 
					                for b in B.outer_iterator() {
 | 
				
			||||||
                    for (ai, &a) in a.iter() {
 | 
					                    for (ai, &a) in a.iter() {
 | 
				
			||||||
                        for (bi, &b) in b.iter() {
 | 
					                        for (bi, &b) in b.iter() {
 | 
				
			||||||
                            let i = ai * b_shape.1 + bi;
 | 
					                            indices.push(I::from(ai * b_shape.1 + bi).unwrap());
 | 
				
			||||||
                            let j = aj * b_shape.0 + bj;
 | 
					                            element_count += Iptr::one();
 | 
				
			||||||
                            mat.insert(j, i, a * b)
 | 
					                            values.push(a * b);
 | 
				
			||||||
                        }
 | 
					                        }
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
 | 
					                    indptr.push(element_count);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					            let mat = sprs::CsMatBase::new(shape, indptr, indices, values);
 | 
				
			||||||
            debug_assert_eq!(mat.nnz(), nnz);
 | 
					            debug_assert_eq!(mat.nnz(), nnz);
 | 
				
			||||||
            mat
 | 
					            mat
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@@ -38,20 +44,25 @@ pub fn sparse_sparse_outer_product<
 | 
				
			|||||||
            let a_shape = A.shape();
 | 
					            let a_shape = A.shape();
 | 
				
			||||||
            let b_shape = B.shape();
 | 
					            let b_shape = B.shape();
 | 
				
			||||||
            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
					            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
				
			||||||
            let mat = sprs::CsMatI::zero(shape);
 | 
					            let mut values = Vec::with_capacity(nnz);
 | 
				
			||||||
            let mut mat = mat.to_csc();
 | 
					            let mut indices = Vec::with_capacity(nnz);
 | 
				
			||||||
 | 
					            let mut indptr = Vec::with_capacity(shape.0 + 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            for (ai, a) in A.outer_iterator().enumerate() {
 | 
					            let mut element_count = Iptr::zero();
 | 
				
			||||||
                for (bi, b) in B.outer_iterator().enumerate() {
 | 
					            indptr.push(element_count);
 | 
				
			||||||
 | 
					            for a in A.outer_iterator() {
 | 
				
			||||||
 | 
					                for b in B.outer_iterator() {
 | 
				
			||||||
                    for (aj, &a) in a.iter() {
 | 
					                    for (aj, &a) in a.iter() {
 | 
				
			||||||
                        for (bj, &b) in b.iter() {
 | 
					                        for (bj, &b) in b.iter() {
 | 
				
			||||||
                            let i = ai * b_shape.1 + bi;
 | 
					                            indices.push(I::from(aj * b_shape.0 + bj).unwrap());
 | 
				
			||||||
                            let j = aj * b_shape.0 + bj;
 | 
					                            element_count += Iptr::one();
 | 
				
			||||||
                            mat.insert(j, i, a * b)
 | 
					                            values.push(a * b);
 | 
				
			||||||
                        }
 | 
					                        }
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
 | 
					                    indptr.push(element_count);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					            let mat = sprs::CsMatBase::new_csc(shape, indptr, indices, values);
 | 
				
			||||||
            debug_assert_eq!(mat.nnz(), nnz);
 | 
					            debug_assert_eq!(mat.nnz(), nnz);
 | 
				
			||||||
            mat
 | 
					            mat
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@@ -61,6 +72,7 @@ pub fn sparse_sparse_outer_product<
 | 
				
			|||||||
            let b_shape = B.shape();
 | 
					            let b_shape = B.shape();
 | 
				
			||||||
            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
					            let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
 | 
				
			||||||
            let mut mat = sprs::CsMatI::zero(shape);
 | 
					            let mut mat = sprs::CsMatI::zero(shape);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            mat.reserve_nnz_exact(nnz);
 | 
					            mat.reserve_nnz_exact(nnz);
 | 
				
			||||||
            for (aj, a) in A.outer_iterator().enumerate() {
 | 
					            for (aj, a) in A.outer_iterator().enumerate() {
 | 
				
			||||||
                for (bi, b) in B.outer_iterator().enumerate() {
 | 
					                for (bi, b) in B.outer_iterator().enumerate() {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user