feat: migrate to iroh 0.5

This commit is contained in:
dignifiedquire
2023-07-14 15:47:19 +02:00
parent f930576fd1
commit f325961505
4 changed files with 1506 additions and 821 deletions

View File

@@ -27,19 +27,27 @@ use std::net::Ipv4Addr;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::task::Poll;
use anyhow::{anyhow, bail, ensure, format_err, Context as _, Result};
use async_channel::Receiver;
use bytes::Bytes;
use futures::FutureExt;
use futures_lite::StreamExt;
use iroh::blobs::Collection;
use iroh::get::DataStream;
use iroh::progress::ProgressEmitter;
use iroh::protocol::AuthToken;
use iroh::provider::{DataSource, Event, Provider, Ticket};
use iroh::Hash;
use iroh::bytes::get::{fsm, Stats};
use iroh::bytes::protocol::{AnyGetRequest, GetRequest, RequestToken};
use iroh::bytes::provider::Event as ProviderEvent;
use iroh::collection::Collection;
use iroh::database::flat::DataSource;
use iroh::dial::Ticket;
use iroh::net::defaults::default_derp_map;
use iroh::net::tls::Keypair;
use iroh::node::{Event, Node as IrohNode, StaticTokenAuthHandler};
use iroh::util::progress::ProgressEmitter;
use tokio::fs::{self, File};
use tokio::io::{self, AsyncWriteExt, BufWriter};
use tokio::io::AsyncWriteExt;
use tokio::sync::broadcast::error::RecvError;
use tokio::sync::{broadcast, Mutex};
use tokio::task::{JoinHandle, JoinSet};
@@ -58,6 +66,8 @@ use super::{export_database, DBFILE_BACKUP_NAME};
const MAX_CONCURRENT_DIALS: u8 = 16;
type Node = IrohNode<iroh::database::flat::Database>;
/// Provide or send a backup of this device.
///
/// This creates a backup of the current device and starts a service which offers another
@@ -154,9 +164,9 @@ impl BackupProvider {
/// Creates the provider task.
///
/// Having this as a function makes it easier to cancel it when needed.
async fn prepare_inner(context: &Context, dbfile: &Path) -> Result<(Provider, Ticket)> {
async fn prepare_inner(context: &Context, dbfile: &Path) -> Result<(Node, Ticket)> {
// Generate the token up front: we also use it to encrypt the database.
let token = AuthToken::generate();
let token = RequestToken::generate();
context.emit_event(SendProgress::Started.into());
export_database(context, dbfile, token.to_string())
.await
@@ -176,19 +186,21 @@ impl BackupProvider {
}
// Start listening.
let (db, hash) = iroh::provider::create_collection(files).await?;
let (db, hash) = iroh::database::flat::create_collection(files).await?;
context.emit_event(SendProgress::CollectionCreated.into());
let provider = Provider::builder(db)
let auth_token_handler = StaticTokenAuthHandler::new(Some(token.clone()));
let provider = Node::builder(db)
.bind_addr((Ipv4Addr::UNSPECIFIED, 0).into())
.auth_token(token)
.spawn()?;
.custom_auth_handler(Arc::new(auth_token_handler))
.spawn()
.await?;
context.emit_event(SendProgress::ProviderListening.into());
info!(context, "Waiting for remote to connect");
let ticket = provider.ticket(hash)?;
let ticket = provider.ticket(hash, Some(token)).await?;
Ok((provider, ticket))
}
/// Supervises the iroh [`Provider`], terminating it when needed.
/// Supervises the iroh [`Node`], terminating it when needed.
///
/// This will watch the provider and terminate it when:
///
@@ -200,67 +212,80 @@ impl BackupProvider {
/// we must cancel this operation.
async fn watch_provider(
context: &Context,
mut provider: Provider,
mut provider: Node,
cancel_token: Receiver<()>,
drop_token: CancellationToken,
) -> Result<()> {
let mut events = provider.subscribe();
let mut total_size = 0;
let mut current_size = 0;
let total_size = Arc::new(AtomicU64::new(0));
let current_size = Arc::new(AtomicU64::new(0));
let (transfer_done, mut transfer_done_r) = tokio::sync::mpsc::channel(1);
let ctx = context.clone();
provider
.subscribe(move |event| {
let total_size = total_size.clone();
let current_size = current_size.clone();
let transfer_done = transfer_done.clone();
let context = ctx.clone();
async move {
match event {
Event::ByteProvide(event) => match event {
ProviderEvent::ClientConnected { .. } => {
context.emit_event(SendProgress::ClientConnected.into());
}
ProviderEvent::GetRequestReceived { .. } => {}
ProviderEvent::TransferCollectionStarted {
total_blobs_size, ..
} => {
total_size
.store(total_blobs_size.unwrap_or_default(), Ordering::Relaxed);
context.emit_event(
SendProgress::TransferInProgress {
current_size: current_size.load(Ordering::Relaxed),
total_size: total_size.load(Ordering::Relaxed),
}
.into(),
);
}
ProviderEvent::TransferBlobCompleted { size, .. } => {
current_size.fetch_add(size, Ordering::Relaxed);
context.emit_event(
SendProgress::TransferInProgress {
current_size: current_size.load(Ordering::Relaxed),
total_size: total_size.load(Ordering::Relaxed),
}
.into(),
);
}
ProviderEvent::TransferCollectionCompleted { .. } => {
let total_size = total_size.load(Ordering::Relaxed);
context.emit_event(
SendProgress::TransferInProgress {
current_size: total_size,
total_size,
}
.into(),
);
transfer_done.send(()).await.ok();
}
ProviderEvent::TransferAborted { .. } => {
transfer_done.send(()).await.ok();
}
ProviderEvent::CollectionAdded { .. } => {}
ProviderEvent::CustomGetRequestReceived { .. } => {}
},
}
}
.boxed()
})
.await?;
let res = loop {
tokio::select! {
biased;
res = &mut provider => {
break res.context("BackupProvider failed");
},
maybe_event = events.recv() => {
match maybe_event {
Ok(event) => {
match event {
Event::ClientConnected { ..} => {
context.emit_event(SendProgress::ClientConnected.into());
}
Event::RequestReceived { .. } => {
}
Event::TransferCollectionStarted { total_blobs_size, .. } => {
total_size = total_blobs_size;
context.emit_event(SendProgress::TransferInProgress {
current_size,
total_size,
}.into());
}
Event::TransferBlobCompleted { size, .. } => {
current_size += size;
context.emit_event(SendProgress::TransferInProgress {
current_size,
total_size,
}.into());
}
Event::TransferCollectionCompleted { .. } => {
context.emit_event(SendProgress::TransferInProgress {
current_size: total_size,
total_size
}.into());
provider.shutdown();
}
Event::TransferAborted { .. } => {
provider.shutdown();
break Err(anyhow!("BackupProvider transfer aborted"));
}
}
}
Err(broadcast::error::RecvError::Closed) => {
// We should never see this, provider.join() should complete
// first.
}
Err(broadcast::error::RecvError::Lagged(_)) => {
// We really shouldn't be lagging, if we did we may have missed
// a completion event.
provider.shutdown();
break Err(anyhow!("Missed events from BackupProvider"));
}
}
},
}
_ = cancel_token.recv() => {
provider.shutdown();
break Err(anyhow!("BackupProvider cancelled"));
@@ -269,6 +294,10 @@ impl BackupProvider {
provider.shutdown();
break Err(anyhow!("BackupProvider dropped"));
}
_ = transfer_done_r.recv() => {
provider.shutdown();
break Ok(());
}
}
};
match &res {
@@ -441,30 +470,11 @@ async fn get_backup_inner(context: &Context, qr: Qr) -> Result<()> {
async fn transfer_from_provider(context: &Context, ticket: &Ticket) -> Result<()> {
let progress = ProgressEmitter::new(0, ReceiveProgress::max_blob_progress());
spawn_progress_proxy(context.clone(), progress.subscribe());
let on_connected = || {
context.emit_event(ReceiveProgress::Connected.into());
async { Ok(()) }
};
let on_collection = |collection: &Collection| {
context.emit_event(ReceiveProgress::CollectionReceived.into());
progress.set_total(collection.total_blobs_size());
async { Ok(()) }
};
let jobs = Mutex::new(JoinSet::default());
let on_blob =
|hash, reader, name| on_blob(context, &progress, &jobs, ticket, hash, reader, name);
// Perform the transfer.
let keylog = false; // Do not enable rustls SSLKEYLOGFILE env var functionality
let stats = iroh::get::run_ticket(
ticket,
keylog,
MAX_CONCURRENT_DIALS,
on_connected,
on_collection,
on_blob,
)
.await?;
let stats = run_get_request(context, &progress, &jobs, ticket.clone()).await?;
let mut jobs = jobs.lock().await;
while let Some(job) = jobs.join_next().await {
@@ -479,6 +489,64 @@ async fn transfer_from_provider(context: &Context, ticket: &Ticket) -> Result<()
Ok(())
}
/// Run the get request
async fn run_get_request(
context: &Context,
progress: &ProgressEmitter,
jobs: &Mutex<JoinSet<()>>,
ticket: Ticket,
) -> anyhow::Result<Stats> {
let opts = ticket.as_get_options(Keypair::generate(), Some(default_derp_map()));
let request = AnyGetRequest::Get(GetRequest::all(ticket.hash()));
let connection = iroh::dial::dial(opts).await?;
let initial = fsm::start(connection, request);
use fsm::*;
let connected = initial.next().await?;
context.emit_event(ReceiveProgress::Connected.into());
// we assume that the request includes the entire collection
let (mut next, root, collection) = {
let ConnectedNext::StartRoot(sc) = connected.next().await? else {
bail!("request did not include collection");
};
let (done, data) = sc.next().concatenate_into_vec().await?;
let data = Bytes::from(data);
let collection = Collection::from_bytes(&data)?;
context.emit_event(ReceiveProgress::CollectionReceived.into());
progress.set_total(collection.total_blobs_size());
(done.next(), data, collection)
};
// download all the children
let mut current_blob = 0;
let finishing = loop {
let start = match next {
EndBlobNext::MoreChildren(start) => start,
EndBlobNext::Closing(finishing) => break finishing,
};
let child_offset = start.child_offset();
let offset = child_offset + 1;
// get the hash of the next blob, or finish if there are no more
let Some(blob) = collection.blobs().get(current_blob) else {
break start.finish();
};
let start = start.next(blob.hash);
let done = on_blob(context, progress, jobs, &ticket, start, &blob.name).await?;
current_blob += 1;
next = done.next();
};
let stats = finishing.next().await?;
Ok(stats)
}
/// Get callback when a blob is received from the provider.
///
/// This writes the blobs to the blobdir. If the blob is the database it will import it to
@@ -488,10 +556,9 @@ async fn on_blob(
progress: &ProgressEmitter,
jobs: &Mutex<JoinSet<()>>,
ticket: &Ticket,
_hash: Hash,
mut reader: DataStream,
name: String,
) -> Result<DataStream> {
mut state: fsm::AtBlobHeader,
name: &str,
) -> Result<fsm::AtEndBlob> {
ensure!(!name.is_empty(), "Received a nameless blob");
let path = if name.starts_with("db/") {
let context_dir = context
@@ -510,15 +577,17 @@ async fn on_blob(
context.get_blobdir().join(blobname)
};
let mut wrapped_reader = progress.wrap_async_read(&mut reader);
let file = File::create(&path).await?;
let mut file = BufWriter::with_capacity(128 * 1024, file);
io::copy(&mut wrapped_reader, &mut file).await?;
// TODO: BufWriter doesn't implement AsyncSliceWriter :(
// let mut file = BufWriter::with_capacity(128 * 1024, file);
// TODO: ProgressEmitter doesn't support writers :(
// let mut wrapped_file = progress.wrap_async_write(&mut file);
let done = state.write_all(&mut file).await?;
file.flush().await?;
if name.starts_with("db/") {
let context = context.clone();
let token = ticket.token().to_string();
let token = ticket.token().map(|t| t.to_string()).unwrap_or_default();
jobs.lock().await.spawn(async move {
if let Err(err) = context.sql.import(&path, token).await {
error!(context, "cannot import database: {:#?}", err);
@@ -533,7 +602,8 @@ async fn on_blob(
}
});
}
Ok(reader)
Ok(done)
}
/// Spawns a task proxying progress events.

View File

@@ -5,6 +5,7 @@ use std::collections::BTreeMap;
use anyhow::{anyhow, bail, ensure, Context as _, Result};
pub use dclogin_scheme::LoginOptions;
use iroh::dial::Ticket;
use once_cell::sync::Lazy;
use percent_encoding::percent_decode_str;
use serde::Deserialize;
@@ -113,7 +114,7 @@ pub enum Qr {
/// information to connect to and authenticate a backup provider.
///
/// The format is somewhat opaque, but `sendme` can deserialise this.
ticket: iroh::provider::Ticket,
ticket: Ticket,
},
/// Ask the user if they want to use the given service for video chats.
@@ -501,7 +502,7 @@ fn decode_backup(qr: &str) -> Result<Qr> {
let payload = qr
.strip_prefix(DCBACKUP_SCHEME)
.ok_or_else(|| anyhow!("invalid DCBACKUP scheme"))?;
let ticket: iroh::provider::Ticket = payload.parse().context("invalid DCBACKUP payload")?;
let ticket: Ticket = payload.parse().context("invalid DCBACKUP payload")?;
Ok(Qr::Backup { ticket })
}