diff --git a/multigrid/src/system.rs b/multigrid/src/system.rs index 1ba6bd1..add187f 100644 --- a/multigrid/src/system.rs +++ b/multigrid/src/system.rs @@ -333,6 +333,8 @@ impl BaseSystem { wb, wb_ns: Array2::zeros((4, nx)), wb_ew: Array2::zeros((4, ny)), + + progressbar: None, }; sys.run(); @@ -410,10 +412,10 @@ impl System { match self { Self::SingleThreaded(sys) => sys.progressbar.take().unwrap().finish_and_clear(), Self::MultiThreaded(sys) => { - let (target, pbs) = sys.progressbar.take().unwrap(); - for pb in pbs.into_iter() { - pb.finish_and_clear() + for tid in &sys.send { + tid.send(MsgFromHost::ProgressbarDrop).unwrap(); } + let target = sys.progressbar.take().unwrap(); target.clear().unwrap(); } } @@ -627,7 +629,7 @@ pub struct DistributedSystem { /// All threads should be joined to mark the end of the computation sys: Vec>, output: hdf5::File, - progressbar: Option<(indicatif::MultiProgress, Vec)>, + progressbar: Option, } impl DistributedSystem { @@ -635,17 +637,6 @@ impl DistributedSystem { for tid in &self.send { tid.send(MsgFromHost::Advance(ntime)).unwrap(); } - if let Some(pbar) = &self.progressbar { - let expected_messages = ntime * self.sys.len() as u64; - for _i in 0..expected_messages { - match self.recv.recv().unwrap() { - (i, MsgToHost::CurrentTimestep(_)) => { - pbar.1[i].inc(1); - } - _ => unreachable!(), - } - } - } } pub fn output(&self, ntime: u64) { for tid in &self.send { @@ -659,13 +650,13 @@ impl DistributedSystem { } pub fn attach_progressbar(&mut self, ntime: u64) { let target = indicatif::MultiProgress::new(); - let mut progressbars = Vec::with_capacity(self.sys.len()); - for _ in 0..self.sys.len() { + for tid in &self.send { let pb = super::progressbar(ntime); - progressbars.push(target.add(pb)); + let pb = target.add(pb); + tid.send(MsgFromHost::Progressbar(pb)).unwrap(); } target.set_move_cursor(true); - self.progressbar = Some((target, progressbars)); + self.progressbar = Some(target); } } @@ -696,6 +687,10 @@ enum MsgFromHost { Stop, /// Request the current error Error, + /// Progressbar to report progress + Progressbar(indicatif::ProgressBar), + /// Clear and remove the progressbar + ProgressbarDrop, } /// Messages sent back to the host @@ -703,8 +698,6 @@ enum MsgFromHost { enum MsgToHost { /// Maximum dt allowed by the current grid MaxDt(Float), - /// Timestep which we have currently computed - CurrentTimestep(u64), /// Error from the current grid Error(Float), } @@ -750,6 +743,8 @@ struct DistributedSystemPart { wb_ns: Array2, /// Work buffer for east/west boundary wb_ew: Array2, + + progressbar: Option, } impl DistributedSystemPart { @@ -759,15 +754,17 @@ impl DistributedSystemPart { MsgFromHost::DtSet(dt) => self.dt = dt, MsgFromHost::DtRequest => { let dt = self.max_dt(); - self.send.send((self.id, MsgToHost::MaxDt(dt))).unwrap(); + self.send(MsgToHost::MaxDt(dt)).unwrap(); } MsgFromHost::Advance(ntime) => self.advance(ntime), MsgFromHost::Output(ntime) => self.output(ntime), MsgFromHost::Stop => return, - MsgFromHost::Error => self - .send - .send((self.id, MsgToHost::Error(self.error()))) - .unwrap(), + MsgFromHost::Error => self.send(MsgToHost::Error(self.error())).unwrap(), + MsgFromHost::Progressbar(pbar) => self.progressbar = Some(pbar), + MsgFromHost::ProgressbarDrop => { + let pb = self.progressbar.take().unwrap(); + pb.finish_and_clear() + } } } } @@ -837,10 +834,10 @@ impl DistributedSystemPart { } fn advance(&mut self, ntime: u64) { - for ntime in 0..ntime { - self.send - .send((self.id, MsgToHost::CurrentTimestep(ntime))) - .unwrap(); + for _itime in 0..ntime { + if let Some(pbar) = &self.progressbar { + pbar.inc(1) + } let metrics = &self.grid.1; let wb = &mut self.wb.0; let sbp = &self.sbp; @@ -1128,6 +1125,10 @@ impl DistributedSystemPart { } } + fn send(&self, msg: MsgToHost) -> Result<(), crossbeam_channel::SendError<(usize, MsgToHost)>> { + self.send.send((self.id, msg)) + } + fn error(&self) -> Float { let mut fvort = self.current.clone(); match &self.initial_conditions {