Improve the performance of the outer product
This commit is contained in:
parent
3671ba5e1f
commit
feeb254468
|
@ -17,19 +17,25 @@ pub fn sparse_sparse_outer_product<
|
|||
let a_shape = A.shape();
|
||||
let b_shape = B.shape();
|
||||
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
||||
let mut mat = sprs::CsMatI::zero(shape);
|
||||
mat.reserve_nnz_exact(nnz);
|
||||
for (aj, a) in A.outer_iterator().enumerate() {
|
||||
for (bj, b) in B.outer_iterator().enumerate() {
|
||||
let mut values = Vec::with_capacity(nnz);
|
||||
let mut indices = Vec::with_capacity(nnz);
|
||||
let mut indptr = Vec::with_capacity(shape.1 + 1);
|
||||
|
||||
let mut element_count = Iptr::zero();
|
||||
indptr.push(element_count);
|
||||
for a in A.outer_iterator() {
|
||||
for b in B.outer_iterator() {
|
||||
for (ai, &a) in a.iter() {
|
||||
for (bi, &b) in b.iter() {
|
||||
let i = ai * b_shape.1 + bi;
|
||||
let j = aj * b_shape.0 + bj;
|
||||
mat.insert(j, i, a * b)
|
||||
indices.push(I::from(ai * b_shape.1 + bi).unwrap());
|
||||
element_count += Iptr::one();
|
||||
values.push(a * b);
|
||||
}
|
||||
}
|
||||
indptr.push(element_count);
|
||||
}
|
||||
}
|
||||
let mat = sprs::CsMatBase::new(shape, indptr, indices, values);
|
||||
debug_assert_eq!(mat.nnz(), nnz);
|
||||
mat
|
||||
}
|
||||
|
@ -38,20 +44,25 @@ pub fn sparse_sparse_outer_product<
|
|||
let a_shape = A.shape();
|
||||
let b_shape = B.shape();
|
||||
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
||||
let mat = sprs::CsMatI::zero(shape);
|
||||
let mut mat = mat.to_csc();
|
||||
let mut values = Vec::with_capacity(nnz);
|
||||
let mut indices = Vec::with_capacity(nnz);
|
||||
let mut indptr = Vec::with_capacity(shape.0 + 1);
|
||||
|
||||
for (ai, a) in A.outer_iterator().enumerate() {
|
||||
for (bi, b) in B.outer_iterator().enumerate() {
|
||||
let mut element_count = Iptr::zero();
|
||||
indptr.push(element_count);
|
||||
for a in A.outer_iterator() {
|
||||
for b in B.outer_iterator() {
|
||||
for (aj, &a) in a.iter() {
|
||||
for (bj, &b) in b.iter() {
|
||||
let i = ai * b_shape.1 + bi;
|
||||
let j = aj * b_shape.0 + bj;
|
||||
mat.insert(j, i, a * b)
|
||||
indices.push(I::from(aj * b_shape.0 + bj).unwrap());
|
||||
element_count += Iptr::one();
|
||||
values.push(a * b);
|
||||
}
|
||||
}
|
||||
indptr.push(element_count);
|
||||
}
|
||||
}
|
||||
let mat = sprs::CsMatBase::new_csc(shape, indptr, indices, values);
|
||||
debug_assert_eq!(mat.nnz(), nnz);
|
||||
mat
|
||||
}
|
||||
|
@ -61,6 +72,7 @@ pub fn sparse_sparse_outer_product<
|
|||
let b_shape = B.shape();
|
||||
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
||||
let mut mat = sprs::CsMatI::zero(shape);
|
||||
|
||||
mat.reserve_nnz_exact(nnz);
|
||||
for (aj, a) in A.outer_iterator().enumerate() {
|
||||
for (bi, b) in B.outer_iterator().enumerate() {
|
||||
|
|
Loading…
Reference in New Issue