Files
chatmail-core/src/imex/transfer.rs
dignifiedquire 438cf0d953 happy clippy
2023-07-25 21:33:04 +02:00

781 lines
28 KiB
Rust

//! Transfer a backup to another device.
//!
//! This module provides support for using n0's iroh tool to initiate transfer of a backup
//! to another device using a QR code.
//!
//! Using the iroh terminology there are two parties to this:
//!
//! - The *Provider*, which starts a server and listens for connections.
//! - The *Getter*, which connects to the server and retrieves the data.
//!
//! Iroh is designed around the idea of verifying hashes: downloads are verified as
//! they are retrieved. The entire transfer is initiated by requesting the data of a single
//! root hash.
//!
//! Both the provider and the getter are authenticated:
//!
//! - The provider is known by its *peer ID*.
//! - The provider needs an *authentication token* from the getter before it accepts a
//! connection.
//!
//! Both of these are transferred in the QR code offered to the getter. This ensures that the
//! getter cannot connect to an impersonated provider and the provider does not offer the
//! download to an impersonated getter.
use std::future::Future;
use std::net::Ipv4Addr;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::Poll;
use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result};
use async_channel::Receiver;
use bytes::Bytes;
use futures::FutureExt;
use futures_lite::StreamExt;
use iroh::bytes::get::{fsm, Stats};
use iroh::bytes::protocol::{AnyGetRequest, GetRequest, RequestToken};
use iroh::bytes::provider::Event as ProviderEvent;
use iroh::bytes::util::runtime;
use iroh::collection::{Collection, IrohCollectionParser};
use iroh::database::flat::DataSource;
use iroh::dial::Ticket;
use iroh::net::tls::Keypair;
use iroh::node::{Event, Node as IrohNode, StaticTokenAuthHandler};
use iroh::util::progress::ProgressEmitter;
use iroh_io::AsyncSliceWriter;
use tokio::fs;
use tokio::sync::{
broadcast::{self, error::RecvError},
Mutex,
};
use tokio::task::{JoinHandle, JoinSet};
use tokio_stream::wrappers::ReadDirStream;
use tokio_util::sync::CancellationToken;
use crate::blob::BlobDirContents;
use crate::chat::{add_device_msg, delete_and_reset_all_device_msgs};
use crate::context::Context;
use crate::message::{Message, Viewtype};
use crate::qr::{self, Qr};
use crate::stock_str::backup_transfer_msg_body;
use crate::{e2ee, EventType};
use super::{export_database, DBFILE_BACKUP_NAME};
/// The iroh node used to provide backups, backed by `iroh::database::flat::Database`.
type Node = IrohNode<iroh::database::flat::Database>;
/// Provide or send a backup of this device.
///
/// This creates a backup of the current device and starts a service which offers another
/// device to download this backup.
///
/// This does not make a full backup on disk, only the SQLite database is created on disk,
/// the blobs in the blob directory are not copied.
///
/// Preparing a provider acquires the global "ongoing" mutex, which is released again by
/// the supervisor task when it ends. If you need to stop the task use the
/// [`Context::stop_ongoing`] mechanism.
///
/// [`BackupProvider`] implements [`Future`]: awaiting it completes once a transfer has
/// either completed or been aborted.
#[derive(Debug)]
pub struct BackupProvider {
    /// The supervisor task, run by [`BackupProvider::watch_provider`].
    handle: JoinHandle<Result<()>>,
    /// The ticket to retrieve the backup collection.
    ticket: Ticket,
    /// Guard to cancel the provider on drop.
    ///
    /// Dropping it cancels the token watched by the supervisor task, which then shuts the
    /// provider down.
    _drop_guard: tokio_util::sync::DropGuard,
}
impl BackupProvider {
    /// Prepares for sending a backup to a second device.
    ///
    /// Before calling this function all I/O must be stopped so that no changes to the blobs
    /// or database are happening, this is done by calling the [`Accounts::stop_io`] or
    /// [`Context::stop_io`] APIs first.
    ///
    /// This will acquire the global "ongoing process" mutex, which can be used to cancel
    /// the process. If preparing fails the mutex is released again before returning;
    /// otherwise it is released by the supervisor task once the transfer has ended.
    ///
    /// # Errors
    ///
    /// Fails if no secret key exists, if another ongoing process is already running, if
    /// the database export or the provider setup fails, or if the ongoing process is
    /// cancelled while preparing.
    ///
    /// [`Accounts::stop_io`]: crate::accounts::Accounts::stop_io
    pub async fn prepare(context: &Context) -> Result<Self> {
        e2ee::ensure_secret_key_exists(context)
            .await
            .context("Private key not available, aborting backup export")?;

        // Acquire global "ongoing" mutex.
        let cancel_token = context.alloc_ongoing().await?;

        // All fallible setup after acquiring the "ongoing" mutex runs inside this block so
        // that every failure can release the mutex again below instead of leaking it.
        let res = async {
            let paused_guard = context.scheduler.pause(context.clone()).await?;
            let context_dir = context
                .get_blobdir()
                .parent()
                .ok_or_else(|| anyhow!("Context dir not found"))?;
            let dbfile = context_dir.join(DBFILE_BACKUP_NAME);
            if fs::metadata(&dbfile).await.is_ok() {
                fs::remove_file(&dbfile).await?;
                warn!(context, "Previous database export deleted");
            }
            let dbfile = TempPathGuard::new(dbfile);
            let res = tokio::select! {
                biased;
                res = Self::prepare_inner(context, &dbfile) => {
                    match res {
                        Ok(slf) => Ok(slf),
                        Err(err) => {
                            error!(context, "Failed to set up second device setup: {:#}", err);
                            Err(err)
                        },
                    }
                },
                _ = cancel_token.recv() => Err(format_err!("cancelled")),
            };
            let (provider, ticket) = res?;
            anyhow::Ok((paused_guard, dbfile, provider, ticket))
        }
        .await;
        let (paused_guard, dbfile, provider, ticket) = match res {
            Ok(parts) => parts,
            Err(err) => {
                context.free_ongoing().await;
                return Err(err);
            }
        };

        let drop_token = CancellationToken::new();
        let handle = {
            let context = context.clone();
            let drop_token = drop_token.clone();
            tokio::spawn(async move {
                let res = Self::watch_provider(&context, provider, cancel_token, drop_token).await;
                // Release the temporary database export before the "ongoing" mutex, so its
                // removal is initiated before another export can be prepared.
                drop(dbfile);
                context.free_ongoing().await;
                // Explicit drop to move the guard into this future.
                drop(paused_guard);
                res
            })
        };
        Ok(Self {
            handle,
            ticket,
            _drop_guard: drop_token.drop_guard(),
        })
    }

