mirror of
https://github.com/chatmail/core.git
synced 2026-04-20 06:56:29 +03:00
Add a ton of code for receiver-side progress
This commit is contained in:
@@ -23,11 +23,15 @@
|
||||
//! download to an impersonated getter.
|
||||
|
||||
use std::path::Path;
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicU16, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
|
||||
use crate::chat::delete_and_reset_all_device_msgs;
|
||||
use crate::context::Context;
|
||||
use crate::e2ee;
|
||||
use crate::qr::Qr;
|
||||
use crate::{e2ee, EventType};
|
||||
use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result};
|
||||
use async_channel::Receiver;
|
||||
use futures_lite::StreamExt;
|
||||
@@ -36,8 +40,9 @@ use sendme::protocol::AuthToken;
|
||||
use sendme::provider::{DataSource, Event, Provider, Ticket};
|
||||
use sendme::Hash;
|
||||
use tokio::fs::{self, File};
|
||||
use tokio::io::{self, BufWriter};
|
||||
use tokio::io::{self, AsyncRead, AsyncWriteExt, BufWriter};
|
||||
use tokio::sync::broadcast;
|
||||
use tokio::sync::broadcast::error::RecvError;
|
||||
use tokio::task::JoinHandle;
|
||||
use tokio_stream::wrappers::ReadDirStream;
|
||||
|
||||
@@ -244,19 +249,31 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
|
||||
addr: ticket.addr,
|
||||
peer_id: Some(ticket.peer),
|
||||
};
|
||||
let on_blob = |hash, reader, name| on_blob(context, &ticket, hash, reader, name);
|
||||
match sendme::get::run(
|
||||
let progress = ProgressEmitter::new(0, 85);
|
||||
spawn_progress_proxy(context.clone(), progress.subscribe());
|
||||
let on_connected = || {
|
||||
context.emit_event(ReceiveProgress::Connected.into());
|
||||
async { Ok(()) }
|
||||
};
|
||||
let on_blob = |hash, reader, name| on_blob(context, &progress, &ticket, hash, reader, name);
|
||||
let res = sendme::get::run(
|
||||
ticket.hash,
|
||||
ticket.token,
|
||||
opts,
|
||||
on_connected,
|
||||
|_collection| async { Ok(()) },
|
||||
|collection| {
|
||||
context.emit_event(ReceiveProgress::CollectionRecieved.into());
|
||||
progress.set_total(collection.total_blobs_size());
|
||||
async { Ok(()) }
|
||||
},
|
||||
on_blob,
|
||||
)
|
||||
.await
|
||||
{
|
||||
.await;
|
||||
drop(progress);
|
||||
match res {
|
||||
Ok(_) => {
|
||||
delete_and_reset_all_device_msgs(context).await?;
|
||||
context.emit_event(ReceiveProgress::Completed.into());
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
@@ -268,20 +285,16 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
|
||||
fs::remove_file(dirent.path()).await.ok();
|
||||
}
|
||||
}
|
||||
context.emit_event(ReceiveProgress::Failed.into());
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get callback when the connection is established with the provider.
|
||||
#[allow(clippy::unused_async)]
|
||||
async fn on_connected() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get callback when a blob is received from the provider.
|
||||
async fn on_blob(
|
||||
context: &Context,
|
||||
progress: &ProgressEmitter,
|
||||
ticket: &Ticket,
|
||||
_hash: Hash,
|
||||
mut reader: DataStream,
|
||||
@@ -305,9 +318,13 @@ async fn on_blob(
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut wrapped_reader = progress.wrap_async_read(&mut reader);
|
||||
let file = File::create(&path).await?;
|
||||
let mut file = BufWriter::with_capacity(128 * 1024, file);
|
||||
io::copy(&mut reader, &mut file).await?;
|
||||
io::copy(&mut wrapped_reader, &mut file).await?;
|
||||
file.flush().await?;
|
||||
|
||||
if name.starts_with("db/") {
|
||||
context
|
||||
.sql
|
||||
@@ -321,6 +338,155 @@ async fn on_blob(
|
||||
Ok(reader)
|
||||
}
|
||||
|
||||
/// Spawns a task proxying progress events.
|
||||
///
|
||||
/// This spawns a tokio tasks which receives events from the [`ProgressEmitter`] and sends
|
||||
/// them to the context. The task finishes when the emitter is dropped.
|
||||
///
|
||||
/// This could be done directly in the emitter by making it less generic.
|
||||
fn spawn_progress_proxy(context: Context, mut rx: broadcast::Receiver<u16>) {
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
match rx.recv().await {
|
||||
Ok(step) => context.emit_event(ReceiveProgress::BlobProgress(step).into()),
|
||||
Err(RecvError::Closed) => break,
|
||||
Err(RecvError::Lagged(_)) => continue,
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Create [`EventType::ImexProgress`] events using readable names.
|
||||
///
|
||||
/// Plus you get warnings if you don't use all variants.
|
||||
enum ReceiveProgress {
|
||||
Connected,
|
||||
CollectionRecieved,
|
||||
/// A value between 0 and 85 as percentage
|
||||
BlobProgress(u16),
|
||||
Completed,
|
||||
Failed,
|
||||
}
|
||||
|
||||
impl From<ReceiveProgress> for EventType {
|
||||
fn from(source: ReceiveProgress) -> Self {
|
||||
let val = match source {
|
||||
ReceiveProgress::Connected => 50,
|
||||
ReceiveProgress::CollectionRecieved => 100,
|
||||
ReceiveProgress::BlobProgress(val) => 100 + 10 * val,
|
||||
ReceiveProgress::Completed => 1000,
|
||||
ReceiveProgress::Failed => 0,
|
||||
};
|
||||
EventType::ImexProgress(val.into())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ProgressEmitter {
|
||||
inner: Arc<InnerProgressEmitter>,
|
||||
}
|
||||
|
||||
impl ProgressEmitter {
|
||||
/// Creates a new emitter.
|
||||
///
|
||||
/// The emitter expects to see *total* being added via [`ProgressEmitter::inc`] and will
|
||||
/// emit *steps* updates.
|
||||
fn new(total: u64, steps: u16) -> Self {
|
||||
let (tx, _rx) = broadcast::channel(16);
|
||||
Self {
|
||||
inner: Arc::new(InnerProgressEmitter {
|
||||
total: AtomicU64::new(total),
|
||||
count: AtomicU64::new(0),
|
||||
steps,
|
||||
last_step: AtomicU16::new(0u16),
|
||||
tx,
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
/// Sets a new total in case you did not now the total up front.
|
||||
fn set_total(&self, value: u64) {
|
||||
self.inner.set_total(value)
|
||||
}
|
||||
|
||||
/// Return a receiver that gets incremental values.
|
||||
///
|
||||
/// The values yielded depend on *steps* passed to [`ProgressEmitter::new`]: it will go
|
||||
/// from `1..steps`.
|
||||
fn subscribe(&self) -> broadcast::Receiver<u16> {
|
||||
self.inner.subscribe()
|
||||
}
|
||||
|
||||
/// Increments the progress by *amount*.
|
||||
fn inc(&self, amount: u64) {
|
||||
self.inner.inc(amount);
|
||||
}
|
||||
|
||||
fn wrap_async_read<R: AsyncRead + Unpin>(&self, read: R) -> ProgressAsyncReader<R> {
|
||||
ProgressAsyncReader {
|
||||
emitter: self.clone(),
|
||||
inner: read,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct InnerProgressEmitter {
|
||||
total: AtomicU64,
|
||||
count: AtomicU64,
|
||||
steps: u16,
|
||||
last_step: AtomicU16,
|
||||
tx: broadcast::Sender<u16>,
|
||||
}
|
||||
|
||||
impl InnerProgressEmitter {
|
||||
fn inc(&self, amount: u64) {
|
||||
let prev_count = self.count.fetch_add(amount, Ordering::Relaxed);
|
||||
let count = prev_count + amount;
|
||||
let total = self.total.load(Ordering::Relaxed);
|
||||
let step = (total * u64::from(self.steps) / std::cmp::min(count, total)) as u16;
|
||||
let last_step = self.last_step.swap(step, Ordering::Relaxed);
|
||||
if step > last_step {
|
||||
self.tx.send(step).ok();
|
||||
}
|
||||
}
|
||||
|
||||
fn set_total(&self, value: u64) {
|
||||
self.total.store(value, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
fn subscribe(&self) -> broadcast::Receiver<u16> {
|
||||
self.tx.subscribe()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct ProgressAsyncReader<R: AsyncRead + Unpin> {
|
||||
emitter: ProgressEmitter,
|
||||
inner: R,
|
||||
}
|
||||
|
||||
impl<R> AsyncRead for ProgressAsyncReader<R>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut std::task::Context<'_>,
|
||||
buf: &mut io::ReadBuf<'_>,
|
||||
) -> Poll<std::io::Result<()>> {
|
||||
let prev_len = buf.filled().len() as u64;
|
||||
match Pin::new(&mut self.inner).poll_read(cx, buf) {
|
||||
Poll::Ready(val) => {
|
||||
let new_len = buf.filled().len() as u64;
|
||||
self.emitter.inc(new_len - prev_len);
|
||||
Poll::Ready(val)
|
||||
}
|
||||
Poll::Pending => Poll::Pending,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::time::Duration;
|
||||
|
||||
Reference in New Issue
Block a user