fix more tests

This commit is contained in:
dignifiedquire
2020-03-22 21:57:26 +01:00
parent 20ef115eb2
commit b616a2b3e7
7 changed files with 126 additions and 73 deletions

View File

@@ -223,7 +223,6 @@ def acfactory(pytestconfig, tmpdir, request, session_liveconfig, datadir):
pre_generated_key=pre_generated_key) pre_generated_key=pre_generated_key)
configdict.update(config) configdict.update(config)
ac.configure(**configdict) ac.configure(**configdict)
ac.start_threads()
return ac return ac
def get_one_online_account(self, pre_generated_key=True): def get_one_online_account(self, pre_generated_key=True):
@@ -231,6 +230,7 @@ def acfactory(pytestconfig, tmpdir, request, session_liveconfig, datadir):
pre_generated_key=pre_generated_key) pre_generated_key=pre_generated_key)
wait_successful_IMAP_SMTP_connection(ac1) wait_successful_IMAP_SMTP_connection(ac1)
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
return ac1 return ac1
def get_two_online_accounts(self): def get_two_online_accounts(self):
@@ -238,8 +238,11 @@ def acfactory(pytestconfig, tmpdir, request, session_liveconfig, datadir):
ac2 = self.get_online_configuring_account() ac2 = self.get_online_configuring_account()
wait_successful_IMAP_SMTP_connection(ac1) wait_successful_IMAP_SMTP_connection(ac1)
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
wait_successful_IMAP_SMTP_connection(ac2) wait_successful_IMAP_SMTP_connection(ac2)
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
return ac1, ac2 return ac1, ac2
def clone_online_account(self, account, pre_generated_key=True): def clone_online_account(self, account, pre_generated_key=True):
@@ -251,7 +254,7 @@ def acfactory(pytestconfig, tmpdir, request, session_liveconfig, datadir):
ac._evlogger.init_time = self.init_time ac._evlogger.init_time = self.init_time
ac._evlogger.set_timeout(30) ac._evlogger.set_timeout(30)
ac.configure(addr=account.get_config("addr"), mail_pw=account.get_config("mail_pw")) ac.configure(addr=account.get_config("addr"), mail_pw=account.get_config("mail_pw"))
ac.start_threads()
return ac return ac
am = AccountMaker() am = AccountMaker()

View File

