From c6c20d8f3c503039860c6b33de311ce01fe515ff Mon Sep 17 00:00:00 2001 From: Floris Bruynooghe Date: Mon, 3 Apr 2023 11:13:44 +0200 Subject: [PATCH] ref(scheduler): Make InnerSchedulerState an enum (#4251) This is more verbose, but makes reasoning about things easier. --- src/imex.rs | 2 +- src/imex/transfer.rs | 4 +- src/scheduler.rs | 130 +++++++++++++++++++++++++--------- src/scheduler/connectivity.rs | 20 +++--- 4 files changed, 109 insertions(+), 47 deletions(-) diff --git a/src/imex.rs b/src/imex.rs index 62d97f8fc..cc6a59dc9 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -91,7 +91,7 @@ pub async fn imex( let cancel = context.alloc_ongoing().await?; let res = { - let _guard = context.scheduler.pause(context.clone()).await; + let _guard = context.scheduler.pause(context.clone()).await?; imex_inner(context, what, path, passphrase) .race(async { cancel.recv().await.ok(); diff --git a/src/imex/transfer.rs b/src/imex/transfer.rs index e2a257511..cfb390f42 100644 --- a/src/imex/transfer.rs +++ b/src/imex/transfer.rs @@ -97,7 +97,7 @@ impl BackupProvider { // Acquire global "ongoing" mutex. let cancel_token = context.alloc_ongoing().await?; - let paused_guard = context.scheduler.pause(context.clone()).await; + let paused_guard = context.scheduler.pause(context.clone()).await?; let context_dir = context .get_blobdir() .parent() @@ -386,7 +386,7 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> { !context.is_configured().await?, "Cannot import backups to accounts in use." ); - let _guard = context.scheduler.pause(context.clone()).await; + let _guard = context.scheduler.pause(context.clone()).await?; // Acquire global "ongoing" mutex. let cancel_token = context.alloc_ongoing().await?; diff --git a/src/scheduler.rs b/src/scheduler.rs index 1c1d3eb5d..2404d7e22 100644 --- a/src/scheduler.rs +++ b/src/scheduler.rs @@ -1,7 +1,8 @@ use std::iter::{self, once}; +use std::num::NonZeroUsize; use std::sync::atomic::Ordering; -use anyhow::{bail, Context as _, Result}; +use anyhow::{bail, Context as _, Error, Result}; use async_channel::{self as channel, Receiver, Sender}; use futures::future::try_join_all; use futures_lite::FutureExt; @@ -43,15 +44,18 @@ impl SchedulerState { /// Whether the scheduler is currently running. pub(crate) async fn is_running(&self) -> bool { let inner = self.inner.read().await; - inner.scheduler.is_some() + matches!(*inner, InnerSchedulerState::Started(_)) } /// Starts the scheduler if it is not yet started. pub(crate) async fn start(&self, context: Context) { let mut inner = self.inner.write().await; - inner.started = true; - if inner.scheduler.is_none() && inner.paused == 0 { - Self::do_start(inner, context).await; + match *inner { + InnerSchedulerState::Started(_) => (), + InnerSchedulerState::Stopped => Self::do_start(inner, context).await, + InnerSchedulerState::Paused { + ref mut started, .. + } => *started = true, } } @@ -60,7 +64,7 @@ impl SchedulerState { info!(context, "starting IO"); let ctx = context.clone(); match Scheduler::start(context).await { - Ok(scheduler) => inner.scheduler = Some(scheduler), + Ok(scheduler) => *inner = InnerSchedulerState::Started(scheduler), Err(err) => error!(&ctx, "Failed to start IO: {:#}", err), } } @@ -68,12 +72,23 @@ impl SchedulerState { /// Stops the scheduler if it is currently running. pub(crate) async fn stop(&self, context: &Context) { let mut inner = self.inner.write().await; - inner.started = false; - Self::do_stop(inner, context).await; + match *inner { + InnerSchedulerState::Started(_) => { + Self::do_stop(inner, context, InnerSchedulerState::Stopped).await + } + InnerSchedulerState::Stopped => (), + InnerSchedulerState::Paused { + ref mut started, .. + } => *started = false, + } } /// Stops the scheduler if it is currently running. - async fn do_stop(mut inner: RwLockWriteGuard<'_, InnerSchedulerState>, context: &Context) { + async fn do_stop( + mut inner: RwLockWriteGuard<'_, InnerSchedulerState>, + context: &Context, + new_state: InnerSchedulerState, + ) { // Sending an event wakes up event pollers (get_next_event) // so the caller of stop_io() can arrange for proper termination. // For this, the caller needs to instruct the event poller @@ -83,8 +98,10 @@ impl SchedulerState { if let Some(debug_logging) = context.debug_logging.read().await.as_ref() { debug_logging.loop_handle.abort(); } - if let Some(scheduler) = inner.scheduler.take() { - scheduler.stop(context).await; + let prev_state = std::mem::replace(&mut *inner, new_state); + match prev_state { + InnerSchedulerState::Started(scheduler) => scheduler.stop(context).await, + InnerSchedulerState::Stopped | InnerSchedulerState::Paused { .. } => (), } } @@ -96,22 +113,63 @@ impl SchedulerState { /// If in the meantime [`SchedulerState::start`] or [`SchedulerState::stop`] is called /// resume will do the right thing and restore the scheduler to the state requested by /// the last call. - pub(crate) async fn pause<'a>(&'_ self, context: Context) -> IoPausedGuard { + pub(crate) async fn pause<'a>(&'_ self, context: Context) -> Result { { let mut inner = self.inner.write().await; - inner.paused += 1; - Self::do_stop(inner, &context).await; + match *inner { + InnerSchedulerState::Started(_) => { + let new_state = InnerSchedulerState::Paused { + started: true, + pause_guards_count: NonZeroUsize::new(1).unwrap(), + }; + Self::do_stop(inner, &context, new_state).await; + } + InnerSchedulerState::Stopped => { + *inner = InnerSchedulerState::Paused { + started: false, + pause_guards_count: NonZeroUsize::new(1).unwrap(), + }; + } + InnerSchedulerState::Paused { + ref mut pause_guards_count, + .. + } => { + *pause_guards_count = pause_guards_count + .checked_add(1) + .ok_or_else(|| Error::msg("Too many pause guards active"))? + } + } } + let (tx, rx) = oneshot::channel(); tokio::spawn(async move { rx.await.ok(); let mut inner = context.scheduler.inner.write().await; - inner.paused -= 1; - if inner.paused == 0 && inner.started && inner.scheduler.is_none() { - SchedulerState::do_start(inner, context.clone()).await; + match *inner { + InnerSchedulerState::Started(_) => { + warn!(&context, "IoPausedGuard resume: started instead of paused"); + } + InnerSchedulerState::Stopped => { + warn!(&context, "IoPausedGuard resume: stopped instead of paused"); + } + InnerSchedulerState::Paused { + ref started, + ref mut pause_guards_count, + } => { + if *pause_guards_count == NonZeroUsize::new(1).unwrap() { + match *started { + true => SchedulerState::do_start(inner, context.clone()).await, + false => *inner = InnerSchedulerState::Stopped, + } + } else { + let new_count = pause_guards_count.get() - 1; + // SAFETY: Value was >=2 before due to if condition + *pause_guards_count = NonZeroUsize::new(new_count).unwrap(); + } + } } }); - IoPausedGuard { sender: Some(tx) } + Ok(IoPausedGuard { sender: Some(tx) }) } /// Restarts the scheduler, only if it is running. @@ -126,8 +184,8 @@ impl SchedulerState { /// Indicate that the network likely has come back. pub(crate) async fn maybe_network(&self) { let inner = self.inner.read().await; - let (inbox, oboxes) = match inner.scheduler { - Some(ref scheduler) => { + let (inbox, oboxes) = match *inner { + InnerSchedulerState::Started(ref scheduler) => { scheduler.maybe_network(); let inbox = scheduler.inbox.conn_state.state.connectivity.clone(); let oboxes = scheduler @@ -137,7 +195,7 @@ impl SchedulerState { .collect::>(); (inbox, oboxes) } - None => return, + _ => return, }; drop(inner); connectivity::idle_interrupted(inbox, oboxes).await; @@ -146,15 +204,15 @@ impl SchedulerState { /// Indicate that the network likely is lost. pub(crate) async fn maybe_network_lost(&self, context: &Context) { let inner = self.inner.read().await; - let stores = match inner.scheduler { - Some(ref scheduler) => { + let stores = match *inner { + InnerSchedulerState::Started(ref scheduler) => { scheduler.maybe_network_lost(); scheduler .boxes() .map(|b| b.conn_state.state.connectivity.clone()) .collect() } - None => return, + _ => return, }; drop(inner); connectivity::maybe_network_lost(context, stores).await; @@ -162,47 +220,49 @@ impl SchedulerState { pub(crate) async fn interrupt_inbox(&self, info: InterruptInfo) { let inner = self.inner.read().await; - if let Some(ref scheduler) = inner.scheduler { + if let InnerSchedulerState::Started(ref scheduler) = *inner { scheduler.interrupt_inbox(info); } } pub(crate) async fn interrupt_smtp(&self, info: InterruptInfo) { let inner = self.inner.read().await; - if let Some(ref scheduler) = inner.scheduler { + if let InnerSchedulerState::Started(ref scheduler) = *inner { scheduler.interrupt_smtp(info); } } pub(crate) async fn interrupt_ephemeral_task(&self) { let inner = self.inner.read().await; - if let Some(ref scheduler) = inner.scheduler { + if let InnerSchedulerState::Started(ref scheduler) = *inner { scheduler.interrupt_ephemeral_task(); } } pub(crate) async fn interrupt_location(&self) { let inner = self.inner.read().await; - if let Some(ref scheduler) = inner.scheduler { + if let InnerSchedulerState::Started(ref scheduler) = *inner { scheduler.interrupt_location(); } } pub(crate) async fn interrupt_recently_seen(&self, contact_id: ContactId, timestamp: i64) { let inner = self.inner.read().await; - if let Some(ref scheduler) = inner.scheduler { + if let InnerSchedulerState::Started(ref scheduler) = *inner { scheduler.interrupt_recently_seen(contact_id, timestamp); } } } #[derive(Debug, Default)] -struct InnerSchedulerState { - scheduler: Option, - /// Whether IO should be started if there is no [`IoPausedGuard`] active. - started: bool, - /// The number of [`IoPausedGuard`]s that are outstanding. - paused: u32, +enum InnerSchedulerState { + Started(Scheduler), + #[default] + Stopped, + Paused { + started: bool, + pause_guards_count: NonZeroUsize, + }, } /// Guard to make sure the IO Scheduler is resumed. diff --git a/src/scheduler/connectivity.rs b/src/scheduler/connectivity.rs index 9febb8b67..49f3ba55e 100644 --- a/src/scheduler/connectivity.rs +++ b/src/scheduler/connectivity.rs @@ -14,6 +14,8 @@ use crate::tools::time; use crate::{context::Context, log::LogExt}; use crate::{stock_str, tools}; +use super::InnerSchedulerState; + #[derive(Debug, Clone, Copy, PartialEq, Eq, EnumProperty, PartialOrd, Ord)] pub enum Connectivity { NotConnected = 1000, @@ -226,12 +228,12 @@ impl Context { /// If the connectivity changes, a DC_EVENT_CONNECTIVITY_CHANGED will be emitted. pub async fn get_connectivity(&self) -> Connectivity { let lock = self.scheduler.inner.read().await; - let stores: Vec<_> = match lock.scheduler { - Some(ref sched) => sched + let stores: Vec<_> = match *lock { + InnerSchedulerState::Started(ref sched) => sched .boxes() .map(|b| b.conn_state.state.connectivity.clone()) .collect(), - None => return Connectivity::NotConnected, + _ => return Connectivity::NotConnected, }; drop(lock); @@ -309,15 +311,15 @@ impl Context { // ============================================================================================= let lock = self.scheduler.inner.read().await; - let (folders_states, smtp) = match lock.scheduler { - Some(ref sched) => ( + let (folders_states, smtp) = match *lock { + InnerSchedulerState::Started(ref sched) => ( sched .boxes() .map(|b| (b.meaning, b.conn_state.state.connectivity.clone())) .collect::>(), sched.smtp.state.connectivity.clone(), ), - None => { + _ => { return Err(anyhow!("Not started")); } }; @@ -480,14 +482,14 @@ impl Context { /// Returns true if all background work is done. pub async fn all_work_done(&self) -> bool { let lock = self.scheduler.inner.read().await; - let stores: Vec<_> = match lock.scheduler { - Some(ref sched) => sched + let stores: Vec<_> = match *lock { + InnerSchedulerState::Started(ref sched) => sched .boxes() .map(|b| &b.conn_state.state) .chain(once(&sched.smtp.state)) .map(|state| state.connectivity.clone()) .collect(), - None => return false, + _ => return false, }; drop(lock);