From c4d07ab99ec54ab75a6a51a0ece34aff788c0efb Mon Sep 17 00:00:00 2001 From: link2xt Date: Fri, 4 Oct 2024 23:08:02 +0000 Subject: [PATCH] fix: smooth progress bar for backup transfer Before this change progress bar only started when database is already transferred. Database is usually the largest file in the whole transfer, so the transfer appears to be stuck for the sender. With this change progress bar starts for backup export as soon as connection is received and counts bytes transferred over the connection using AsyncWrite wrapper. Similarly for backup import, AsyncRead wrapper counts the bytes received and emits progress events. --- Cargo.lock | 1 + Cargo.toml | 1 + src/blob.rs | 4 - src/imex.rs | 174 ++++++++++++++++++++++++++++++++++++------- src/imex/transfer.rs | 5 +- 5 files changed, 151 insertions(+), 34 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index e8ac81203..d7012809d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1315,6 +1315,7 @@ dependencies = [ "parking_lot", "percent-encoding", "pgp", + "pin-project", "pretty_assertions", "proptest", "qrcodegen", diff --git a/Cargo.toml b/Cargo.toml index 056070a00..7ec378b43 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,6 +77,7 @@ once_cell = { workspace = true } parking_lot = "0.12" percent-encoding = "2.3" pgp = { version = "0.13.2", default-features = false } +pin-project = "1" qrcodegen = "1.7.0" quick-xml = "0.36" quoted_printable = "0.5" diff --git a/src/blob.rs b/src/blob.rs index 320cc011e..6c0c03209 100644 --- a/src/blob.rs +++ b/src/blob.rs @@ -666,10 +666,6 @@ impl<'a> BlobDirContents<'a> { pub(crate) fn iter(&self) -> BlobDirIter<'_> { BlobDirIter::new(self.context, self.inner.iter()) } - - pub(crate) fn len(&self) -> usize { - self.inner.len() - } } /// A iterator over all the [`BlobObject`]s in the blobdir. diff --git a/src/imex.rs b/src/imex.rs index 89273a7a5..f99115cb6 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -2,13 +2,16 @@ use std::ffi::OsStr; use std::path::{Path, PathBuf}; +use std::pin::Pin; use ::pgp::types::KeyTrait; use anyhow::{bail, ensure, format_err, Context as _, Result}; use futures::TryStreamExt; use futures_lite::FutureExt; +use pin_project::pin_project; use tokio::fs::{self, File}; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio_tar::Archive; use crate::blob::BlobDirContents; @@ -212,7 +215,7 @@ async fn imex_inner( path.display() ); ensure!(context.sql.is_open().await, "Database not opened."); - context.emit_event(EventType::ImexProgress(10)); + context.emit_event(EventType::ImexProgress(1)); if what == ImexMode::ExportBackup || what == ImexMode::ExportSelfKeys { // before we export anything, make sure the private key exists @@ -294,12 +297,71 @@ pub(crate) async fn import_backup_stream( .0 } +/// Reader that emits progress events as bytes are read from it. +#[pin_project] +struct ProgressReader { + /// Wrapped reader. + #[pin] + inner: R, + + /// Number of bytes successfully read from the internal reader. + read: usize, + + /// Total size of the backup .tar file expected to be read from the reader. + /// Used to calculate the progress. + file_size: usize, + + /// Last progress emitted to avoid emitting the same progress value twice. + last_progress: usize, + + /// Context for emitting progress events. + context: Context, +} + +impl ProgressReader { + fn new(r: R, context: Context, file_size: u64) -> Self { + Self { + inner: r, + read: 0, + file_size: file_size as usize, + last_progress: 1, + context, + } + } +} + +impl AsyncRead for ProgressReader +where + R: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> std::task::Poll> { + let this = self.project(); + let before = buf.filled().len(); + let res = this.inner.poll_read(cx, buf); + if let std::task::Poll::Ready(Ok(())) = res { + *this.read = this.read.saturating_add(buf.filled().len() - before); + + let progress = std::cmp::min(1000 * *this.read / *this.file_size, 999); + if progress > *this.last_progress { + this.context.emit_event(EventType::ImexProgress(progress)); + *this.last_progress = progress; + } + } + res + } +} + async fn import_backup_stream_inner( context: &Context, backup_file: R, file_size: u64, passphrase: String, ) -> (Result<()>,) { + let backup_file = ProgressReader::new(backup_file, context.clone(), file_size); let mut archive = Archive::new(backup_file); let mut entries = match archive.entries() { @@ -307,29 +369,12 @@ async fn import_backup_stream_inner( Err(e) => return (Err(e).context("Failed to get archive entries"),), }; let mut blobs = Vec::new(); - // We already emitted ImexProgress(10) above - let mut last_progress = 10; - const PROGRESS_MIGRATIONS: u128 = 999; - let mut total_size: u64 = 0; let mut res: Result<()> = loop { let mut f = match entries.try_next().await { Ok(Some(f)) => f, Ok(None) => break Ok(()), Err(e) => break Err(e).context("Failed to get next entry"), }; - total_size += match f.header().entry_size() { - Ok(size) => size, - Err(e) => break Err(e).context("Failed to get entry size"), - }; - let max = PROGRESS_MIGRATIONS - 1; - let progress = std::cmp::min( - max * u128::from(total_size) / std::cmp::max(u128::from(file_size), 1), - max, - ); - if progress > last_progress { - context.emit_event(EventType::ImexProgress(progress as usize)); - last_progress = progress; - } let path = match f.path() { Ok(path) => path.to_path_buf(), @@ -379,7 +424,7 @@ async fn import_backup_stream_inner( .log_err(context) .ok(); if res.is_ok() { - context.emit_event(EventType::ImexProgress(PROGRESS_MIGRATIONS as usize)); + context.emit_event(EventType::ImexProgress(999)); res = context.sql.run_migrations(context).await; } if res.is_ok() { @@ -452,7 +497,14 @@ async fn export_backup(context: &Context, dir: &Path, passphrase: String) -> Res let file = File::create(&temp_path).await?; let blobdir = BlobDirContents::new(context).await?; - export_backup_stream(context, &temp_db_path, blobdir, file) + + let mut file_size = 0; + file_size += temp_db_path.metadata()?.len(); + for blob in blobdir.iter() { + file_size += blob.to_abs_path().metadata()?.len() + } + + export_backup_stream(context, &temp_db_path, blobdir, file, file_size) .await .context("Exporting backup to file failed")?; fs::rename(temp_path, &dest_path).await?; @@ -460,33 +512,99 @@ async fn export_backup(context: &Context, dir: &Path, passphrase: String) -> Res Ok(()) } +/// Writer that emits progress events as bytes are written into it. +#[pin_project] +struct ProgressWriter { + /// Wrapped writer. + #[pin] + inner: W, + + /// Number of bytes successfully written into the internal writer. + written: usize, + + /// Total size of the backup .tar file expected to be written into the writer. + /// Used to calculate the progress. + file_size: usize, + + /// Last progress emitted to avoid emitting the same progress value twice. + last_progress: usize, + + /// Context for emitting progress events. + context: Context, +} + +impl ProgressWriter { + fn new(w: W, context: Context, file_size: u64) -> Self { + Self { + inner: w, + written: 0, + file_size: file_size as usize, + last_progress: 1, + context, + } + } +} + +impl AsyncWrite for ProgressWriter +where + W: AsyncWrite, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + let this = self.project(); + let res = this.inner.poll_write(cx, buf); + if let std::task::Poll::Ready(Ok(written)) = res { + *this.written = this.written.saturating_add(written); + + let progress = std::cmp::min(1000 * *this.written / *this.file_size, 999); + if progress > *this.last_progress { + this.context.emit_event(EventType::ImexProgress(progress)); + *this.last_progress = progress; + } + } + res + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_flush(cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + self.project().inner.poll_shutdown(cx) + } +} + /// Exports the database and blobs into a stream. pub(crate) async fn export_backup_stream<'a, W>( context: &'a Context, temp_db_path: &Path, blobdir: BlobDirContents<'a>, writer: W, + file_size: u64, ) -> Result<()> where W: tokio::io::AsyncWrite + tokio::io::AsyncWriteExt + Unpin + Send + 'static, { + let writer = ProgressWriter::new(writer, context.clone(), file_size); let mut builder = tokio_tar::Builder::new(writer); builder .append_path_with_name(temp_db_path, DBFILE_BACKUP_NAME) .await?; - let mut last_progress = 10; - - for (i, blob) in blobdir.iter().enumerate() { + for blob in blobdir.iter() { let mut file = File::open(blob.to_abs_path()).await?; let path_in_archive = PathBuf::from(BLOBS_BACKUP_NAME).join(blob.as_name()); builder.append_file(path_in_archive, &mut file).await?; - let progress = std::cmp::min(1000 * i / blobdir.len(), 999); - if progress > last_progress { - context.emit_event(EventType::ImexProgress(progress)); - last_progress = progress; - } } builder.finish().await?; diff --git a/src/imex/transfer.rs b/src/imex/transfer.rs index b7743771e..15f87c794 100644 --- a/src/imex/transfer.rs +++ b/src/imex/transfer.rs @@ -109,6 +109,7 @@ impl BackupProvider { .get_blobdir() .parent() .context("Context dir not found")?; + let dbfile = context_dir.join(DBFILE_BACKUP_NAME); if fs::metadata(&dbfile).await.is_ok() { fs::remove_file(&dbfile).await?; @@ -124,7 +125,6 @@ impl BackupProvider { export_database(context, &dbfile, passphrase, time()) .await .context("Database export failed")?; - context.emit_event(EventType::ImexProgress(300)); let drop_token = CancellationToken::new(); let handle = { @@ -178,6 +178,7 @@ impl BackupProvider { } info!(context, "Received valid backup authentication token."); + context.emit_event(EventType::ImexProgress(1)); let blobdir = BlobDirContents::new(&context).await?; @@ -189,7 +190,7 @@ impl BackupProvider { send_stream.write_all(&file_size.to_be_bytes()).await?; - export_backup_stream(&context, &dbfile, blobdir, send_stream) + export_backup_stream(&context, &dbfile, blobdir, send_stream, file_size) .await .context("Failed to write backup into QUIC stream")?; info!(context, "Finished writing backup into QUIC stream.");