    /// Creates the provider task.
    ///
    /// Exports the database to `dbfile`, encrypted with a freshly generated request token,
    /// builds an iroh collection from the database export and all blobs, and spawns an iroh
    /// [`Node`] serving it which only accepts requests carrying that token.
    ///
    /// Returns the running node and the ticket (including the token) to fetch the backup.
    ///
    /// Having this as a function makes it easier to cancel it when needed.
    async fn prepare_inner(context: &Context, dbfile: &Path) -> Result<(Node, Ticket)> {
        // Generate the token up front: we also use it to encrypt the database.
        let token = RequestToken::generate();
        context.emit_event(SendProgress::Started.into());
        export_database(context, dbfile, token.to_string())
            .await
            .context("Database export failed")?;
        context.emit_event(SendProgress::DatabaseExported.into());

        // Now we can be sure IO is not running.
        let mut files = vec![DataSource::with_name(
            dbfile.to_owned(),
            format!("db/{DBFILE_BACKUP_NAME}"),
        )];
        let blobdir = BlobDirContents::new(context).await?;
        for blob in blobdir.iter() {
            let path = blob.to_abs_path();
            let name = format!("blob/{}", blob.as_file_name());
            files.push(DataSource::with_name(path, name));
        }

        // Start listening.
        let (db, hash) = iroh::database::flat::create_collection(files).await?;
        context.emit_event(SendProgress::CollectionCreated.into());
        let auth_token_handler = StaticTokenAuthHandler::new(Some(token.clone()));
        let rt = runtime::Handle::from_currrent(1)?;
        let provider = Node::builder(db)
            .bind_addr((Ipv4Addr::UNSPECIFIED, 0).into())
            .custom_auth_handler(Arc::new(auth_token_handler))
            .collection_parser(IrohCollectionParser)
            .runtime(&rt)
            .spawn()
            .await?;
        context.emit_event(SendProgress::ProviderListening.into());
        info!(context, "Waiting for remote to connect");
        let ticket = provider.ticket(hash).await?.with_token(Some(token));
        Ok((provider, ticket))
    }

