add sparse matrix creating to all diff ops
This commit is contained in:
parent
4f772b8dc5
commit
e2a3bed1ff
|
@ -10,6 +10,7 @@ approx = "0.3.2"
|
||||||
packed_simd = "0.3.3"
|
packed_simd = "0.3.3"
|
||||||
rayon = { version = "1.3.0", optional = true }
|
rayon = { version = "1.3.0", optional = true }
|
||||||
sprs = { version = "0.7.1", optional = true }
|
sprs = { version = "0.7.1", optional = true }
|
||||||
|
num-traits = "0.2.11"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
# Use f32 as precision, default is f64
|
# Use f32 as precision, default is f64
|
||||||
|
@ -17,7 +18,7 @@ f32 = []
|
||||||
sparse = ["sprs"]
|
sparse = ["sprs"]
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
criterion = "0.3.1"
|
criterion = "0.3.2"
|
||||||
|
|
||||||
[[bench]]
|
[[bench]]
|
||||||
name = "sbpoperators"
|
name = "sbpoperators"
|
||||||
|
|
|
@ -12,13 +12,9 @@ pub trait SbpOperator1d: Send + Sync {
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float>;
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float>;
|
||||||
unimplemented!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait SbpOperator2d: Send + Sync {
|
pub trait SbpOperator2d: Send + Sync {
|
||||||
|
@ -30,6 +26,9 @@ pub trait SbpOperator2d: Send + Sync {
|
||||||
|
|
||||||
fn is_h2xi(&self) -> bool;
|
fn is_h2xi(&self) -> bool;
|
||||||
fn is_h2eta(&self) -> bool;
|
fn is_h2eta(&self) -> bool;
|
||||||
|
|
||||||
|
fn op_xi(&self) -> &dyn SbpOperator1d;
|
||||||
|
fn op_eta(&self) -> &dyn SbpOperator1d;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (&SBPeta, &SBPxi) {
|
impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (&SBPeta, &SBPxi) {
|
||||||
|
@ -55,6 +54,13 @@ impl<SBPeta: SbpOperator1d, SBPxi: SbpOperator1d> SbpOperator2d for (&SBPeta, &S
|
||||||
fn is_h2eta(&self) -> bool {
|
fn is_h2eta(&self) -> bool {
|
||||||
self.0.is_h2()
|
self.0.is_h2()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn op_xi(&self) -> &dyn SbpOperator1d {
|
||||||
|
self.1
|
||||||
|
}
|
||||||
|
fn op_eta(&self) -> &dyn SbpOperator1d {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<SBP: SbpOperator1d + Copy> SbpOperator2d for SBP {
|
impl<SBP: SbpOperator1d + Copy> SbpOperator2d for SBP {
|
||||||
|
@ -76,6 +82,13 @@ impl<SBP: SbpOperator1d + Copy> SbpOperator2d for SBP {
|
||||||
fn is_h2eta(&self) -> bool {
|
fn is_h2eta(&self) -> bool {
|
||||||
<(&SBP, &SBP) as SbpOperator2d>::is_h2eta(&(self, self))
|
<(&SBP, &SBP) as SbpOperator2d>::is_h2eta(&(self, self))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn op_xi(&self) -> &dyn SbpOperator1d {
|
||||||
|
self
|
||||||
|
}
|
||||||
|
fn op_eta(&self) -> &dyn SbpOperator1d {
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait UpwindOperator1d: SbpOperator1d + Send + Sync {
|
pub trait UpwindOperator1d: SbpOperator1d + Send + Sync {
|
||||||
|
|
|
@ -38,6 +38,21 @@ impl SbpOperator1d for SBP4 {
|
||||||
fn h(&self) -> &'static [Float] {
|
fn h(&self) -> &'static [Float] {
|
||||||
Self::HBLOCK
|
Self::HBLOCK
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::sparse_from_block(
|
||||||
|
Self::BLOCK,
|
||||||
|
Self::DIAG,
|
||||||
|
super::Symmetry::AntiSymmetric,
|
||||||
|
super::OperatorType::Normal,
|
||||||
|
n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::h_matrix(Self::DIAG, n, self.is_h2())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &SBP4) {
|
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &SBP4) {
|
||||||
|
|
|
@ -42,6 +42,21 @@ impl SbpOperator1d for SBP8 {
|
||||||
fn h(&self) -> &'static [Float] {
|
fn h(&self) -> &'static [Float] {
|
||||||
Self::HBLOCK
|
Self::HBLOCK
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::sparse_from_block(
|
||||||
|
Self::BLOCK,
|
||||||
|
Self::DIAG,
|
||||||
|
super::Symmetry::AntiSymmetric,
|
||||||
|
super::OperatorType::Normal,
|
||||||
|
n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::h_matrix(Self::DIAG, n, self.is_h2())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &SBP8) {
|
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &SBP8) {
|
||||||
|
|
|
@ -56,6 +56,21 @@ impl SbpOperator1d for Upwind4h2 {
|
||||||
fn is_h2(&self) -> bool {
|
fn is_h2(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::sparse_from_block(
|
||||||
|
Self::BLOCK,
|
||||||
|
Self::DIAG,
|
||||||
|
super::Symmetry::AntiSymmetric,
|
||||||
|
super::OperatorType::H2,
|
||||||
|
n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::h_matrix(Self::DIAG, n, self.is_h2())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind4h2) {
|
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind4h2) {
|
||||||
|
|
|
@ -61,6 +61,21 @@ impl SbpOperator1d for Upwind9 {
|
||||||
fn h(&self) -> &'static [Float] {
|
fn h(&self) -> &'static [Float] {
|
||||||
Self::HBLOCK
|
Self::HBLOCK
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::sparse_from_block(
|
||||||
|
Self::BLOCK,
|
||||||
|
Self::DIAG,
|
||||||
|
super::Symmetry::AntiSymmetric,
|
||||||
|
super::OperatorType::Normal,
|
||||||
|
n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::h_matrix(Self::DIAG, n, self.is_h2())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind9) {
|
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind9) {
|
||||||
|
|
|
@ -64,6 +64,21 @@ impl SbpOperator1d for Upwind9h2 {
|
||||||
fn is_h2(&self) -> bool {
|
fn is_h2(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn diff_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::sparse_from_block(
|
||||||
|
Self::BLOCK,
|
||||||
|
Self::DIAG,
|
||||||
|
super::Symmetry::AntiSymmetric,
|
||||||
|
super::OperatorType::H2,
|
||||||
|
n,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
|
super::h_matrix(Self::DIAG, n, self.is_h2())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind9h2) {
|
impl<SBP: SbpOperator1d> SbpOperator2d for (&SBP, &Upwind9h2) {
|
||||||
|
|
|
@ -1,5 +1,14 @@
|
||||||
use crate::Float;
|
use crate::Float;
|
||||||
|
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
mod jacobi;
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
pub use jacobi::*;
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
mod outer_product;
|
||||||
|
#[cfg(feature = "sparse")]
|
||||||
|
pub use outer_product::sparse_sparse_outer_product;
|
||||||
|
|
||||||
pub struct Direction<T> {
|
pub struct Direction<T> {
|
||||||
pub north: T,
|
pub north: T,
|
||||||
pub south: T,
|
pub south: T,
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
use crate::Float;
|
||||||
|
|
||||||
|
/// A x = b
|
||||||
|
/// with A and b known
|
||||||
|
/// x should contain a first guess of
|
||||||
|
pub fn jacobi_method(
|
||||||
|
a: sprs::CsMatView<Float>,
|
||||||
|
b: &[Float],
|
||||||
|
x: &mut [Float],
|
||||||
|
tmp: &mut [Float],
|
||||||
|
iter_count: usize,
|
||||||
|
) {
|
||||||
|
for _ in 0..iter_count {
|
||||||
|
jacobi_step(a, b, x, tmp);
|
||||||
|
x.copy_from_slice(tmp);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn jacobi_step(a: sprs::CsMatView<Float>, b: &[Float], x0: &[Float], x: &mut [Float]) {
|
||||||
|
let n = a.shape().0;
|
||||||
|
assert_eq!(n, a.shape().1);
|
||||||
|
let b = &b[..n];
|
||||||
|
let x0 = &x0[..n];
|
||||||
|
let x = &mut x[..n];
|
||||||
|
for (((i, ai), xi), &bi) in a
|
||||||
|
.outer_iterator()
|
||||||
|
.enumerate()
|
||||||
|
.zip(x.iter_mut())
|
||||||
|
.zip(b.iter())
|
||||||
|
{
|
||||||
|
let mut summa = 0.0;
|
||||||
|
let mut aii = None;
|
||||||
|
for (j, aij) in ai.iter() {
|
||||||
|
if i == j {
|
||||||
|
aii = Some(aij);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
summa += aij * x0[j];
|
||||||
|
}
|
||||||
|
*xi = 1.0 / aii.unwrap() * (bi - summa);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jacobi_2x2() {
|
||||||
|
let mut a = sprs::CsMat::zero((2, 2));
|
||||||
|
a.insert(0, 0, 2.0);
|
||||||
|
a.insert(0, 1, 1.0);
|
||||||
|
a.insert(1, 0, 5.0);
|
||||||
|
a.insert(1, 1, 7.0);
|
||||||
|
|
||||||
|
let b = ndarray::arr1(&[11.0, 13.0]);
|
||||||
|
|
||||||
|
let mut x0 = ndarray::arr1(&[1.0; 2]);
|
||||||
|
let mut tmp = x0.clone();
|
||||||
|
|
||||||
|
jacobi_method(
|
||||||
|
a.view(),
|
||||||
|
b.as_slice().unwrap(),
|
||||||
|
x0.as_slice_mut().unwrap(),
|
||||||
|
tmp.as_slice_mut().unwrap(),
|
||||||
|
25,
|
||||||
|
);
|
||||||
|
|
||||||
|
approx::assert_abs_diff_eq!(x0, ndarray::arr1(&[7.111, -3.222]), epsilon = 1e-2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_jacobi_4x4() {
|
||||||
|
let mut a = sprs::CsMat::zero((4, 4));
|
||||||
|
a.insert(0, 0, 10.0);
|
||||||
|
a.insert(0, 1, -1.0);
|
||||||
|
a.insert(0, 2, 2.0);
|
||||||
|
a.insert(1, 0, -1.0);
|
||||||
|
a.insert(1, 1, 11.0);
|
||||||
|
a.insert(1, 2, -1.0);
|
||||||
|
a.insert(1, 3, 3.0);
|
||||||
|
a.insert(2, 0, 2.0);
|
||||||
|
a.insert(2, 1, -1.0);
|
||||||
|
a.insert(2, 2, 10.0);
|
||||||
|
a.insert(2, 3, -1.0);
|
||||||
|
a.insert(3, 1, 3.0);
|
||||||
|
a.insert(3, 2, -1.0);
|
||||||
|
a.insert(3, 3, 8.0);
|
||||||
|
|
||||||
|
let b = ndarray::arr1(&[6.0, 25.0, -11.0, 15.0]);
|
||||||
|
|
||||||
|
let mut x0 = ndarray::Array::zeros(b.len());
|
||||||
|
let mut tmp = x0.clone();
|
||||||
|
|
||||||
|
for iter in 0.. {
|
||||||
|
jacobi_step(
|
||||||
|
a.view(),
|
||||||
|
b.as_slice().unwrap(),
|
||||||
|
x0.as_slice().unwrap(),
|
||||||
|
tmp.as_slice_mut().unwrap(),
|
||||||
|
);
|
||||||
|
x0.as_slice_mut()
|
||||||
|
.unwrap()
|
||||||
|
.copy_from_slice(tmp.as_slice().unwrap());
|
||||||
|
match iter {
|
||||||
|
0 => approx::assert_abs_diff_eq!(
|
||||||
|
x0,
|
||||||
|
ndarray::arr1(&[0.6, 2.27272, -1.1, 1.875]),
|
||||||
|
epsilon = 1e-4
|
||||||
|
),
|
||||||
|
1 => approx::assert_abs_diff_eq!(
|
||||||
|
x0,
|
||||||
|
ndarray::arr1(&[1.04727, 1.7159, -0.80522, 0.88522]),
|
||||||
|
epsilon = 1e-4
|
||||||
|
),
|
||||||
|
2 => approx::assert_abs_diff_eq!(
|
||||||
|
x0,
|
||||||
|
ndarray::arr1(&[0.93263, 2.05330, -1.0493, 1.13088]),
|
||||||
|
epsilon = 1e-4
|
||||||
|
),
|
||||||
|
3 => approx::assert_abs_diff_eq!(
|
||||||
|
x0,
|
||||||
|
ndarray::arr1(&[1.01519, 1.95369, -0.9681, 0.97384]),
|
||||||
|
epsilon = 1e-4
|
||||||
|
),
|
||||||
|
4 => approx::assert_abs_diff_eq!(
|
||||||
|
x0,
|
||||||
|
ndarray::arr1(&[0.98899, 2.0114, -1.0102, 1.02135]),
|
||||||
|
epsilon = 1e-4
|
||||||
|
),
|
||||||
|
_ => break,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,197 @@
|
||||||
|
/// Computes the sparse kronecker product
|
||||||
|
/// M = A \kron B
|
||||||
|
#[allow(non_snake_case)]
|
||||||
|
#[must_use]
|
||||||
|
pub fn sparse_sparse_outer_product<
|
||||||
|
N: num_traits::Num + Copy + Default,
|
||||||
|
I: sprs::SpIndex,
|
||||||
|
Iptr: sprs::SpIndex,
|
||||||
|
>(
|
||||||
|
A: sprs::CsMatViewI<N, I, Iptr>,
|
||||||
|
B: sprs::CsMatViewI<N, I, Iptr>,
|
||||||
|
) -> sprs::CsMatI<N, I, Iptr> {
|
||||||
|
match (A.storage(), B.storage()) {
|
||||||
|
(sprs::CompressedStorage::CSR, sprs::CompressedStorage::CSR) => {
|
||||||
|
let nnz = A.nnz() * B.nnz();
|
||||||
|
let a_shape = A.shape();
|
||||||
|
let b_shape = B.shape();
|
||||||
|
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
||||||
|
let mut mat = sprs::CsMatI::zero(shape);
|
||||||
|
mat.reserve_nnz_exact(nnz);
|
||||||
|
for (aj, a) in A.outer_iterator().enumerate() {
|
||||||
|
for (bj, b) in B.outer_iterator().enumerate() {
|
||||||
|
for (ai, &a) in a.iter() {
|
||||||
|
for (bi, &b) in b.iter() {
|
||||||
|
let i = ai * b_shape.1 + bi;
|
||||||
|
let j = aj * b_shape.0 + bj;
|
||||||
|
mat.insert(j, i, a * b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debug_assert_eq!(mat.nnz(), nnz);
|
||||||
|
mat
|
||||||
|
}
|
||||||
|
(sprs::CompressedStorage::CSC, sprs::CompressedStorage::CSC) => {
|
||||||
|
let nnz = A.nnz() * B.nnz();
|
||||||
|
let a_shape = A.shape();
|
||||||
|
let b_shape = B.shape();
|
||||||
|
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
||||||
|
let mat = sprs::CsMatI::zero(shape);
|
||||||
|
let mut mat = mat.to_csc();
|
||||||
|
|
||||||
|
for (ai, a) in A.outer_iterator().enumerate() {
|
||||||
|
for (bi, b) in B.outer_iterator().enumerate() {
|
||||||
|
for (aj, &a) in a.iter() {
|
||||||
|
for (bj, &b) in b.iter() {
|
||||||
|
let i = ai * b_shape.1 + bi;
|
||||||
|
let j = aj * b_shape.0 + bj;
|
||||||
|
mat.insert(j, i, a * b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debug_assert_eq!(mat.nnz(), nnz);
|
||||||
|
mat
|
||||||
|
}
|
||||||
|
(sprs::CompressedStorage::CSR, sprs::CompressedStorage::CSC) => {
|
||||||
|
let nnz = A.nnz() * B.nnz();
|
||||||
|
let a_shape = A.shape();
|
||||||
|
let b_shape = B.shape();
|
||||||
|
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
||||||
|
let mut mat = sprs::CsMatI::zero(shape);
|
||||||
|
mat.reserve_nnz_exact(nnz);
|
||||||
|
for (aj, a) in A.outer_iterator().enumerate() {
|
||||||
|
for (bi, b) in B.outer_iterator().enumerate() {
|
||||||
|
for (ai, &a) in a.iter() {
|
||||||
|
for (bj, &b) in b.iter() {
|
||||||
|
let i = ai * b_shape.1 + bi;
|
||||||
|
let j = aj * b_shape.0 + bj;
|
||||||
|
mat.insert(j, i, a * b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debug_assert_eq!(mat.nnz(), nnz);
|
||||||
|
mat
|
||||||
|
}
|
||||||
|
(sprs::CompressedStorage::CSC, sprs::CompressedStorage::CSR) => {
|
||||||
|
let nnz = A.nnz() * B.nnz();
|
||||||
|
let a_shape = A.shape();
|
||||||
|
let b_shape = B.shape();
|
||||||
|
let shape = (a_shape.0 * b_shape.0, a_shape.1 * b_shape.1);
|
||||||
|
let mat = sprs::CsMatI::zero(shape);
|
||||||
|
let mut mat = mat.to_csc();
|
||||||
|
|
||||||
|
for (aj, a) in A.outer_iterator().enumerate() {
|
||||||
|
for (bi, b) in B.outer_iterator().enumerate() {
|
||||||
|
for (ai, &a) in a.iter() {
|
||||||
|
for (bj, &b) in b.iter() {
|
||||||
|
let i = ai * b_shape.1 + bi;
|
||||||
|
let j = aj * b_shape.0 + bj;
|
||||||
|
mat.insert(j, i, a * b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
debug_assert_eq!(mat.nnz(), nnz);
|
||||||
|
mat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_outer_product() {
|
||||||
|
let mut a = sprs::TriMat::new((2, 3));
|
||||||
|
a.add_triplet(0, 1, 2);
|
||||||
|
a.add_triplet(0, 2, 3);
|
||||||
|
a.add_triplet(1, 0, 6);
|
||||||
|
a.add_triplet(1, 2, 8);
|
||||||
|
let a = a.to_csr();
|
||||||
|
|
||||||
|
let mut b = sprs::TriMat::new((3, 2));
|
||||||
|
b.add_triplet(0, 0, 1);
|
||||||
|
b.add_triplet(1, 0, 2);
|
||||||
|
b.add_triplet(2, 0, 3);
|
||||||
|
b.add_triplet(2, 1, -3);
|
||||||
|
let b = b.to_csr();
|
||||||
|
|
||||||
|
let c = sparse_sparse_outer_product(a.view(), b.view());
|
||||||
|
for (&n, (j, i)) in c.iter() {
|
||||||
|
match (j, i) {
|
||||||
|
(0, 2) => assert_eq!(n, 2),
|
||||||
|
(0, 4) => assert_eq!(n, 3),
|
||||||
|
(1, 2) => assert_eq!(n, 4),
|
||||||
|
(1, 4) => assert_eq!(n, 6),
|
||||||
|
(2, 2) => assert_eq!(n, 6),
|
||||||
|
(2, 3) => assert_eq!(n, -6),
|
||||||
|
(2, 4) => assert_eq!(n, 9),
|
||||||
|
(2, 5) => assert_eq!(n, -9),
|
||||||
|
(3, 0) => assert_eq!(n, 6),
|
||||||
|
(3, 4) => assert_eq!(n, 8),
|
||||||
|
(4, 0) => assert_eq!(n, 12),
|
||||||
|
(4, 4) => assert_eq!(n, 16),
|
||||||
|
(5, 0) => assert_eq!(n, 18),
|
||||||
|
(5, 1) => assert_eq!(n, -18),
|
||||||
|
(5, 4) => assert_eq!(n, 24),
|
||||||
|
(5, 5) => assert_eq!(n, -24),
|
||||||
|
_ => panic!("index ({},{}) should be 0, found {}", j, i, n),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_outer_product_csc() {
|
||||||
|
let mut a = sprs::TriMat::new((2, 3));
|
||||||
|
a.add_triplet(0, 1, 2);
|
||||||
|
a.add_triplet(0, 2, 3);
|
||||||
|
a.add_triplet(1, 0, 6);
|
||||||
|
a.add_triplet(1, 2, 8);
|
||||||
|
let a = a.to_csc();
|
||||||
|
|
||||||
|
let mut b = sprs::TriMat::new((3, 2));
|
||||||
|
b.add_triplet(0, 0, 1);
|
||||||
|
b.add_triplet(1, 0, 2);
|
||||||
|
b.add_triplet(2, 0, 3);
|
||||||
|
b.add_triplet(2, 1, -3);
|
||||||
|
let b = b.to_csc();
|
||||||
|
|
||||||
|
let c = sparse_sparse_outer_product(a.view(), b.view());
|
||||||
|
for (&n, (j, i)) in c.iter() {
|
||||||
|
match (j, i) {
|
||||||
|
(0, 2) => assert_eq!(n, 2),
|
||||||
|
(0, 4) => assert_eq!(n, 3),
|
||||||
|
(1, 2) => assert_eq!(n, 4),
|
||||||
|
(1, 4) => assert_eq!(n, 6),
|
||||||
|
(2, 2) => assert_eq!(n, 6),
|
||||||
|
(2, 3) => assert_eq!(n, -6),
|
||||||
|
(2, 4) => assert_eq!(n, 9),
|
||||||
|
(2, 5) => assert_eq!(n, -9),
|
||||||
|
(3, 0) => assert_eq!(n, 6),
|
||||||
|
(3, 4) => assert_eq!(n, 8),
|
||||||
|
(4, 0) => assert_eq!(n, 12),
|
||||||
|
(4, 4) => assert_eq!(n, 16),
|
||||||
|
(5, 0) => assert_eq!(n, 18),
|
||||||
|
(5, 1) => assert_eq!(n, -18),
|
||||||
|
(5, 4) => assert_eq!(n, 24),
|
||||||
|
(5, 5) => assert_eq!(n, -24),
|
||||||
|
_ => panic!("index ({},{}) should be 0, found {}", j, i, n),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_outer_product_2() {
|
||||||
|
let mut e0 = sprs::CsMat::zero((10, 1));
|
||||||
|
e0.insert(0, 0, 1);
|
||||||
|
let mut en = sprs::CsMat::zero((11, 1));
|
||||||
|
en.insert(10, 0, 1);
|
||||||
|
|
||||||
|
let v = sparse_sparse_outer_product(e0.view(), en.transpose_view());
|
||||||
|
for (&val, (j, i)) in v.iter() {
|
||||||
|
match (j, i) {
|
||||||
|
(0, 10) => assert_eq!(val, 1),
|
||||||
|
_ => panic!("Unexpected element: ({},{}): {}", j, i, val),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue