feat: improve internal sql interface

Switches from rusqlite to sqlx to have a fully async based interface
to sqlite.

Co-authored-by: B. Petersen <r10s@b44t.com>
Co-authored-by: Hocuri <hocuri@gmx.de>
Co-authored-by: link2xt <link2xt@testrun.org>
This commit is contained in:
Friedel Ziegelmayer
2021-04-06 16:03:10 +02:00
committed by dignifiedquire
parent 4dedc2d8ce
commit 6bb5721f29
52 changed files with 5505 additions and 4983 deletions

View File

@@ -2,6 +2,7 @@
use std::collections::HashMap;
use anyhow::Result;
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
use serde::Deserialize;
@@ -58,11 +59,7 @@ pub async fn dc_get_oauth2_url(
if let Some(oauth2) = Oauth2::from_address(addr).await {
if context
.sql
.set_raw_config(
context,
"oauth2_pending_redirect_uri",
Some(redirect_uri.as_ref()),
)
.set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri.as_ref()))
.await
.is_err()
{
@@ -82,31 +79,25 @@ pub async fn dc_get_oauth2_access_token(
addr: impl AsRef<str>,
code: impl AsRef<str>,
regenerate: bool,
) -> Option<String> {
) -> Result<Option<String>> {
if let Some(oauth2) = Oauth2::from_address(addr).await {
let lock = context.oauth2_mutex.lock().await;
// read generated token
if !regenerate && !is_expired(context).await {
let access_token = context
.sql
.get_raw_config(context, "oauth2_access_token")
.await;
if !regenerate && !is_expired(context).await? {
let access_token = context.sql.get_raw_config("oauth2_access_token").await?;
if access_token.is_some() {
// success
return access_token;
return Ok(access_token);
}
}
// generate new token: build & call auth url
let refresh_token = context
.sql
.get_raw_config(context, "oauth2_refresh_token")
.await;
let refresh_token = context.sql.get_raw_config("oauth2_refresh_token").await?;
let refresh_token_for = context
.sql
.get_raw_config(context, "oauth2_refresh_token_for")
.await
.get_raw_config("oauth2_refresh_token_for")
.await?
.unwrap_or_else(|| "unset".into());
let (redirect_uri, token_url, update_redirect_uri_on_success) =
@@ -115,8 +106,8 @@ pub async fn dc_get_oauth2_access_token(
(
context
.sql
.get_raw_config(context, "oauth2_pending_redirect_uri")
.await
.get_raw_config("oauth2_pending_redirect_uri")
.await?
.unwrap_or_else(|| "unset".into()),
oauth2.init_token,
true,
@@ -129,8 +120,8 @@ pub async fn dc_get_oauth2_access_token(
(
context
.sql
.get_raw_config(context, "oauth2_redirect_uri")
.await
.get_raw_config("oauth2_redirect_uri")
.await?
.unwrap_or_else(|| "unset".into()),
oauth2.refresh_token,
false,
@@ -166,7 +157,7 @@ pub async fn dc_get_oauth2_access_token(
let mut req = surf::post(post_url).build();
if let Err(err) = req.body_form(&post_param) {
warn!(context, "Error calling OAuth2 at {}: {:?}", token_url, err);
return None;
return Ok(None);
}
let client = surf::Client::new();
@@ -176,7 +167,7 @@ pub async fn dc_get_oauth2_access_token(
context,
"Failed to parse OAuth2 JSON response from {}: error: {:?}", token_url, parsed
);
return None;
return Ok(None);
}
// update refresh_token if given, typically on the first round, but we update it later as well.
@@ -184,14 +175,12 @@ pub async fn dc_get_oauth2_access_token(
if let Some(ref token) = response.refresh_token {
context
.sql
.set_raw_config(context, "oauth2_refresh_token", Some(token))
.await
.ok();
.set_raw_config("oauth2_refresh_token", Some(token))
.await?;
context
.sql
.set_raw_config(context, "oauth2_refresh_token_for", Some(code.as_ref()))
.await
.ok();
.set_raw_config("oauth2_refresh_token_for", Some(code.as_ref()))
.await?;
}
// after that, save the access token.
@@ -199,9 +188,8 @@ pub async fn dc_get_oauth2_access_token(
if let Some(ref token) = response.access_token {
context
.sql
.set_raw_config(context, "oauth2_access_token", Some(token))
.await
.ok();
.set_raw_config("oauth2_access_token", Some(token))
.await?;
let expires_in = response
.expires_in
// refresh a bit before
@@ -209,16 +197,14 @@ pub async fn dc_get_oauth2_access_token(
.unwrap_or_else(|| 0);
context
.sql
.set_raw_config_int64(context, "oauth2_timestamp_expires", expires_in)
.await
.ok();
.set_raw_config_int64("oauth2_timestamp_expires", expires_in)
.await?;
if update_redirect_uri_on_success {
context
.sql
.set_raw_config(context, "oauth2_redirect_uri", Some(redirect_uri.as_ref()))
.await
.ok();
.set_raw_config("oauth2_redirect_uri", Some(redirect_uri.as_ref()))
.await?;
}
} else {
warn!(context, "Failed to find OAuth2 access token");
@@ -226,11 +212,11 @@ pub async fn dc_get_oauth2_access_token(
drop(lock);
response.access_token
Ok(response.access_token)
} else {
warn!(context, "Internal OAuth2 error: 2");
None
Ok(None)
}
}
@@ -238,27 +224,33 @@ pub async fn dc_get_oauth2_addr(
context: &Context,
addr: impl AsRef<str>,
code: impl AsRef<str>,
) -> Option<String> {
let oauth2 = Oauth2::from_address(addr.as_ref()).await?;
oauth2.get_userinfo?;
) -> Result<Option<String>> {
let oauth2 = match Oauth2::from_address(addr.as_ref()).await {
Some(o) => o,
None => return Ok(None),
};
if oauth2.get_userinfo.is_none() {
return Ok(None);
}
if let Some(access_token) =
dc_get_oauth2_access_token(context, addr.as_ref(), code.as_ref(), false).await
dc_get_oauth2_access_token(context, addr.as_ref(), code.as_ref(), false).await?
{
let addr_out = oauth2.get_addr(context, access_token).await;
if addr_out.is_none() {
// regenerate
if let Some(access_token) = dc_get_oauth2_access_token(context, addr, code, true).await
if let Some(access_token) =
dc_get_oauth2_access_token(context, addr, code, true).await?
{
oauth2.get_addr(context, access_token).await
Ok(oauth2.get_addr(context, access_token).await)
} else {
None
Ok(None)
}
} else {
addr_out
Ok(addr_out)
}
} else {
None
Ok(None)
}
}
@@ -317,21 +309,21 @@ impl Oauth2 {
}
}
async fn is_expired(context: &Context) -> bool {
async fn is_expired(context: &Context) -> Result<bool, crate::sql::Error> {
let expire_timestamp = context
.sql
.get_raw_config_int64(context, "oauth2_timestamp_expires")
.await
.get_raw_config_int64("oauth2_timestamp_expires")
.await?
.unwrap_or_default();
if expire_timestamp <= 0 {
return false;
return Ok(false);
}
if expire_timestamp > time() {
return false;
return Ok(false);
}
true
Ok(true)
}
fn replace_in_uri(uri: impl AsRef<str>, key: impl AsRef<str>, value: impl AsRef<str>) -> String {
@@ -399,7 +391,7 @@ mod tests {
let ctx = TestContext::new().await;
let addr = "dignifiedquire@gmail.com";
let code = "fail";
let res = dc_get_oauth2_addr(&ctx.ctx, addr, code).await;
let res = dc_get_oauth2_addr(&ctx.ctx, addr, code).await.unwrap();
// this should fail as it is an invalid password
assert_eq!(res, None);
}
@@ -419,7 +411,9 @@ mod tests {
let ctx = TestContext::new().await;
let addr = "dignifiedquire@gmail.com";
let code = "fail";
let res = dc_get_oauth2_access_token(&ctx.ctx, addr, code, false).await;
let res = dc_get_oauth2_access_token(&ctx.ctx, addr, code, false)
.await
.unwrap();
// this should fail as it is an invalid password
assert_eq!(res, None);
}