diff --git a/src/configure.rs b/src/configure.rs index eb69ad25c..d6589975a 100644 --- a/src/configure.rs +++ b/src/configure.rs @@ -70,7 +70,7 @@ impl Context { let res = self .inner_configure() - .race(cancel_channel.recv().map(|_| { + .race(cancel_channel.map(|_| { progress!(self, 0); Ok(()) })) diff --git a/src/context.rs b/src/context.rs index a9bbd01ad..8835aeb51 100644 --- a/src/context.rs +++ b/src/context.rs @@ -2,16 +2,19 @@ use std::collections::{BTreeMap, HashMap}; use std::ffi::OsString; +use std::future::Future; use std::ops::Deref; use std::path::{Path, PathBuf}; +use std::pin::Pin; use std::sync::atomic::AtomicBool; use std::sync::Arc; +use std::task::Poll; use std::time::{Duration, Instant, SystemTime}; use anyhow::{bail, ensure, Context as _, Result}; -use async_channel::{self as channel, Receiver, Sender}; +use async_channel::Sender; use ratelimit::Ratelimit; -use tokio::sync::{Mutex, RwLock}; +use tokio::sync::{oneshot, Mutex, RwLock}; use tokio::task; use crate::chat::{get_chat_cnt, ChatId}; @@ -257,7 +260,7 @@ pub(crate) struct DebugLogging { #[derive(Debug)] enum RunningState { /// Ongoing process is allocated. - Running { cancel_sender: Sender<()> }, + Running { cancel_sender: oneshot::Sender<()> }, /// Cancel signal has been sent, waiting for ongoing process to be freed. ShallStop, @@ -511,21 +514,35 @@ impl Context { /// This is for modal operations during which no other user actions are allowed. Only /// one such operation is allowed at any given time. /// - /// The return value is a cancel token, which will release the ongoing mutex when - /// dropped. - pub(crate) async fn alloc_ongoing(&self) -> Result> { + /// The return value is a guard which does two things: + /// + /// - It is a Future which will complete when the ongoing process is cancelled using + /// [`Context::stop_ongoing`] and must stop. + /// - It will free the ongoing process, aka release the mutex, when dropped. + pub(crate) async fn alloc_ongoing(&self) -> Result { let mut s = self.running_state.write().await; ensure!( matches!(*s, RunningState::Stopped), "There is already another ongoing process running." ); - let (sender, receiver) = channel::bounded(1); + let (cancel_tx, cancel_rx) = oneshot::channel(); *s = RunningState::Running { - cancel_sender: sender, + cancel_sender: cancel_tx, }; + let (drop_tx, drop_rx) = oneshot::channel(); + let context = self.clone(); - Ok(receiver) + tokio::spawn(async move { + drop_rx.await.ok(); + let mut s = context.running_state.write().await; + *s = RunningState::Stopped; + }); + + Ok(OngoingGuard { + cancel_rx, + drop_tx: Some(drop_tx), + }) } pub(crate) async fn free_ongoing(&self) { @@ -536,21 +553,24 @@ impl Context { /// Signal an ongoing process to stop. pub async fn stop_ongoing(&self) { let mut s = self.running_state.write().await; - match &*s { - RunningState::Running { cancel_sender } => { - if let Err(err) = cancel_sender.send(()).await { - warn!(self, "could not cancel ongoing: {:#}", err); - } - info!(self, "Signaling the ongoing process to stop ASAP.",); - *s = RunningState::ShallStop; - } + + // Take out the state so we can call the oneshot sender (which takes ownership). + let current_state = std::mem::replace(&mut *s, RunningState::ShallStop); + + match current_state { + RunningState::Running { cancel_sender } => match cancel_sender.send(()) { + Ok(()) => info!(self, "Signaling the ongoing process to stop ASAP."), + Err(()) => warn!(self, "could not cancel ongoing"), + }, RunningState::ShallStop | RunningState::Stopped => { + // Put back the current state + *s = current_state; info!(self, "No ongoing process to stop.",); } } } - #[allow(unused)] + #[cfg(test)] pub(crate) async fn shall_stop_ongoing(&self) -> bool { match &*self.running_state.read().await { RunningState::Running { .. } => false, @@ -945,6 +965,54 @@ impl Context { } } +/// Guard received when calling [`Context::alloc_ongoing`]. +/// +/// While holding this guard the ongoing mutex is held, dropping this guard frees the +/// ongoing process. +/// +/// The ongoing process can also be cancelled by unrelated code calling +/// [`Context::stop_ongoing`]. This guard implements [`Future`] and the future will +/// complete when the ongoing process is cancelled and must be aborted. Freeing the ongoing +/// process works as usual in this case: when this guard is dropped. So if you need to do +/// some more work before freeing make sure to keep ownership of the guard, e.g.: +/// +/// ```no_compile +/// let mut guard = context.alloc_ongoing().await?; +/// tokio::select!{ +/// biased; +/// _ = &mut guard => (), // guard is not moved, so we keep ownership. +/// _ = do_work() => (), +/// }; +/// do_cleaup().await; +/// drop(guard); +/// ``` +pub(crate) struct OngoingGuard { + /// Receives a message when the ongoing process should be cancelled. + cancel_rx: oneshot::Receiver<()>, + /// Used by `Drop` to send a message which will free the ongoing process. + drop_tx: Option>, +} + +impl Future for OngoingGuard { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + match Pin::new(&mut self.cancel_rx).poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } +} + +impl Drop for OngoingGuard { + fn drop(&mut self) { + if let Some(sender) = self.drop_tx.take() { + // TODO: Maybe this should log? But we'd need to have a context. + sender.send(()).ok(); + } + } +} + /// Returns core version as a string. pub fn get_version_str() -> &'static str { &DC_VERSION_STR @@ -1409,38 +1477,52 @@ mod tests { async fn test_ongoing() -> Result<()> { let context = TestContext::new().await; - // No ongoing process allocated. + println!("No ongoing process allocated."); assert!(context.shall_stop_ongoing().await); - let receiver = context.alloc_ongoing().await?; + let mut guard = context.alloc_ongoing().await?; - // Cannot allocate another ongoing process while the first one is running. + println!("Cannot allocate another ongoing process while the first one is running."); assert!(context.alloc_ongoing().await.is_err()); - // Stop signal is not sent yet. - assert!(receiver.try_recv().is_err()); + println!("Stop signal is not sent yet."); + assert!(matches!(futures::poll!(&mut guard), Poll::Pending)); assert!(!context.shall_stop_ongoing().await); - // Send the stop signal. + println!("Send the stop signal."); context.stop_ongoing().await; - // Receive stop signal. - receiver.recv().await?; + println!("Receive stop signal."); + (&mut guard).await; assert!(context.shall_stop_ongoing().await); - // Ongoing process is still running even though stop signal was received, - // so another one cannot be allocated. + println!("Ongoing process still running even though stop signal was received"); assert!(context.alloc_ongoing().await.is_err()); - context.free_ongoing().await; + println!("free the ongoing process"); + // context.free_ongoing().await; + drop(guard); - // No ongoing process allocated, should have been stopped already. - assert!(context.shall_stop_ongoing().await); - - // Another ongoing process can be allocated now. - let _receiver = context.alloc_ongoing().await?; + println!("re-acquire the ongoing process"); + // Since the drop guard needs to send a message and the receiving task must run and + // acquire a lock this needs some time so won't succeed immediately. + #[allow(clippy::async_yields_async)] + let _guard = tokio::time::timeout(Duration::from_secs(10), async { + loop { + match context.alloc_ongoing().await { + Ok(guard) => break guard, + Err(_) => { + // tokio::task::yield_now() results in a lot hotter loop, it takes a + // lot of yields. + tokio::time::sleep(Duration::from_millis(1)).await; + } + } + } + }) + .await + .expect("timeout"); Ok(()) } diff --git a/src/imex.rs b/src/imex.rs index 885187cae..0dc519d52 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -94,7 +94,7 @@ pub async fn imex( let _guard = context.scheduler.pause(context.clone()).await; imex_inner(context, what, path, passphrase) .race(async { - cancel.recv().await.ok(); + cancel.await; Err(format_err!("canceled")) }) .await diff --git a/src/imex/transfer.rs b/src/imex/transfer.rs index d93cf4f6d..3ef7a16fe 100644 --- a/src/imex/transfer.rs +++ b/src/imex/transfer.rs @@ -31,7 +31,6 @@ use std::pin::Pin; use std::task::Poll; use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result}; -use async_channel::Receiver; use futures_lite::StreamExt; use iroh::get::{DataStream, Options}; use iroh::progress::ProgressEmitter; @@ -47,7 +46,7 @@ use tokio_stream::wrappers::ReadDirStream; use crate::blob::BlobDirContents; use crate::chat::delete_and_reset_all_device_msgs; -use crate::context::Context; +use crate::context::{Context, OngoingGuard}; use crate::qr::Qr; use crate::{e2ee, EventType}; @@ -91,7 +90,7 @@ impl BackupProvider { .context("Private key not available, aborting backup export")?; // Acquire global "ongoing" mutex. - let cancel_token = context.alloc_ongoing().await?; + let mut cancel_token = context.alloc_ongoing().await?; let paused_guard = context.scheduler.pause(context.clone()).await; let context_dir = context .get_blobdir() @@ -114,7 +113,7 @@ impl BackupProvider { }, } }, - _ = cancel_token.recv() => Err(format_err!("cancelled")), + _ = &mut cancel_token => Err(format_err!("cancelled")), }; let (provider, ticket) = match res { Ok((provider, ticket)) => (provider, ticket), @@ -188,7 +187,7 @@ impl BackupProvider { async fn watch_provider( context: &Context, mut provider: Provider, - cancel_token: Receiver<()>, + mut cancel_token: OngoingGuard, ) -> Result<()> { // _dbfile exists so we can clean up the file once it is no longer needed let mut events = provider.subscribe(); @@ -248,7 +247,7 @@ impl BackupProvider { } } }, - _ = cancel_token.recv() => { + _ = &mut cancel_token => { provider.shutdown(); break Err(anyhow!("BackupSender cancelled")); }, @@ -381,7 +380,7 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> { context.free_ongoing().await; res } - _ = cancel_token.recv() => Err(format_err!("cancelled")), + _ = cancel_token => Err(format_err!("cancelled")), }; res }