From 2129b2b7a0efd3fe422bcce79103b215c4c8f42c Mon Sep 17 00:00:00 2001 From: Floris Bruynooghe Date: Thu, 9 Feb 2023 18:09:16 +0100 Subject: [PATCH] Add a ton of code for receiver-side progress --- src/imex/transfer.rs | 194 +++++++++++++++++++++++++++++++++++++++---- 1 file changed, 180 insertions(+), 14 deletions(-) diff --git a/src/imex/transfer.rs b/src/imex/transfer.rs index a09d51c0b..5e586354d 100644 --- a/src/imex/transfer.rs +++ b/src/imex/transfer.rs @@ -23,11 +23,15 @@ //! download to an impersonated getter. use std::path::Path; +use std::pin::Pin; +use std::sync::atomic::{AtomicU16, AtomicU64, Ordering}; +use std::sync::Arc; +use std::task::Poll; use crate::chat::delete_and_reset_all_device_msgs; use crate::context::Context; -use crate::e2ee; use crate::qr::Qr; +use crate::{e2ee, EventType}; use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result}; use async_channel::Receiver; use futures_lite::StreamExt; @@ -36,8 +40,9 @@ use sendme::protocol::AuthToken; use sendme::provider::{DataSource, Event, Provider, Ticket}; use sendme::Hash; use tokio::fs::{self, File}; -use tokio::io::{self, BufWriter}; +use tokio::io::{self, AsyncRead, AsyncWriteExt, BufWriter}; use tokio::sync::broadcast; +use tokio::sync::broadcast::error::RecvError; use tokio::task::JoinHandle; use tokio_stream::wrappers::ReadDirStream; @@ -244,19 +249,31 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> { addr: ticket.addr, peer_id: Some(ticket.peer), }; - let on_blob = |hash, reader, name| on_blob(context, &ticket, hash, reader, name); - match sendme::get::run( + let progress = ProgressEmitter::new(0, 85); + spawn_progress_proxy(context.clone(), progress.subscribe()); + let on_connected = || { + context.emit_event(ReceiveProgress::Connected.into()); + async { Ok(()) } + }; + let on_blob = |hash, reader, name| on_blob(context, &progress, &ticket, hash, reader, name); + let res = sendme::get::run( ticket.hash, ticket.token, opts, on_connected, - |_collection| async { Ok(()) }, + |collection| { + context.emit_event(ReceiveProgress::CollectionRecieved.into()); + progress.set_total(collection.total_blobs_size()); + async { Ok(()) } + }, on_blob, ) - .await - { + .await; + drop(progress); + match res { Ok(_) => { delete_and_reset_all_device_msgs(context).await?; + context.emit_event(ReceiveProgress::Completed.into()); Ok(()) } Err(err) => { @@ -268,20 +285,16 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> { fs::remove_file(dirent.path()).await.ok(); } } + context.emit_event(ReceiveProgress::Failed.into()); Err(err) } } } -/// Get callback when the connection is established with the provider. -#[allow(clippy::unused_async)] -async fn on_connected() -> Result<()> { - Ok(()) -} - /// Get callback when a blob is received from the provider. async fn on_blob( context: &Context, + progress: &ProgressEmitter, ticket: &Ticket, _hash: Hash, mut reader: DataStream, @@ -305,9 +318,13 @@ async fn on_blob( } else { None }; + + let mut wrapped_reader = progress.wrap_async_read(&mut reader); let file = File::create(&path).await?; let mut file = BufWriter::with_capacity(128 * 1024, file); - io::copy(&mut reader, &mut file).await?; + io::copy(&mut wrapped_reader, &mut file).await?; + file.flush().await?; + if name.starts_with("db/") { context .sql @@ -321,6 +338,155 @@ async fn on_blob( Ok(reader) } +/// Spawns a task proxying progress events. +/// +/// This spawns a tokio tasks which receives events from the [`ProgressEmitter`] and sends +/// them to the context. The task finishes when the emitter is dropped. +/// +/// This could be done directly in the emitter by making it less generic. +fn spawn_progress_proxy(context: Context, mut rx: broadcast::Receiver) { + tokio::spawn(async move { + loop { + match rx.recv().await { + Ok(step) => context.emit_event(ReceiveProgress::BlobProgress(step).into()), + Err(RecvError::Closed) => break, + Err(RecvError::Lagged(_)) => continue, + } + } + }); +} + +/// Create [`EventType::ImexProgress`] events using readable names. +/// +/// Plus you get warnings if you don't use all variants. +enum ReceiveProgress { + Connected, + CollectionRecieved, + /// A value between 0 and 85 as percentage + BlobProgress(u16), + Completed, + Failed, +} + +impl From for EventType { + fn from(source: ReceiveProgress) -> Self { + let val = match source { + ReceiveProgress::Connected => 50, + ReceiveProgress::CollectionRecieved => 100, + ReceiveProgress::BlobProgress(val) => 100 + 10 * val, + ReceiveProgress::Completed => 1000, + ReceiveProgress::Failed => 0, + }; + EventType::ImexProgress(val.into()) + } +} + +#[derive(Debug, Clone)] +struct ProgressEmitter { + inner: Arc, +} + +impl ProgressEmitter { + /// Creates a new emitter. + /// + /// The emitter expects to see *total* being added via [`ProgressEmitter::inc`] and will + /// emit *steps* updates. + fn new(total: u64, steps: u16) -> Self { + let (tx, _rx) = broadcast::channel(16); + Self { + inner: Arc::new(InnerProgressEmitter { + total: AtomicU64::new(total), + count: AtomicU64::new(0), + steps, + last_step: AtomicU16::new(0u16), + tx, + }), + } + } + + /// Sets a new total in case you did not now the total up front. + fn set_total(&self, value: u64) { + self.inner.set_total(value) + } + + /// Return a receiver that gets incremental values. + /// + /// The values yielded depend on *steps* passed to [`ProgressEmitter::new`]: it will go + /// from `1..steps`. + fn subscribe(&self) -> broadcast::Receiver { + self.inner.subscribe() + } + + /// Increments the progress by *amount*. + fn inc(&self, amount: u64) { + self.inner.inc(amount); + } + + fn wrap_async_read(&self, read: R) -> ProgressAsyncReader { + ProgressAsyncReader { + emitter: self.clone(), + inner: read, + } + } +} + +#[derive(Debug)] +struct InnerProgressEmitter { + total: AtomicU64, + count: AtomicU64, + steps: u16, + last_step: AtomicU16, + tx: broadcast::Sender, +} + +impl InnerProgressEmitter { + fn inc(&self, amount: u64) { + let prev_count = self.count.fetch_add(amount, Ordering::Relaxed); + let count = prev_count + amount; + let total = self.total.load(Ordering::Relaxed); + let step = (total * u64::from(self.steps) / std::cmp::min(count, total)) as u16; + let last_step = self.last_step.swap(step, Ordering::Relaxed); + if step > last_step { + self.tx.send(step).ok(); + } + } + + fn set_total(&self, value: u64) { + self.total.store(value, Ordering::Relaxed); + } + + fn subscribe(&self) -> broadcast::Receiver { + self.tx.subscribe() + } +} + +#[derive(Debug)] +struct ProgressAsyncReader { + emitter: ProgressEmitter, + inner: R, +} + +impl AsyncRead for ProgressAsyncReader +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut io::ReadBuf<'_>, + ) -> Poll> { + let prev_len = buf.filled().len() as u64; + match Pin::new(&mut self.inner).poll_read(cx, buf) { + Poll::Ready(val) => { + let new_len = buf.filled().len() as u64; + self.emitter.inc(new_len - prev_len); + Poll::Ready(val) + } + Poll::Pending => Poll::Pending, + } + } +} + #[cfg(test)] mod tests { use std::time::Duration;