diff --git a/src/e2ee.rs b/src/e2ee.rs index 47585a066..2df7ba624 100644 --- a/src/e2ee.rs +++ b/src/e2ee.rs @@ -417,7 +417,10 @@ mod tests { #[async_std::test] async fn test_prexisting() { let t = TestContext::new_alice().await; - assert_eq!(ensure_secret_key_exists(&t).await.unwrap(), "alice@example.com"); + assert_eq!( + ensure_secret_key_exists(&t).await.unwrap(), + "alice@example.org" + ); } #[async_std::test] diff --git a/src/securejoin.rs b/src/securejoin.rs index c1a324063..1e2a2465c 100644 --- a/src/securejoin.rs +++ b/src/securejoin.rs @@ -942,7 +942,6 @@ mod tests { use crate::chat::ProtectionStatus; use crate::chatlist::Chatlist; use crate::constants::Chattype; - use crate::events::Event; use crate::peerstate::Peerstate; use crate::test_utils::TestContext; use std::time::Duration; @@ -955,16 +954,8 @@ mod tests { assert_eq!(Chatlist::try_load(&bob, 0, None, None).await?.len(), 0); // Setup JoinerProgress sinks. - let (joiner_progress_tx, joiner_progress_rx) = async_std::channel::bounded(100); - bob.add_event_sink(move |event: Event| { - let joiner_progress_tx = joiner_progress_tx.clone(); - async move { - if let EventType::SecurejoinJoinerProgress { .. } = event.typ { - joiner_progress_tx.try_send(event).unwrap(); - } - } - }) - .await; + let (joiner_progress_tx, joiner_progress_rx) = async_std::channel::unbounded(); + bob.add_event_sender(joiner_progress_tx).await; // Step 1: Generate QR-code, ChatId(0) indicates setup-contact let qr = dc_get_securejoin_qr(&alice.ctx, None).await?; @@ -997,29 +988,33 @@ mod tests { bob.recv_msg(&sent).await; // Check Bob emitted the JoinerProgress event. - { - let evt = joiner_progress_rx - .recv() - .timeout(Duration::from_secs(10)) - .await - .expect("timeout waiting for JoinerProgress event") - .expect("missing JoinerProgress event"); - match evt.typ { - EventType::SecurejoinJoinerProgress { - contact_id, - progress, - } => { - let alice_contact_id = - Contact::lookup_id_by_addr(&bob.ctx, "alice@example.org", Origin::Unknown) - .await - .expect("Error looking up contact") - .expect("Contact not found"); - assert_eq!(contact_id, alice_contact_id); - assert_eq!(progress, 400); + async { + loop { + let event = joiner_progress_rx.recv().await.unwrap(); + match event.typ { + EventType::SecurejoinJoinerProgress { + contact_id, + progress, + } => { + let alice_contact_id = Contact::lookup_id_by_addr( + &bob.ctx, + "alice@example.org", + Origin::Unknown, + ) + .await + .expect("Error looking up contact") + .expect("Contact not found"); + assert_eq!(contact_id, alice_contact_id); + assert_eq!(progress, 400); + break; + } + _ => {} } - _ => panic!("Wrong event type"), } } + .timeout(Duration::from_secs(10)) + .await + .expect("timeout waiting for JoinerProgress event"); // Check Bob sent the right message. let sent = bob.pop_sent_msg().await; @@ -1157,16 +1152,8 @@ mod tests { let bob = TestContext::new_bob().await; // Setup JoinerProgress sinks. - let (joiner_progress_tx, joiner_progress_rx) = async_std::channel::bounded(100); - bob.add_event_sink(move |event: Event| { - let joiner_progress_tx = joiner_progress_tx.clone(); - async move { - if let EventType::SecurejoinJoinerProgress { .. } = event.typ { - joiner_progress_tx.try_send(event).unwrap(); - } - } - }) - .await; + let (joiner_progress_tx, joiner_progress_rx) = async_std::channel::unbounded(); + bob.add_event_sender(joiner_progress_tx).await; // Ensure Bob knows Alice_FP let alice_pubkey = SignedPublicKey::load_self(&alice.ctx).await?; @@ -1194,29 +1181,33 @@ mod tests { dc_join_securejoin(&bob.ctx, &qr).await.unwrap(); // Check Bob emitted the JoinerProgress event. - { - let evt = joiner_progress_rx - .recv() - .timeout(Duration::from_secs(10)) - .await - .expect("timeout waiting for JoinerProgress event") - .expect("missing JoinerProgress event"); - match evt.typ { - EventType::SecurejoinJoinerProgress { - contact_id, - progress, - } => { - let alice_contact_id = - Contact::lookup_id_by_addr(&bob.ctx, "alice@example.org", Origin::Unknown) - .await - .expect("Error looking up contact") - .expect("Contact not found"); - assert_eq!(contact_id, alice_contact_id); - assert_eq!(progress, 400); + async { + loop { + let event = joiner_progress_rx.recv().await.unwrap(); + match event.typ { + EventType::SecurejoinJoinerProgress { + contact_id, + progress, + } => { + let alice_contact_id = Contact::lookup_id_by_addr( + &bob.ctx, + "alice@example.org", + Origin::Unknown, + ) + .await + .expect("Error looking up contact") + .expect("Contact not found"); + assert_eq!(contact_id, alice_contact_id); + assert_eq!(progress, 400); + break; + } + _ => {} } - _ => panic!("Wrong event type"), } } + .timeout(Duration::from_secs(10)) + .await + .expect("timeout waiting for JoinerProgress event"); assert!(!bob.ctx.has_ongoing().await); // Check Bob sent the right handshake message. @@ -1330,16 +1321,8 @@ mod tests { assert_eq!(Chatlist::try_load(&bob, 0, None, None).await?.len(), 0); // Setup JoinerProgress sinks. - let (joiner_progress_tx, joiner_progress_rx) = async_std::channel::bounded(100); - bob.add_event_sink(move |event: Event| { - let joiner_progress_tx = joiner_progress_tx.clone(); - async move { - if let EventType::SecurejoinJoinerProgress { .. } = event.typ { - joiner_progress_tx.try_send(event).unwrap(); - } - } - }) - .await; + let (joiner_progress_tx, joiner_progress_rx) = async_std::channel::unbounded(); + bob.add_event_sender(joiner_progress_tx).await; let chatid = chat::create_group_chat(&alice.ctx, ProtectionStatus::Protected, "the chat").await?; @@ -1376,29 +1359,33 @@ mod tests { let sent = bob.pop_sent_msg().await; // Check Bob emitted the JoinerProgress event. - { - let evt = joiner_progress_rx - .recv() - .timeout(Duration::from_secs(10)) - .await - .expect("timeout waiting for JoinerProgress event") - .expect("missing JoinerProgress event"); - match evt.typ { - EventType::SecurejoinJoinerProgress { - contact_id, - progress, - } => { - let alice_contact_id = - Contact::lookup_id_by_addr(&bob.ctx, "alice@example.org", Origin::Unknown) - .await - .expect("Error looking up contact") - .expect("Contact not found"); - assert_eq!(contact_id, alice_contact_id); - assert_eq!(progress, 400); + async { + loop { + let event = joiner_progress_rx.recv().await.unwrap(); + match event.typ { + EventType::SecurejoinJoinerProgress { + contact_id, + progress, + } => { + let alice_contact_id = Contact::lookup_id_by_addr( + &bob.ctx, + "alice@example.org", + Origin::Unknown, + ) + .await + .expect("Error looking up contact") + .expect("Contact not found"); + assert_eq!(contact_id, alice_contact_id); + assert_eq!(progress, 400); + break; + } + _ => {} } - _ => panic!("Wrong event type"), } } + .timeout(Duration::from_secs(10)) + .await + .expect("timeout waiting for JoinerProgress event"); // Check Bob sent the right handshake message. let msg = alice.parse_msg(&sent).await; diff --git a/src/sql.rs b/src/sql.rs index 676701ae9..59e97c2c2 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -678,11 +678,12 @@ async fn prune_tombstones(sql: &Sql) -> Result<()> { } #[cfg(test)] -mod test { +mod tests { + use async_std::channel; use async_std::fs::File; use crate::config::Config; - use crate::{test_utils::TestContext, Event, EventType}; + use crate::{test_utils::TestContext, EventType}; use super::*; @@ -743,18 +744,8 @@ mod test { .await .unwrap(); - t.add_event_sink(move |event: Event| async move { - match event.typ { - EventType::Info(s) => assert!( - !s.contains("Keeping new unreferenced file"), - "File {} was almost deleted, only reason it was kept is that it was created recently (as the tests don't run for a long time)", - s - ), - EventType::Error(s) => panic!("{}", s), - _ => {} - } - }) - .await; + let (event_sink, event_source) = channel::unbounded(); + t.add_event_sender(event_sink).await; let a = t.get_config(Config::Selfavatar).await.unwrap().unwrap(); assert_eq!(avatar_bytes, &async_std::fs::read(&a).await.unwrap()[..]); @@ -765,6 +756,18 @@ mod test { let a = t.get_config(Config::Selfavatar).await.unwrap().unwrap(); assert_eq!(avatar_bytes, &async_std::fs::read(&a).await.unwrap()[..]); + + while let Ok(event) = event_source.try_recv() { + match event.typ { + EventType::Info(s) => assert!( + !s.contains("Keeping new unreferenced file"), + "File {} was almost deleted, only reason it was kept is that it was created recently (as the tests don't run for a long time)", + s + ), + EventType::Error(s) => panic!("{}", s), + _ => {} + } + } } /// Regression test. diff --git a/src/test_utils.rs b/src/test_utils.rs index b3762c20e..48eaf40fe 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -12,9 +12,7 @@ use std::time::{Duration, Instant}; use ansi_term::Color; use async_std::channel::{self, Receiver, Sender}; -use async_std::future::Future; use async_std::path::PathBuf; -use async_std::pin::Pin; use async_std::sync::{Arc, RwLock}; use async_std::task; use chat::ChatItem; @@ -41,9 +39,6 @@ use crate::param::{Param, Params}; #[allow(non_upper_case_globals)] pub const AVATAR_900x900_BYTES: &[u8] = include_bytes!("../test-data/image/avatar900x900.png"); -type EventSink = - dyn Fn(Event) -> Pin + Send + 'static>> + Send + Sync + 'static; - /// Map of [`Context::id`] to names for [`TestContext`]s. static CONTEXT_NAMES: Lazy>> = Lazy::new(|| std::sync::RwLock::new(BTreeMap::new())); @@ -55,7 +50,7 @@ pub struct TestContextBuilder { } impl TestContextBuilder { - /// Configures as alice@example.com with fixed secret key. + /// Configures as alice@example.org with fixed secret key. /// /// This is a shortcut for `.with_key_pair(alice_keypair()). pub fn configure_alice(self) -> Self { @@ -115,8 +110,8 @@ pub struct TestContext { pub ctx: Context, pub dir: TempDir, pub evtracker: EvTracker, - /// Functions to call for events received. - event_sinks: Arc>>>, + /// Channels which should receive events from this context. + event_senders: Arc>>>, /// Receives panics from sinks ("sink" means "event handler" here) poison_receiver: Receiver, /// Reference to implicit [`LogSink`] so it is dropped together with the context. @@ -196,16 +191,15 @@ impl TestContext { let events = ctx.get_event_emitter(); let (log_sender, log_sink) = match log_sender { - Some(sender) => (Arc::new(RwLock::new(sender)), None), + Some(sender) => (sender, None), None => { let (sender, sink) = LogSink::create(); - (Arc::new(RwLock::new(sender)), Some(sink)) + (sender, Some(sink)) } }; - let log_sender_clone = Arc::clone(&log_sender); - let event_sinks: Arc>>> = Arc::new(RwLock::new(Vec::new())); - let sinks = Arc::clone(&event_sinks); + let event_senders = Arc::new(RwLock::new(vec![log_sender])); + let senders = Arc::clone(&event_senders); let (poison_sender, poison_receiver) = channel::bounded(1); let (evtracker_sender, evtracker_receiver) = channel::unbounded(); @@ -225,17 +219,13 @@ impl TestContext { while let Some(event) = events.recv().await { { - let sinks = sinks.read().await; - for sink in sinks.iter() { - sink(event.clone()).await; + let sinks = senders.read().await; + for sender in sinks.iter() { + // Best effort, don't block because someone wanted to use a oneshot + // receiver. + sender.try_send(event.clone()).ok(); } } - log_sender_clone - .read() - .await - .send(event.clone()) - .await - .expect("log sender can not block"); evtracker_sender.send(event.typ).await.ok(); } }); @@ -244,7 +234,7 @@ impl TestContext { ctx, dir, evtracker: EvTracker(evtracker_receiver), - event_sinks, + event_senders, poison_receiver, log_sink, } @@ -260,22 +250,14 @@ impl TestContext { .or_insert_with(|| name.into()); } - /// Add a new callback which will receive events. + /// Adds a new [`Event`]s sender. /// - /// The test context runs an async task receiving all events from the [`Context`], which - /// are logged to stdout. This allows you to register additional callbacks which will - /// receive all events in case your tests need to watch for a specific event. - pub async fn add_event_sink(&self, sink: F) - where - // Aka `F: EventSink` but type aliases are not allowed. - F: Fn(Event) -> R + Send + Sync + 'static, - R: Future + Send + 'static, - { - let mut sinks = self.event_sinks.write().await; - sinks.push(Box::new(move |evt| Box::pin(sink(evt)))); + /// Once added, all events emitted by this context will be sent to this channel. This + /// is useful if you need to wait for events or make assertions on them. + pub async fn add_event_sender(&self, sink: Sender) { + self.event_senders.write().await.push(sink) } - /// Configure as a given email address. /// /// The context will be configured but the key will not be pre-generated so if a key is