diff --git a/euler/src/lib.rs b/euler/src/lib.rs index 03d2c54..e9c64d4 100644 --- a/euler/src/lib.rs +++ b/euler/src/lib.rs @@ -325,6 +325,49 @@ impl Field { &vortex_param, ) } + fn iter(&self) -> impl ExactSizeIterator + '_ { + let n = self.nx() * self.ny(); + let slice = self.0.as_slice().unwrap(); + let rho = &slice[0 * n..1 * n]; + let rhou = &slice[1 * n..2 * n]; + let rhov = &slice[2 * n..3 * n]; + let e = &slice[3 * n..4 * n]; + + rho.iter() + .zip(rhou) + .zip(rhov) + .zip(e) + .map(|(((&rho, &rhou), &rhov), &e)| FieldValue { rho, rhou, rhov, e }) + } + fn iter_mut(&mut self) -> impl ExactSizeIterator> + '_ { + let n = self.nx() * self.ny(); + let slice = self.0.as_slice_mut().unwrap(); + let (rho, slice) = slice.split_at_mut(n); + let (rhou, slice) = slice.split_at_mut(n); + let (rhov, slice) = slice.split_at_mut(n); + let (e, slice) = slice.split_at_mut(n); + assert_eq!(slice.len(), 0); + + rho.iter_mut() + .zip(rhou.iter_mut()) + .zip(rhov.iter_mut()) + .zip(e.iter_mut()) + .map(|(((rho, rhou), rhov), e)| FieldValueMut { rho, rhou, rhov, e }) + } +} + +struct FieldValue { + rho: Float, + rhou: Float, + rhov: Float, + e: Float, +} + +struct FieldValueMut<'a> { + rho: &'a mut Float, + rhou: &'a mut Float, + rhov: &'a mut Float, + e: &'a mut Float, } impl Field { @@ -613,22 +656,13 @@ fn upwind_dissipation( metrics: &Metrics, tmp: (&mut Field, &mut Field), ) { - let n = y.nx() * y.ny(); - let yview = y.0.view().into_shape((4, n)).unwrap(); - let mut tmp0 = tmp.0 .0.view_mut().into_shape((4, n)).unwrap(); - let mut tmp1 = tmp.1 .0.view_mut().into_shape((4, n)).unwrap(); - - for (((y, mut tmp0), mut tmp1), metric) in yview - .axis_iter(ndarray::Axis(1)) - .zip(tmp0.axis_iter_mut(ndarray::Axis(1))) - .zip(tmp1.axis_iter_mut(ndarray::Axis(1))) + for (((FieldValue { rho, rhou, rhov, e }, tmp0), tmp1), metric) in y + .iter() + .zip(tmp.0.iter_mut()) + .zip(tmp.1.iter_mut()) .zip(metrics.iter()) { - let rho = y[0]; assert!(rho > 0.0); - let rhou = y[1]; - let rhov = y[2]; - let e = y[3]; let u = rhou / rho; let v = rhov / rho; @@ -647,17 +681,17 @@ fn upwind_dissipation( let alpha_u = uhat.abs() + c * hypot(metric.detj_dxi_dx, metric.detj_dxi_dy); let alpha_v = vhat.abs() + c * hypot(metric.detj_deta_dx, metric.detj_deta_dy); - tmp0[0] = alpha_u * rho; - tmp1[0] = alpha_v * rho; + *tmp0.rho = alpha_u * rho; + *tmp1.rho = alpha_v * rho; - tmp0[1] = alpha_u * rhou; - tmp1[1] = alpha_v * rhou; + *tmp0.rhou = alpha_u * rhou; + *tmp1.rhou = alpha_v * rhou; - tmp0[2] = alpha_u * rhov; - tmp1[2] = alpha_v * rhov; + *tmp0.rhov = alpha_u * rhov; + *tmp1.rhov = alpha_v * rhov; - tmp0[3] = alpha_u * e; - tmp1[3] = alpha_v * e; + *tmp0.e = alpha_u * e; + *tmp1.e = alpha_v * e; } op.dissxi(tmp.0.rho(), k.0.rho_mut());