diff --git a/src/contact.rs b/src/contact.rs index 029f43694..72e01e347 100644 --- a/src/contact.rs +++ b/src/contact.rs @@ -2029,7 +2029,7 @@ CCCB 5AA9 F6E1 141C 9431 alice1 .evtracker - .get_matching(|e| e == EventType::SelfavatarChanged) + .get_matching(|e| matches!(e, EventType::SelfavatarChanged)) .await; // Bob sends a message so that Alice can encrypt to him. @@ -2059,7 +2059,7 @@ CCCB 5AA9 F6E1 141C 9431 assert!(alice2.get_config(Config::Selfavatar).await?.is_some()); alice2 .evtracker - .get_matching(|e| e == EventType::SelfavatarChanged) + .get_matching(|e| matches!(e, EventType::SelfavatarChanged)) .await; Ok(()) diff --git a/src/dc_receive_imf.rs b/src/dc_receive_imf.rs index eecef33ac..2ac87fd68 100644 --- a/src/dc_receive_imf.rs +++ b/src/dc_receive_imf.rs @@ -4684,14 +4684,18 @@ Second thread."#; let chat = chat::Chat::load_from_db(&t, msg.chat_id).await?; assert!(chat.is_contact_request()); - let duration = std::time::Duration::from_secs(1); loop { - let event = async_std::future::timeout(duration, t.evtracker.recv()).await??; - - if let EventType::IncomingMsg { chat_id, msg_id } = &event { - assert_eq!(msg.chat_id, *chat_id); - assert_eq!(msg.id, *msg_id); - return Ok(()); + let event = t + .evtracker + .get_matching(|evt| matches!(evt, EventType::IncomingMsg { .. })) + .await; + match event { + EventType::IncomingMsg { chat_id, msg_id } => { + assert_eq!(msg.chat_id, chat_id); + assert_eq!(msg.id, msg_id); + return Ok(()); + } + _ => unreachable!(), } } } diff --git a/src/sql.rs b/src/sql.rs index 59e97c2c2..83a404749 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -829,17 +829,24 @@ mod tests { assert!(!disable_server_delete); assert!(!recode_avatar); - info!(&t, "test_migration_flags: XXX"); + info!(&t, "test_migration_flags: XXX END MARKER"); loop { - if let EventType::Info(info) = t.evtracker.recv().await.unwrap() { - assert!( - !info.contains("[migration]"), - "Migrations were run twice, you probably forgot to update the db version" - ); - if info.contains("test_migration_flags: XXX") { - break; + let evt = t + .evtracker + .get_matching(|evt| matches!(evt, EventType::Info(_))) + .await; + match evt { + EventType::Info(msg) => { + assert!( + !msg.contains("[migration]"), + "Migrations were run twice, you probably forgot to update the db version" + ); + if msg.contains("test_migration_flags: XXX END MARKER") { + break; + } } + _ => unreachable!(), } } diff --git a/src/test_utils.rs b/src/test_utils.rs index 48eaf40fe..dc9a5569d 100644 --- a/src/test_utils.rs +++ b/src/test_utils.rs @@ -3,7 +3,6 @@ //! This private module is only compiled for test runs. use std::collections::BTreeMap; -use std::fmt; use std::ops::Deref; use std::panic; use std::str::FromStr; @@ -13,6 +12,7 @@ use std::time::{Duration, Instant}; use ansi_term::Color; use async_std::channel::{self, Receiver, Sender}; use async_std::path::PathBuf; +use async_std::prelude::*; use async_std::sync::{Arc, RwLock}; use async_std::task; use chat::ChatItem; @@ -106,10 +106,11 @@ impl TestContextBuilder { /// /// The temporary directory can be used to store the SQLite database, /// see e.g. [test_context] which does this. +#[derive(Debug)] pub struct TestContext { pub ctx: Context, pub dir: TempDir, - pub evtracker: EvTracker, + pub evtracker: EventTracker, /// Channels which should receive events from this context. event_senders: Arc>>>, /// Receives panics from sinks ("sink" means "event handler" here) @@ -126,16 +127,6 @@ pub struct TestContext { log_sink: Option, } -impl fmt::Debug for TestContext { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("TestContext") - .field("ctx", &self.ctx) - .field("dir", &self.dir) - .field("event_sinks", &String::from("Vec")) - .finish() - } -} - impl TestContext { /// Returns the builder to have more control over creating the context. pub fn builder() -> TestContextBuilder { @@ -198,10 +189,10 @@ impl TestContext { } }; - let event_senders = Arc::new(RwLock::new(vec![log_sender])); + let (evtracker_sender, evtracker_receiver) = channel::unbounded(); + let event_senders = Arc::new(RwLock::new(vec![log_sender, evtracker_sender])); let senders = Arc::clone(&event_senders); let (poison_sender, poison_receiver) = channel::bounded(1); - let (evtracker_sender, evtracker_receiver) = channel::unbounded(); task::spawn(async move { // Make sure that the test fails if there is a panic on this thread here @@ -221,19 +212,18 @@ impl TestContext { { let sinks = senders.read().await; for sender in sinks.iter() { - // Best effort, don't block because someone wanted to use a oneshot - // receiver. + // Don't block because someone wanted to use a oneshot receiver, use + // an unbounded channel if you want all events. sender.try_send(event.clone()).ok(); } } - evtracker_sender.send(event.typ).await.ok(); } }); Self { ctx, dir, - evtracker: EvTracker(evtracker_receiver), + evtracker: EventTracker(evtracker_receiver), event_senders, poison_receiver, log_sink, @@ -573,6 +563,7 @@ impl Drop for TestContext { /// /// To use this create an instance using [`LogSink::create`] and then use the /// [`TestContextBuilder::with_log_sink`]. +#[derive(Debug)] pub struct LogSink { events: Receiver, } @@ -662,40 +653,43 @@ pub fn bob_keypair() -> KeyPair { } } -pub struct EvTracker(Receiver); +/// Utility to help wait for and retrieve events. +/// +/// This buffers the events in order they are emitted. This allows consuming events in +/// order while looking for the right events using the provided methods. +/// +/// The methods only return [`EventType`] rather than the full [`Event`] since it can only +/// be attached to a single [`TestContext`] and therefore the context is already known as +/// you will be accessing it as [`TestContext::evtracker`]. +#[derive(Debug)] +pub struct EventTracker(Receiver); -impl EvTracker { - pub async fn get_info_contains(&self, s: &str) -> EventType { - loop { - let event = self.0.recv().await.unwrap(); - if let EventType::Info(i) = &event { - if i.contains(s) { - return event; +impl EventTracker { + /// Consumes emitted events returning the first matching one. + /// + /// If no matching events are ready this will wait for new events to arrive and time out + /// after 10 seconds. + pub async fn get_matching bool>(&self, event_matcher: F) -> EventType { + async move { + loop { + let event = self.0.recv().await.unwrap(); + if event_matcher(&event.typ) { + return event.typ; } } } + .timeout(Duration::from_secs(10)) + .await + .expect("timeout waiting for event match") } - pub async fn get_matching bool>(&self, event_matcher: F) -> EventType { - const TIMEOUT: Duration = Duration::from_secs(20); - - loop { - let event = async_std::future::timeout(TIMEOUT, self.recv()) - .await - .unwrap() - .unwrap(); - - if event_matcher(event.clone()) { - return event; - } - } - } -} - -impl Deref for EvTracker { - type Target = Receiver; - fn deref(&self) -> &Self::Target { - &self.0 + /// Consumes events looking for an [`EventType::Info`] with substring matching. + pub async fn get_info_contains(&self, s: &str) -> EventType { + self.get_matching(|evt| match evt { + EventType::Info(ref msg) => msg.contains(s), + _ => false, + }) + .await } }