mirror of
https://github.com/chatmail/core.git
synced 2026-04-27 02:16:29 +03:00
feat: improve internal sql interface
Switches from rusqlite to sqlx to have a fully async based interface to sqlite. Co-authored-by: B. Petersen <r10s@b44t.com> Co-authored-by: Hocuri <hocuri@gmx.de> Co-authored-by: link2xt <link2xt@testrun.org>
This commit is contained in:
committed by
dignifiedquire
parent
4dedc2d8ce
commit
6bb5721f29
110
src/oauth2.rs
110
src/oauth2.rs
@@ -2,6 +2,7 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Result;
|
||||
use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
|
||||
use serde::Deserialize;
|
||||
|
||||
@@ -58,11 +59,7 @@ pub async fn dc_get_oauth2_url(
|
||||
if let Some(oauth2) = Oauth2::from_address(addr).await {
|
||||
if context
|
||||
.sql
|
||||
.set_raw_config(
|
||||
context,
|
||||
"oauth2_pending_redirect_uri",
|
||||
Some(redirect_uri.as_ref()),
|
||||
)
|
||||
.set_raw_config("oauth2_pending_redirect_uri", Some(redirect_uri.as_ref()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
@@ -82,31 +79,25 @@ pub async fn dc_get_oauth2_access_token(
|
||||
addr: impl AsRef<str>,
|
||||
code: impl AsRef<str>,
|
||||
regenerate: bool,
|
||||
) -> Option<String> {
|
||||
) -> Result<Option<String>> {
|
||||
if let Some(oauth2) = Oauth2::from_address(addr).await {
|
||||
let lock = context.oauth2_mutex.lock().await;
|
||||
|
||||
// read generated token
|
||||
if !regenerate && !is_expired(context).await {
|
||||
let access_token = context
|
||||
.sql
|
||||
.get_raw_config(context, "oauth2_access_token")
|
||||
.await;
|
||||
if !regenerate && !is_expired(context).await? {
|
||||
let access_token = context.sql.get_raw_config("oauth2_access_token").await?;
|
||||
if access_token.is_some() {
|
||||
// success
|
||||
return access_token;
|
||||
return Ok(access_token);
|
||||
}
|
||||
}
|
||||
|
||||
// generate new token: build & call auth url
|
||||
let refresh_token = context
|
||||
.sql
|
||||
.get_raw_config(context, "oauth2_refresh_token")
|
||||
.await;
|
||||
let refresh_token = context.sql.get_raw_config("oauth2_refresh_token").await?;
|
||||
let refresh_token_for = context
|
||||
.sql
|
||||
.get_raw_config(context, "oauth2_refresh_token_for")
|
||||
.await
|
||||
.get_raw_config("oauth2_refresh_token_for")
|
||||
.await?
|
||||
.unwrap_or_else(|| "unset".into());
|
||||
|
||||
let (redirect_uri, token_url, update_redirect_uri_on_success) =
|
||||
@@ -115,8 +106,8 @@ pub async fn dc_get_oauth2_access_token(
|
||||
(
|
||||
context
|
||||
.sql
|
||||
.get_raw_config(context, "oauth2_pending_redirect_uri")
|
||||
.await
|
||||
.get_raw_config("oauth2_pending_redirect_uri")
|
||||
.await?
|
||||
.unwrap_or_else(|| "unset".into()),
|
||||
oauth2.init_token,
|
||||
true,
|
||||
@@ -129,8 +120,8 @@ pub async fn dc_get_oauth2_access_token(
|
||||
(
|
||||
context
|
||||
.sql
|
||||
.get_raw_config(context, "oauth2_redirect_uri")
|
||||
.await
|
||||
.get_raw_config("oauth2_redirect_uri")
|
||||
.await?
|
||||
.unwrap_or_else(|| "unset".into()),
|
||||
oauth2.refresh_token,
|
||||
false,
|
||||
@@ -166,7 +157,7 @@ pub async fn dc_get_oauth2_access_token(
|
||||
let mut req = surf::post(post_url).build();
|
||||
if let Err(err) = req.body_form(&post_param) {
|
||||
warn!(context, "Error calling OAuth2 at {}: {:?}", token_url, err);
|
||||
return None;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let client = surf::Client::new();
|
||||
@@ -176,7 +167,7 @@ pub async fn dc_get_oauth2_access_token(
|
||||
context,
|
||||
"Failed to parse OAuth2 JSON response from {}: error: {:?}", token_url, parsed
|
||||
);
|
||||
return None;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// update refresh_token if given, typically on the first round, but we update it later as well.
|
||||
@@ -184,14 +175,12 @@ pub async fn dc_get_oauth2_access_token(
|
||||
if let Some(ref token) = response.refresh_token {
|
||||
context
|
||||
.sql
|
||||
.set_raw_config(context, "oauth2_refresh_token", Some(token))
|
||||
.await
|
||||
.ok();
|
||||
.set_raw_config("oauth2_refresh_token", Some(token))
|
||||
.await?;
|
||||
context
|
||||
.sql
|
||||
.set_raw_config(context, "oauth2_refresh_token_for", Some(code.as_ref()))
|
||||
.await
|
||||
.ok();
|
||||
.set_raw_config("oauth2_refresh_token_for", Some(code.as_ref()))
|
||||
.await?;
|
||||
}
|
||||
|
||||
// after that, save the access token.
|
||||
@@ -199,9 +188,8 @@ pub async fn dc_get_oauth2_access_token(
|
||||
if let Some(ref token) = response.access_token {
|
||||
context
|
||||
.sql
|
||||
.set_raw_config(context, "oauth2_access_token", Some(token))
|
||||
.await
|
||||
.ok();
|
||||
.set_raw_config("oauth2_access_token", Some(token))
|
||||
.await?;
|
||||
let expires_in = response
|
||||
.expires_in
|
||||
// refresh a bit before
|
||||
@@ -209,16 +197,14 @@ pub async fn dc_get_oauth2_access_token(
|
||||
.unwrap_or_else(|| 0);
|
||||
context
|
||||
.sql
|
||||
.set_raw_config_int64(context, "oauth2_timestamp_expires", expires_in)
|
||||
.await
|
||||
.ok();
|
||||
.set_raw_config_int64("oauth2_timestamp_expires", expires_in)
|
||||
.await?;
|
||||
|
||||
if update_redirect_uri_on_success {
|
||||
context
|
||||
.sql
|
||||
.set_raw_config(context, "oauth2_redirect_uri", Some(redirect_uri.as_ref()))
|
||||
.await
|
||||
.ok();
|
||||
.set_raw_config("oauth2_redirect_uri", Some(redirect_uri.as_ref()))
|
||||
.await?;
|
||||
}
|
||||
} else {
|
||||
warn!(context, "Failed to find OAuth2 access token");
|
||||
@@ -226,11 +212,11 @@ pub async fn dc_get_oauth2_access_token(
|
||||
|
||||
drop(lock);
|
||||
|
||||
response.access_token
|
||||
Ok(response.access_token)
|
||||
} else {
|
||||
warn!(context, "Internal OAuth2 error: 2");
|
||||
|
||||
None
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,27 +224,33 @@ pub async fn dc_get_oauth2_addr(
|
||||
context: &Context,
|
||||
addr: impl AsRef<str>,
|
||||
code: impl AsRef<str>,
|
||||
) -> Option<String> {
|
||||
let oauth2 = Oauth2::from_address(addr.as_ref()).await?;
|
||||
oauth2.get_userinfo?;
|
||||
) -> Result<Option<String>> {
|
||||
let oauth2 = match Oauth2::from_address(addr.as_ref()).await {
|
||||
Some(o) => o,
|
||||
None => return Ok(None),
|
||||
};
|
||||
if oauth2.get_userinfo.is_none() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
if let Some(access_token) =
|
||||
dc_get_oauth2_access_token(context, addr.as_ref(), code.as_ref(), false).await
|
||||
dc_get_oauth2_access_token(context, addr.as_ref(), code.as_ref(), false).await?
|
||||
{
|
||||
let addr_out = oauth2.get_addr(context, access_token).await;
|
||||
if addr_out.is_none() {
|
||||
// regenerate
|
||||
if let Some(access_token) = dc_get_oauth2_access_token(context, addr, code, true).await
|
||||
if let Some(access_token) =
|
||||
dc_get_oauth2_access_token(context, addr, code, true).await?
|
||||
{
|
||||
oauth2.get_addr(context, access_token).await
|
||||
Ok(oauth2.get_addr(context, access_token).await)
|
||||
} else {
|
||||
None
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
addr_out
|
||||
Ok(addr_out)
|
||||
}
|
||||
} else {
|
||||
None
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -317,21 +309,21 @@ impl Oauth2 {
|
||||
}
|
||||
}
|
||||
|
||||
async fn is_expired(context: &Context) -> bool {
|
||||
async fn is_expired(context: &Context) -> Result<bool, crate::sql::Error> {
|
||||
let expire_timestamp = context
|
||||
.sql
|
||||
.get_raw_config_int64(context, "oauth2_timestamp_expires")
|
||||
.await
|
||||
.get_raw_config_int64("oauth2_timestamp_expires")
|
||||
.await?
|
||||
.unwrap_or_default();
|
||||
|
||||
if expire_timestamp <= 0 {
|
||||
return false;
|
||||
return Ok(false);
|
||||
}
|
||||
if expire_timestamp > time() {
|
||||
return false;
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
true
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
fn replace_in_uri(uri: impl AsRef<str>, key: impl AsRef<str>, value: impl AsRef<str>) -> String {
|
||||
@@ -399,7 +391,7 @@ mod tests {
|
||||
let ctx = TestContext::new().await;
|
||||
let addr = "dignifiedquire@gmail.com";
|
||||
let code = "fail";
|
||||
let res = dc_get_oauth2_addr(&ctx.ctx, addr, code).await;
|
||||
let res = dc_get_oauth2_addr(&ctx.ctx, addr, code).await.unwrap();
|
||||
// this should fail as it is an invalid password
|
||||
assert_eq!(res, None);
|
||||
}
|
||||
@@ -419,7 +411,9 @@ mod tests {
|
||||
let ctx = TestContext::new().await;
|
||||
let addr = "dignifiedquire@gmail.com";
|
||||
let code = "fail";
|
||||
let res = dc_get_oauth2_access_token(&ctx.ctx, addr, code, false).await;
|
||||
let res = dc_get_oauth2_access_token(&ctx.ctx, addr, code, false)
|
||||
.await
|
||||
.unwrap();
|
||||
// this should fail as it is an invalid password
|
||||
assert_eq!(res, None);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user