readd diff_op naive

This commit is contained in:
Magnus Ulimoen 2020-05-01 17:44:33 +02:00
parent 78da9baaea
commit 29db6b73df
1 changed files with 75 additions and 0 deletions

View File

@ -186,12 +186,87 @@ enum OperatorType {
H2, H2,
} }
#[inline(always)]
#[allow(unused)]
fn diff_op_col_naive(
block: &'static [&'static [Float]],
diag: &'static [Float],
symmetry: Symmetry,
optype: OperatorType,
) -> impl Fn(ArrayView2<Float>, ArrayViewMut2<Float>) {
#[inline(always)]
move |prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>| {
assert_eq!(prev.shape(), fut.shape());
let nx = prev.shape()[1];
assert!(nx >= 2 * block.len());
assert_eq!(prev.strides()[0], 1);
assert_eq!(fut.strides()[0], 1);
let dx = if optype == OperatorType::H2 {
1.0 / (nx - 2) as Float
} else {
1.0 / (nx - 1) as Float
};
let idx = 1.0 / dx;
fut.fill(0.0);
// First block
for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) {
debug_assert_eq!(fut.len(), prev.shape()[0]);
for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
debug_assert_eq!(prev.len(), fut.len());
fut.scaled_add(idx * bl, &prev);
}
}
let half_diag_width = (diag.len() - 1) / 2;
assert!(half_diag_width <= block.len());
// Diagonal entries
for (ifut, mut fut) in fut
.axis_iter_mut(ndarray::Axis(1))
.enumerate()
.skip(block.len())
.take(nx - 2 * block.len())
{
for (id, &d) in diag.iter().enumerate() {
let offset = ifut - half_diag_width + id;
fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset]))
}
}
// End block
for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) {
fut.fill(0.0);
for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) {
if symmetry == Symmetry::Symmetric {
fut.scaled_add(idx * bl, &prev);
} else {
fut.scaled_add(-idx * bl, &prev);
}
}
}
}
}
#[inline(always)] #[inline(always)]
fn diff_op_col( fn diff_op_col(
block: &'static [&'static [Float]], block: &'static [&'static [Float]],
diag: &'static [Float], diag: &'static [Float],
symmetry: Symmetry, symmetry: Symmetry,
optype: OperatorType, optype: OperatorType,
) -> impl Fn(ArrayView2<Float>, ArrayViewMut2<Float>) {
diff_op_col_simd(block, diag, symmetry, optype)
}
#[inline(always)]
fn diff_op_col_simd(
block: &'static [&'static [Float]],
diag: &'static [Float],
symmetry: Symmetry,
optype: OperatorType,
) -> impl Fn(ArrayView2<Float>, ArrayViewMut2<Float>) { ) -> impl Fn(ArrayView2<Float>, ArrayViewMut2<Float>) {
#[inline(always)] #[inline(always)]
move |prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>| { move |prev: ArrayView2<Float>, mut fut: ArrayViewMut2<Float>| {