refactor: replace DcKey.load_self trait method with functions

This commit is contained in:
link2xt
2023-07-27 18:23:56 +00:00
parent c55a3d3873
commit 9b9703a48e
8 changed files with 77 additions and 97 deletions

View File

@@ -3,11 +3,9 @@
use std::collections::BTreeMap;
use std::fmt;
use std::io::Cursor;
use std::pin::Pin;
use anyhow::{ensure, Context as _, Result};
use base64::Engine as _;
use futures::Future;
use num_traits::FromPrimitive;
use pgp::composed::Deserializable;
pub use pgp::composed::{SignedPublicKey, SignedSecretKey};
@@ -49,11 +47,6 @@ pub trait DcKey: Serialize + Deserializable + KeyTrait + Clone {
Self::from_armor_single(Cursor::new(bytes)).context("rPGP error")
}
/// Load the users' default key from the database.
fn load_self<'a>(
context: &'a Context,
) -> Pin<Box<dyn Future<Output = Result<Self>> + 'a + Send>>;
/// Serialise the key as bytes.
fn to_bytes(&self) -> Vec<u8> {
// Not using Serialize::to_bytes() to make clear *why* it is
@@ -84,38 +77,55 @@ pub trait DcKey: Serialize + Deserializable + KeyTrait + Clone {
}
}
impl DcKey for SignedPublicKey {
fn load_self<'a>(
context: &'a Context,
) -> Pin<Box<dyn Future<Output = Result<Self>> + 'a + Send>> {
Box::pin(async move {
let addr = context.get_primary_self_addr().await?;
match context
.sql
.query_row_optional(
r#"
SELECT public_key
FROM keypairs
WHERE addr=?
AND is_default=1;
"#,
(addr,),
|row| {
let bytes: Vec<u8> = row.get(0)?;
Ok(bytes)
},
)
.await?
{
Some(bytes) => Self::from_slice(&bytes),
None => {
let keypair = generate_keypair(context).await?;
Ok(keypair.public)
}
}
})
pub(crate) async fn load_self_public_key(context: &Context) -> Result<SignedPublicKey> {
match context
.sql
.query_row_optional(
r#"SELECT public_key
FROM keypairs
WHERE addr=(SELECT value FROM config WHERE keyname="configured_addr")
AND is_default=1"#,
(),
|row| {
let bytes: Vec<u8> = row.get(0)?;
Ok(bytes)
},
)
.await?
{
Some(bytes) => SignedPublicKey::from_slice(&bytes),
None => {
let keypair = generate_keypair(context).await?;
Ok(keypair.public)
}
}
}
pub(crate) async fn load_self_secret_key(context: &Context) -> Result<SignedSecretKey> {
match context
.sql
.query_row_optional(
r#"SELECT private_key
FROM keypairs
WHERE addr=(SELECT value FROM config WHERE keyname="configured_addr")
AND is_default=1"#,
(),
|row| {
let bytes: Vec<u8> = row.get(0)?;
Ok(bytes)
},
)
.await?
{
Some(bytes) => SignedSecretKey::from_slice(&bytes),
None => {
let keypair = generate_keypair(context).await?;
Ok(keypair.secret)
}
}
}
impl DcKey for SignedPublicKey {
fn to_asc(&self, header: Option<(&str, &str)>) -> String {
// Not using .to_armored_string() to make clear *why* it is
// safe to ignore this error.
@@ -134,36 +144,6 @@ impl DcKey for SignedPublicKey {
}
impl DcKey for SignedSecretKey {
fn load_self<'a>(
context: &'a Context,
) -> Pin<Box<dyn Future<Output = Result<Self>> + 'a + Send>> {
Box::pin(async move {
match context
.sql
.query_row_optional(
r#"
SELECT private_key
FROM keypairs
WHERE addr=(SELECT value FROM config WHERE keyname="configured_addr")
AND is_default=1;
"#,
(),
|row| {
let bytes: Vec<u8> = row.get(0)?;
Ok(bytes)
},
)
.await?
{
Some(bytes) => Self::from_slice(&bytes),
None => {
let keypair = generate_keypair(context).await?;
Ok(keypair.secret)
}
}
})
}
fn to_asc(&self, header: Option<(&str, &str)>) -> String {
// Not using .to_armored_string() to make clear *why* it is
// safe to do these unwraps.
@@ -521,9 +501,9 @@ i8pcjGO+IZffvyZJVRWfVooBJmWWbPB1pueo3tx8w3+fcuzpxz+RLFKaPyqXO+dD
async fn test_load_self_existing() {
let alice = alice_keypair();
let t = TestContext::new_alice().await;
let pubkey = SignedPublicKey::load_self(&t).await.unwrap();
let pubkey = load_self_public_key(&t).await.unwrap();
assert_eq!(alice.public, pubkey);
let seckey = SignedSecretKey::load_self(&t).await.unwrap();
let seckey = load_self_secret_key(&t).await.unwrap();
assert_eq!(alice.secret, seckey);
}
@@ -533,7 +513,7 @@ i8pcjGO+IZffvyZJVRWfVooBJmWWbPB1pueo3tx8w3+fcuzpxz+RLFKaPyqXO+dD
t.set_config(Config::ConfiguredAddr, Some("alice@example.org"))
.await
.unwrap();
let key = SignedPublicKey::load_self(&t).await;
let key = load_self_public_key(&t).await;
assert!(key.is_ok());
}
@@ -543,7 +523,7 @@ i8pcjGO+IZffvyZJVRWfVooBJmWWbPB1pueo3tx8w3+fcuzpxz+RLFKaPyqXO+dD
t.set_config(Config::ConfiguredAddr, Some("alice@example.org"))
.await
.unwrap();
let key = SignedSecretKey::load_self(&t).await;
let key = load_self_secret_key(&t).await;
assert!(key.is_ok());
}
@@ -560,7 +540,7 @@ i8pcjGO+IZffvyZJVRWfVooBJmWWbPB1pueo3tx8w3+fcuzpxz+RLFKaPyqXO+dD
thread::spawn(move || {
tokio::runtime::Runtime::new()
.unwrap()
.block_on(SignedPublicKey::load_self(&ctx))
.block_on(load_self_public_key(&ctx))
})
};
let thr1 = {
@@ -568,7 +548,7 @@ i8pcjGO+IZffvyZJVRWfVooBJmWWbPB1pueo3tx8w3+fcuzpxz+RLFKaPyqXO+dD
thread::spawn(move || {
tokio::runtime::Runtime::new()
.unwrap()
.block_on(SignedPublicKey::load_self(&ctx))
.block_on(load_self_public_key(&ctx))
})
};
let res0 = thr0.join().unwrap();