Vectorize diff_op_col with explicit SIMD
This commit is contained in:
		@@ -192,10 +192,16 @@ pub(crate) fn diff_op_col(
 | 
			
		||||
        };
 | 
			
		||||
        let idx = 1.0 / dx;
 | 
			
		||||
 | 
			
		||||
        fut.fill(0.0);
 | 
			
		||||
        #[cfg(not(feature = "f32"))]
 | 
			
		||||
        type SimdT = packed_simd::f64x8;
 | 
			
		||||
        #[cfg(feature = "f32")]
 | 
			
		||||
        type SimdT = packed_simd::f32x16;
 | 
			
		||||
 | 
			
		||||
        let ny = prev.shape()[0];
 | 
			
		||||
 | 
			
		||||
        // First block
 | 
			
		||||
        for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) {
 | 
			
		||||
            fut.fill(0.0);
 | 
			
		||||
            debug_assert_eq!(fut.len(), prev.shape()[0]);
 | 
			
		||||
            for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
 | 
			
		||||
                debug_assert_eq!(prev.len(), fut.len());
 | 
			
		||||
@@ -204,22 +210,44 @@ pub(crate) fn diff_op_col(
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let half_diag_width = (diag.len() - 1) / 2;
 | 
			
		||||
        assert!(half_diag_width <= block.len());
 | 
			
		||||
 | 
			
		||||
        // Diagonal entries
 | 
			
		||||
        for (ifut, mut fut) in fut
 | 
			
		||||
            .axis_iter_mut(ndarray::Axis(1))
 | 
			
		||||
            .enumerate()
 | 
			
		||||
            .skip(block.len())
 | 
			
		||||
            .take(nx - 2 * block.len())
 | 
			
		||||
        {
 | 
			
		||||
            for (id, d) in diag.iter().enumerate() {
 | 
			
		||||
                let offset = ifut - half_diag_width + id;
 | 
			
		||||
                fut.scaled_add(idx * d, &prev.slice(ndarray::s![.., offset]))
 | 
			
		||||
        for ifut in block.len()..nx - block.len() {
 | 
			
		||||
            let simdified = SimdT::lanes() * (ny / SimdT::lanes());
 | 
			
		||||
            for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
			
		||||
                let index_to_simd = |(j, i)| unsafe {
 | 
			
		||||
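                    // SAFETY: reads `SimdT::lanes()` consecutive elements starting at (j, i).
                    // `j + SimdT::lanes() <= simdified <= ny`, so the read stays within column `i`
                    // as long as `prev` has stride 1 along axis 0.
                    // NOTE(review): that layout is assumed here, not checked — confirm at the caller.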
 | 
			
		||||
                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
			
		||||
                        prev.uget((j, i)),
 | 
			
		||||
                        SimdT::lanes(),
 | 
			
		||||
                    ))
 | 
			
		||||
                };
 | 
			
		||||
                let mut f = SimdT::splat(0.0);
 | 
			
		||||
                for (id, &d) in diag.iter().enumerate() {
 | 
			
		||||
                    let offset = ifut - half_diag_width + id;
 | 
			
		||||
                    f = f + d * index_to_simd((j, offset));
 | 
			
		||||
                }
 | 
			
		||||
                f = f * idx;
 | 
			
		||||
                unsafe {
 | 
			
		||||
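                    // SAFETY: writes `SimdT::lanes()` consecutive elements starting at (j, ifut).
                    // `j + SimdT::lanes() <= simdified <= ny`, so the write stays within the column
                    // as long as `fut` has `ny` rows and stride 1 along axis 0.
                    // NOTE(review): that layout is assumed here, not checked — confirm at the caller.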
 | 
			
		||||
                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
			
		||||
                        fut.uget_mut((j, ifut)),
 | 
			
		||||
                        SimdT::lanes(),
 | 
			
		||||
                    ));
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
            for j in simdified..ny {
 | 
			
		||||
                let mut f = 0.0;
 | 
			
		||||
                for (id, &d) in diag.iter().enumerate() {
 | 
			
		||||
                    let offset = ifut - half_diag_width + id;
 | 
			
		||||
                    f += d * prev[(j, offset)];
 | 
			
		||||
                }
 | 
			
		||||
                fut[(j, ifut)] = idx * f;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // End block
 | 
			
		||||
        for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) {
 | 
			
		||||
            fut.fill(0.0);
 | 
			
		||||
            for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) {
 | 
			
		||||
                if symmetric {
 | 
			
		||||
                    fut.scaled_add(idx * bl, &prev);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user