@@ -349,7 +349,6 @@ class TestOfflineChat:
ac1.configure(addr="123@example.org") ac1.configure(addr="123@example.org")
def test_import_export_one_contact(self, acfactory, tmpdir): def test_import_export_one_contact(self, acfactory, tmpdir):
print("START")
backupdir = tmpdir.mkdir("backup") backupdir = tmpdir.mkdir("backup")
ac1 = acfactory.get_configured_offline_account() ac1 = acfactory.get_configured_offline_account()
contact1 = ac1.create_contact("some1@hello.com", name="some1") contact1 = ac1.create_contact("some1@hello.com", name="some1")
@@ -361,27 +360,22 @@ class TestOfflineChat:
with bin.open("w") as f: with bin.open("w") as f:
f.write("\00123" * 10000) f.write("\00123" * 10000)
msg = chat.send_file(bin.strpath) msg = chat.send_file(bin.strpath)
print("L1")
contact = msg.get_sender_contact() contact = msg.get_sender_contact()
assert contact == ac1.get_self_contact() assert contact == ac1.get_self_contact()
assert not backupdir.listdir() assert not backupdir.listdir()
print("L2")
path = ac1.export_all(backupdir.strpath) path = ac1.export_all(backupdir.strpath)
assert os.path.exists(path) assert os.path.exists(path)
ac2 = acfactory.get_unconfigured_account() ac2 = acfactory.get_unconfigured_account()
ac2.import_all(path) ac2.import_all(path)
contacts = ac2.get_contacts(query="some1") contacts = ac2.get_contacts(query="some1")
assert len(contacts) == 1 assert len(contacts) == 1
print("L3")
contact2 = contacts[0] contact2 = contacts[0]
assert contact2.addr == "some1@hello.com" assert contact2.addr == "some1@hello.com"
chat2 = ac2.create_chat_by_contact(contact2) chat2 = ac2.create_chat_by_contact(contact2)
messages = chat2.get_messages() messages = chat2.get_messages()
assert len(messages) == 2 assert len(messages) == 2
print("L4")
assert messages[0].text == "msg1" assert messages[0].text == "msg1"
assert os.path.exists(messages[1].filename) assert os.path.exists(messages[1].filename)
print("STOP")
def test_ac_setup_message_fails(self, ac1): def test_ac_setup_message_fails(self, ac1):
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
@@ -447,7 +441,9 @@ class TestOnlineAccount:
config={"key_gen_type": str(const.DC_KEY_GEN_ED25519)} config={"key_gen_type": str(const.DC_KEY_GEN_ED25519)}
) )
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
chat = self.get_chat(ac1, ac2, both_created=True) chat = self.get_chat(ac1, ac2, both_created=True)
lp.sec("ac1: send unencrypted message to ac2") lp.sec("ac1: send unencrypted message to ac2")
@@ -482,6 +478,9 @@ class TestOnlineAccount:
def test_export_import_self_keys(self, acfactory, tmpdir): def test_export_import_self_keys(self, acfactory, tmpdir):
ac1, ac2 = acfactory.get_two_online_accounts() ac1, ac2 = acfactory.get_two_online_accounts()
ac1.stop_threads()
ac2.stop_threads()
dir = tmpdir.mkdir("exportdir") dir = tmpdir.mkdir("exportdir")
export_files = ac1.export_self_keys(dir.strpath) export_files = ac1.export_self_keys(dir.strpath)
assert len(export_files) == 2 assert len(export_files) == 2
@@ -499,8 +498,11 @@ class TestOnlineAccount:
ac1_clone = acfactory.clone_online_account(ac1) ac1_clone = acfactory.clone_online_account(ac1)
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
wait_configuration_progress(ac1_clone, 1000) wait_configuration_progress(ac1_clone, 1000)
ac1_clone.start_threads()
chat = self.get_chat(ac1, ac2) chat = self.get_chat(ac1, ac2)
@@ -605,10 +607,12 @@ class TestOnlineAccount:
lp.sec("ac2: waiting for configuration") lp.sec("ac2: waiting for configuration")
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
lp.sec("ac1: waiting for configuration") lp.sec("ac1: waiting for configuration")
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
lp.sec("ac1: send message and wait for ac2 to receive it") lp.sec("ac1: send message and wait for ac2 to receive it")
chat = self.get_chat(ac1, ac2) chat = self.get_chat(ac1, ac2)
chat.send_text("message1") chat.send_text("message1")
@@ -620,7 +624,9 @@ class TestOnlineAccount:
ac1 = acfactory.get_online_configuring_account() ac1 = acfactory.get_online_configuring_account()
ac2 = acfactory.get_online_configuring_account(mvbox=True) ac2 = acfactory.get_online_configuring_account(mvbox=True)
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
chat = self.get_chat(ac1, ac2) chat = self.get_chat(ac1, ac2)
chat.send_text("message1") chat.send_text("message1")
ev = ac2._evlogger.get_matching("DC_EVENT_INCOMING_MSG|DC_EVENT_MSGS_CHANGED") ev = ac2._evlogger.get_matching("DC_EVENT_INCOMING_MSG|DC_EVENT_MSGS_CHANGED")
@@ -632,7 +638,9 @@ class TestOnlineAccount:
ac1.set_config("bcc_self", "1") ac1.set_config("bcc_self", "1")
ac2 = acfactory.get_online_configuring_account() ac2 = acfactory.get_online_configuring_account()
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
chat = self.get_chat(ac1, ac2) chat = self.get_chat(ac1, ac2)
chat.send_text("message1") chat.send_text("message1")
chat.send_text("message2") chat.send_text("message2")
@@ -979,6 +987,7 @@ class TestOnlineAccount:
def test_import_export_online_all(self, acfactory, tmpdir, lp): def test_import_export_online_all(self, acfactory, tmpdir, lp):
ac1 = acfactory.get_online_configuring_account() ac1 = acfactory.get_online_configuring_account()
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
lp.sec("create some chat content") lp.sec("create some chat content")
contact1 = ac1.create_contact("some1@hello.com", name="some1") contact1 = ac1.create_contact("some1@hello.com", name="some1")
@@ -1027,7 +1036,9 @@ class TestOnlineAccount:
ac1 = acfactory.get_online_configuring_account() ac1 = acfactory.get_online_configuring_account()
ac2 = acfactory.clone_online_account(ac1) ac2 = acfactory.clone_online_account(ac1)
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
lp.sec("trigger ac setup message and return setupcode") lp.sec("trigger ac setup message and return setupcode")
assert ac1.get_info()["fingerprint"] != ac2.get_info()["fingerprint"] assert ac1.get_info()["fingerprint"] != ac2.get_info()["fingerprint"]
setup_code = ac1.initiate_key_transfer() setup_code = ac1.initiate_key_transfer()
@@ -1050,7 +1061,9 @@ class TestOnlineAccount:
ac2 = acfactory.clone_online_account(ac1) ac2 = acfactory.clone_online_account(ac1)
ac2._evlogger.set_timeout(30) ac2._evlogger.set_timeout(30)
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
lp.sec("trigger ac setup message but ignore") lp.sec("trigger ac setup message but ignore")
assert ac1.get_info()["fingerprint"] != ac2.get_info()["fingerprint"] assert ac1.get_info()["fingerprint"] != ac2.get_info()["fingerprint"]
@@ -1279,6 +1292,7 @@ class TestGroupStressTests:
accounts = [acfactory.get_online_configuring_account() for i in range(5)] accounts = [acfactory.get_online_configuring_account() for i in range(5)]
for acc in accounts: for acc in accounts:
wait_configuration_progress(acc, 1000) wait_configuration_progress(acc, 1000)
acc.start_threads()
ac1 = accounts.pop() ac1 = accounts.pop()
lp.sec("ac1: setting up contacts with 4 other members") lp.sec("ac1: setting up contacts with 4 other members")
@@ -1382,6 +1396,7 @@ class TestGroupStressTests:
accounts = [acfactory.get_online_configuring_account() for i in range(3)] accounts = [acfactory.get_online_configuring_account() for i in range(3)]
for acc in accounts: for acc in accounts:
wait_configuration_progress(acc, 1000) wait_configuration_progress(acc, 1000)
acc.start_threads()
ac1 = accounts.pop() ac1 = accounts.pop()
lp.sec("ac1: setting up contacts with 2 other members") lp.sec("ac1: setting up contacts with 2 other members")
@@ -1449,7 +1464,6 @@ class TestOnlineConfigureFails:
def test_invalid_password(self, acfactory): def test_invalid_password(self, acfactory):
ac1, configdict = acfactory.get_online_config() ac1, configdict = acfactory.get_online_config()
ac1.configure(addr=configdict["addr"], mail_pw="123") ac1.configure(addr=configdict["addr"], mail_pw="123")
ac1.start_threads()
wait_configuration_progress(ac1, 500) wait_configuration_progress(ac1, 500)
ev1 = ac1._evlogger.get_matching("DC_EVENT_ERROR_NETWORK") ev1 = ac1._evlogger.get_matching("DC_EVENT_ERROR_NETWORK")
assert "cannot login" in ev1[2].lower() assert "cannot login" in ev1[2].lower()
@@ -1458,7 +1472,6 @@ class TestOnlineConfigureFails:
def test_invalid_user(self, acfactory): def test_invalid_user(self, acfactory):
ac1, configdict = acfactory.get_online_config() ac1, configdict = acfactory.get_online_config()
ac1.configure(addr="x" + configdict["addr"], mail_pw=configdict["mail_pw"]) ac1.configure(addr="x" + configdict["addr"], mail_pw=configdict["mail_pw"])
ac1.start_threads()
wait_configuration_progress(ac1, 500) wait_configuration_progress(ac1, 500)
ev1 = ac1._evlogger.get_matching("DC_EVENT_ERROR_NETWORK") ev1 = ac1._evlogger.get_matching("DC_EVENT_ERROR_NETWORK")
assert "cannot login" in ev1[2].lower() assert "cannot login" in ev1[2].lower()
@@ -1467,7 +1480,6 @@ class TestOnlineConfigureFails:
def test_invalid_domain(self, acfactory): def test_invalid_domain(self, acfactory):
ac1, configdict = acfactory.get_online_config() ac1, configdict = acfactory.get_online_config()
ac1.configure(addr=configdict["addr"] + "x", mail_pw=configdict["mail_pw"]) ac1.configure(addr=configdict["addr"] + "x", mail_pw=configdict["mail_pw"])
ac1.start_threads()
wait_configuration_progress(ac1, 500) wait_configuration_progress(ac1, 500)
ev1 = ac1._evlogger.get_matching("DC_EVENT_ERROR_NETWORK") ev1 = ac1._evlogger.get_matching("DC_EVENT_ERROR_NETWORK")
assert "could not connect" in ev1[2].lower() assert "could not connect" in ev1[2].lower()