    /// Supervises the iroh [`Node`], terminating it when needed.
    ///
    /// This will watch the provider and terminate it when:
    ///
    /// - A transfer is completed, successful or unsuccessful.
    /// - The ongoing process is cancelled.
    /// - The [`BackupProvider`] is dropped.
    ///
    /// If the node itself stops first, its result is returned.
    ///
    /// Only a fully completed transfer yields `Ok`: it emits the `Completed` progress
    /// event and adds a device message. An aborted transfer, cancellation or dropping the
    /// provider are reported as errors and emit the `Failed` progress event.
    ///
    /// The *cancel_token* is the handle for the ongoing process mutex, when this completes
    /// we must cancel this operation.
    async fn watch_provider(
        context: &Context,
        mut provider: Node,
        cancel_token: Receiver<()>,
        drop_token: CancellationToken,
    ) -> Result<()> {
        let total_size = Arc::new(AtomicU64::new(0));
        let current_size = Arc::new(AtomicU64::new(0));
        // Signals the end of a transfer: `true` if the whole collection was transferred,
        // `false` if the transfer was aborted.
        let (transfer_done, mut transfer_done_r) = tokio::sync::mpsc::channel::<bool>(1);
        let ctx = context.clone();
        provider
            .subscribe(move |event| {
                let total_size = total_size.clone();
                let current_size = current_size.clone();
                let transfer_done = transfer_done.clone();
                let context = ctx.clone();
                async move {
                    match event {
                        Event::ByteProvide(event) => match event {
                            ProviderEvent::ClientConnected { .. } => {
                                context.emit_event(SendProgress::ClientConnected.into());
                            }
                            ProviderEvent::GetRequestReceived { .. } => {}
                            ProviderEvent::TransferCollectionStarted {
                                total_blobs_size, ..
                            } => {
                                total_size
                                    .store(total_blobs_size.unwrap_or_default(), Ordering::Relaxed);
                                context.emit_event(
                                    SendProgress::TransferInProgress {
                                        current_size: current_size.load(Ordering::Relaxed),
                                        total_size: total_size.load(Ordering::Relaxed),
                                    }
                                    .into(),
                                );
                            }
                            ProviderEvent::TransferBlobCompleted { size, .. } => {
                                current_size.fetch_add(size, Ordering::Relaxed);
                                context.emit_event(
                                    SendProgress::TransferInProgress {
                                        current_size: current_size.load(Ordering::Relaxed),
                                        total_size: total_size.load(Ordering::Relaxed),
                                    }
                                    .into(),
                                );
                            }
                            ProviderEvent::TransferCollectionCompleted { .. } => {
                                let total_size = total_size.load(Ordering::Relaxed);
                                context.emit_event(
                                    SendProgress::TransferInProgress {
                                        current_size: total_size,
                                        total_size,
                                    }
                                    .into(),
                                );
                                transfer_done.send(true).await.ok();
                            }
                            ProviderEvent::TransferAborted { .. } => {
                                transfer_done.send(false).await.ok();
                            }
                            ProviderEvent::CollectionAdded { .. } => {}
                            ProviderEvent::CustomGetRequestReceived { .. } => {}
                        },
                    }
                }
                .boxed()
            })
            .await?;

        let res = loop {
            tokio::select! {
                biased;
                res = &mut provider => {
                    break res.context("BackupProvider failed");
                }
                _ = cancel_token.recv() => {
                    provider.shutdown();
                    break Err(anyhow!("BackupProvider cancelled"));
                },
                _ = drop_token.cancelled() => {
                    provider.shutdown();
                    break Err(anyhow!("BackupProvider dropped"));
                }
                done = transfer_done_r.recv() => {
                    provider.shutdown();
                    // An aborted transfer must not be reported as a completed backup.
                    if done == Some(false) {
                        break Err(anyhow!("BackupProvider transfer aborted"));
                    }
                    break Ok(());
                }
            }
        };
        match &res {
            Ok(_) => {
                context.emit_event(SendProgress::Completed.into());
                let mut msg = Message::new(Viewtype::Text);
                msg.text = backup_transfer_msg_body(context).await;
                add_device_msg(context, None, Some(&mut msg)).await?;
            }
            Err(err) => {
                error!(context, "Backup transfer failure: {err:#}");
                context.emit_event(SendProgress::Failed.into())
            }
        }
        res
    }

