handle the database

This commit is contained in:
Floris Bruynooghe
2023-02-02 17:43:12 +01:00
parent 5f29b93970
commit 3267596a30
3 changed files with 28 additions and 6 deletions

View File

@@ -14,6 +14,8 @@ use crate::summary::{Summary, SummaryPrefix};
/// eg. by chatlist.get_summary() or dc_msg_get_summary(). /// eg. by chatlist.get_summary() or dc_msg_get_summary().
/// ///
/// *Lot* is used in the meaning *heap* here. /// *Lot* is used in the meaning *heap* here.
// The QR code grew too large. So be it.
#[allow(clippy::large_enum_variant)]
#[derive(Debug)] #[derive(Debug)]
pub enum Lot { pub enum Lot {
Summary(Summary), Summary(Summary),

View File

@@ -242,7 +242,8 @@ impl<'a> BlobObject<'a> {
/// including the dot. E.g. "foo.txt" is returned as `("foo", /// including the dot. E.g. "foo.txt" is returned as `("foo",
/// ".txt")` while "bar" is returned as `("bar", "")`. /// ".txt")` while "bar" is returned as `("bar", "")`.
/// ///
/// The extension part will always be lowercased. /// The extension part will always be lowercased. Note that [`imex::transfer`] relies
/// on this for safety, if uppercase extensions are ever allowed it needs to be adapted.
fn sanitise_name(name: &str) -> (String, String) { fn sanitise_name(name: &str) -> (String, String) {
let mut name = name.to_string(); let mut name = name.to_string();
for part in name.rsplit('/') { for part in name.rsplit('/') {

View File

@@ -33,7 +33,7 @@ use sendme::blobs::Collection;
use sendme::get::{AsyncSliceDecoder, Hash, Options, ReceiveStream}; use sendme::get::{AsyncSliceDecoder, Hash, Options, ReceiveStream};
use sendme::protocol::AuthToken; use sendme::protocol::AuthToken;
use sendme::provider::{DataSource, Provider, Ticket}; use sendme::provider::{DataSource, Provider, Ticket};
use tokio::fs::File; use tokio::fs::{self, File};
use tokio::io::{self, BufWriter}; use tokio::io::{self, BufWriter};
use super::{export_database, BlobDirContents, DBFILE_BACKUP_NAME}; use super::{export_database, BlobDirContents, DBFILE_BACKUP_NAME};
@@ -184,7 +184,7 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
addr: ticket.addr, addr: ticket.addr,
peer_id: Some(ticket.peer), peer_id: Some(ticket.peer),
}; };
let on_blob = |hash, reader, name| on_blob(context, hash, reader, name); let on_blob = |hash, reader, name| on_blob(context, &ticket, hash, reader, name);
sendme::get::run( sendme::get::run(
ticket.hash, ticket.hash,
ticket.token, ticket.token,
@@ -194,16 +194,17 @@ pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
on_blob, on_blob,
) )
.await?; .await?;
Ok(())
todo!();
} }
/// Get callback when the connection is established with the provider. /// Get callback when the connection is established with the provider.
#[allow(clippy::unused_async)]
async fn on_connected() -> Result<()> { async fn on_connected() -> Result<()> {
Ok(()) Ok(())
} }
/// Get callback when a collection is received from the provider. /// Get callback when a collection is received from the provider.
#[allow(clippy::unused_async)]
async fn on_collection(_collection: Collection) -> Result<()> { async fn on_collection(_collection: Collection) -> Result<()> {
Ok(()) Ok(())
} }
@@ -211,14 +212,32 @@ async fn on_collection(_collection: Collection) -> Result<()> {
/// Get callback when a blob is received from the provider. /// Get callback when a blob is received from the provider.
async fn on_blob( async fn on_blob(
context: &Context, context: &Context,
ticket: &Ticket,
_hash: Hash, _hash: Hash,
mut reader: AsyncSliceDecoder<ReceiveStream>, mut reader: AsyncSliceDecoder<ReceiveStream>,
name: String, name: String,
) -> Result<AsyncSliceDecoder<ReceiveStream>> { ) -> Result<AsyncSliceDecoder<ReceiveStream>> {
ensure!(!name.is_empty(), "Received a nameless blob"); ensure!(!name.is_empty(), "Received a nameless blob");
let path = context.get_blobdir().join(name); let path = if name == DBFILE_BACKUP_NAME {
// We can only safely write to the blobdir. But the blobdir could have a file named
// exactly like our special name. We solve this by using an uppercase extension
// which is forbidden for normal blobs.
context.get_blobdir().join(format!("{name}.SPECIAL"))
} else {
context.get_blobdir().join(&name)
};
let file = File::create(&path).await?; let file = File::create(&path).await?;
let mut file = BufWriter::with_capacity(128 * 1024, file); let mut file = BufWriter::with_capacity(128 * 1024, file);
io::copy(&mut reader, &mut file).await?; io::copy(&mut reader, &mut file).await?;
if name == DBFILE_BACKUP_NAME {
context
.sql
.import(&path, ticket.token.to_string())
.await
.context("cannot import database")?;
fs::remove_file(&path)
.await
.with_context(|| format!("Database file: {}", path.display()))?;
}
Ok(reader) Ok(reader)
} }