remove some unsafe from simd
This commit is contained in:
		@@ -390,10 +390,16 @@ pub(crate) fn diff_op_2d_sliceable_y_simd<const M: usize, const N: usize, const
 | 
				
			|||||||
) {
 | 
					) {
 | 
				
			||||||
    assert_eq!(prev.shape(), fut.shape());
 | 
					    assert_eq!(prev.shape(), fut.shape());
 | 
				
			||||||
    let nx = prev.shape()[1];
 | 
					    let nx = prev.shape()[1];
 | 
				
			||||||
 | 
					    let ny = prev.shape()[0];
 | 
				
			||||||
    assert!(nx >= 2 * M);
 | 
					    assert!(nx >= 2 * M);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    assert_eq!(prev.strides()[0], 1);
 | 
					    assert_eq!(prev.strides(), fut.strides());
 | 
				
			||||||
    assert_eq!(fut.strides()[0], 1);
 | 
					    assert_eq!(prev.strides(), &[1, ny as isize]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    let prev = prev.as_slice_memory_order().unwrap();
 | 
				
			||||||
 | 
					    let fut = fut.as_slice_memory_order_mut().unwrap();
 | 
				
			||||||
 | 
					    let prev = &prev[..nx * ny];
 | 
				
			||||||
 | 
					    let fut = &mut fut[..nx * ny];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    let dx = if optype == OperatorType::H2 {
 | 
					    let dx = if optype == OperatorType::H2 {
 | 
				
			||||||
        1.0 / (nx - 2) as Float
 | 
					        1.0 / (nx - 2) as Float
 | 
				
			||||||
@@ -407,103 +413,88 @@ pub(crate) fn diff_op_2d_sliceable_y_simd<const M: usize, const N: usize, const
 | 
				
			|||||||
    #[cfg(feature = "f32")]
 | 
					    #[cfg(feature = "f32")]
 | 
				
			||||||
    type SimdT = packed_simd::f32x16;
 | 
					    type SimdT = packed_simd::f32x16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    let ny = prev.shape()[0];
 | 
					 | 
				
			||||||
    // How many elements that can be simdified
 | 
					    // How many elements that can be simdified
 | 
				
			||||||
    let simdified = SimdT::lanes() * (ny / SimdT::lanes());
 | 
					    let simdified = SimdT::lanes() * (ny / SimdT::lanes());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    let half_diag_width = (D - 1) / 2;
 | 
					    let (fut0, futmid) = fut.split_at_mut(M * ny);
 | 
				
			||||||
    assert!(half_diag_width <= M);
 | 
					    let (futmid, futn) = futmid.split_at_mut((nx - 2 * M) * ny);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    let fut_base_ptr = fut.as_mut_ptr();
 | 
					    #[inline(always)]
 | 
				
			||||||
    let fut_stride = fut.strides()[1];
 | 
					    fn block_multiply<const M: usize, const N: usize>(
 | 
				
			||||||
    let fut_ptr = |j, i| {
 | 
					        matrix: &Matrix<Float, M, N>,
 | 
				
			||||||
        debug_assert!(j < ny && i < nx);
 | 
					        idx: Float,
 | 
				
			||||||
        unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) }
 | 
					        ny: usize,
 | 
				
			||||||
    };
 | 
					        simdified: usize,
 | 
				
			||||||
 | 
					        prev: &[Float],
 | 
				
			||||||
 | 
					        fut: &mut [Float],
 | 
				
			||||||
 | 
					    ) {
 | 
				
			||||||
 | 
					        assert_eq!(prev.len(), N * ny);
 | 
				
			||||||
 | 
					        assert_eq!(fut.len(), M * ny);
 | 
				
			||||||
 | 
					        let prev = &prev[..N * ny];
 | 
				
			||||||
 | 
					        let fut = &mut fut[..M * ny];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    let prev_base_ptr = prev.as_ptr();
 | 
					        let prevcol = |i: usize| -> &[Float] { &prev[i * ny..(i + 1) * ny] };
 | 
				
			||||||
    let prev_stride = prev.strides()[1];
 | 
					 | 
				
			||||||
    let prev_ptr = |j, i| {
 | 
					 | 
				
			||||||
        debug_assert!(j < ny && i < nx);
 | 
					 | 
				
			||||||
        unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) }
 | 
					 | 
				
			||||||
    };
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Not algo necessary, but gives performance increase
 | 
					        for (&bl, fut) in matrix.iter_rows().zip(fut.chunks_exact_mut(ny)) {
 | 
				
			||||||
    assert_eq!(fut_stride, prev_stride);
 | 
					            let mut fut = fut.chunks_exact_mut(SimdT::lanes());
 | 
				
			||||||
 | 
					            for (j, fut) in fut.by_ref().enumerate() {
 | 
				
			||||||
    // First block
 | 
					                let index_to_simd =
 | 
				
			||||||
    {
 | 
					                    |i| SimdT::from_slice_unaligned(&prevcol(i)[SimdT::lanes() * j..]);
 | 
				
			||||||
        for (ifut, &bl) in matrix.start.iter_rows().enumerate() {
 | 
					 | 
				
			||||||
            for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
					 | 
				
			||||||
                let index_to_simd = |i| unsafe {
 | 
					 | 
				
			||||||
                    // j never moves past end of slice due to step_by and
 | 
					 | 
				
			||||||
                    // rounding down
 | 
					 | 
				
			||||||
                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
					 | 
				
			||||||
                        prev_ptr(j, i),
 | 
					 | 
				
			||||||
                        SimdT::lanes(),
 | 
					 | 
				
			||||||
                    ))
 | 
					 | 
				
			||||||
                };
 | 
					 | 
				
			||||||
                let mut f = SimdT::splat(0.0);
 | 
					                let mut f = SimdT::splat(0.0);
 | 
				
			||||||
                for (iprev, &bl) in bl.iter().enumerate() {
 | 
					                for (iprev, &bl) in bl.iter().enumerate() {
 | 
				
			||||||
                    f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
 | 
					                    f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                f *= idx;
 | 
					                f *= idx;
 | 
				
			||||||
 | 
					                f.write_to_slice_unaligned(fut);
 | 
				
			||||||
                unsafe {
 | 
					 | 
				
			||||||
                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
					 | 
				
			||||||
                        fut_ptr(j, ifut),
 | 
					 | 
				
			||||||
                        SimdT::lanes(),
 | 
					 | 
				
			||||||
                    ));
 | 
					 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            }
 | 
					            for (j, fut) in (simdified..ny).zip(fut.into_remainder()) {
 | 
				
			||||||
            for j in simdified..ny {
 | 
					 | 
				
			||||||
                unsafe {
 | 
					 | 
				
			||||||
                let mut f = 0.0;
 | 
					                let mut f = 0.0;
 | 
				
			||||||
                for (iprev, bl) in bl.iter().enumerate() {
 | 
					                for (iprev, bl) in bl.iter().enumerate() {
 | 
				
			||||||
                        f += bl * *prev_ptr(j, iprev);
 | 
					                    f += bl * prevcol(iprev)[j];
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                    *fut_ptr(j, ifut) = f * idx;
 | 
					                *fut = f * idx;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // First block
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        let prev = &prev[..N * ny];
 | 
				
			||||||
 | 
					        block_multiply(&matrix.start, idx, ny, simdified, prev, fut0);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Diagonal elements
 | 
					    // Diagonal elements
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        for ifut in M..nx - M {
 | 
					        let half_diag_width = (D - 1) / 2;
 | 
				
			||||||
            for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
					        assert!(half_diag_width <= M);
 | 
				
			||||||
                let index_to_simd = |i| unsafe {
 | 
					
 | 
				
			||||||
                    // j never moves past end of slice due to step_by and
 | 
					        let prevcol = |i: usize| -> &[Float] { &prev[i * ny..(i + 1) * ny] };
 | 
				
			||||||
                    // rounding down
 | 
					        for (fut, ifut) in futmid.chunks_exact_mut(ny).zip(M..nx - M) {
 | 
				
			||||||
                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
					            let mut fut = fut.chunks_exact_mut(SimdT::lanes());
 | 
				
			||||||
                        prev_ptr(j, i),
 | 
					            for (j, fut) in fut.by_ref().enumerate() {
 | 
				
			||||||
                        SimdT::lanes(),
 | 
					                let index_to_simd =
 | 
				
			||||||
                    ))
 | 
					                    |i| SimdT::from_slice_unaligned(&prevcol(i)[SimdT::lanes() * j..]);
 | 
				
			||||||
                };
 | 
					 | 
				
			||||||
                let mut f = SimdT::splat(0.0);
 | 
					                let mut f = SimdT::splat(0.0);
 | 
				
			||||||
                for (id, &d) in matrix.diag.iter().enumerate() {
 | 
					                for (id, &d) in matrix.diag.iter().enumerate() {
 | 
				
			||||||
                    let offset = ifut - half_diag_width + id;
 | 
					                    let offset = ifut - half_diag_width + id;
 | 
				
			||||||
                    f = index_to_simd(offset).mul_adde(SimdT::splat(d), f);
 | 
					                    f = index_to_simd(offset).mul_adde(SimdT::splat(d), f);
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                f *= idx;
 | 
					                f *= idx;
 | 
				
			||||||
                unsafe {
 | 
					                {
 | 
				
			||||||
                    // puts simd along stride 1, j never goes past end of slice
 | 
					                    // puts simd along stride 1, j never goes past end of slice
 | 
				
			||||||
                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
					                    f.write_to_slice_unaligned(fut);
 | 
				
			||||||
                        fut_ptr(j, ifut),
 | 
					 | 
				
			||||||
                        SimdT::lanes(),
 | 
					 | 
				
			||||||
                    ));
 | 
					 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
            for j in simdified..ny {
 | 
					            for (j, fut) in (simdified..ny).zip(fut.into_remainder()) {
 | 
				
			||||||
                let mut f = 0.0;
 | 
					                let mut f = 0.0;
 | 
				
			||||||
                for (id, &d) in matrix.diag.iter().enumerate() {
 | 
					                for (id, &d) in matrix.diag.iter().enumerate() {
 | 
				
			||||||
                    let offset = ifut - half_diag_width + id;
 | 
					                    let offset = ifut - half_diag_width + id;
 | 
				
			||||||
                    unsafe {
 | 
					                    {
 | 
				
			||||||
                        f += d * *prev_ptr(j, offset);
 | 
					                        f += d * prevcol(offset)[j];
 | 
				
			||||||
                    }
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
                unsafe {
 | 
					                {
 | 
				
			||||||
                    *fut_ptr(j, ifut) = idx * f;
 | 
					                    *fut = idx * f;
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
@@ -511,41 +502,8 @@ pub(crate) fn diff_op_2d_sliceable_y_simd<const M: usize, const N: usize, const
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    // End block
 | 
					    // End block
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
        // Get blocks and corresponding offsets
 | 
					        let prev = &prev[((nx - N) * ny)..];
 | 
				
			||||||
        // (rev to iterate in ifut increasing order)
 | 
					        block_multiply(&matrix.end, idx, ny, simdified, prev, futn);
 | 
				
			||||||
        for (bl, ifut) in matrix.end.iter_rows().zip(nx - M..) {
 | 
					 | 
				
			||||||
            for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
					 | 
				
			||||||
                let index_to_simd = |i| unsafe {
 | 
					 | 
				
			||||||
                    // j never moves past end of slice due to step_by and
 | 
					 | 
				
			||||||
                    // rounding down
 | 
					 | 
				
			||||||
                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
					 | 
				
			||||||
                        prev_ptr(j, i),
 | 
					 | 
				
			||||||
                        SimdT::lanes(),
 | 
					 | 
				
			||||||
                    ))
 | 
					 | 
				
			||||||
                };
 | 
					 | 
				
			||||||
                let mut f = SimdT::splat(0.0);
 | 
					 | 
				
			||||||
                for (&bl, iprev) in bl.iter().zip(nx - N..) {
 | 
					 | 
				
			||||||
                    f = index_to_simd(iprev).mul_adde(SimdT::splat(bl), f);
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                f = f * idx;
 | 
					 | 
				
			||||||
                unsafe {
 | 
					 | 
				
			||||||
                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
					 | 
				
			||||||
                        fut_ptr(j, ifut),
 | 
					 | 
				
			||||||
                        SimdT::lanes(),
 | 
					 | 
				
			||||||
                    ));
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            for j in simdified..ny {
 | 
					 | 
				
			||||||
                unsafe {
 | 
					 | 
				
			||||||
                    let mut f = 0.0;
 | 
					 | 
				
			||||||
                    for (&bl, iprev) in bl.iter().zip(nx - N..) {
 | 
					 | 
				
			||||||
                        f += bl * *prev_ptr(j, iprev);
 | 
					 | 
				
			||||||
                    }
 | 
					 | 
				
			||||||
                    *fut_ptr(j, ifut) = f * idx;
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -576,7 +534,11 @@ pub(crate) fn diff_op_2d<const M: usize, const N: usize, const D: usize>(
 | 
				
			|||||||
    assert_eq!(prev.shape(), fut.shape());
 | 
					    assert_eq!(prev.shape(), fut.shape());
 | 
				
			||||||
    match (prev.strides(), fut.strides()) {
 | 
					    match (prev.strides(), fut.strides()) {
 | 
				
			||||||
        ([_, 1], [_, 1]) => diff_op_2d_sliceable(matrix, optype, prev, fut),
 | 
					        ([_, 1], [_, 1]) => diff_op_2d_sliceable(matrix, optype, prev, fut),
 | 
				
			||||||
        ([1, _], [1, _]) => diff_op_2d_sliceable_y_simd(matrix, optype, prev, fut),
 | 
					        ([1, _], [1, _])
 | 
				
			||||||
 | 
					            if prev.as_slice_memory_order().is_some() && fut.as_slice_memory_order().is_some() =>
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            diff_op_2d_sliceable_y_simd(matrix, optype, prev, fut)
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
        _ => diff_op_2d_fallback(matrix, optype, prev, fut),
 | 
					        _ => diff_op_2d_fallback(matrix, optype, prev, fut),
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user