Add progressbar inside multi-sys
This commit is contained in:
		@@ -333,6 +333,8 @@ impl BaseSystem {
 | 
			
		||||
                            wb,
 | 
			
		||||
                            wb_ns: Array2::zeros((4, nx)),
 | 
			
		||||
                            wb_ew: Array2::zeros((4, ny)),
 | 
			
		||||
 | 
			
		||||
                            progressbar: None,
 | 
			
		||||
                        };
 | 
			
		||||
 | 
			
		||||
                        sys.run();
 | 
			
		||||
@@ -410,10 +412,10 @@ impl System {
 | 
			
		||||
        match self {
 | 
			
		||||
            Self::SingleThreaded(sys) => sys.progressbar.take().unwrap().finish_and_clear(),
 | 
			
		||||
            Self::MultiThreaded(sys) => {
 | 
			
		||||
                let (target, pbs) = sys.progressbar.take().unwrap();
 | 
			
		||||
                for pb in pbs.into_iter() {
 | 
			
		||||
                    pb.finish_and_clear()
 | 
			
		||||
                for tid in &sys.send {
 | 
			
		||||
                    tid.send(MsgFromHost::ProgressbarDrop).unwrap();
 | 
			
		||||
                }
 | 
			
		||||
                let target = sys.progressbar.take().unwrap();
 | 
			
		||||
                target.clear().unwrap();
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
@@ -627,7 +629,7 @@ pub struct DistributedSystem {
 | 
			
		||||
    /// All threads should be joined to mark the end of the computation
 | 
			
		||||
    sys: Vec<std::thread::JoinHandle<()>>,
 | 
			
		||||
    output: hdf5::File,
 | 
			
		||||
    progressbar: Option<(indicatif::MultiProgress, Vec<indicatif::ProgressBar>)>,
 | 
			
		||||
    progressbar: Option<indicatif::MultiProgress>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl DistributedSystem {
 | 
			
		||||
@@ -635,17 +637,6 @@ impl DistributedSystem {
 | 
			
		||||
        for tid in &self.send {
 | 
			
		||||
            tid.send(MsgFromHost::Advance(ntime)).unwrap();
 | 
			
		||||
        }
 | 
			
		||||
        if let Some(pbar) = &self.progressbar {
 | 
			
		||||
            let expected_messages = ntime * self.sys.len() as u64;
 | 
			
		||||
            for _i in 0..expected_messages {
 | 
			
		||||
                match self.recv.recv().unwrap() {
 | 
			
		||||
                    (i, MsgToHost::CurrentTimestep(_)) => {
 | 
			
		||||
                        pbar.1[i].inc(1);
 | 
			
		||||
                    }
 | 
			
		||||
                    _ => unreachable!(),
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    pub fn output(&self, ntime: u64) {
 | 
			
		||||
        for tid in &self.send {
 | 
			
		||||
@@ -659,13 +650,13 @@ impl DistributedSystem {
 | 
			
		||||
    }
 | 
			
		||||
    pub fn attach_progressbar(&mut self, ntime: u64) {
 | 
			
		||||
        let target = indicatif::MultiProgress::new();
 | 
			
		||||
        let mut progressbars = Vec::with_capacity(self.sys.len());
 | 
			
		||||
        for _ in 0..self.sys.len() {
 | 
			
		||||
        for tid in &self.send {
 | 
			
		||||
            let pb = super::progressbar(ntime);
 | 
			
		||||
            progressbars.push(target.add(pb));
 | 
			
		||||
            let pb = target.add(pb);
 | 
			
		||||
            tid.send(MsgFromHost::Progressbar(pb)).unwrap();
 | 
			
		||||
        }
 | 
			
		||||
        target.set_move_cursor(true);
 | 
			
		||||
        self.progressbar = Some((target, progressbars));
 | 
			
		||||
        self.progressbar = Some(target);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -696,6 +687,10 @@ enum MsgFromHost {
 | 
			
		||||
    Stop,
 | 
			
		||||
    /// Request the current error
 | 
			
		||||
    Error,
 | 
			
		||||
    /// Progressbar to report progress
 | 
			
		||||
    Progressbar(indicatif::ProgressBar),
 | 
			
		||||
    /// Clear and remove the progressbar
 | 
			
		||||
    ProgressbarDrop,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Messages sent back to the host
 | 
			
		||||
@@ -703,8 +698,6 @@ enum MsgFromHost {
 | 
			
		||||
enum MsgToHost {
 | 
			
		||||
    /// Maximum dt allowed by the current grid
 | 
			
		||||
    MaxDt(Float),
 | 
			
		||||
    /// Timestep which we have currently computed
 | 
			
		||||
    CurrentTimestep(u64),
 | 
			
		||||
    /// Error from the current grid
 | 
			
		||||
    Error(Float),
 | 
			
		||||
}
 | 
			
		||||
@@ -750,6 +743,8 @@ struct DistributedSystemPart {
 | 
			
		||||
    wb_ns: Array2<Float>,
 | 
			
		||||
    /// Work buffer for east/west boundary
 | 
			
		||||
    wb_ew: Array2<Float>,
 | 
			
		||||
 | 
			
		||||
    progressbar: Option<indicatif::ProgressBar>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl DistributedSystemPart {
 | 
			
		||||
@@ -759,15 +754,17 @@ impl DistributedSystemPart {
 | 
			
		||||
                MsgFromHost::DtSet(dt) => self.dt = dt,
 | 
			
		||||
                MsgFromHost::DtRequest => {
 | 
			
		||||
                    let dt = self.max_dt();
 | 
			
		||||
                    self.send.send((self.id, MsgToHost::MaxDt(dt))).unwrap();
 | 
			
		||||
                    self.send(MsgToHost::MaxDt(dt)).unwrap();
 | 
			
		||||
                }
 | 
			
		||||
                MsgFromHost::Advance(ntime) => self.advance(ntime),
 | 
			
		||||
                MsgFromHost::Output(ntime) => self.output(ntime),
 | 
			
		||||
                MsgFromHost::Stop => return,
 | 
			
		||||
                MsgFromHost::Error => self
 | 
			
		||||
                    .send
 | 
			
		||||
                    .send((self.id, MsgToHost::Error(self.error())))
 | 
			
		||||
                    .unwrap(),
 | 
			
		||||
                MsgFromHost::Error => self.send(MsgToHost::Error(self.error())).unwrap(),
 | 
			
		||||
                MsgFromHost::Progressbar(pbar) => self.progressbar = Some(pbar),
 | 
			
		||||
                MsgFromHost::ProgressbarDrop => {
 | 
			
		||||
                    let pb = self.progressbar.take().unwrap();
 | 
			
		||||
                    pb.finish_and_clear()
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
@@ -837,10 +834,10 @@ impl DistributedSystemPart {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn advance(&mut self, ntime: u64) {
 | 
			
		||||
        for ntime in 0..ntime {
 | 
			
		||||
            self.send
 | 
			
		||||
                .send((self.id, MsgToHost::CurrentTimestep(ntime)))
 | 
			
		||||
                .unwrap();
 | 
			
		||||
        for _itime in 0..ntime {
 | 
			
		||||
            if let Some(pbar) = &self.progressbar {
 | 
			
		||||
                pbar.inc(1)
 | 
			
		||||
            }
 | 
			
		||||
            let metrics = &self.grid.1;
 | 
			
		||||
            let wb = &mut self.wb.0;
 | 
			
		||||
            let sbp = &self.sbp;
 | 
			
		||||
@@ -1128,6 +1125,10 @@ impl DistributedSystemPart {
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn send(&self, msg: MsgToHost) -> Result<(), crossbeam_channel::SendError<(usize, MsgToHost)>> {
 | 
			
		||||
        self.send.send((self.id, msg))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn error(&self) -> Float {
 | 
			
		||||
        let mut fvort = self.current.clone();
 | 
			
		||||
        match &self.initial_conditions {
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user