diff --git a/src/imex.rs b/src/imex.rs index 17bf631d8..b03c9d5ee 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -292,62 +292,99 @@ pub(crate) async fn import_backup_stream( file_size: u64, passphrase: String, ) -> Result<()> { + import_backup_stream_inner(context, backup_file, file_size, passphrase) + .await + .0 +} + +async fn import_backup_stream_inner( + context: &Context, + backup_file: R, + file_size: u64, + passphrase: String, +) -> (Result<()>,) { let mut archive = Archive::new(backup_file); - let mut entries = archive.entries()?; - + let mut entries = match archive.entries() { + Ok(entries) => entries, + Err(e) => return (Err(e).context("Failed to get archive entries"),), + }; + let mut blobs = Vec::new(); // We already emitted ImexProgress(10) above let mut last_progress = 10; + const PROGRESS_MIGRATIONS: u64 = 999; let mut total_size = 0; - while let Some(mut f) = entries - .try_next() - .await - .context("Failed to get next entry")? - { - total_size += f.header().entry_size()?; + let mut res: Result<()> = loop { + let mut f = match entries.try_next().await { + Ok(Some(f)) => f, + Ok(None) => break Ok(()), + Err(e) => break Err(e).context("Failed to get next entry"), + }; + total_size += match f.header().entry_size() { + Ok(size) => size, + Err(e) => break Err(e).context("Failed to get entry size"), + }; let progress = std::cmp::min( 1000 * total_size.checked_div(file_size).unwrap_or_default(), - 999, + PROGRESS_MIGRATIONS - 1, ); if progress > last_progress { context.emit_event(EventType::ImexProgress(progress as usize)); last_progress = progress; } - if f.path()?.file_name() == Some(OsStr::new(DBFILE_BACKUP_NAME)) { - // async_tar can't unpack to a specified file name, so we just unpack to the blobdir and then move the unpacked file. - f.unpack_in(context.get_blobdir()) - .await - .context("Failed to unpack database")?; - let unpacked_database = context.get_blobdir().join(DBFILE_BACKUP_NAME); - context - .sql - .import(&unpacked_database, passphrase.clone()) - .await - .context("cannot import unpacked database")?; - fs::remove_file(unpacked_database) - .await - .context("cannot remove unpacked database")?; - } else { - // async_tar will unpack to blobdir/BLOBS_BACKUP_NAME, so we move the file afterwards. - f.unpack_in(context.get_blobdir()) - .await - .context("Failed to unpack blob")?; - let from_path = context.get_blobdir().join(f.path()?); - if from_path.is_file() { - if let Some(name) = from_path.file_name() { - fs::rename(&from_path, context.get_blobdir().join(name)).await?; - } else { - warn!(context, "No file name"); + let path = match f.path() { + Ok(path) => path.to_path_buf(), + Err(e) => break Err(e).context("Failed to get entry path"), + }; + if let Err(e) = f.unpack_in(context.get_blobdir()).await { + break Err(e).context("Failed to unpack file"); + } + if path.file_name() == Some(OsStr::new(DBFILE_BACKUP_NAME)) { + continue; + } + // async_tar unpacked to $BLOBDIR/BLOBS_BACKUP_NAME/, so we move the file afterwards. + let from_path = context.get_blobdir().join(&path); + if from_path.is_file() { + if let Some(name) = from_path.file_name() { + let to_path = context.get_blobdir().join(name); + if let Err(e) = fs::rename(&from_path, &to_path).await { + blobs.push(from_path); + break Err(e).context("Failed to move file to blobdir"); } + blobs.push(to_path); + } else { + warn!(context, "No file name"); } } + }; + if res.is_err() { + for blob in blobs { + fs::remove_file(&blob).await.log_err(context).ok(); + } } - context.sql.run_migrations(context).await?; - delete_and_reset_all_device_msgs(context).await?; - - Ok(()) + let unpacked_database = context.get_blobdir().join(DBFILE_BACKUP_NAME); + if res.is_ok() { + res = context + .sql + .import(&unpacked_database, passphrase.clone()) + .await + .context("cannot import unpacked database"); + } + fs::remove_file(unpacked_database) + .await + .context("cannot remove unpacked database") + .log_err(context) + .ok(); + if res.is_ok() { + context.emit_event(EventType::ImexProgress(PROGRESS_MIGRATIONS as usize)); + res = context.sql.run_migrations(context).await; + } + if res.is_ok() { + res = delete_and_reset_all_device_msgs(context).await; + } + (res,) } /*******************************************************************************