Wait for join-group to finish using ongoing channel

The ongoing mechanism actually has a channel to wait on instead of
sleeping in a loop.  Let's use it.
This commit is contained in:
Floris Bruynooghe
2021-02-17 21:41:17 +01:00
parent 24cb6aa9a4
commit 036c9cd513

View File

@@ -4,6 +4,7 @@ use std::convert::TryFrom;
use std::time::{Duration, Instant};
use anyhow::{bail, Context as _, Error, Result};
use async_std::channel::Receiver;
use async_std::sync::Mutex;
use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};
@@ -70,6 +71,20 @@ pub(crate) struct Bob {
inner: Mutex<Option<BobState>>,
}
/// Return value for [`Bob::start_protocol`].
///
/// This indicates which protocol variant was started and provides the required information
/// about it.
enum StartedProtocolVariant {
/// The setup-contact protocol, to verify a contact.
SetupContact,
/// The secure-join protocol, to join a group.
SecureJoin {
ongoing_receiver: Receiver<()>,
group_id: String,
},
}
impl Bob {
/// Starts the securejoin protocol with the QR `invite`.
///
@@ -78,19 +93,32 @@ impl Bob {
///
/// This function takes care of starting the "ongoing" mechanism if required and
/// handling errors while starting the protocol.
async fn start_protocol(&self, context: &Context, invite: QrInvite) -> Result<(), JoinError> {
///
/// # Returns
///
/// If the started protocol is joining a group the returned struct contains information
/// about the group and ongoing process.
async fn start_protocol(
&self,
context: &Context,
invite: QrInvite,
) -> Result<StartedProtocolVariant, JoinError> {
let mut guard = self.inner.lock().await;
if guard.is_some() {
return Err(JoinError::AlreadyRunning);
}
let did_alloc_ongoing = match invite {
QrInvite::Group { .. } => {
if context.alloc_ongoing().await.is_err() {
return Err(JoinError::OngoingRunning);
let variant = match invite {
QrInvite::Group { ref grpid, .. } => {
let receiver = context
.alloc_ongoing()
.await
.map_err(|_| JoinError::OngoingRunning)?;
StartedProtocolVariant::SecureJoin {
ongoing_receiver: receiver,
group_id: grpid.clone(),
}
true
}
_ => false,
_ => StartedProtocolVariant::SetupContact,
};
match BobState::start_protocol(context, invite).await {
Ok((state, stage)) => {
@@ -98,10 +126,10 @@ impl Bob {
joiner_progress!(context, state.invite().contact_id(), 400);
}
*guard = Some(state);
Ok(())
Ok(variant)
}
Err(err) => {
if did_alloc_ongoing {
if let StartedProtocolVariant::SecureJoin { .. } = variant {
context.free_ongoing().await;
}
Err(err)
@@ -233,6 +261,8 @@ pub enum JoinError {
// Note that this can only occur if we failed to create the chat correctly.
#[error("No Chat found for group (this is a bug)")]
MissingChat(#[source] sql::Error),
#[error("Ongoing sender dropped (this is a bug)")]
OngoingSenderDropped,
}
/// Take a scanned QR-code and do the setup-contact/join-group/invite handshake.
@@ -262,10 +292,8 @@ async fn securejoin(context: &Context, qr: &str) -> Result<ChatId, JoinError> {
let qr_scan = check_qr(context, &qr).await;
let invite = QrInvite::try_from(qr_scan)?;
context.bob.start_protocol(context, invite.clone()).await?;
match invite {
QrInvite::Contact { .. } => {
match context.bob.start_protocol(context, invite.clone()).await? {
StartedProtocolVariant::SetupContact => {
// for a one-to-one-chat, the chat is already known, return the chat-id,
// the verification runs in background
let chat_id = chat::create_by_contact_id(context, invite.contact_id())
@@ -273,11 +301,15 @@ async fn securejoin(context: &Context, qr: &str) -> Result<ChatId, JoinError> {
.map_err(JoinError::UnknownContact)?;
Ok(chat_id)
}
QrInvite::Group { ref grpid, .. } => {
// for a group-join, wait until the secure-join is done and the group is created
while !context.shall_stop_ongoing().await {
async_std::task::sleep(Duration::from_millis(50)).await;
}
StartedProtocolVariant::SecureJoin {
ongoing_receiver,
group_id,
} => {
// for a group-join, wait until the protocol is finished and the group is created
ongoing_receiver
.recv()
.await
.map_err(|_| JoinError::OngoingSenderDropped)?;
// handle_securejoin_handshake() calls Context::stop_ongoing before the group
// chat is created (it is created after handle_securejoin_handshake() returns by
@@ -287,7 +319,7 @@ async fn securejoin(context: &Context, qr: &str) -> Result<ChatId, JoinError> {
let start = Instant::now();
let chatid = loop {
{
match chat::get_chat_id_by_grpid(context, grpid).await {
match chat::get_chat_id_by_grpid(context, &group_id).await {
Ok((chatid, _is_protected, _blocked)) => break chatid,
Err(err) => {
if start.elapsed() > Duration::from_secs(7) {