diff --git a/src/imex.rs b/src/imex.rs index f19b02c4c..8310e9a13 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -8,13 +8,12 @@ use std::iter::FusedIterator; use std::path::{Path, PathBuf}; use ::pgp::types::KeyTrait; -use anyhow::{bail, ensure, format_err, Context as _, Result}; +use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result}; use async_channel::Receiver; use futures::StreamExt; use futures_lite::FutureExt; use rand::{thread_rng, Rng}; use tokio::fs::{self, File}; -use tokio::task::JoinError; use tokio_stream::wrappers::ReadDirStream; use tokio_tar::Archive; @@ -771,7 +770,7 @@ impl BackupSender { // - [x] check i/o is not running // - [x] check we have secret key // - [x] alloc ongoing - // - [ ] correctly cancel Provider when cancelled + // - [x] correctly cancel Provider when cancelled // - [x] create auth token // - [x] export backup with generated token as password // - needs a path to store the database @@ -787,7 +786,9 @@ impl BackupSender { /// or database are happening, this is done by calling the `dc_accounts_stop_io` or /// `dc_stop_io` APIs first. TODO: Add the rust equivalents. /// - /// This will acquire the global "ongoing" mutex. + /// This will acquire the global "ongoing process" mutex. You must call + /// [`BackupSender::join`] after creating this struct, otherwise this will not respect + /// the possible cancellation of the "ongoing process". pub async fn perpare(context: &Context, dir: &Path) -> Result { ensure!( // TODO: Should we worry about path normalisation? @@ -818,13 +819,24 @@ impl BackupSender { }, _ = cancel_token.recv() => Err(format_err!("cancelled")), }; - - // TODO: This is all wrong, too early to release - context.free_ongoing().await; - res + let (provider, ticket) = match res { + Ok((provider, ticket)) => (provider, ticket), + Err(err) => { + context.free_ongoing().await; + return Err(err); + } + }; + Ok(Self { + provider, + ticket, + cancel_token, + }) } - async fn prepare_inner(context: &Context, dir: &Path) -> Result { + async fn prepare_inner( + context: &Context, + dir: &Path, + ) -> Result<(sendme::provider::Provider, sendme::provider::Ticket)> { // Generate the token up front: we also use it to encrypt the database. let token = sendme::protocol::AuthToken::generate(); let dbfile = dir.join(DBFILE_BACKUP_NAME); @@ -845,7 +857,7 @@ impl BackupSender { .auth_token(token) .spawn()?; let ticket = provider.ticket(hash); - Ok(Self { provider, ticket }) + Ok((provider, ticket)) } pub fn qr(&self) -> Qr { @@ -854,9 +866,20 @@ impl BackupSender { } } - pub async fn join(self) -> Result<(), JoinError> { + /// Wait for the backup sender to complete. + /// + /// The sender completes when an authenticated client disconnects, whether the transfer + /// was successful or not. When the ongoing task is cancelled the sender also completes + /// with an error. + /// + /// Note that this must be called and awaited for the ongoing cancellation to work. + pub async fn join(self) -> Result<()> { // TODO: should wait for 1 transfer to complete or abort - self.provider.join().await + tokio::select! { + biased; + res = self.provider.join() => res.context("BackupSender failed"), + _ = self.cancel_token.recv() => Err(anyhow!("BackupSender cancelled")), + } } pub fn abort(&self) {