View File

@@ -15,7 +15,9 @@ class TestOnlineInCreation:
ac1 = acfactory.get_online_configuring_account() ac1 = acfactory.get_online_configuring_account()
ac2 = acfactory.get_online_configuring_account() ac2 = acfactory.get_online_configuring_account()
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
c2 = ac1.create_contact(email=ac2.get_config("addr")) c2 = ac1.create_contact(email=ac2.get_config("addr"))
chat = ac1.create_chat_by_contact(c2) chat = ac1.create_chat_by_contact(c2)
@@ -30,7 +32,9 @@ class TestOnlineInCreation:
ac1 = acfactory.get_online_configuring_account() ac1 = acfactory.get_online_configuring_account()
ac2 = acfactory.get_online_configuring_account() ac2 = acfactory.get_online_configuring_account()
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
c2 = ac1.create_contact(email=ac2.get_config("addr")) c2 = ac1.create_contact(email=ac2.get_config("addr"))
chat = ac1.create_chat_by_contact(c2) chat = ac1.create_chat_by_contact(c2)
@@ -48,7 +52,9 @@ class TestOnlineInCreation:
ac1 = acfactory.get_online_configuring_account() ac1 = acfactory.get_online_configuring_account()
ac2 = acfactory.get_online_configuring_account() ac2 = acfactory.get_online_configuring_account()
wait_configuration_progress(ac1, 1000) wait_configuration_progress(ac1, 1000)
ac1.start_threads()
wait_configuration_progress(ac2, 1000) wait_configuration_progress(ac2, 1000)
ac2.start_threads()
c2 = ac1.create_contact(email=ac2.get_config("addr")) c2 = ac1.create_contact(email=ac2.get_config("addr"))
chat = ac1.create_chat_by_contact(c2) chat = ac1.create_chat_by_contact(c2)

