Wait for join-group to finish using ongoing channel

The ongoing mechanism actually has a channel to wait on instead of
sleeping in a loop.  Let's use it.
This commit is contained in:
Floris Bruynooghe
2021-02-17 21:41:17 +01:00
parent 24cb6aa9a4
commit 036c9cd513

View File

@@ -4,6 +4,7 @@ use std::convert::TryFrom;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use anyhow::{bail, Context as _, Error, Result}; use anyhow::{bail, Context as _, Error, Result};
use async_std::channel::Receiver;
use async_std::sync::Mutex; use async_std::sync::Mutex;
use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC}; use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};
@@ -70,6 +71,20 @@ pub(crate) struct Bob {
inner: Mutex<Option<BobState>>, inner: Mutex<Option<BobState>>,
} }
/// Return value for [`Bob::start_protocol`].
///
/// This indicates which protocol variant was started and provides the required information
/// about it.
enum StartedProtocolVariant {
/// The setup-contact protocol, to verify a contact.
SetupContact,
/// The secure-join protocol, to join a group.
SecureJoin {
ongoing_receiver: Receiver<()>,
group_id: String,
},
}
impl Bob { impl Bob {
/// Starts the securejoin protocol with the QR `invite`. /// Starts the securejoin protocol with the QR `invite`.
/// ///
@@ -78,19 +93,32 @@ impl Bob {
/// ///
/// This function takes care of starting the "ongoing" mechanism if required and /// This function takes care of starting the "ongoing" mechanism if required and
/// handling errors while starting the protocol. /// handling errors while starting the protocol.
async fn start_protocol(&self, context: &Context, invite: QrInvite) -> Result<(), JoinError> { ///
/// # Returns
///
/// If the started protocol is joining a group the returned struct contains information
/// about the group and ongoing process.
async fn start_protocol(
&self,
context: &Context,
invite: QrInvite,
) -> Result<StartedProtocolVariant, JoinError> {
let mut guard = self.inner.lock().await; let mut guard = self.inner.lock().await;
if guard.is_some() { if guard.is_some() {
return Err(JoinError::AlreadyRunning); return Err(JoinError::AlreadyRunning);
} }
let did_alloc_ongoing = match invite { let variant = match invite {
QrInvite::Group { .. } => { QrInvite::Group { ref grpid, .. } => {
if context.alloc_ongoing().await.is_err() { let receiver = context
return Err(JoinError::OngoingRunning); .alloc_ongoing()
.await
.map_err(|_| JoinError::OngoingRunning)?;
StartedProtocolVariant::SecureJoin {
ongoing_receiver: receiver,
group_id: grpid.clone(),
} }
true
} }
_ => false, _ => StartedProtocolVariant::SetupContact,
}; };
match BobState::start_protocol(context, invite).await { match BobState::start_protocol(context, invite).await {
Ok((state, stage)) => { Ok((state, stage)) => {
@@ -98,10 +126,10 @@ impl Bob {
joiner_progress!(context, state.invite().contact_id(), 400); joiner_progress!(context, state.invite().contact_id(), 400);
} }
*guard = Some(state); *guard = Some(state);
Ok(()) Ok(variant)
} }
Err(err) => { Err(err) => {
if did_alloc_ongoing { if let StartedProtocolVariant::SecureJoin { .. } = variant {
context.free_ongoing().await; context.free_ongoing().await;
} }
Err(err) Err(err)
@@ -233,6 +261,8 @@ pub enum JoinError {
// Note that this can only occur if we failed to create the chat correctly. // Note that this can only occur if we failed to create the chat correctly.
#[error("No Chat found for group (this is a bug)")] #[error("No Chat found for group (this is a bug)")]
MissingChat(#[source] sql::Error), MissingChat(#[source] sql::Error),
#[error("Ongoing sender dropped (this is a bug)")]
OngoingSenderDropped,
} }
/// Take a scanned QR-code and do the setup-contact/join-group/invite handshake. /// Take a scanned QR-code and do the setup-contact/join-group/invite handshake.
@@ -262,10 +292,8 @@ async fn securejoin(context: &Context, qr: &str) -> Result<ChatId, JoinError> {
let qr_scan = check_qr(context, &qr).await; let qr_scan = check_qr(context, &qr).await;
let invite = QrInvite::try_from(qr_scan)?; let invite = QrInvite::try_from(qr_scan)?;
context.bob.start_protocol(context, invite.clone()).await?; match context.bob.start_protocol(context, invite.clone()).await? {
StartedProtocolVariant::SetupContact => {
match invite {
QrInvite::Contact { .. } => {
// for a one-to-one-chat, the chat is already known, return the chat-id, // for a one-to-one-chat, the chat is already known, return the chat-id,
// the verification runs in background // the verification runs in background
let chat_id = chat::create_by_contact_id(context, invite.contact_id()) let chat_id = chat::create_by_contact_id(context, invite.contact_id())
@@ -273,11 +301,15 @@ async fn securejoin(context: &Context, qr: &str) -> Result<ChatId, JoinError> {
.map_err(JoinError::UnknownContact)?; .map_err(JoinError::UnknownContact)?;
Ok(chat_id) Ok(chat_id)
} }
QrInvite::Group { ref grpid, .. } => { StartedProtocolVariant::SecureJoin {
// for a group-join, wait until the secure-join is done and the group is created ongoing_receiver,
while !context.shall_stop_ongoing().await { group_id,
async_std::task::sleep(Duration::from_millis(50)).await; } => {
} // for a group-join, wait until the protocol is finished and the group is created
ongoing_receiver
.recv()
.await
.map_err(|_| JoinError::OngoingSenderDropped)?;
// handle_securejoin_handshake() calls Context::stop_ongoing before the group // handle_securejoin_handshake() calls Context::stop_ongoing before the group
// chat is created (it is created after handle_securejoin_handshake() returns by // chat is created (it is created after handle_securejoin_handshake() returns by
@@ -287,7 +319,7 @@ async fn securejoin(context: &Context, qr: &str) -> Result<ChatId, JoinError> {
let start = Instant::now(); let start = Instant::now();
let chatid = loop { let chatid = loop {
{ {
match chat::get_chat_id_by_grpid(context, grpid).await { match chat::get_chat_id_by_grpid(context, &group_id).await {
Ok((chatid, _is_protected, _blocked)) => break chatid, Ok((chatid, _is_protected, _blocked)) => break chatid,
Err(err) => { Err(err) => {
if start.elapsed() > Duration::from_secs(7) { if start.elapsed() > Duration::from_secs(7) {