diff --git a/sbp/examples/multigrid/bin.rs b/sbp/examples/multigrid/bin.rs index f8efe92..541c400 100644 --- a/sbp/examples/multigrid/bin.rs +++ b/sbp/examples/multigrid/bin.rs @@ -338,41 +338,9 @@ fn main() { }; let output = File::create(&opt.output, sys.grids.as_slice()).unwrap(); + let mut output = OutputThread::new(output); - // Pingpong back and forth a number of Vec to be used for the - // output. The sync_channel applies some backpressure - let (tx_thread, rx) = std::sync::mpsc::channel::>(); - let (tx, rx_thread) = std::sync::mpsc::sync_channel::<(u64, Vec)>(3); - let outputthread = std::thread::Builder::new() - .name("multigrid_output".to_owned()) - .spawn(move || { - let mut times = Vec::::new(); - - for (ntime, fields) in rx_thread.iter() { - if !times.contains(&ntime) { - output.add_timestep(ntime, fields.as_slice()).unwrap(); - times.push(ntime); - } - tx_thread.send(fields).unwrap(); - } - }) - .unwrap(); - - let output = |ntime: u64, nowfield: &[euler::Field]| match rx.try_recv() { - Ok(mut fields) => { - for (from, to) in nowfield.iter().zip(fields.iter_mut()) { - to.assign(&from); - } - tx.send((ntime, fields)).unwrap(); - } - Err(std::sync::mpsc::TryRecvError::Empty) => { - let fields = nowfield.to_vec(); - tx.send((ntime, fields)).unwrap(); - } - Err(e) => panic!("{:?}", e), - }; - - output(0, &sys.fnow); + output.add_timestep(0, &sys.fnow); let bar = progressbar(opt.no_progressbar, ntime); for _ in 0..ntime { @@ -381,10 +349,7 @@ fn main() { } bar.finish(); - output(ntime, &sys.fnow); - - std::mem::drop(tx); - outputthread.join().unwrap(); + output.add_timestep(ntime, &sys.fnow); } fn progressbar(dummy: bool, ntime: u64) -> indicatif::ProgressBar { @@ -399,6 +364,70 @@ fn progressbar(dummy: bool, ntime: u64) -> indicatif::ProgressBar { } } +struct OutputThread { + rx: Option>>, + tx: Option)>>, + thread: Option>, +} + +impl OutputThread { + fn new(file: File) -> Self { + // Pingpong back and forth a number of Vec to be used for the + // output. The sync_channel applies some backpressure + let (tx_thread, rx) = std::sync::mpsc::channel::>(); + let (tx, rx_thread) = std::sync::mpsc::sync_channel::<(u64, Vec)>(3); + let thread = std::thread::Builder::new() + .name("multigrid_output".to_owned()) + .spawn(move || { + let mut times = Vec::::new(); + + for (ntime, fields) in rx_thread.iter() { + if !times.contains(&ntime) { + file.add_timestep(ntime, fields.as_slice()).unwrap(); + times.push(ntime); + } + tx_thread.send(fields).unwrap(); + } + }) + .unwrap(); + + Self { + tx: Some(tx), + rx: Some(rx), + thread: Some(thread), + } + } + + fn add_timestep(&mut self, ntime: u64, fields: &[euler::Field]) { + match self.rx.as_ref().unwrap().try_recv() { + Ok(mut copy_fields) => { + for (from, to) in fields.iter().zip(copy_fields.iter_mut()) { + to.assign(&from); + } + self.tx + .as_ref() + .unwrap() + .send((ntime, copy_fields)) + .unwrap(); + } + Err(std::sync::mpsc::TryRecvError::Empty) => { + let fields = fields.to_vec(); + self.tx.as_ref().unwrap().send((ntime, fields)).unwrap(); + } + Err(e) => panic!("{:?}", e), + }; + } +} + +impl Drop for OutputThread { + fn drop(&mut self) { + let tx = self.tx.take(); + std::mem::drop(tx); + let thread = self.thread.take().unwrap(); + thread.join().unwrap(); + } +} + #[derive(Debug, Clone)] struct File(hdf5::File);