View File

@@ -4,6 +4,7 @@ mod auto_mozilla;
mod auto_outlook; mod auto_outlook;
mod read_url; mod read_url;
use async_std::prelude::*;
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC}; use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
use crate::config::Config; use crate::config::Config;
@@ -13,11 +14,11 @@ use crate::dc_tools::*;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::imap::Imap; use crate::imap::Imap;
use crate::login_param::{CertificateChecks, LoginParam}; use crate::login_param::{CertificateChecks, LoginParam};
use crate::message::Message;
use crate::oauth2::*; use crate::oauth2::*;
use crate::smtp::Smtp; use crate::smtp::Smtp;
use crate::{chat, e2ee, provider}; use crate::{chat, e2ee, provider};
use crate::message::Message;
use auto_mozilla::moz_autoconfigure; use auto_mozilla::moz_autoconfigure;
use auto_outlook::outlk_autodiscover; use auto_outlook::outlk_autodiscover;
@@ -39,10 +40,8 @@ impl Context {
/// Configures this account with the currently set parameters. /// Configures this account with the currently set parameters.
pub async fn configure(&self) -> Result<()> { pub async fn configure(&self) -> Result<()> {
ensure!( use futures::future::FutureExt;
!self.has_ongoing().await,
"There is already another ongoing process running."
);
ensure!( ensure!(
!self.scheduler.read().await.is_running(), !self.scheduler.read().await.is_running(),
"Can not configure, already running" "Can not configure, already running"
@@ -51,11 +50,22 @@ impl Context {
self.sql.is_open().await, self.sql.is_open().await,
"Cannot configure, database not opened." "Cannot configure, database not opened."
); );
ensure!( let cancel_channel = self.alloc_ongoing().await?;
self.alloc_ongoing().await,
"Cannot allocate ongoing process"
);
let res = self
.inner_configure()
.race(cancel_channel.recv().map(|_| {
progress!(self, 0);
Ok(())
}))
.await;
self.free_ongoing().await;
res
}
async fn inner_configure(&self) -> Result<()> {
let mut success = false; let mut success = false;
let mut param_autoconfig: Option<LoginParam> = None; let mut param_autoconfig: Option<LoginParam> = None;
@@ -127,15 +137,12 @@ impl Context {
// and restore to last-entered on failure. // and restore to last-entered on failure.
// this way, the parameters visible to the ui are always in-sync with the current configuration. // this way, the parameters visible to the ui are always in-sync with the current configuration.
if success { if success {
assert!(self.is_configured().await, "epic fail");
LoginParam::from_database(self, "") LoginParam::from_database(self, "")
.await .await
.save_to_database(self, "configured_raw_") .save_to_database(self, "configured_raw_")
.await .await
.ok(); .ok();
self.free_ongoing().await;
progress!(self, 1000); progress!(self, 1000);
Ok(()) Ok(())
} else { } else {
@@ -145,8 +152,6 @@ impl Context {
.await .await
.ok(); .ok();
self.free_ongoing().await;
progress!(self, 0); progress!(self, 0);
Err(Error::Message("Configure failed".to_string())) Err(Error::Message("Configure failed".to_string()))
} }
@@ -398,8 +403,8 @@ async fn exec_step(
progress!(ctx, 600); progress!(ctx, 600);
/* try to connect to IMAP - if we did not got an autoconfig, /* try to connect to IMAP - if we did not got an autoconfig,
do some further tries with different settings and username variations */ do some further tries with different settings and username variations */
try_imap_connections(ctx, param, param_autoconfig.is_some(), imap).await?; *is_imap_connected =
*is_imap_connected = true; try_imap_connections(ctx, param, param_autoconfig.is_some(), imap).await?;
} }
15 => { 15 => {
progress!(ctx, 800); progress!(ctx, 800);
@@ -512,13 +517,10 @@ async fn try_imap_connections(
mut param: &mut LoginParam, mut param: &mut LoginParam,
was_autoconfig: bool, was_autoconfig: bool,
imap: &mut Imap, imap: &mut Imap,
) -> Result<()> { ) -> Result<bool> {
// progress 650 and 660 // progress 650 and 660
if try_imap_connection(context, &mut param, was_autoconfig, 0, imap) if let Ok(val) = try_imap_connection(context, &mut param, was_autoconfig, 0, imap).await {
.await return Ok(val);
.is_ok()
{
return Ok(());
} }
progress!(context, 670); progress!(context, 670);
param.server_flags &= !(DC_LP_IMAP_SOCKET_FLAGS); param.server_flags &= !(DC_LP_IMAP_SOCKET_FLAGS);
@@ -532,9 +534,7 @@ async fn try_imap_connections(
param.send_user = param.send_user.split_at(at).0.to_string(); param.send_user = param.send_user.split_at(at).0.to_string();
} }
// progress 680 and 690 // progress 680 and 690
try_imap_connection(context, &mut param, was_autoconfig, 1, imap).await?; try_imap_connection(context, &mut param, was_autoconfig, 1, imap).await
Ok(())
} }
async fn try_imap_connection( async fn try_imap_connection(
@@ -543,24 +543,26 @@ async fn try_imap_connection(
was_autoconfig: bool, was_autoconfig: bool,
variation: usize, variation: usize,
imap: &mut Imap, imap: &mut Imap,
) -> Result<()> { ) -> Result<bool> {
if try_imap_one_param(context, &param, imap).await.is_ok() { if try_imap_one_param(context, &param, imap).await.is_ok() {
return Ok(()); return Ok(true);
} }
if was_autoconfig { if was_autoconfig {
bail!("autoconfig"); return Ok(false);
} }
progress!(context, 650 + variation * 30); progress!(context, 650 + variation * 30);
param.server_flags &= !(DC_LP_IMAP_SOCKET_FLAGS); param.server_flags &= !(DC_LP_IMAP_SOCKET_FLAGS);
param.server_flags |= DC_LP_IMAP_SOCKET_STARTTLS; param.server_flags |= DC_LP_IMAP_SOCKET_STARTTLS;
if try_imap_one_param(context, &param, imap).await.is_ok() { if try_imap_one_param(context, &param, imap).await.is_ok() {
return Ok(()); return Ok(true);
} }
progress!(context, 660 + variation * 30); progress!(context, 660 + variation * 30);
param.mail_port = 143; param.mail_port = 143;
try_imap_one_param(context, &param, imap).await try_imap_one_param(context, &param, imap).await?;
Ok(true)
} }
async fn try_imap_one_param(context: &Context, param: &LoginParam, imap: &mut Imap) -> Result<()> { async fn try_imap_one_param(context: &Context, param: &LoginParam, imap: &mut Imap) -> Result<()> {
@@ -579,6 +581,10 @@ async fn try_imap_one_param(context: &Context, param: &LoginParam, imap: &mut Im
return Ok(()); return Ok(());
} }
if context.shall_stop_ongoing().await {
bail!("Interrupted");
}
bail!("Could not connect: {}", inf); bail!("Could not connect: {}", inf);
} }
@@ -593,7 +599,7 @@ async fn try_smtp_connections(
return Ok(()); return Ok(());
} }
if was_autoconfig { if was_autoconfig {
bail!("autoconfig"); return Ok(());
} }
progress!(context, 850); progress!(context, 850);
param.server_flags &= !(DC_LP_SMTP_SOCKET_FLAGS as i32); param.server_flags &= !(DC_LP_SMTP_SOCKET_FLAGS as i32);

View File

@@ -5,7 +5,7 @@ use std::ffi::OsString;
use std::ops::Deref; use std::ops::Deref;
use async_std::path::{Path, PathBuf}; use async_std::path::{Path, PathBuf};
use async_std::sync::{Arc, Mutex, RwLock}; use async_std::sync::{channel, Arc, Mutex, Receiver, RwLock, Sender};
use crossbeam_queue::SegQueue; use crossbeam_queue::SegQueue;
use crate::chat::*; use crate::chat::*;
@@ -55,10 +55,11 @@ pub struct InnerContext {
pub(crate) scheduler: RwLock<Scheduler>, pub(crate) scheduler: RwLock<Scheduler>,
} }
#[derive(Debug, PartialEq, Eq)] #[derive(Debug)]
pub struct RunningState { pub struct RunningState {
pub ongoing_running: bool, pub ongoing_running: bool,
shall_stop_ongoing: bool, shall_stop_ongoing: bool,
cancel_sender: Option<Sender<()>>,
} }
/// Return some info about deltachat-core /// Return some info about deltachat-core
@@ -180,20 +181,20 @@ impl Context {
* Ongoing process allocation/free/check * Ongoing process allocation/free/check
******************************************************************************/ ******************************************************************************/
pub async fn alloc_ongoing(&self) -> bool { pub async fn alloc_ongoing(&self) -> Result<Receiver<()>> {
if self.has_ongoing().await { if self.has_ongoing().await {
warn!(self, "There is already another ongoing process running.",); bail!("There is already another ongoing process running.");
false
} else {
let s_a = &self.running_state;
let mut s = s_a.write().await;
s.ongoing_running = true;
s.shall_stop_ongoing = false;
true
} }
let s_a = &self.running_state;
let mut s = s_a.write().await;
s.ongoing_running = true;
s.shall_stop_ongoing = false;
let (sender, receiver) = channel(1);
s.cancel_sender = Some(sender);
Ok(receiver)
} }
pub async fn free_ongoing(&self) { pub async fn free_ongoing(&self) {
@@ -202,6 +203,7 @@ impl Context {
s.ongoing_running = false; s.ongoing_running = false;
s.shall_stop_ongoing = true; s.shall_stop_ongoing = true;
s.cancel_sender.take();
} }
pub async fn has_ongoing(&self) -> bool { pub async fn has_ongoing(&self) -> bool {
@@ -215,6 +217,9 @@ impl Context {
pub async fn stop_ongoing(&self) { pub async fn stop_ongoing(&self) {
let s_a = &self.running_state; let s_a = &self.running_state;
let mut s = s_a.write().await; let mut s = s_a.write().await;
if let Some(cancel) = s.cancel_sender.take() {
cancel.send(()).await;
}
if s.ongoing_running && !s.shall_stop_ongoing { if s.ongoing_running && !s.shall_stop_ongoing {
info!(self, "Signaling the ongoing process to stop ASAP.",); info!(self, "Signaling the ongoing process to stop ASAP.",);
@@ -503,6 +508,7 @@ impl Default for RunningState {
RunningState { RunningState {
ongoing_running: false, ongoing_running: false,
shall_stop_ongoing: true, shall_stop_ongoing: true,
cancel_sender: None,
} }
} }
} }

View File

@@ -70,9 +70,16 @@ pub async fn imex(
what: ImexMode, what: ImexMode,
param1: Option<impl AsRef<Path>>, param1: Option<impl AsRef<Path>>,
) -> Result<()> { ) -> Result<()> {
job_imex_imap(context, what, param1).await?; use futures::future::FutureExt;
Ok(()) let cancel = context.alloc_ongoing().await?;
let res = imex_inner(context, what, param1)
.race(cancel.recv().map(|_| Err(format_err!("canceled"))))
.await;
context.free_ongoing().await;
res
} }
/// Returns the filename of the backup found (otherwise an error) /// Returns the filename of the backup found (otherwise an error)
@@ -110,8 +117,13 @@ pub async fn has_backup(context: &Context, dir_name: impl AsRef<Path>) -> Result
} }
pub async fn initiate_key_transfer(context: &Context) -> Result<String> { pub async fn initiate_key_transfer(context: &Context) -> Result<String> {
ensure!(context.alloc_ongoing().await, "could not allocate ongoing"); use futures::future::FutureExt;
let res = do_initiate_key_transfer(context).await;
let cancel = context.alloc_ongoing().await?;
let res = do_initiate_key_transfer(context)
.race(cancel.recv().map(|_| Err(format_err!("canceled"))))
.await;
context.free_ongoing().await; context.free_ongoing().await;
res res
} }
@@ -120,10 +132,8 @@ async fn do_initiate_key_transfer(context: &Context) -> Result<String> {
let mut msg: Message; let mut msg: Message;
let setup_code = create_setup_code(context); let setup_code = create_setup_code(context);
/* this may require a keypair to be created. this may take a second ... */ /* this may require a keypair to be created. this may take a second ... */
ensure!(!context.shall_stop_ongoing().await, "canceled");
let setup_file_content = render_setup_file(context, &setup_code).await?; let setup_file_content = render_setup_file(context, &setup_code).await?;
/* encrypting may also take a while ... */ /* encrypting may also take a while ... */
ensure!(!context.shall_stop_ongoing().await, "canceled");
let setup_file_blob = BlobObject::create( let setup_file_blob = BlobObject::create(
context, context,
"autocrypt-setup-message.html", "autocrypt-setup-message.html",
@@ -144,7 +154,6 @@ async fn do_initiate_key_transfer(context: &Context) -> Result<String> {
ForcePlaintext::NoAutocryptHeader as i32, ForcePlaintext::NoAutocryptHeader as i32,
); );
ensure!(!context.shall_stop_ongoing().await, "canceled");
let msg_id = chat::send_msg(context, chat_id, &mut msg).await?; let msg_id = chat::send_msg(context, chat_id, &mut msg).await?;
info!(context, "Wait for setup message being sent ...",); info!(context, "Wait for setup message being sent ...",);
while !context.shall_stop_ongoing().await { while !context.shall_stop_ongoing().await {
@@ -363,13 +372,12 @@ pub fn normalize_setup_code(s: &str) -> String {
out out
} }
pub async fn job_imex_imap( async fn imex_inner(
context: &Context, context: &Context,
what: ImexMode, what: ImexMode,
param: Option<impl AsRef<Path>>, param: Option<impl AsRef<Path>>,
) -> Result<()> { ) -> Result<()> {
ensure!(context.alloc_ongoing().await, "could not allocate ongoing"); ensure!(param.is_some(), "No Import/export dir/file given.");
ensure!(!param.is_some(), "No Import/export dir/file given.");
info!(context, "Import/export process started."); info!(context, "Import/export process started.");
context.call_cb(Event::ImexProgress(10)); context.call_cb(Event::ImexProgress(10));
@@ -380,7 +388,6 @@ pub async fn job_imex_imap(
if what == ImexMode::ExportBackup || what == ImexMode::ExportSelfKeys { if what == ImexMode::ExportBackup || what == ImexMode::ExportSelfKeys {
// before we export anything, make sure the private key exists // before we export anything, make sure the private key exists
if e2ee::ensure_secret_key_exists(context).await.is_err() { if e2ee::ensure_secret_key_exists(context).await.is_err() {
context.free_ongoing().await;
bail!("Cannot create private key or private key not available."); bail!("Cannot create private key or private key not available.");
} else { } else {
dc_create_folder(context, &path).await?; dc_create_folder(context, &path).await?;
@@ -393,7 +400,7 @@ pub async fn job_imex_imap(
ImexMode::ExportBackup => export_backup(context, path).await, ImexMode::ExportBackup => export_backup(context, path).await,
ImexMode::ImportBackup => import_backup(context, path).await, ImexMode::ImportBackup => import_backup(context, path).await,
}; };
context.free_ongoing().await;
match success { match success {
Ok(()) => { Ok(()) => {
info!(context, "IMEX successfully completed"); info!(context, "IMEX successfully completed");

View File

@@ -1,5 +1,6 @@
//! Verified contact protocol implementation as [specified by countermitm project](https://countermitm.readthedocs.io/en/stable/new.html#setup-contact-protocol) //! Verified contact protocol implementation as [specified by countermitm project](https://countermitm.readthedocs.io/en/stable/new.html#setup-contact-protocol)
use async_std::prelude::*;
use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC}; use percent_encoding::{utf8_percent_encode, AsciiSet, NON_ALPHANUMERIC};
use crate::aheader::EncryptPreference; use crate::aheader::EncryptPreference;
@@ -181,6 +182,21 @@ async fn cleanup(
/// Take a scanned QR-code and do the setup-contact/join-group handshake. /// Take a scanned QR-code and do the setup-contact/join-group handshake.
/// See the ffi-documentation for more details. /// See the ffi-documentation for more details.
pub async fn dc_join_securejoin(context: &Context, qr: &str) -> ChatId { pub async fn dc_join_securejoin(context: &Context, qr: &str) -> ChatId {
use futures::future::FutureExt;
let cancel = match context.alloc_ongoing().await {
Ok(cancel) => cancel,
Err(_) => {
return cleanup(&context, ChatId::new(0), false, false).await;
}
};
securejoin(context, qr)
.race(cancel.recv().map(|_| ChatId::new(0)))
.await
}
async fn securejoin(context: &Context, qr: &str) -> ChatId {
/*======================================================== /*========================================================
==== Bob - the joiner's side ===== ==== Bob - the joiner's side =====
==== Step 2 in "Setup verified contact" protocol ===== ==== Step 2 in "Setup verified contact" protocol =====
@@ -191,9 +207,6 @@ pub async fn dc_join_securejoin(context: &Context, qr: &str) -> ChatId {
info!(context, "Requesting secure-join ...",); info!(context, "Requesting secure-join ...",);
ensure_secret_key_exists(context).await.ok(); ensure_secret_key_exists(context).await.ok();
if !context.alloc_ongoing().await {
return cleanup(&context, contact_chat_id, false, join_vg).await;
}
let qr_scan = check_qr(context, &qr).await; let qr_scan = check_qr(context, &qr).await;
if qr_scan.state != LotState::QrAskVerifyContact && qr_scan.state != LotState::QrAskVerifyGroup if qr_scan.state != LotState::QrAskVerifyContact && qr_scan.state != LotState::QrAskVerifyGroup
{ {