Remove the need for a directory for db export

Plus on import use the context directory.  We can actually write there
just fine.
This commit is contained in:
Floris Bruynooghe
2023-02-16 16:06:41 +01:00
parent dcce6ef50b
commit 490a14c5ef
9 changed files with 48 additions and 48 deletions

View File

@@ -22,7 +22,7 @@
//! getter can not connect to an impersonated provider and the provider does not offer the
//! download to an impersonated getter.
use std::path::Path;
use std::path::{Path, PathBuf};
use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result};
use async_channel::Receiver;
@@ -76,21 +76,25 @@ impl BackupProvider {
/// the possible cancellation of the "ongoing process".
///
/// [`Accounts::stop_io`]: crate::accounts::Accounts::stop_io
pub async fn prepare(context: &Context, dir: &Path) -> Result<Self> {
ensure!(
// TODO: Should we worry about path normalisation?
dir != context.get_blobdir(),
"Temporary database export directory should not be in blobdir"
);
pub async fn prepare(context: &Context) -> Result<Self> {
e2ee::ensure_secret_key_exists(context)
.await
.context("Private key not available, aborting backup export")?;
// Acquire global "ongoing" mutex.
let cancel_token = context.alloc_ongoing().await?;
let context_dir = context
.get_blobdir()
.parent()
.ok_or(anyhow!("Context dir not found"))?;
let dbfile = context_dir.join(DBFILE_BACKUP_NAME);
if fs::metadata(&dbfile).await.is_ok() {
fs::remove_file(&dbfile).await?;
warn!(context, "Previous database export deleted");
}
let res = tokio::select! {
biased;
res = Self::prepare_inner(context, dir) => {
res = Self::prepare_inner(context, &dbfile) => {
match res {
Ok(slf) => Ok(slf),
Err(err) => {
@@ -112,6 +116,7 @@ impl BackupProvider {
context.clone(),
provider,
cancel_token,
dbfile,
));
Ok(Self { handle, ticket })
}
@@ -119,11 +124,10 @@ impl BackupProvider {
/// Creates the provider task.
///
/// Having this as a function makes it easier to cancel it when needed.
async fn prepare_inner(context: &Context, dir: &Path) -> Result<(Provider, Ticket)> {
async fn prepare_inner(context: &Context, dbfile: &Path) -> Result<(Provider, Ticket)> {
// Generate the token up front: we also use it to encrypt the database.
let token = AuthToken::generate();
context.emit_event(SendProgress::Started.into());
let dbfile = dir.join(DBFILE_BACKUP_NAME);
export_database(context, &dbfile, token.to_string())
.await
.context("Database export failed")?;
@@ -131,7 +135,7 @@ impl BackupProvider {
// Now we can be sure IO is not running.
let mut files = vec![DataSource::with_name(
dbfile,
dbfile.to_owned(),
format!("db/{DBFILE_BACKUP_NAME}"),
)];
let blobdir = BlobDirContents::new(context).await?;
@@ -165,6 +169,7 @@ impl BackupProvider {
context: Context,
mut provider: Provider,
cancel_token: Receiver<()>,
dbfile: PathBuf,
) -> Result<()> {
context.emit_event(SendProgress::ProviderListening.into());
let mut events = provider.subscribe();
@@ -213,7 +218,9 @@ impl BackupProvider {
},
}
};
// TODO: delete the database?
if let Err(err) = fs::remove_file(&dbfile).await {
error!(context, "Failed to remove database export: {err:#}");
}
context.emit_event(SendProgress::Completed.into());
context.free_ongoing().await;
res
@@ -361,12 +368,16 @@ async fn on_blob(
) -> Result<DataStream> {
ensure!(!name.is_empty(), "Received a nameless blob");
let path = if name.starts_with("db/") {
// We can only safely write to the blobdir. But the blobdir could have a file named
// exactly like our special name. We solve this by using an uppercase extension
// which is forbidden for normal blobs.
context
let context_dir = context
.get_blobdir()
.join(format!("{DBFILE_BACKUP_NAME}.SPECIAL"))
.parent()
.ok_or(anyhow!("Context dir not found"))?;
let dbfile = context_dir.join(DBFILE_BACKUP_NAME);
if fs::metadata(&dbfile).await.is_ok() {
fs::remove_file(&dbfile).await?;
warn!(context, "Previous database export deleted");
}
dbfile
} else {
ensure!(name.starts_with("blob/"), "malformatted blob name");
let blobname = name.rsplit('/').next().context("malformatted blob name")?;
@@ -456,8 +467,6 @@ impl From<ReceiveProgress> for EventType {
mod tests {
use std::time::Duration;
use testdir::testdir;
use crate::chat::{get_chat_msgs, send_msg, ChatItem};
use crate::message::{Message, Viewtype};
use crate::test_utils::TestContextManager;
@@ -466,7 +475,6 @@ mod tests {
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_send_receive() {
let dir = testdir!();
let mut tcm = TestContextManager::new();
// Create first device.
@@ -479,7 +487,7 @@ mod tests {
send_msg(&ctx0, self_chat.id, &mut msg).await.unwrap();
// Prepare to transfer backup.
let provider = BackupProvider::prepare(&ctx0, &dir).await.unwrap();
let provider = BackupProvider::prepare(&ctx0).await.unwrap();
// Set up second device.
let ctx1 = tcm.unconfigured().await;