use raw pointer in simd code
This commit is contained in:
		@@ -290,61 +290,91 @@ fn diff_op_col_simd(
 | 
				
			|||||||
        type SimdT = packed_simd::f32x16;
 | 
					        type SimdT = packed_simd::f32x16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let ny = prev.shape()[0];
 | 
					        let ny = prev.shape()[0];
 | 
				
			||||||
 | 
					        // How many elements that can be simdified
 | 
				
			||||||
 | 
					        let simdified = SimdT::lanes() * (ny / SimdT::lanes());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // First block
 | 
					        // First block
 | 
				
			||||||
        for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) {
 | 
					        {
 | 
				
			||||||
            fut.fill(0.0);
 | 
					            for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1))) {
 | 
				
			||||||
            debug_assert_eq!(fut.len(), prev.shape()[0]);
 | 
					                fut.fill(0.0);
 | 
				
			||||||
            for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
 | 
					                debug_assert_eq!(fut.len(), prev.shape()[0]);
 | 
				
			||||||
                debug_assert_eq!(prev.len(), fut.len());
 | 
					                for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1))) {
 | 
				
			||||||
                fut.scaled_add(idx * bl, &prev);
 | 
					                    debug_assert_eq!(prev.len(), fut.len());
 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        let half_diag_width = (diag.len() - 1) / 2;
 | 
					 | 
				
			||||||
        assert!(half_diag_width <= block.len());
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for ifut in block.len()..nx - block.len() {
 | 
					 | 
				
			||||||
            let simdified = SimdT::lanes() * (ny / SimdT::lanes());
 | 
					 | 
				
			||||||
            for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
					 | 
				
			||||||
                let index_to_simd = |(j, i)| unsafe {
 | 
					 | 
				
			||||||
                    // gets simd along stride 1, j never goes past end of slice
 | 
					 | 
				
			||||||
                    SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
					 | 
				
			||||||
                        prev.uget((j, i)),
 | 
					 | 
				
			||||||
                        SimdT::lanes(),
 | 
					 | 
				
			||||||
                    ))
 | 
					 | 
				
			||||||
                };
 | 
					 | 
				
			||||||
                let mut f = SimdT::splat(0.0);
 | 
					 | 
				
			||||||
                for (id, &d) in diag.iter().enumerate() {
 | 
					 | 
				
			||||||
                    let offset = ifut - half_diag_width + id;
 | 
					 | 
				
			||||||
                    f = f + d * index_to_simd((j, offset));
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                f = f * idx;
 | 
					 | 
				
			||||||
                unsafe {
 | 
					 | 
				
			||||||
                    // puts simd along stride 1, j never goes past end of slice
 | 
					 | 
				
			||||||
                    f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
					 | 
				
			||||||
                        fut.uget_mut((j, ifut)),
 | 
					 | 
				
			||||||
                        SimdT::lanes(),
 | 
					 | 
				
			||||||
                    ));
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
            for j in simdified..ny {
 | 
					 | 
				
			||||||
                let mut f = 0.0;
 | 
					 | 
				
			||||||
                for (id, &d) in diag.iter().enumerate() {
 | 
					 | 
				
			||||||
                    let offset = ifut - half_diag_width + id;
 | 
					 | 
				
			||||||
                    f += d * prev[(j, offset)];
 | 
					 | 
				
			||||||
                }
 | 
					 | 
				
			||||||
                fut[(j, ifut)] = idx * f;
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        // End block
 | 
					 | 
				
			||||||
        for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) {
 | 
					 | 
				
			||||||
            fut.fill(0.0);
 | 
					 | 
				
			||||||
            for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) {
 | 
					 | 
				
			||||||
                if symmetry == Symmetry::Symmetric {
 | 
					 | 
				
			||||||
                    fut.scaled_add(idx * bl, &prev);
 | 
					                    fut.scaled_add(idx * bl, &prev);
 | 
				
			||||||
                } else {
 | 
					                }
 | 
				
			||||||
                    fut.scaled_add(-idx * bl, &prev);
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // Diagonal elements
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            let half_diag_width = (diag.len() - 1) / 2;
 | 
				
			||||||
 | 
					            assert!(half_diag_width <= block.len());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            let fut_base_ptr = fut.as_mut_ptr();
 | 
				
			||||||
 | 
					            let fut_stride = fut.strides()[1];
 | 
				
			||||||
 | 
					            let fut_ptr = |j, i| {
 | 
				
			||||||
 | 
					                debug_assert!(j < ny && i < nx);
 | 
				
			||||||
 | 
					                unsafe { fut_base_ptr.offset(fut_stride * i as isize + j as isize) }
 | 
				
			||||||
 | 
					            };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            let prev_base_ptr = prev.as_ptr();
 | 
				
			||||||
 | 
					            let prev_stride = prev.strides()[1];
 | 
				
			||||||
 | 
					            let prev_ptr = |j, i| {
 | 
				
			||||||
 | 
					                debug_assert!(j < ny && i < nx);
 | 
				
			||||||
 | 
					                unsafe { prev_base_ptr.offset(prev_stride * i as isize + j as isize) }
 | 
				
			||||||
 | 
					            };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            assert_eq!(fut_stride, prev_stride);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for ifut in block.len()..nx - block.len() {
 | 
				
			||||||
 | 
					                for j in (0..simdified).step_by(SimdT::lanes()) {
 | 
				
			||||||
 | 
					                    let index_to_simd = |(j, i)| unsafe {
 | 
				
			||||||
 | 
					                        // j never moves past end of slice due to step_by and
 | 
				
			||||||
 | 
					                        // rounding down
 | 
				
			||||||
 | 
					                        SimdT::from_slice_unaligned(std::slice::from_raw_parts(
 | 
				
			||||||
 | 
					                            prev_ptr(j, i),
 | 
				
			||||||
 | 
					                            SimdT::lanes(),
 | 
				
			||||||
 | 
					                        ))
 | 
				
			||||||
 | 
					                    };
 | 
				
			||||||
 | 
					                    let mut f = SimdT::splat(0.0);
 | 
				
			||||||
 | 
					                    for (id, &d) in diag.iter().enumerate() {
 | 
				
			||||||
 | 
					                        let offset = ifut - half_diag_width + id;
 | 
				
			||||||
 | 
					                        f = index_to_simd((j, offset)).mul_adde(SimdT::splat(d), f);
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    f = f * idx;
 | 
				
			||||||
 | 
					                    unsafe {
 | 
				
			||||||
 | 
					                        // puts simd along stride 1, j never goes past end of slice
 | 
				
			||||||
 | 
					                        f.write_to_slice_unaligned(std::slice::from_raw_parts_mut(
 | 
				
			||||||
 | 
					                            fut_ptr(j, ifut),
 | 
				
			||||||
 | 
					                            SimdT::lanes(),
 | 
				
			||||||
 | 
					                        ));
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                for j in simdified..ny {
 | 
				
			||||||
 | 
					                    let mut f = 0.0;
 | 
				
			||||||
 | 
					                    for (id, &d) in diag.iter().enumerate() {
 | 
				
			||||||
 | 
					                        let offset = ifut - half_diag_width + id;
 | 
				
			||||||
 | 
					                        unsafe {
 | 
				
			||||||
 | 
					                            f += d * *prev_ptr(j, offset);
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                    unsafe {
 | 
				
			||||||
 | 
					                        *fut_ptr(j, ifut) = idx * f;
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        // End block
 | 
				
			||||||
 | 
					        {
 | 
				
			||||||
 | 
					            for (bl, mut fut) in block.iter().zip(fut.axis_iter_mut(ndarray::Axis(1)).rev()) {
 | 
				
			||||||
 | 
					                fut.fill(0.0);
 | 
				
			||||||
 | 
					                for (&bl, prev) in bl.iter().zip(prev.axis_iter(ndarray::Axis(1)).rev()) {
 | 
				
			||||||
 | 
					                    if symmetry == Symmetry::Symmetric {
 | 
				
			||||||
 | 
					                        fut.scaled_add(idx * bl, &prev);
 | 
				
			||||||
 | 
					                    } else {
 | 
				
			||||||
 | 
					                        fut.scaled_add(-idx * bl, &prev);
 | 
				
			||||||
 | 
					                    }
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user