Resultify join_securejoin

This gets rid of ChatId::new(0) usage and is generally a nice first
refactoing step.  The complexity of cleanup() unravels nicely.
This commit is contained in:
Floris Bruynooghe
2020-09-06 19:11:46 +02:00
parent b0bb0214c0
commit 428dbfb537
4 changed files with 89 additions and 64 deletions

View File

@@ -1877,8 +1877,13 @@ pub unsafe extern "C" fn dc_join_securejoin(
}
let ctx = &*context;
block_on(async move { securejoin::dc_join_securejoin(&ctx, &to_string_lossy(qr)).await })
.to_u32()
block_on(async move {
securejoin::dc_join_securejoin(&ctx, &to_string_lossy(qr))
.await
.map(|chatid| chatid.to_u32())
.log_err(ctx, "failed dc_join_securejoin() call")
.unwrap_or_default()
})
}
#[no_mangle]

View File

@@ -410,7 +410,7 @@ async fn handle_cmd(
"joinqr" => {
ctx.start_io().await;
if !arg0.is_empty() {
dc_join_securejoin(&ctx, arg1).await;
dc_join_securejoin(&ctx, arg1).await?;
}
}
"exit" | "quit" => return Ok(ExitResult::Exit),

View File

@@ -2701,6 +2701,7 @@ pub(crate) async fn get_chat_cnt(context: &Context) -> usize {
}
}
/// Returns a tuple of `(chatid, is_verified, blocked)`.
pub(crate) async fn get_chat_id_by_grpid(
context: &Context,
grpid: impl AsRef<str>,

View File

@@ -2,6 +2,7 @@
use std::time::{Duration, Instant};
use anyhow::{bail, Error};
use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};
use crate::aheader::EncryptPreference;
@@ -11,7 +12,6 @@ use crate::constants::*;
use crate::contact::*;
use crate::context::Context;
use crate::e2ee::*;
use crate::error::{bail, Error};
use crate::events::EventType;
use crate::headerdef::HeaderDef;
use crate::key::{DcKey, Fingerprint, SignedPublicKey};
@@ -21,6 +21,7 @@ use crate::mimeparser::*;
use crate::param::*;
use crate::peerstate::*;
use crate::qr::check_qr;
use crate::sql;
use crate::stock::StockMessage;
use crate::token;
@@ -204,75 +205,77 @@ async fn get_self_fingerprint(context: &Context) -> Option<Fingerprint> {
}
}
async fn cleanup(
context: &Context,
contact_chat_id: ChatId,
ongoing_allocated: bool,
join_vg: bool,
) -> ChatId {
async fn cleanup(context: &Context, ongoing_allocated: bool) {
let mut bob = context.bob.write().await;
bob.expects = SecureJoinStep::NotActive;
let ret_chat_id: ChatId = if bob.status == BobStatus::Success {
if join_vg {
chat::get_chat_id_by_grpid(
context,
bob.qr_scan.as_ref().unwrap().text2.as_ref().unwrap(),
)
.await
.unwrap_or((ChatId::new(0), false, Blocked::Not))
.0
} else {
contact_chat_id
}
} else {
ChatId::new(0)
};
bob.qr_scan = None;
if ongoing_allocated {
context.free_ongoing().await;
}
ret_chat_id
}
/// Take a scanned QR-code and do the setup-contact/join-group handshake.
/// See the ffi-documentation for more details.
pub async fn dc_join_securejoin(context: &Context, qr: &str) -> ChatId {
#[derive(Debug, thiserror::Error)]
pub enum JoinError {
#[error("Unknown QR-code")]
QrCode,
#[error("Aborted by user")]
Aborted,
#[error("Failed to send handshake message")]
SendMessage(#[from] SendMsgError),
// Note that this can currently only occur if there is a bug in the QR/Lot code as this
// is supposed to create a contact for us.
#[error("Unknown contact (this is a bug)")]
UnknownContact,
// Note that this can only occur if we failed to create the chat correctly.
#[error("No Chat found for group (this is a bug)")]
MissingChat(#[source] sql::Error),
}
/// Take a scanned QR-code and do the setup-contact/join-group/invite handshake.
///
/// This is the start of the process for the joiner. See the module and ffi documentation
/// for more details.
///
/// When joining a group this will start an "ongoing" process and will block until the
/// process is completed, the [ChatId] for the new group is not known any sooner. When
/// verifying a contact this returns immediately.
pub async fn dc_join_securejoin(context: &Context, qr: &str) -> Result<ChatId, JoinError> {
if context.alloc_ongoing().await.is_err() {
return cleanup(&context, ChatId::new(0), false, false).await;
cleanup(&context, false).await;
return Err(JoinError::Aborted);
}
securejoin(context, qr).await
}
async fn securejoin(context: &Context, qr: &str) -> ChatId {
async fn securejoin(context: &Context, qr: &str) -> Result<ChatId, JoinError> {
/*========================================================
==== Bob - the joiner's side =====
==== Step 2 in "Setup verified contact" protocol =====
========================================================*/
let mut contact_chat_id = ChatId::new(0);
let mut join_vg: bool = false;
info!(context, "Requesting secure-join ...",);
ensure_secret_key_exists(context).await.ok();
let qr_scan = check_qr(context, &qr).await;
if qr_scan.state != LotState::QrAskVerifyContact && qr_scan.state != LotState::QrAskVerifyGroup
{
error!(context, "Unknown QR code.",);
return cleanup(&context, contact_chat_id, true, join_vg).await;
cleanup(&context, true).await;
return Err(JoinError::QrCode);
}
contact_chat_id = match chat::create_by_contact_id(context, qr_scan.id).await {
let contact_chat_id = match chat::create_by_contact_id(context, qr_scan.id).await {
Ok(chat_id) => chat_id,
Err(_) => {
error!(context, "Unknown contact.");
return cleanup(&context, contact_chat_id, true, join_vg).await;
cleanup(&context, true).await;
return Err(JoinError::UnknownContact);
}
};
if context.shall_stop_ongoing().await {
return cleanup(&context, contact_chat_id, true, join_vg).await;
cleanup(&context, true).await;
return Err(JoinError::Aborted);
}
join_vg = qr_scan.get_state() == LotState::QrAskVerifyGroup;
let join_vg = qr_scan.get_state() == LotState::QrAskVerifyGroup;
{
let mut bob = context.bob.write().await;
bob.status = BobStatus::Error;
@@ -325,7 +328,8 @@ async fn securejoin(context: &Context, qr: &str) -> ChatId {
.await
{
error!(context, "failed to send handshake message: {}", err);
return cleanup(&context, contact_chat_id, true, join_vg).await;
cleanup(&context, true).await;
return Err(JoinError::SendMessage(err));
}
} else {
context.bob.write().await.expects = SecureJoinStep::AuthRequired;
@@ -342,7 +346,8 @@ async fn securejoin(context: &Context, qr: &str) -> ChatId {
.await
{
error!(context, "failed to send handshake message: {}", err);
return cleanup(&context, contact_chat_id, true, join_vg).await;
cleanup(&context, true).await;
return Err(JoinError::SendMessage(err));
}
}
@@ -356,31 +361,40 @@ async fn securejoin(context: &Context, qr: &str) -> ChatId {
// is created (it is created after handle_securejoin_handshake() returns by
// dc_receive_imf()). As a hack we just wait a bit for it to appear.
let start = Instant::now();
while start.elapsed() < Duration::from_secs(7) {
let chatid = loop {
{
let bob = context.bob.read().await;
if chat::get_chat_id_by_grpid(
context,
bob.qr_scan.as_ref().unwrap().text2.as_ref().unwrap(),
)
.await
.is_ok()
{
break;
let grpid = bob.qr_scan.as_ref().unwrap().text2.as_ref().unwrap();
match chat::get_chat_id_by_grpid(context, grpid).await {
Ok((chatid, _is_verified, _blocked)) => break chatid,
Err(err) => {
if start.elapsed() > Duration::from_secs(7) {
return Err(JoinError::MissingChat(err));
}
}
}
}
async_std::task::sleep(Duration::from_millis(50)).await
}
async_std::task::sleep(Duration::from_millis(50)).await;
};
cleanup(&context, contact_chat_id, true, join_vg).await
cleanup(&context, true).await;
Ok(chatid)
} else {
// for a one-to-one-chat, the chat is already known, return the chat-id,
// the verification runs in background
context.free_ongoing().await;
contact_chat_id
Ok(contact_chat_id)
}
}
/// Error for [send_handshake_msg].
///
/// Wrapping the [anyhow::Error] means we can "impl From" more easily on errors from this
/// function.
#[derive(Debug, thiserror::Error)]
#[error("Failed sending handshake message")]
pub struct SendMsgError(#[from] anyhow::Error);
async fn send_handshake_msg(
context: &Context,
contact_chat_id: ChatId,
@@ -388,7 +402,7 @@ async fn send_handshake_msg(
param2: impl AsRef<str>,
fingerprint: Option<Fingerprint>,
grpid: impl AsRef<str>,
) -> Result<(), HandshakeError> {
) -> Result<(), SendMsgError> {
let mut msg = Message::default();
msg.viewtype = Viewtype::Text;
msg.text = Some(format!("Secure-Join: {}", step));
@@ -414,10 +428,7 @@ async fn send_handshake_msg(
msg.param.set_int(Param::GuaranteeE2ee, 1);
}
chat::send_msg(context, contact_chat_id, &mut msg)
.await
.map_err(HandshakeError::MsgSendFailed)?;
chat::send_msg(context, contact_chat_id, &mut msg).await?;
Ok(())
}
@@ -477,7 +488,7 @@ pub(crate) enum HandshakeError {
#[error("No configured self address found")]
NoSelfAddr,
#[error("Failed to send message")]
MsgSendFailed(#[source] Error),
MsgSendFailed(#[from] SendMsgError),
#[error("Failed to parse fingerprint")]
BadFingerprint(#[from] crate::key::FingerprintError),
}
@@ -895,6 +906,7 @@ pub(crate) async fn handle_securejoin_handshake(
return Ok(HandshakeMessage::Ignore);
}
if join_vg {
// Responsible for showing "$Bob securely joined $group" message
inviter_progress!(context, contact_id, 800);
inviter_progress!(context, contact_id, 1000);
let field_grpid = mime_message
@@ -1126,7 +1138,7 @@ mod tests {
.unwrap();
// Bob scans QR-code, sends vc-request
let bob_chatid = dc_join_securejoin(&bob.ctx, &qr).await;
let bob_chatid = dc_join_securejoin(&bob.ctx, &qr).await.unwrap();
let sent = bob.pop_sent_msg().await;
assert_eq!(sent.id(), bob_chatid);
@@ -1217,6 +1229,13 @@ mod tests {
);
}
#[async_std::test]
async fn test_setup_contact_bad_qr() {
let bob = TestContext::new_bob().await;
let ret = dc_join_securejoin(&bob.ctx, "not a qr code").await;
assert!(matches!(ret, Err(JoinError::QrCode)));
}
#[async_std::test]
async fn test_setup_contact_bob_knows_alice() {
let alice = TestContext::new_alice().await;
@@ -1248,7 +1267,7 @@ mod tests {
.unwrap();
// Bob scans QR-code, sends vc-request-with-auth, skipping vc-request
dc_join_securejoin(&bob.ctx, &qr).await;
dc_join_securejoin(&bob.ctx, &qr).await.unwrap();
let sent = bob.pop_sent_msg().await;
let msg = alice.parse_msg(&sent).await;
@@ -1342,7 +1361,7 @@ mod tests {
let joiner = {
let qr = qr.clone();
let ctx = bob.ctx.clone();
async_std::task::spawn(async move { dc_join_securejoin(&ctx, &qr).await })
async_std::task::spawn(async move { dc_join_securejoin(&ctx, &qr).await.unwrap() })
};
let sent = bob.pop_sent_msg().await;