specialise on contigous ny
This commit is contained in:
		@@ -281,6 +281,105 @@ pub(crate) fn diff_op_2d_fallback<const M: usize, const N: usize, const D: usize
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#[inline(always)]
 | 
				
			||||||
 | 
					/// 2D diff when first axis is contiguous
 | 
				
			||||||
 | 
					pub(crate) fn diff_op_2d_sliceable_y<const M: usize, const N: usize, const D: usize>(
 | 
				
			||||||
 | 
					    matrix: &BlockMatrix<Float, M, N, D>,
 | 
				
			||||||
 | 
					    optype: OperatorType,
 | 
				
			||||||
 | 
					    prev: ArrayView2<Float>,
 | 
				
			||||||
 | 
					    mut fut: ArrayViewMut2<Float>,
 | 
				
			||||||
 | 
					) {
 | 
				
			||||||
 | 
					    /* Does not increase the perf...
 | 
				
			||||||
 | 
					    #[cfg(feature = "fast-float")]
 | 
				
			||||||
 | 
					    let (matrix, prev, mut fut) = unsafe {
 | 
				
			||||||
 | 
					        (
 | 
				
			||||||
 | 
					            std::mem::transmute::<_, &BlockMatrix<FastFloat, M, N, D>>(matrix),
 | 
				
			||||||
 | 
					            std::mem::transmute::<_, ArrayView2<FastFloat>>(prev),
 | 
				
			||||||
 | 
					            std::mem::transmute::<_, ArrayViewMut2<FastFloat>>(fut),
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					    #[cfg(not(feature = "fast-float"))]
 | 
				
			||||||
 | 
					    let mut fut = fut;
 | 
				
			||||||
 | 
					    */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    assert_eq!(prev.shape(), fut.shape());
 | 
				
			||||||
 | 
					    let nx = prev.shape()[1];
 | 
				
			||||||
 | 
					    let ny = prev.shape()[0];
 | 
				
			||||||
 | 
					    assert!(nx >= 2 * M);
 | 
				
			||||||
 | 
					    assert_eq!(prev.strides()[0], 1);
 | 
				
			||||||
 | 
					    assert_eq!(fut.strides()[0], 1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let dx = if optype == OperatorType::H2 {
 | 
				
			||||||
 | 
					        1.0 / (nx - 2) as Float
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					        1.0 / (nx - 1) as Float
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					    let idx = 1.0 / dx;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    fut.fill(0.0.into());
 | 
				
			||||||
 | 
					    let (mut fut0, mut futmid, mut futn) = fut.multi_slice_mut((
 | 
				
			||||||
 | 
					        ndarray::s![.., ..M],
 | 
				
			||||||
 | 
					        ndarray::s![.., M..nx - M],
 | 
				
			||||||
 | 
					        ndarray::s![.., nx - M..],
 | 
				
			||||||
 | 
					    ));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // First block
 | 
				
			||||||
 | 
					    for (bl, mut fut) in matrix
 | 
				
			||||||
 | 
					        .start
 | 
				
			||||||
 | 
					        .iter_rows()
 | 
				
			||||||
 | 
					        .zip(fut0.axis_iter_mut(ndarray::Axis(1)))
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        let fut = &mut fut.as_slice_mut().unwrap()[..ny];
 | 
				
			||||||
 | 
					        for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
 | 
				
			||||||
 | 
					            if bl.is_zero() {
 | 
				
			||||||
 | 
					                continue;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            let prev = &prev.as_slice().unwrap()[..ny];
 | 
				
			||||||
 | 
					            for (fut, prev) in fut.iter_mut().zip(prev) {
 | 
				
			||||||
 | 
					                *fut = *fut + idx * prev * bl;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let window_elems_to_skip = M - ((D - 1) / 2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // Diagonal entries
 | 
				
			||||||
 | 
					    for (mut fut, id) in futmid
 | 
				
			||||||
 | 
					        .axis_iter_mut(ndarray::Axis(1))
 | 
				
			||||||
 | 
					        .zip(prev.windows((ny, D)).into_iter().skip(window_elems_to_skip))
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        let fut = &mut fut.as_slice_mut().unwrap()[..ny];
 | 
				
			||||||
 | 
					        for (&d, id) in matrix.diag.iter().zip(id.axis_iter(ndarray::Axis(1))) {
 | 
				
			||||||
 | 
					            if d.is_zero() {
 | 
				
			||||||
 | 
					                continue;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            let id = id.as_slice().unwrap();
 | 
				
			||||||
 | 
					            for (fut, id) in fut.iter_mut().zip(id) {
 | 
				
			||||||
 | 
					                *fut = *fut + idx * id * d;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // End block
 | 
				
			||||||
 | 
					    let prev = prev.slice(ndarray::s!(.., nx - N..));
 | 
				
			||||||
 | 
					    for (bl, mut fut) in matrix
 | 
				
			||||||
 | 
					        .end
 | 
				
			||||||
 | 
					        .iter_rows()
 | 
				
			||||||
 | 
					        .zip(futn.axis_iter_mut(ndarray::Axis(1)))
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        let fut = &mut fut.as_slice_mut().unwrap()[..ny];
 | 
				
			||||||
 | 
					        for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
 | 
				
			||||||
 | 
					            if bl.is_zero() {
 | 
				
			||||||
 | 
					                continue;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            let prev = &prev.as_slice().unwrap()[..ny];
 | 
				
			||||||
 | 
					            for (fut, prev) in fut.iter_mut().zip(prev) {
 | 
				
			||||||
 | 
					                *fut = *fut + idx * prev * bl;
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#[inline(always)]
 | 
					#[inline(always)]
 | 
				
			||||||
pub(crate) fn diff_op_2d_sliceable<const M: usize, const N: usize, const D: usize>(
 | 
					pub(crate) fn diff_op_2d_sliceable<const M: usize, const N: usize, const D: usize>(
 | 
				
			||||||
    matrix: &BlockMatrix<Float, M, N, D>,
 | 
					    matrix: &BlockMatrix<Float, M, N, D>,
 | 
				
			||||||
@@ -308,6 +407,7 @@ pub(crate) fn diff_op_2d<const M: usize, const N: usize, const D: usize>(
 | 
				
			|||||||
    assert_eq!(prev.shape(), fut.shape());
 | 
					    assert_eq!(prev.shape(), fut.shape());
 | 
				
			||||||
    match (prev.strides(), fut.strides()) {
 | 
					    match (prev.strides(), fut.strides()) {
 | 
				
			||||||
        ([_, 1], [_, 1]) => diff_op_2d_sliceable(matrix, optype, prev, fut),
 | 
					        ([_, 1], [_, 1]) => diff_op_2d_sliceable(matrix, optype, prev, fut),
 | 
				
			||||||
 | 
					        ([1, _], [1, _]) => diff_op_2d_sliceable_y(matrix, optype, prev, fut),
 | 
				
			||||||
        _ => diff_op_2d_fallback(matrix, optype, prev, fut),
 | 
					        _ => diff_op_2d_fallback(matrix, optype, prev, fut),
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user