mirror of
https://github.com/chatmail/core.git
synced 2026-05-02 21:06:31 +03:00
feat: migrate to iroh 0.5
This commit is contained in:
@@ -27,19 +27,27 @@ use std::net::Ipv4Addr;
|
||||
use std::ops::Deref;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::pin::Pin;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::task::Poll;
|
||||
|
||||
use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result};
|
||||
use async_channel::Receiver;
|
||||
use bytes::Bytes;
|
||||
use futures::FutureExt;
|
||||
use futures_lite::StreamExt;
|
||||
use iroh::blobs::Collection;
|
||||
use iroh::get::DataStream;
|
||||
use iroh::progress::ProgressEmitter;
|
||||
use iroh::protocol::AuthToken;
|
||||
use iroh::provider::{DataSource, Event, Provider, Ticket};
|
||||
use iroh::Hash;
|
||||
use iroh::bytes::get::{fsm, Stats};
|
||||
use iroh::bytes::protocol::{AnyGetRequest, GetRequest, RequestToken};
|
||||
use iroh::bytes::provider::Event as ProviderEvent;
|
||||
use iroh::collection::Collection;
|
||||
use iroh::database::flat::DataSource;
|
||||
use iroh::dial::Ticket;
|
||||
use iroh::net::defaults::default_derp_map;
|
||||
use iroh::net::tls::Keypair;
|
||||
use iroh::node::{Event, Node as IrohNode, StaticTokenAuthHandler};
|
||||
use iroh::util::progress::ProgressEmitter;
|
||||
use tokio::fs::{self, File};
|
||||
use tokio::io::{self, AsyncWriteExt, BufWriter};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
use tokio::sync::broadcast::error::RecvError;
|
||||
use tokio::sync::{broadcast, Mutex};
|
||||
use tokio::task::{JoinHandle, JoinSet};
|
||||
@@ -58,6 +66,8 @@ use super::{export_database, DBFILE_BACKUP_NAME};
|
||||
|
||||
const MAX_CONCURRENT_DIALS: u8 = 16;
|
||||
|
||||
type Node = IrohNode<iroh::database::flat::Database>;
|
||||
|
||||
/// Provide or send a backup of this device.
|
||||
///
|
||||
/// This creates a backup of the current device and starts a service which offers another
|
||||
@@ -154,9 +164,9 @@ impl BackupProvider {
|
||||
/// Creates the provider task.
|
||||
///
|
||||
/// Having this as a function makes it easier to cancel it when needed.
|
||||
async fn prepare_inner(context: &Context, dbfile: &Path) -> Result<(Provider, Ticket)> {
|
||||
async fn prepare_inner(context: &Context, dbfile: &Path) -> Result<(Node, Ticket)> {
|
||||
// Generate the token up front: we also use it to encrypt the database.
|
||||
let token = AuthToken::generate();
|
||||
let token = RequestToken::generate();
|
||||
context.emit_event(SendProgress::Started.into());
|
||||
export_database(context, dbfile, token.to_string())
|
||||
.await
|
||||
@@ -176,19 +186,21 @@ impl BackupProvider {
|
||||
}
|
||||
|
||||
// Start listening.
|
||||
let (db, hash) = iroh::provider::create_collection(files).await?;
|
||||
let (db, hash) = iroh::database::flat::create_collection(files).await?;
|
||||
context.emit_event(SendProgress::CollectionCreated.into());
|
||||
let provider = Provider::builder(db)
|
||||
let auth_token_handler = StaticTokenAuthHandler::new(Some(token.clone()));
|
||||
let provider = Node::builder(db)
|
||||
.bind_addr((Ipv4Addr::UNSPECIFIED, 0).into())
|
||||
.auth_token(token)
|
||||
.spawn()?;
|
||||
.custom_auth_handler(Arc::new(auth_token_handler))
|
||||
.spawn()
|
||||
.await?;
|
||||
context.emit_event(SendProgress::ProviderListening.into());
|
||||
info!(context, "Waiting for remote to connect");
|
||||
let ticket = provider.ticket(hash)?;
|
||||
let ticket = provider.ticket(hash, Some(token)).await?;
|
||||
Ok((provider, ticket))
|
||||
}
|
||||
|
||||
/// Supervises the iroh [`Provider`], terminating it when needed.
|
||||
/// Supervises the iroh [`Node`], terminating it when needed.
|
||||
///
|
||||
/// This will watch the provider and terminate it when:
|
||||
///
|
||||
@@ -200,67 +212,80 @@ impl BackupProvider {
|
||||
/// we must cancel this operation.
|
||||
async fn watch_provider(
|
||||
context: &Context,
|
||||
mut provider: Provider,
|
||||
mut provider: Node,
|
||||
cancel_token: Receiver<()>,
|
||||
drop_token: CancellationToken,
|
||||
) -> Result<()> {
|
||||
let mut events = provider.subscribe();
|
||||
let mut total_size = 0;
|
||||
let mut current_size = 0;
|
||||
let total_size = Arc::new(AtomicU64::new(0));
|
||||
let current_size = Arc::new(AtomicU64::new(0));
|
||||
let (transfer_done, mut transfer_done_r) = tokio::sync::mpsc::channel(1);
|
||||
|
||||
let ctx = context.clone();
|
||||
provider
|
||||
.subscribe(move |event| {
|
||||
let total_size = total_size.clone();
|
||||
let current_size = current_size.clone();
|
||||
let transfer_done = transfer_done.clone();
|
||||
let context = ctx.clone();
|
||||
async move {
|
||||
match event {
|
||||
Event::ByteProvide(event) => match event {
|
||||
ProviderEvent::ClientConnected { .. } => {
|
||||
context.emit_event(SendProgress::ClientConnected.into());
|
||||
}
|
||||
ProviderEvent::GetRequestReceived { .. } => {}
|
||||
ProviderEvent::TransferCollectionStarted {
|
||||
total_blobs_size, ..
|
||||
} => {
|
||||
total_size
|
||||
.store(total_blobs_size.unwrap_or_default(), Ordering::Relaxed);
|
||||
context.emit_event(
|
||||
SendProgress::TransferInProgress {
|
||||
current_size: current_size.load(Ordering::Relaxed),
|
||||
total_size: total_size.load(Ordering::Relaxed),
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
ProviderEvent::TransferBlobCompleted { size, .. } => {
|
||||
current_size.fetch_add(size, Ordering::Relaxed);
|
||||
context.emit_event(
|
||||
SendProgress::TransferInProgress {
|
||||
current_size: current_size.load(Ordering::Relaxed),
|
||||
total_size: total_size.load(Ordering::Relaxed),
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
}
|
||||
ProviderEvent::TransferCollectionCompleted { .. } => {
|
||||
let total_size = total_size.load(Ordering::Relaxed);
|
||||
context.emit_event(
|
||||
SendProgress::TransferInProgress {
|
||||
current_size: total_size,
|
||||
total_size,
|
||||
}
|
||||
.into(),
|
||||
);
|
||||
transfer_done.send(()).await.ok();
|
||||
}
|
||||
ProviderEvent::TransferAborted { .. } => {
|
||||
transfer_done.send(()).await.ok();
|
||||
}
|
||||
ProviderEvent::CollectionAdded { .. } => {}
|
||||
ProviderEvent::CustomGetRequestReceived { .. } => {}
|
||||
},
|
||||
}
|
||||
}
|
||||
.boxed()
|
||||
})
|
||||
.await?;
|
||||
|
||||
let res = loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
res = &mut provider => {
|
||||
break res.context("BackupProvider failed");
|
||||
},
|
||||
maybe_event = events.recv() => {
|
||||
match maybe_event {
|
||||
Ok(event) => {
|
||||
match event {
|
||||
Event::ClientConnected { ..} => {
|
||||
context.emit_event(SendProgress::ClientConnected.into());
|
||||
}
|
||||
Event::RequestReceived { .. } => {
|
||||
}
|
||||
Event::TransferCollectionStarted { total_blobs_size, .. } => {
|
||||
total_size = total_blobs_size;
|
||||
context.emit_event(SendProgress::TransferInProgress {
|
||||
current_size,
|
||||
total_size,
|
||||
}.into());
|
||||
}
|
||||
Event::TransferBlobCompleted { size, .. } => {
|
||||
current_size += size;
|
||||
context.emit_event(SendProgress::TransferInProgress {
|
||||
current_size,
|
||||
total_size,
|
||||
}.into());
|
||||
}
|
||||
Event::TransferCollectionCompleted { .. } => {
|
||||
context.emit_event(SendProgress::TransferInProgress {
|
||||
current_size: total_size,
|
||||
total_size
|
||||
}.into());
|
||||
provider.shutdown();
|
||||
}
|
||||
Event::TransferAborted { .. } => {
|
||||
provider.shutdown();
|
||||
break Err(anyhow!("BackupProvider transfer aborted"));
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(broadcast::error::RecvError::Closed) => {
|
||||
// We should never see this, provider.join() should complete
|
||||
// first.
|
||||
}
|
||||
Err(broadcast::error::RecvError::Lagged(_)) => {
|
||||
// We really shouldn't be lagging, if we did we may have missed
|
||||
// a completion event.
|
||||
provider.shutdown();
|
||||
break Err(anyhow!("Missed events from BackupProvider"));
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
_ = cancel_token.recv() => {
|
||||
provider.shutdown();
|
||||
break Err(anyhow!("BackupProvider cancelled"));
|
||||
@@ -269,6 +294,10 @@ impl BackupProvider {
|
||||
provider.shutdown();
|
||||
break Err(anyhow!("BackupProvider dropped"));
|
||||
}
|
||||
_ = transfer_done_r.recv() => {
|
||||
provider.shutdown();
|
||||
break Ok(());
|
||||
}
|
||||
}
|
||||
};
|
||||
match &res {
|
||||
@@ -441,30 +470,11 @@ async fn get_backup_inner(context: &Context, qr: Qr) -> Result<()> {
|
||||
async fn transfer_from_provider(context: &Context, ticket: &Ticket) -> Result<()> {
|
||||
let progress = ProgressEmitter::new(0, ReceiveProgress::max_blob_progress());
|
||||
spawn_progress_proxy(context.clone(), progress.subscribe());
|
||||
let on_connected = || {
|
||||
context.emit_event(ReceiveProgress::Connected.into());
|
||||
async { Ok(()) }
|
||||
};
|
||||
let on_collection = |collection: &Collection| {
|
||||
context.emit_event(ReceiveProgress::CollectionReceived.into());
|
||||
progress.set_total(collection.total_blobs_size());
|
||||
async { Ok(()) }
|
||||
};
|
||||
|
||||
let jobs = Mutex::new(JoinSet::default());
|
||||
let on_blob =
|
||||
|hash, reader, name| on_blob(context, &progress, &jobs, ticket, hash, reader, name);
|
||||
|
||||
// Perform the transfer.
|
||||
let keylog = false; // Do not enable rustls SSLKEYLOGFILE env var functionality
|
||||
let stats = iroh::get::run_ticket(
|
||||
ticket,
|
||||
keylog,
|
||||
MAX_CONCURRENT_DIALS,
|
||||
on_connected,
|
||||
on_collection,
|
||||
on_blob,
|
||||
)
|
||||
.await?;
|
||||
let stats = run_get_request(context, &progress, &jobs, ticket.clone()).await?;
|
||||
|
||||
let mut jobs = jobs.lock().await;
|
||||
while let Some(job) = jobs.join_next().await {
|
||||
@@ -479,6 +489,64 @@ async fn transfer_from_provider(context: &Context, ticket: &Ticket) -> Result<()
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Run the get request
|
||||
async fn run_get_request(
|
||||
context: &Context,
|
||||
progress: &ProgressEmitter,
|
||||
jobs: &Mutex<JoinSet<()>>,
|
||||
ticket: Ticket,
|
||||
) -> anyhow::Result<Stats> {
|
||||
let opts = ticket.as_get_options(Keypair::generate(), Some(default_derp_map()));
|
||||
let request = AnyGetRequest::Get(GetRequest::all(ticket.hash()));
|
||||
let connection = iroh::dial::dial(opts).await?;
|
||||
let initial = fsm::start(connection, request);
|
||||
use fsm::*;
|
||||
|
||||
let connected = initial.next().await?;
|
||||
context.emit_event(ReceiveProgress::Connected.into());
|
||||
|
||||
// we assume that the request includes the entire collection
|
||||
let (mut next, root, collection) = {
|
||||
let ConnectedNext::StartRoot(sc) = connected.next().await? else {
|
||||
bail!("request did not include collection");
|
||||
};
|
||||
|
||||
let (done, data) = sc.next().concatenate_into_vec().await?;
|
||||
let data = Bytes::from(data);
|
||||
let collection = Collection::from_bytes(&data)?;
|
||||
|
||||
context.emit_event(ReceiveProgress::CollectionReceived.into());
|
||||
progress.set_total(collection.total_blobs_size());
|
||||
|
||||
(done.next(), data, collection)
|
||||
};
|
||||
|
||||
// download all the children
|
||||
let mut current_blob = 0;
|
||||
let finishing = loop {
|
||||
let start = match next {
|
||||
EndBlobNext::MoreChildren(start) => start,
|
||||
EndBlobNext::Closing(finishing) => break finishing,
|
||||
};
|
||||
let child_offset = start.child_offset();
|
||||
let offset = child_offset + 1;
|
||||
|
||||
// get the hash of the next blob, or finish if there are no more
|
||||
let Some(blob) = collection.blobs().get(current_blob) else {
|
||||
break start.finish();
|
||||
};
|
||||
|
||||
let start = start.next(blob.hash);
|
||||
let done = on_blob(context, progress, jobs, &ticket, start, &blob.name).await?;
|
||||
|
||||
current_blob += 1;
|
||||
next = done.next();
|
||||
};
|
||||
let stats = finishing.next().await?;
|
||||
|
||||
Ok(stats)
|
||||
}
|
||||
|
||||
/// Get callback when a blob is received from the provider.
|
||||
///
|
||||
/// This writes the blobs to the blobdir. If the blob is the database it will import it to
|
||||
@@ -488,10 +556,9 @@ async fn on_blob(
|
||||
progress: &ProgressEmitter,
|
||||
jobs: &Mutex<JoinSet<()>>,
|
||||
ticket: &Ticket,
|
||||
_hash: Hash,
|
||||
mut reader: DataStream,
|
||||
name: String,
|
||||
) -> Result<DataStream> {
|
||||
mut state: fsm::AtBlobHeader,
|
||||
name: &str,
|
||||
) -> Result<fsm::AtEndBlob> {
|
||||
ensure!(!name.is_empty(), "Received a nameless blob");
|
||||
let path = if name.starts_with("db/") {
|
||||
let context_dir = context
|
||||
@@ -510,15 +577,17 @@ async fn on_blob(
|
||||
context.get_blobdir().join(blobname)
|
||||
};
|
||||
|
||||
let mut wrapped_reader = progress.wrap_async_read(&mut reader);
|
||||
let file = File::create(&path).await?;
|
||||
let mut file = BufWriter::with_capacity(128 * 1024, file);
|
||||
io::copy(&mut wrapped_reader, &mut file).await?;
|
||||
// TODO: BufWriter doesn't implement AsyncSliceWriter :(
|
||||
// let mut file = BufWriter::with_capacity(128 * 1024, file);
|
||||
// TODO: ProgressEmitter doesn't support writers :(
|
||||
// let mut wrapped_file = progress.wrap_async_write(&mut file);
|
||||
let done = state.write_all(&mut file).await?;
|
||||
file.flush().await?;
|
||||
|
||||
if name.starts_with("db/") {
|
||||
let context = context.clone();
|
||||
let token = ticket.token().to_string();
|
||||
let token = ticket.token().map(|t| t.to_string()).unwrap_or_default();
|
||||
jobs.lock().await.spawn(async move {
|
||||
if let Err(err) = context.sql.import(&path, token).await {
|
||||
error!(context, "cannot import database: {:#?}", err);
|
||||
@@ -533,7 +602,8 @@ async fn on_blob(
|
||||
}
|
||||
});
|
||||
}
|
||||
Ok(reader)
|
||||
|
||||
Ok(done)
|
||||
}
|
||||
|
||||
/// Spawns a task proxying progress events.
|
||||
|
||||
@@ -5,6 +5,7 @@ use std::collections::BTreeMap;
|
||||
|
||||
use anyhow::{anyhow, bail, ensure, Context as _, Result};
|
||||
pub use dclogin_scheme::LoginOptions;
|
||||
use iroh::dial::Ticket;
|
||||
use once_cell::sync::Lazy;
|
||||
use percent_encoding::percent_decode_str;
|
||||
use serde::Deserialize;
|
||||
@@ -113,7 +114,7 @@ pub enum Qr {
|
||||
/// information to connect to and authenticate a backup provider.
|
||||
///
|
||||
/// The format is somewhat opaque, but `sendme` can deserialise this.
|
||||
ticket: iroh::provider::Ticket,
|
||||
ticket: Ticket,
|
||||
},
|
||||
|
||||
/// Ask the user if they want to use the given service for video chats.
|
||||
@@ -501,7 +502,7 @@ fn decode_backup(qr: &str) -> Result<Qr> {
|
||||
let payload = qr
|
||||
.strip_prefix(DCBACKUP_SCHEME)
|
||||
.ok_or_else(|| anyhow!("invalid DCBACKUP scheme"))?;
|
||||
let ticket: iroh::provider::Ticket = payload.parse().context("invalid DCBACKUP payload")?;
|
||||
let ticket: Ticket = payload.parse().context("invalid DCBACKUP payload")?;
|
||||
Ok(Qr::Backup { ticket })
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user