    /// Returns a QR code that allows fetching this backup.
    ///
    /// This QR code can be passed to [`get_backup`] on a (different) device.
    pub fn qr(&self) -> Qr {
        Qr::Backup {
            ticket: self.ticket.clone(),
        }
    }
}
impl Future for BackupProvider {
    type Output = Result<()>;

    /// Resolves with the supervisor task's result once that task has finished.
    ///
    /// A supervisor task that panicked or was cancelled is reported as an error as well.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let joined = std::task::ready!(Pin::new(&mut self.handle).poll(cx));
        Poll::Ready(joined.unwrap_or_else(|join_err| Err(join_err.into())))
    }
}
/// A guard which will remove the path when dropped.
///
/// It implements [`Deref`] so it can be used as a `&Path`.
#[derive(Debug)]
struct TempPathGuard {
    /// The file that is removed when the guard is dropped.
    path: PathBuf,
}
impl TempPathGuard {
fn new(path: PathBuf) -> Self {
Self { path }
}
}
impl Drop for TempPathGuard {
fn drop(&mut self) {
let path = self.path.clone();
tokio::spawn(async move {
fs::remove_file(&path).await.ok();
});
}
}
impl Deref for TempPathGuard {
    type Target = Path;

    /// Exposes the guarded path, so the guard can be passed wherever a `&Path` is expected.
    fn deref(&self) -> &Path {
        self.path.as_path()
    }
}
/// Readable names for the [`EventType::ImexProgress`] events emitted while sending a backup.
///
/// Each variant is turned into a raw progress number by a single [`From`] conversion, and
/// the compiler warns about variants that are never constructed.
#[derive(Debug)]
enum SendProgress {
    /// Preparing the backup has started.
    Started,
    /// The database was exported to a temporary file.
    DatabaseExported,
    /// The iroh collection of the database export and all blobs was created.
    CollectionCreated,
    /// The provider is listening for a getter to connect.
    ProviderListening,
    /// A getter connected to the provider.
    ClientConnected,
    /// Blob data is being transferred.
    TransferInProgress { current_size: u64, total_size: u64 },
    /// The transfer finished.
    Completed,
    /// The transfer failed, was cancelled or the provider was dropped.
    Failed,
}
impl From<SendProgress> for EventType {
    /// Maps a send step to its [`EventType::ImexProgress`] value.
    ///
    /// `0` signals failure and `1000` completion. The transfer itself occupies `450..=950`.
    fn from(source: SendProgress) -> Self {
        use SendProgress::*;

        let num: u16 = match source {
            Failed => 0,
            Started => 100,
            DatabaseExported => 300,
            CollectionCreated => 350,
            ProviderListening => 400,
            ClientConnected => 450,
            TransferInProgress {
                current_size,
                total_size,
            } => {
                // The range is 450..=950. An unknown (zero) total counts as no progress.
                // The ratio is clamped, so inconsistent sizes can neither overflow the `u16`
                // nor leave the range reserved for the transfer.
                let fraction = if total_size == 0 {
                    0.
                } else {
                    (current_size as f64 / total_size as f64).clamp(0., 1.)
                };
                450 + (fraction * 500.).floor() as u16
            }
            Completed => 1000,
        };
        Self::ImexProgress(num.into())
    }
}
/// Contacts a backup provider and receives the backup from it.
///
/// This uses a QR code to contact another instance of deltachat which is providing a backup
/// using the [`BackupProvider`]. Once connected it will authenticate using the secrets in
/// the QR code and retrieve the backup.
///
/// This is a long running operation which will only when completed.
///
/// Using [`Qr`] as argument is a bit odd as it only accepts one specific variant of it. It
/// does avoid having [`iroh::dial::Ticket`] in the primary API however, without
/// having to revert to untyped bytes.
pub async fn get_backup(context: &Context, qr: Qr) -> Result<()> {
ensure!(
matches!(qr, Qr::Backup { .. }),
"QR code for backup must be of type DCBACKUP"
);
ensure!(
!context.is_configured().await?,
"Cannot import backups to accounts in use."
);
// Acquire global "ongoing" mutex.
let cancel_token = context.alloc_ongoing().await?;
let _guard = context.scheduler.pause(context.clone()).await;
info!(
context,
"Running get_backup for {}",
qr::format_backup(&qr)?
);
let res = tokio::select! {
biased;
res = get_backup_inner(context, qr) => res,
_ = cancel_token.recv() => Err(format_err!("cancelled")),
};
context.free_ongoing().await;
res
}
async fn get_backup_inner(context: &Context, qr: Qr) -> Result<()> {
let ticket = match qr {
Qr::Backup { ticket } => ticket,
_ => bail!("QR code for backup must be of type DCBACKUP"),
};
match transfer_from_provider(context, &ticket).await {
Ok(()) => {
context.sql.run_migrations(context).await?;
delete_and_reset_all_device_msgs(context).await?;
context.emit_event(ReceiveProgress::Completed.into());
Ok(())
}
Err(err) => {
// Clean up any blobs we already wrote.
let readdir = fs::read_dir(context.get_blobdir()).await?;
let mut readdir = ReadDirStream::new(readdir);
while let Some(dirent) = readdir.next().await {
if let Ok(dirent) = dirent {
fs::remove_file(dirent.path()).await.ok();
}
}
context.emit_event(ReceiveProgress::Failed.into());
Err(err)
}
}
}
/// Downloads the backup described by `ticket` into this context.
///
/// Blob progress is forwarded to the context while the transfer runs. After the get
/// request has finished, every job spawned while receiving blobs is awaited before the
/// transfer counts as done.
async fn transfer_from_provider(context: &Context, ticket: &Ticket) -> Result<()> {
    let emitter = ProgressEmitter::new(0, ReceiveProgress::max_blob_progress());
    spawn_progress_proxy(context.clone(), emitter.subscribe());

    let pending_jobs = Mutex::new(JoinSet::default());
    let stats = run_get_request(context, &emitter, &pending_jobs, ticket.clone()).await?;

    // The get request no longer borrows the mutex, so the join set can be taken out of it.
    let mut pending_jobs = pending_jobs.into_inner();
    while let Some(outcome) = pending_jobs.join_next().await {
        outcome.context("job failed")?;
    }

    // Dropping the emitter lets the progress proxy task finish.
    drop(emitter);
    info!(
        context,
        "Backup transfer finished, transfer rate was {} Mbps.",
        stats.mbits()
    );
    Ok(())
}
/// Run the get request
///
/// Dials the provider described by `ticket` and requests everything behind the ticket's
/// root hash, authenticating with the ticket's token. It then walks the iroh get state
/// machine: first the root, which is parsed as a [`Collection`], then each child blob in
/// the collection's order. Each blob is handed to [`on_blob`] to be written out.
///
/// `progress` only receives the collection's total size here.
///
/// Returns the transfer [`Stats`] once the state machine reaches its end.
async fn run_get_request(
    context: &Context,
    progress: &ProgressEmitter,
    jobs: &Mutex<JoinSet<()>>,
    ticket: Ticket,
) -> anyhow::Result<Stats> {
    // DERP usage for NAT traversal and relay are currently disabled.
    let derp_map = None;
    // A fresh keypair is generated for this connection.
    let opts = ticket.as_get_options(Keypair::generate(), derp_map);
    let request =
        AnyGetRequest::Get(GetRequest::all(ticket.hash())).with_token(ticket.token().cloned());
    let connection = iroh::dial::dial(opts).await?;
    let initial = fsm::start(connection, request);
    let connected = initial.next().await?;
    context.emit_event(ReceiveProgress::Connected.into());
    // Runtime handle whose local pool is used by `on_blob` for the `iroh_io` file writes.
    let rt = runtime::Handle::from_currrent(1)?;
    // we assume that the request includes the entire collection
    let (mut next, _root, collection) = {
        let fsm::ConnectedNext::StartRoot(sc) = connected.next().await? else {
            bail!("request did not include collection");
        };
        let (done, data) = sc.next().concatenate_into_vec().await?;
        let data = Bytes::from(data);
        let collection = Collection::from_bytes(&data)?;
        context.emit_event(ReceiveProgress::CollectionReceived.into());
        progress.set_total(collection.total_blobs_size());
        (done.next(), data, collection)
    };
    // download all the children
    let mut blobs = collection.blobs().iter();
    let finishing = loop {
        let start = match next {
            fsm::EndBlobNext::MoreChildren(start) => start,
            fsm::EndBlobNext::Closing(finishing) => break finishing,
        };
        // get the hash of the next blob, or finish if there are no more
        let Some(blob) = blobs.next() else {
            break start.finish();
        };
        // Request the blob by its hash; its name decides where `on_blob` stores it.
        let start = start.next(blob.hash);
        let done = on_blob(context, &rt, jobs, &ticket, start, &blob.name).await?;
        next = done.next();
    };
    let stats = finishing.next().await?;
    Ok(stats)
}
/// Get callback when a blob is received from the provider.
///
/// This writes the blobs to the blobdir. If the blob is the database it will import it to
/// the database of the current [`Context`].
///
/// Blob names starting with `db/` denote the database export. It is written to
/// [`DBFILE_BACKUP_NAME`] in the parent directory of the blob directory, replacing any
/// previous file there. Names starting with `blob/` are written into the blob directory
/// using only their last path component. Any other name is rejected.
///
/// The database import is not awaited here; it is spawned onto `jobs`, using the ticket's
/// token as the database passphrase.
///
/// NOTE(review): errors of the spawned import are only logged, so a failed import does not
/// make the transfer fail — confirm this is intended.
///
/// Returns the state machine positioned at the end of this blob.
async fn on_blob(
    context: &Context,
    rt: &runtime::Handle,
    jobs: &Mutex<JoinSet<()>>,
    ticket: &Ticket,
    state: fsm::AtBlobHeader,
    name: &str,
) -> Result<fsm::AtEndBlob> {
    ensure!(!name.is_empty(), "Received a nameless blob");
    let path = if name.starts_with("db/") {
        let context_dir = context
            .get_blobdir()
            .parent()
            .ok_or_else(|| anyhow!("Context dir not found"))?;
        let dbfile = context_dir.join(DBFILE_BACKUP_NAME);
        if fs::metadata(&dbfile).await.is_ok() {
            fs::remove_file(&dbfile).await?;
            warn!(context, "Previous database export deleted");
        }
        dbfile
    } else {
        ensure!(name.starts_with("blob/"), "malformatted blob name");
        // Only the last path component is used as the file name inside the blob directory.
        let blobname = name.rsplit('/').next().context("malformatted blob name")?;
        context.get_blobdir().join(blobname)
    };
    // `iroh_io` io needs to be done on a local spawn
    let file_path = path.clone();
    let done = rt
        .local_pool()
        .spawn_pinned(move || {
            let file_path = file_path.clone();
            Box::pin(async move {
                let mut file =
                    iroh_io::File::create(move || std::fs::File::create(&file_path)).await?;
                // TODO: ProgressEmitter doesn't support writers :(
                // let mut wrapped_file = progress.wrap_async_write(&mut file);
                let done = state.write_all(&mut file).await?;
                file.sync().await?;
                anyhow::Ok(done)
            })
        })
        // The outer `?` is for the spawned task itself, the inner one for the write.
        .await??;
    if name.starts_with("db/") {
        let context = context.clone();
        // The exported database is encrypted with the request token, so the same token
        // is used as the import passphrase.
        let token = ticket.token().map(|t| t.to_string()).unwrap_or_default();
        jobs.lock().await.spawn(async move {
            if let Err(err) = context.sql.import(&path, token).await {
                error!(context, "cannot import database: {:#?}", err);
            }
            if let Err(err) = fs::remove_file(&path).await {
                error!(
                    context,
                    "failed to delete database import file '{}': {:#?}",
                    path.display(),
                    err,
                );
            }
        });
    }
    Ok(done)
}
/// Spawns a task proxying progress events.
///
/// The spawned tokio task forwards every step received from the [`ProgressEmitter`] to the
/// context as a blob-progress event. Lagging behind simply skips the missed steps. The task
/// ends once the emitter is dropped and the channel closes.
///
/// This could be done directly in the emitter by making it less generic.
fn spawn_progress_proxy(context: Context, mut rx: broadcast::Receiver<u16>) {
    tokio::spawn(async move {
        loop {
            let step = match rx.recv().await {
                Ok(step) => step,
                Err(RecvError::Lagged(_)) => continue,
                Err(RecvError::Closed) => break,
            };
            context.emit_event(ReceiveProgress::BlobProgress(step).into());
        }
    });
}
/// Readable names for the [`EventType::ImexProgress`] events emitted while receiving a
/// backup.
///
/// Each variant is turned into a raw progress number by a single [`From`] conversion, and
/// the compiler warns about variants that are never constructed.
#[derive(Debug)]
enum ReceiveProgress {
    /// Receiving the backup failed.
    Failed,
    /// Connected to the provider.
    Connected,
    /// The collection describing the database export and all blobs was received.
    CollectionReceived,
    /// A progress step between 0 and [`ReceiveProgress::max_blob_progress`].
    ///
    /// Other values are already used by the other variants of this enum.
    BlobProgress(u16),
    /// The backup was received and the database migrated.
    Completed,
}
impl ReceiveProgress {
    /// Upper bound of the step carried by [`ReceiveProgress::BlobProgress`].
    ///
    /// Defining it here keeps the magic number next to the type whose mapping depends on it.
    fn max_blob_progress() -> u16 {
        const MAX_BLOB_PROGRESS: u16 = 85;
        MAX_BLOB_PROGRESS
    }
}
impl From<ReceiveProgress> for EventType {
    /// Maps a receive step to its [`EventType::ImexProgress`] value.
    ///
    /// `0` signals failure and `1000` completion. Blob progress occupies `100..=950`.
    /// Steps above [`ReceiveProgress::max_blob_progress`] are clamped, so they can neither
    /// overflow the `u16` arithmetic nor leave that range.
    fn from(source: ReceiveProgress) -> Self {
        let val = match source {
            ReceiveProgress::Connected => 50,
            ReceiveProgress::CollectionReceived => 100,
            ReceiveProgress::BlobProgress(val) => {
                100 + 10 * val.min(ReceiveProgress::max_blob_progress())
            }
            ReceiveProgress::Completed => 1000,
            ReceiveProgress::Failed => 0,
        };
        EventType::ImexProgress(val.into())
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use crate::chat::{get_chat_msgs, send_msg, ChatItem};
    use crate::message::{Message, Viewtype};
    use crate::test_utils::TestContextManager;

    use super::*;

    /// Transfers a backup from a configured device to an unconfigured one and checks that a
    /// text message and a file attachment in the self chat survive the transfer.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_send_receive() {
        let mut tcm = TestContextManager::new();

        // Create first device.
        let ctx0 = tcm.alice().await;

        // Write a message in the self chat
        let self_chat = ctx0.get_self_chat().await;
        let mut msg = Message::new(Viewtype::Text);
        msg.set_text("hi there".to_string());
        send_msg(&ctx0, self_chat.id, &mut msg).await.unwrap();

        // Send an attachment in the self chat
        let file = ctx0.get_blobdir().join("hello.txt");
        fs::write(&file, "i am attachment").await.unwrap();
        let mut msg = Message::new(Viewtype::File);
        msg.set_file(file.to_str().unwrap(), Some("text/plain"));
        send_msg(&ctx0, self_chat.id, &mut msg).await.unwrap();

        // Prepare to transfer backup.
        let provider = BackupProvider::prepare(&ctx0).await.unwrap();

        // Set up second device.
        let ctx1 = tcm.unconfigured().await;
        get_backup(&ctx1, provider.qr()).await.unwrap();

        // Make sure the provider finishes without an error.
        tokio::time::timeout(Duration::from_secs(30), provider)
            .await
            .expect("timed out")
            .expect("error in provider");

        // Check that we have the self message.
        let self_chat = ctx1.get_self_chat().await;
        let msgs = get_chat_msgs(&ctx1, self_chat.id).await.unwrap();
        assert_eq!(msgs.len(), 2);
        let msgid = match msgs.first().unwrap() {
            ChatItem::Message { msg_id } => msg_id,
            _ => panic!("wrong chat item"),
        };
        let msg = Message::load_from_db(&ctx1, *msgid).await.unwrap();
        let text = msg.get_text();
        assert_eq!(text, "hi there");
        let msgid = match msgs.get(1).unwrap() {
            ChatItem::Message { msg_id } => msg_id,
            _ => panic!("wrong chat item"),
        };
        let msg = Message::load_from_db(&ctx1, *msgid).await.unwrap();
        let path = msg.get_file(&ctx1).unwrap();
        let text = fs::read_to_string(&path).await.unwrap();
        assert_eq!(text, "i am attachment");

        // Check that both received the ImexProgress events.
        ctx0.evtracker
            .get_matching(|ev| matches!(ev, EventType::ImexProgress(1000)))
            .await;
        ctx1.evtracker
            .get_matching(|ev| matches!(ev, EventType::ImexProgress(1000)))
            .await;
    }

    /// Checks the mapping of transfer sizes onto the 450..=950 progress range.
    #[test]
    fn test_send_progress() {
        let cases = [
            ((0, 100), 450),
            ((10, 100), 500),
            ((50, 100), 700),
            ((100, 100), 950),
        ];

        for ((current_size, total_size), progress) in cases {
            let out = EventType::from(SendProgress::TransferInProgress {
                current_size,
                total_size,
            });
            assert_eq!(out, EventType::ImexProgress(progress));
        }
    }

    /// Dropping a prepared provider must stop it and report the failure (progress 0).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn test_drop_provider() {
        let mut tcm = TestContextManager::new();
        let ctx = tcm.alice().await;

        let provider = BackupProvider::prepare(&ctx).await.unwrap();
        drop(provider);
        ctx.evtracker
            .get_matching(|ev| matches!(ev, EventType::ImexProgress(0)))
            .await;
    }
}