mirror of
https://github.com/chatmail/core.git
synced 2026-04-29 03:16:29 +03:00
Add a ton of code for receiver-side progress
This commit is contained in:
@@ -23,11 +23,15 @@
|
|||||||
//! download to an impersonated getter.
|
//! download to an impersonated getter.
|
||||||
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::pin::Pin;
|
||||||
|
use std::sync::atomic::{AtomicU16, AtomicU64, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::task::Poll;
|
||||||
|
|
||||||
use crate::chat::delete_and_reset_all_device_msgs;
|
use crate::chat::delete_and_reset_all_device_msgs;
|
||||||
use crate::context::Context;
|
use crate::context::Context;
|
||||||
use crate::e2ee;
|
|
||||||
use crate::qr::Qr;
|
use crate::qr::Qr;
|
||||||
|
use crate::{e2ee, EventType};
|
||||||
use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result};
|
use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result};
|
||||||
use async_channel::Receiver;
|
use async_channel::Receiver;
|
||||||
use futures_lite::StreamExt;
|
use futures_lite::StreamExt;
|
||||||
@@ -36,8 +40,9 @@ use sendme::protocol::AuthToken;
|
|||||||
use sendme::provider::{DataSource, Event, Provider, Ticket};
|
use sendme::provider::{DataSource, Event, Provider, Ticket};
|
||||||
use sendme::Hash;
|
use sendme::Hash;
|
||||||
use tokio::fs::{self, File};
|
use tokio::fs::{self, File};
|
||||||
use tokio::io::{self, BufWriter};
|
use tokio::io::{self, AsyncRead, AsyncWriteExt, BufWriter};
|
||||||
use tokio::sync::broadcast;
|
use tokio::sync::broadcast;
|
||||||
|
use tokio::sync::broadcast::error::RecvError;
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
use tokio_stream::wrappers::ReadDirStream;
|
use tokio_stream::wrappers::ReadDirStream;
|
||||||
|
|
||||||
@@ -244,19 +249,31 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
|
|||||||
addr: ticket.addr,
|
addr: ticket.addr,
|
||||||
peer_id: Some(ticket.peer),
|
peer_id: Some(ticket.peer),
|
||||||
};
|
};
|
||||||
let on_blob = |hash, reader, name| on_blob(context, &ticket, hash, reader, name);
|
let progress = ProgressEmitter::new(0, 85);
|
||||||
match sendme::get::run(
|
spawn_progress_proxy(context.clone(), progress.subscribe());
|
||||||
|
let on_connected = || {
|
||||||
|
context.emit_event(ReceiveProgress::Connected.into());
|
||||||
|
async { Ok(()) }
|
||||||
|
};
|
||||||
|
let on_blob = |hash, reader, name| on_blob(context, &progress, &ticket, hash, reader, name);
|
||||||
|
let res = sendme::get::run(
|
||||||
ticket.hash,
|
ticket.hash,
|
||||||
ticket.token,
|
ticket.token,
|
||||||
opts,
|
opts,
|
||||||
on_connected,
|
on_connected,
|
||||||
|_collection| async { Ok(()) },
|
|collection| {
|
||||||
|
context.emit_event(ReceiveProgress::CollectionRecieved.into());
|
||||||
|
progress.set_total(collection.total_blobs_size());
|
||||||
|
async { Ok(()) }
|
||||||
|
},
|
||||||
on_blob,
|
on_blob,
|
||||||
)
|
)
|
||||||
.await
|
.await;
|
||||||
{
|
drop(progress);
|
||||||
|
match res {
|
||||||
Ok(_) => {
|
Ok(_) => {
|
||||||
delete_and_reset_all_device_msgs(context).await?;
|
delete_and_reset_all_device_msgs(context).await?;
|
||||||
|
context.emit_event(ReceiveProgress::Completed.into());
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
@@ -268,20 +285,16 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
|
|||||||
fs::remove_file(dirent.path()).await.ok();
|
fs::remove_file(dirent.path()).await.ok();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
context.emit_event(ReceiveProgress::Failed.into());
|
||||||
Err(err)
|
Err(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get callback when the connection is established with the provider.
|
|
||||||
#[allow(clippy::unused_async)]
|
|
||||||
async fn on_connected() -> Result<()> {
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get callback when a blob is received from the provider.
|
/// Get callback when a blob is received from the provider.
|
||||||
async fn on_blob(
|
async fn on_blob(
|
||||||
context: &Context,
|
context: &Context,
|
||||||
|
progress: &ProgressEmitter,
|
||||||
ticket: &Ticket,
|
ticket: &Ticket,
|
||||||
_hash: Hash,
|
_hash: Hash,
|
||||||
mut reader: DataStream,
|
mut reader: DataStream,
|
||||||
@@ -305,9 +318,13 @@ async fn on_blob(
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let mut wrapped_reader = progress.wrap_async_read(&mut reader);
|
||||||
let file = File::create(&path).await?;
|
let file = File::create(&path).await?;
|
||||||
let mut file = BufWriter::with_capacity(128 * 1024, file);
|
let mut file = BufWriter::with_capacity(128 * 1024, file);
|
||||||
io::copy(&mut reader, &mut file).await?;
|
io::copy(&mut wrapped_reader, &mut file).await?;
|
||||||
|
file.flush().await?;
|
||||||
|
|
||||||
if name.starts_with("db/") {
|
if name.starts_with("db/") {
|
||||||
context
|
context
|
||||||
.sql
|
.sql
|
||||||
@@ -321,6 +338,155 @@ async fn on_blob(
|
|||||||
Ok(reader)
|
Ok(reader)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Spawns a task proxying progress events.
|
||||||
|
///
|
||||||
|
/// This spawns a tokio tasks which receives events from the [`ProgressEmitter`] and sends
|
||||||
|
/// them to the context. The task finishes when the emitter is dropped.
|
||||||
|
///
|
||||||
|
/// This could be done directly in the emitter by making it less generic.
|
||||||
|
fn spawn_progress_proxy(context: Context, mut rx: broadcast::Receiver<u16>) {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
loop {
|
||||||
|
match rx.recv().await {
|
||||||
|
Ok(step) => context.emit_event(ReceiveProgress::BlobProgress(step).into()),
|
||||||
|
Err(RecvError::Closed) => break,
|
||||||
|
Err(RecvError::Lagged(_)) => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create [`EventType::ImexProgress`] events using readable names.
|
||||||
|
///
|
||||||
|
/// Plus you get warnings if you don't use all variants.
|
||||||
|
enum ReceiveProgress {
|
||||||
|
Connected,
|
||||||
|
CollectionRecieved,
|
||||||
|
/// A value between 0 and 85 as percentage
|
||||||
|
BlobProgress(u16),
|
||||||
|
Completed,
|
||||||
|
Failed,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<ReceiveProgress> for EventType {
|
||||||
|
fn from(source: ReceiveProgress) -> Self {
|
||||||
|
let val = match source {
|
||||||
|
ReceiveProgress::Connected => 50,
|
||||||
|
ReceiveProgress::CollectionRecieved => 100,
|
||||||
|
ReceiveProgress::BlobProgress(val) => 100 + 10 * val,
|
||||||
|
ReceiveProgress::Completed => 1000,
|
||||||
|
ReceiveProgress::Failed => 0,
|
||||||
|
};
|
||||||
|
EventType::ImexProgress(val.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct ProgressEmitter {
|
||||||
|
inner: Arc<InnerProgressEmitter>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ProgressEmitter {
|
||||||
|
/// Creates a new emitter.
|
||||||
|
///
|
||||||
|
/// The emitter expects to see *total* being added via [`ProgressEmitter::inc`] and will
|
||||||
|
/// emit *steps* updates.
|
||||||
|
fn new(total: u64, steps: u16) -> Self {
|
||||||
|
let (tx, _rx) = broadcast::channel(16);
|
||||||
|
Self {
|
||||||
|
inner: Arc::new(InnerProgressEmitter {
|
||||||
|
total: AtomicU64::new(total),
|
||||||
|
count: AtomicU64::new(0),
|
||||||
|
steps,
|
||||||
|
last_step: AtomicU16::new(0u16),
|
||||||
|
tx,
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets a new total in case you did not now the total up front.
|
||||||
|
fn set_total(&self, value: u64) {
|
||||||
|
self.inner.set_total(value)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a receiver that gets incremental values.
|
||||||
|
///
|
||||||
|
/// The values yielded depend on *steps* passed to [`ProgressEmitter::new`]: it will go
|
||||||
|
/// from `1..steps`.
|
||||||
|
fn subscribe(&self) -> broadcast::Receiver<u16> {
|
||||||
|
self.inner.subscribe()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Increments the progress by *amount*.
|
||||||
|
fn inc(&self, amount: u64) {
|
||||||
|
self.inner.inc(amount);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wrap_async_read<R: AsyncRead + Unpin>(&self, read: R) -> ProgressAsyncReader<R> {
|
||||||
|
ProgressAsyncReader {
|
||||||
|
emitter: self.clone(),
|
||||||
|
inner: read,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct InnerProgressEmitter {
|
||||||
|
total: AtomicU64,
|
||||||
|
count: AtomicU64,
|
||||||
|
steps: u16,
|
||||||
|
last_step: AtomicU16,
|
||||||
|
tx: broadcast::Sender<u16>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl InnerProgressEmitter {
|
||||||
|
fn inc(&self, amount: u64) {
|
||||||
|
let prev_count = self.count.fetch_add(amount, Ordering::Relaxed);
|
||||||
|
let count = prev_count + amount;
|
||||||
|
let total = self.total.load(Ordering::Relaxed);
|
||||||
|
let step = (total * u64::from(self.steps) / std::cmp::min(count, total)) as u16;
|
||||||
|
let last_step = self.last_step.swap(step, Ordering::Relaxed);
|
||||||
|
if step > last_step {
|
||||||
|
self.tx.send(step).ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_total(&self, value: u64) {
|
||||||
|
self.total.store(value, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn subscribe(&self) -> broadcast::Receiver<u16> {
|
||||||
|
self.tx.subscribe()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ProgressAsyncReader<R: AsyncRead + Unpin> {
|
||||||
|
emitter: ProgressEmitter,
|
||||||
|
inner: R,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R> AsyncRead for ProgressAsyncReader<R>
|
||||||
|
where
|
||||||
|
R: AsyncRead + Unpin,
|
||||||
|
{
|
||||||
|
fn poll_read(
|
||||||
|
mut self: Pin<&mut Self>,
|
||||||
|
cx: &mut std::task::Context<'_>,
|
||||||
|
buf: &mut io::ReadBuf<'_>,
|
||||||
|
) -> Poll<std::io::Result<()>> {
|
||||||
|
let prev_len = buf.filled().len() as u64;
|
||||||
|
match Pin::new(&mut self.inner).poll_read(cx, buf) {
|
||||||
|
Poll::Ready(val) => {
|
||||||
|
let new_len = buf.filled().len() as u64;
|
||||||
|
self.emitter.inc(new_len - prev_len);
|
||||||
|
Poll::Ready(val)
|
||||||
|
}
|
||||||
|
Poll::Pending => Poll::Pending,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|||||||
Reference in New Issue
Block a user