diff --git a/src/imex.rs b/src/imex.rs index 27e54039d..be30a6bcd 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -315,13 +315,6 @@ fn set_self_key( ensure!(self_addr.is_some(), "Missing self addr"); let addr = EmailAddress::new(&self_addr.unwrap_or_default())?; - // XXX maybe better make dc_key_save_self_keypair delete things - sql::execute( - context, - &context.sql, - "DELETE FROM keypairs WHERE public_key=? OR private_key=?;", - params![public_key.to_bytes(), private_key.to_bytes()], - )?; let (public, secret) = match (public_key, private_key) { (Key::Public(p), Key::Secret(s)) => (p, s), _ => bail!("wrong keys unpacked"), diff --git a/src/key.rs b/src/key.rs index fd0473c08..8cfc017dd 100644 --- a/src/key.rs +++ b/src/key.rs @@ -11,7 +11,7 @@ use pgp::types::{KeyTrait, SecretKeyTrait}; use crate::constants::*; use crate::context::Context; use crate::dc_tools::*; -use crate::sql::{self, Sql}; +use crate::sql::Sql; // Re-export key types pub use crate::pgp::KeyPair; @@ -350,46 +350,59 @@ impl SaveKeyError { /// "self" here refers to the fact that this DC instance owns the /// keypair. Usually `addr` will be [Config::ConfiguredAddr]. /// +/// If either the public or private keys are already present in the +/// database, this entry will be removed first regardless of the +/// address associated with it. Practically this means saving the +/// same key again overwrites it. +/// /// [Config::ConfiguredAddr]: crate::config::Config::ConfiguredAddr pub fn save_self_keypair( context: &Context, keypair: &KeyPair, default: KeyPairUse, ) -> std::result::Result<(), SaveKeyError> { - // Should really be one transaction, more refactoring is needed for that. - if default == KeyPairUse::Default { - sql::execute( - context, - &context.sql, - "UPDATE keypairs SET is_default=0;", - params![], + // Everything should really be one transaction, more refactoring + // is needed for that. + let public_key = keypair + .public + .to_bytes() + .map_err(|err| SaveKeyError::new("failed to serialise public key", err))?; + let secret_key = keypair + .secret + .to_bytes() + .map_err(|err| SaveKeyError::new("failed to serialise secret key", err))?; + context + .sql + .execute( + "DELETE FROM keypairs WHERE public_key=? OR private_key=?;", + params![public_key, secret_key], ) - .map_err(|err| SaveKeyError::new("failed to clear default", err))?; + .map_err(|err| SaveKeyError::new("failed to remove old use of key", err))?; + if default == KeyPairUse::Default { + context + .sql + .execute("UPDATE keypairs SET is_default=0;", params![]) + .map_err(|err| SaveKeyError::new("failed to clear default", err))?; } let is_default = match default { KeyPairUse::Default => true, KeyPairUse::ReadOnly => false, }; - sql::execute( - context, - &context.sql, - "INSERT INTO keypairs (addr, is_default, public_key, private_key, created) + context + .sql + .execute( + "INSERT INTO keypairs (addr, is_default, public_key, private_key, created) VALUES (?,?,?,?,?);", - params![ - keypair.addr.to_string(), - is_default as i32, - keypair - .public - .to_bytes() - .map_err(|err| SaveKeyError::new("failed to serialise public key", err))?, - keypair - .secret - .to_bytes() - .map_err(|err| SaveKeyError::new("failed to serialise secret key", err))?, - time() - ], - ) - .map_err(|err| SaveKeyError::new("failed to insert keypair", err)) + params![ + keypair.addr.to_string(), + is_default as i32, + public_key, + secret_key, + time() + ], + ) + .map(|_| ()) + .map_err(|err| SaveKeyError::new("failed to insert keypair", err)) } /// Make a fingerprint human-readable, in hex format. @@ -573,4 +586,22 @@ i8pcjGO+IZffvyZJVRWfVooBJmWWbPB1pueo3tx8w3+fcuzpxz+RLFKaPyqXO+dD let public = SignedPublicKey::try_from(public_wrapped).unwrap(); assert_eq!(public.primary_key, KEYPAIR.public.primary_key); } + + #[test] + fn test_save_self_key_twice() { + // Saving the same key twice should result in only one row in + // the keypairs table. + let t = dummy_context(); + let nrows = || { + t.ctx + .sql + .query_get_value::<_, u32>(&t.ctx, "SELECT COUNT(*) FROM keypairs;", params![]) + .unwrap() + }; + assert_eq!(nrows(), 0); + save_self_keypair(&t.ctx, &KEYPAIR, KeyPairUse::Default).unwrap(); + assert_eq!(nrows(), 1); + save_self_keypair(&t.ctx, &KEYPAIR, KeyPairUse::Default).unwrap(); + assert_eq!(nrows(), 1); + } }