diff --git a/src/securejoin/mod.rs b/src/securejoin/mod.rs index 4f5b7e7f4..0a8a1aed7 100644 --- a/src/securejoin/mod.rs +++ b/src/securejoin/mod.rs @@ -4,6 +4,7 @@ use std::convert::TryFrom; use std::time::{Duration, Instant}; use anyhow::{bail, Context as _, Error, Result}; +use async_std::channel::Receiver; use async_std::sync::Mutex; use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC}; @@ -70,6 +71,20 @@ pub(crate) struct Bob { inner: Mutex>, } +/// Return value for [`Bob::start_protocol`]. +/// +/// This indicates which protocol variant was started and provides the required information +/// about it. +enum StartedProtocolVariant { + /// The setup-contact protocol, to verify a contact. + SetupContact, + /// The secure-join protocol, to join a group. + SecureJoin { + ongoing_receiver: Receiver<()>, + group_id: String, + }, +} + impl Bob { /// Starts the securejoin protocol with the QR `invite`. /// @@ -78,19 +93,32 @@ impl Bob { /// /// This function takes care of starting the "ongoing" mechanism if required and /// handling errors while starting the protocol. - async fn start_protocol(&self, context: &Context, invite: QrInvite) -> Result<(), JoinError> { + /// + /// # Returns + /// + /// If the started protocol is joining a group the returned struct contains information + /// about the group and ongoing process. + async fn start_protocol( + &self, + context: &Context, + invite: QrInvite, + ) -> Result { let mut guard = self.inner.lock().await; if guard.is_some() { return Err(JoinError::AlreadyRunning); } - let did_alloc_ongoing = match invite { - QrInvite::Group { .. } => { - if context.alloc_ongoing().await.is_err() { - return Err(JoinError::OngoingRunning); + let variant = match invite { + QrInvite::Group { ref grpid, .. } => { + let receiver = context + .alloc_ongoing() + .await + .map_err(|_| JoinError::OngoingRunning)?; + StartedProtocolVariant::SecureJoin { + ongoing_receiver: receiver, + group_id: grpid.clone(), } - true } - _ => false, + _ => StartedProtocolVariant::SetupContact, }; match BobState::start_protocol(context, invite).await { Ok((state, stage)) => { @@ -98,10 +126,10 @@ impl Bob { joiner_progress!(context, state.invite().contact_id(), 400); } *guard = Some(state); - Ok(()) + Ok(variant) } Err(err) => { - if did_alloc_ongoing { + if let StartedProtocolVariant::SecureJoin { .. } = variant { context.free_ongoing().await; } Err(err) @@ -233,6 +261,8 @@ pub enum JoinError { // Note that this can only occur if we failed to create the chat correctly. #[error("No Chat found for group (this is a bug)")] MissingChat(#[source] sql::Error), + #[error("Ongoing sender dropped (this is a bug)")] + OngoingSenderDropped, } /// Take a scanned QR-code and do the setup-contact/join-group/invite handshake. @@ -262,10 +292,8 @@ async fn securejoin(context: &Context, qr: &str) -> Result { let qr_scan = check_qr(context, &qr).await; let invite = QrInvite::try_from(qr_scan)?; - context.bob.start_protocol(context, invite.clone()).await?; - - match invite { - QrInvite::Contact { .. } => { + match context.bob.start_protocol(context, invite.clone()).await? { + StartedProtocolVariant::SetupContact => { // for a one-to-one-chat, the chat is already known, return the chat-id, // the verification runs in background let chat_id = chat::create_by_contact_id(context, invite.contact_id()) @@ -273,11 +301,15 @@ async fn securejoin(context: &Context, qr: &str) -> Result { .map_err(JoinError::UnknownContact)?; Ok(chat_id) } - QrInvite::Group { ref grpid, .. } => { - // for a group-join, wait until the secure-join is done and the group is created - while !context.shall_stop_ongoing().await { - async_std::task::sleep(Duration::from_millis(50)).await; - } + StartedProtocolVariant::SecureJoin { + ongoing_receiver, + group_id, + } => { + // for a group-join, wait until the protocol is finished and the group is created + ongoing_receiver + .recv() + .await + .map_err(|_| JoinError::OngoingSenderDropped)?; // handle_securejoin_handshake() calls Context::stop_ongoing before the group // chat is created (it is created after handle_securejoin_handshake() returns by @@ -287,7 +319,7 @@ async fn securejoin(context: &Context, qr: &str) -> Result { let start = Instant::now(); let chatid = loop { { - match chat::get_chat_id_by_grpid(context, grpid).await { + match chat::get_chat_id_by_grpid(context, &group_id).await { Ok((chatid, _is_protected, _blocked)) => break chatid, Err(err) => { if start.elapsed() > Duration::from_secs(7) {