brainderp
This commit is contained in:
parent
1e84e2ddf0
commit
459581a3c9
|
@ -51,7 +51,7 @@ impl SbpOperator1d for SBP4 {
|
||||||
}
|
}
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
super::h_matrix(Self::DIAG, n, self.is_h2())
|
super::h_matrix(Self::HBLOCK, n, self.is_h2())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -55,7 +55,7 @@ impl SbpOperator1d for SBP8 {
|
||||||
}
|
}
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
super::h_matrix(Self::DIAG, n, self.is_h2())
|
super::h_matrix(Self::HBLOCK, n, self.is_h2())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -164,7 +164,7 @@ diff_simd_col_7_47!(diss_simd_col, Upwind4::DISS_BLOCK, Upwind4::DISS_DIAG, true
|
||||||
|
|
||||||
impl Upwind4 {
|
impl Upwind4 {
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
const HBLOCK: &'static [Float] = &[
|
pub const HBLOCK: &'static [Float] = &[
|
||||||
49.0 / 144.0, 61.0 / 48.0, 41.0 / 48.0, 149.0 / 144.0
|
49.0 / 144.0, 61.0 / 48.0, 41.0 / 48.0, 149.0 / 144.0
|
||||||
];
|
];
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
|
@ -219,7 +219,7 @@ impl SbpOperator1d for Upwind4 {
|
||||||
}
|
}
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
super::h_matrix(Self::DIAG, n, self.is_h2())
|
super::h_matrix(Self::HBLOCK, n, self.is_h2())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -69,7 +69,7 @@ impl SbpOperator1d for Upwind4h2 {
|
||||||
}
|
}
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
super::h_matrix(Self::DIAG, n, self.is_h2())
|
super::h_matrix(Self::HBLOCK, n, self.is_h2())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -74,7 +74,7 @@ impl SbpOperator1d for Upwind9 {
|
||||||
}
|
}
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
super::h_matrix(Self::DIAG, n, self.is_h2())
|
super::h_matrix(Self::HBLOCK, n, self.is_h2())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ impl SbpOperator1d for Upwind9h2 {
|
||||||
}
|
}
|
||||||
#[cfg(feature = "sparse")]
|
#[cfg(feature = "sparse")]
|
||||||
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
fn h_matrix(&self, n: usize) -> sprs::CsMat<Float> {
|
||||||
super::h_matrix(Self::DIAG, n, self.is_h2())
|
super::h_matrix(Self::HBLOCK, n, self.is_h2())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,8 +10,9 @@ pub fn sparse_sparse_outer_product<
|
||||||
A: sprs::CsMatViewI<N, I, Iptr>,
|
A: sprs::CsMatViewI<N, I, Iptr>,
|
||||||
B: sprs::CsMatViewI<N, I, Iptr>,
|
B: sprs::CsMatViewI<N, I, Iptr>,
|
||||||
) -> sprs::CsMatI<N, I, Iptr> {
|
) -> sprs::CsMatI<N, I, Iptr> {
|
||||||
|
use sprs::{CSC, CSR};
|
||||||
match (A.storage(), B.storage()) {
|
match (A.storage(), B.storage()) {
|
||||||
(sprs::CompressedStorage::CSR, sprs::CompressedStorage::CSR) => {
|
(CSR, CSR) => {
|
||||||
let nnz = A.nnz() * B.nnz();
|
let nnz = A.nnz() * B.nnz();
|
||||||
let a_shape = A.shape();
|
let a_shape = A.shape();
|
||||||
let b_shape = B.shape();
|
let b_shape = B.shape();
|
||||||
|
@ -32,7 +33,7 @@ pub fn sparse_sparse_outer_product<
|
||||||
debug_assert_eq!(mat.nnz(), nnz);
|
debug_assert_eq!(mat.nnz(), nnz);
|
||||||
mat
|
mat
|
||||||
}
|
}
|
||||||
(sprs::CompressedStorage::CSC, sprs::CompressedStorage::CSC) => {
|
(CSC, CSC) => {
|
||||||
let nnz = A.nnz() * B.nnz();
|
let nnz = A.nnz() * B.nnz();
|
||||||
let a_shape = A.shape();
|
let a_shape = A.shape();
|
||||||
let b_shape = B.shape();
|
let b_shape = B.shape();
|
||||||
|
@ -54,7 +55,7 @@ pub fn sparse_sparse_outer_product<
|
||||||
debug_assert_eq!(mat.nnz(), nnz);
|
debug_assert_eq!(mat.nnz(), nnz);
|
||||||
mat
|
mat
|
||||||
}
|
}
|
||||||
(sprs::CompressedStorage::CSR, sprs::CompressedStorage::CSC) => {
|
(CSR, CSC) => {
|
||||||
let nnz = A.nnz() * B.nnz();
|
let nnz = A.nnz() * B.nnz();
|
||||||
let a_shape = A.shape();
|
let a_shape = A.shape();
|
||||||
let b_shape = B.shape();
|
let b_shape = B.shape();
|
||||||
|
@ -75,7 +76,7 @@ pub fn sparse_sparse_outer_product<
|
||||||
debug_assert_eq!(mat.nnz(), nnz);
|
debug_assert_eq!(mat.nnz(), nnz);
|
||||||
mat
|
mat
|
||||||
}
|
}
|
||||||
(sprs::CompressedStorage::CSC, sprs::CompressedStorage::CSR) => {
|
(CSC, CSR) => {
|
||||||
let nnz = A.nnz() * B.nnz();
|
let nnz = A.nnz() * B.nnz();
|
||||||
let a_shape = A.shape();
|
let a_shape = A.shape();
|
||||||
let b_shape = B.shape();
|
let b_shape = B.shape();
|
||||||
|
@ -83,10 +84,10 @@ pub fn sparse_sparse_outer_product<
|
||||||
let mat = sprs::CsMatI::zero(shape);
|
let mat = sprs::CsMatI::zero(shape);
|
||||||
let mut mat = mat.to_csc();
|
let mut mat = mat.to_csc();
|
||||||
|
|
||||||
for (aj, a) in A.outer_iterator().enumerate() {
|
for (ai, a) in A.outer_iterator().enumerate() {
|
||||||
for (bi, b) in B.outer_iterator().enumerate() {
|
for (bj, b) in B.outer_iterator().enumerate() {
|
||||||
for (ai, &a) in a.iter() {
|
for (aj, &a) in a.iter() {
|
||||||
for (bj, &b) in b.iter() {
|
for (bi, &b) in b.iter() {
|
||||||
let i = ai * b_shape.1 + bi;
|
let i = ai * b_shape.1 + bi;
|
||||||
let j = aj * b_shape.0 + bj;
|
let j = aj * b_shape.0 + bj;
|
||||||
mat.insert(j, i, a * b)
|
mat.insert(j, i, a * b)
|
||||||
|
@ -116,82 +117,39 @@ fn test_outer_product() {
|
||||||
b.add_triplet(2, 1, -3);
|
b.add_triplet(2, 1, -3);
|
||||||
let b = b.to_csr();
|
let b = b.to_csr();
|
||||||
|
|
||||||
let c = sparse_sparse_outer_product(a.view(), b.view());
|
let check = |c: sprs::CsMatView<i32>| {
|
||||||
for (&n, (j, i)) in c.iter() {
|
for (&n, (j, i)) in c.iter() {
|
||||||
match (j, i) {
|
match (j, i) {
|
||||||
(0, 2) => assert_eq!(n, 2),
|
(0, 2) => assert_eq!(n, 2),
|
||||||
(0, 4) => assert_eq!(n, 3),
|
(0, 4) => assert_eq!(n, 3),
|
||||||
(1, 2) => assert_eq!(n, 4),
|
(1, 2) => assert_eq!(n, 4),
|
||||||
(1, 4) => assert_eq!(n, 6),
|
(1, 4) => assert_eq!(n, 6),
|
||||||
(2, 2) => assert_eq!(n, 6),
|
(2, 2) => assert_eq!(n, 6),
|
||||||
(2, 3) => assert_eq!(n, -6),
|
(2, 3) => assert_eq!(n, -6),
|
||||||
(2, 4) => assert_eq!(n, 9),
|
(2, 4) => assert_eq!(n, 9),
|
||||||
(2, 5) => assert_eq!(n, -9),
|
(2, 5) => assert_eq!(n, -9),
|
||||||
(3, 0) => assert_eq!(n, 6),
|
(3, 0) => assert_eq!(n, 6),
|
||||||
(3, 4) => assert_eq!(n, 8),
|
(3, 4) => assert_eq!(n, 8),
|
||||||
(4, 0) => assert_eq!(n, 12),
|
(4, 0) => assert_eq!(n, 12),
|
||||||
(4, 4) => assert_eq!(n, 16),
|
(4, 4) => assert_eq!(n, 16),
|
||||||
(5, 0) => assert_eq!(n, 18),
|
(5, 0) => assert_eq!(n, 18),
|
||||||
(5, 1) => assert_eq!(n, -18),
|
(5, 1) => assert_eq!(n, -18),
|
||||||
(5, 4) => assert_eq!(n, 24),
|
(5, 4) => assert_eq!(n, 24),
|
||||||
(5, 5) => assert_eq!(n, -24),
|
(5, 5) => assert_eq!(n, -24),
|
||||||
_ => panic!("index ({},{}) should be 0, found {}", j, i, n),
|
_ => panic!("index ({},{}) should be 0, found {}", j, i, n),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
let c = sparse_sparse_outer_product(a.view(), b.view());
|
||||||
fn test_outer_product_csc() {
|
check(c.view());
|
||||||
let mut a = sprs::TriMat::new((2, 3));
|
|
||||||
a.add_triplet(0, 1, 2);
|
|
||||||
a.add_triplet(0, 2, 3);
|
|
||||||
a.add_triplet(1, 0, 6);
|
|
||||||
a.add_triplet(1, 2, 8);
|
|
||||||
let a = a.to_csc();
|
|
||||||
|
|
||||||
let mut b = sprs::TriMat::new((3, 2));
|
|
||||||
b.add_triplet(0, 0, 1);
|
|
||||||
b.add_triplet(1, 0, 2);
|
|
||||||
b.add_triplet(2, 0, 3);
|
|
||||||
b.add_triplet(2, 1, -3);
|
|
||||||
let b = b.to_csc();
|
let b = b.to_csc();
|
||||||
|
|
||||||
let c = sparse_sparse_outer_product(a.view(), b.view());
|
let c = sparse_sparse_outer_product(a.view(), b.view());
|
||||||
for (&n, (j, i)) in c.iter() {
|
check(c.view());
|
||||||
match (j, i) {
|
let a = a.to_csc();
|
||||||
(0, 2) => assert_eq!(n, 2),
|
let c = sparse_sparse_outer_product(a.view(), b.view());
|
||||||
(0, 4) => assert_eq!(n, 3),
|
check(c.view());
|
||||||
(1, 2) => assert_eq!(n, 4),
|
let b = b.to_csr();
|
||||||
(1, 4) => assert_eq!(n, 6),
|
let c = sparse_sparse_outer_product(a.view(), b.view());
|
||||||
(2, 2) => assert_eq!(n, 6),
|
check(c.view());
|
||||||
(2, 3) => assert_eq!(n, -6),
|
|
||||||
(2, 4) => assert_eq!(n, 9),
|
|
||||||
(2, 5) => assert_eq!(n, -9),
|
|
||||||
(3, 0) => assert_eq!(n, 6),
|
|
||||||
(3, 4) => assert_eq!(n, 8),
|
|
||||||
(4, 0) => assert_eq!(n, 12),
|
|
||||||
(4, 4) => assert_eq!(n, 16),
|
|
||||||
(5, 0) => assert_eq!(n, 18),
|
|
||||||
(5, 1) => assert_eq!(n, -18),
|
|
||||||
(5, 4) => assert_eq!(n, 24),
|
|
||||||
(5, 5) => assert_eq!(n, -24),
|
|
||||||
_ => panic!("index ({},{}) should be 0, found {}", j, i, n),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn test_outer_product_2() {
|
|
||||||
let mut e0 = sprs::CsMat::zero((10, 1));
|
|
||||||
e0.insert(0, 0, 1);
|
|
||||||
let mut en = sprs::CsMat::zero((11, 1));
|
|
||||||
en.insert(10, 0, 1);
|
|
||||||
|
|
||||||
let v = sparse_sparse_outer_product(e0.view(), en.transpose_view());
|
|
||||||
for (&val, (j, i)) in v.iter() {
|
|
||||||
match (j, i) {
|
|
||||||
(0, 10) => assert_eq!(val, 1),
|
|
||||||
_ => panic!("Unexpected element: ({},{}): {}", j, i, val),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue