sql: replace .get_conn() interface with .call()

.call() interface is safer because it ensures
that blocking operations on SQL connection
are called within tokio::task::block_in_place().

Previously some code called blocking operations
in async context, e.g. add_parts() in receive_imf module.

The underlying implementation of .call()
can later be replaced with an implementation
that does not require block_in_place(),
e.g. a worker pool,
without changing the code using the .call() interface.
This commit is contained in:
link2xt
2023-02-18 13:50:44 +00:00
parent 710cec1beb
commit 92c7cc40d4
8 changed files with 189 additions and 171 deletions

View File

@@ -907,7 +907,8 @@ impl ChatId {
async fn parent_query<T, F>(self, context: &Context, fields: &str, f: F) -> Result<Option<T>>
where
F: FnOnce(&rusqlite::Row) -> rusqlite::Result<T>,
F: Send + FnOnce(&rusqlite::Row) -> rusqlite::Result<T>,
T: Send + 'static,
{
let sql = &context.sql;
let query = format!(

View File

@@ -540,25 +540,27 @@ async fn export_backup(context: &Context, dir: &Path, passphrase: String) -> Res
.to_str()
.with_context(|| format!("path {temp_db_path:?} is not valid unicode"))?;
let conn = context.sql.get_conn().await?;
tokio::task::block_in_place(move || {
if let Err(err) = conn.execute("VACUUM", params![]) {
info!(context, "Vacuum failed, exporting anyway: {:#}.", err);
}
conn.execute(
"ATTACH DATABASE ? AS backup KEY ?",
paramsv![path_str, passphrase],
)
.context("failed to attach backup database")?;
let res = conn
.query_row("SELECT sqlcipher_export('backup')", [], |_row| Ok(()))
.context("failed to export to attached backup database");
conn.execute("DETACH DATABASE backup", [])
.context("failed to detach backup database")?;
res?;
context
.sql
.call(|conn| {
if let Err(err) = conn.execute("VACUUM", params![]) {
info!(context, "Vacuum failed, exporting anyway: {:#}.", err);
}
conn.execute(
"ATTACH DATABASE ? AS backup KEY ?",
paramsv![path_str, passphrase],
)
.context("failed to attach backup database")?;
let res = conn
.query_row("SELECT sqlcipher_export('backup')", [], |_row| Ok(()))
.context("failed to export to attached backup database");
conn.execute("DETACH DATABASE backup", [])
.context("failed to detach backup database")?;
res?;
Ok::<_, Error>(())
})?;
Ok::<_, Error>(())
})
.await?;
let res = export_backup_inner(context, &temp_db_path, &temp_path).await;

View File

@@ -289,39 +289,41 @@ pub async fn store_self_keypair(
keypair: &KeyPair,
default: KeyPairUse,
) -> Result<()> {
let mut conn = context.sql.get_conn().await?;
let transaction = conn.transaction()?;
context
.sql
.transaction(|transaction| {
let public_key = DcKey::to_bytes(&keypair.public);
let secret_key = DcKey::to_bytes(&keypair.secret);
transaction
.execute(
"DELETE FROM keypairs WHERE public_key=? OR private_key=?;",
paramsv![public_key, secret_key],
)
.context("failed to remove old use of key")?;
if default == KeyPairUse::Default {
transaction
.execute("UPDATE keypairs SET is_default=0;", paramsv![])
.context("failed to clear default")?;
}
let is_default = match default {
KeyPairUse::Default => i32::from(true),
KeyPairUse::ReadOnly => i32::from(false),
};
let public_key = DcKey::to_bytes(&keypair.public);
let secret_key = DcKey::to_bytes(&keypair.secret);
transaction
.execute(
"DELETE FROM keypairs WHERE public_key=? OR private_key=?;",
paramsv![public_key, secret_key],
)
.context("failed to remove old use of key")?;
if default == KeyPairUse::Default {
transaction
.execute("UPDATE keypairs SET is_default=0;", paramsv![])
.context("failed to clear default")?;
}
let is_default = match default {
KeyPairUse::Default => i32::from(true),
KeyPairUse::ReadOnly => i32::from(false),
};
let addr = keypair.addr.to_string();
let t = time();
let addr = keypair.addr.to_string();
let t = time();
transaction
.execute(
"INSERT INTO keypairs (addr, is_default, public_key, private_key, created)
transaction
.execute(
"INSERT INTO keypairs (addr, is_default, public_key, private_key, created)
VALUES (?,?,?,?,?);",
paramsv![addr, is_default, public_key, secret_key, t],
)
.context("failed to insert keypair")?;
paramsv![addr, is_default, public_key, secret_key, t],
)
.context("failed to insert keypair")?;
transaction.commit()?;
Ok(())
})
.await?;
Ok(())
}

View File

@@ -601,32 +601,38 @@ pub(crate) async fn save(
..
} = location;
let conn = context.sql.get_conn().await?;
let mut stmt_test =
conn.prepare_cached("SELECT id FROM locations WHERE timestamp=? AND from_id=?")?;
let mut stmt_insert = conn.prepare_cached(stmt_insert)?;
context
.sql
.call(|conn| {
let mut stmt_test = conn
.prepare_cached("SELECT id FROM locations WHERE timestamp=? AND from_id=?")?;
let mut stmt_insert = conn.prepare_cached(stmt_insert)?;
let exists = stmt_test.exists(paramsv![timestamp, contact_id])?;
let exists = stmt_test.exists(paramsv![timestamp, contact_id])?;
if independent || !exists {
stmt_insert.execute(paramsv![
timestamp,
contact_id,
chat_id,
latitude,
longitude,
accuracy,
independent,
])?;
if independent || !exists {
stmt_insert.execute(paramsv![
timestamp,
contact_id,
chat_id,
latitude,
longitude,
accuracy,
independent,
])?;
if timestamp > newest_timestamp {
// okay to drop, as we use cached prepared statements
drop(stmt_test);
drop(stmt_insert);
newest_timestamp = timestamp;
newest_location_id = Some(u32::try_from(conn.last_insert_rowid())?);
}
}
if timestamp > newest_timestamp {
// okay to drop, as we use cached prepared statements
drop(stmt_test);
drop(stmt_insert);
newest_timestamp = timestamp;
newest_location_id = Some(u32::try_from(conn.last_insert_rowid())?);
}
}
Ok(())
})
.await?;
}
Ok(newest_location_id)

View File

@@ -186,7 +186,7 @@ impl Peerstate {
async fn from_stmt(
context: &Context,
query: &str,
params: impl rusqlite::Params,
params: impl rusqlite::Params + Send,
) -> Result<Option<Peerstate>> {
let peerstate = context
.sql

View File

@@ -1085,8 +1085,6 @@ async fn add_parts(
let mut created_db_entries = Vec::with_capacity(mime_parser.parts.len());
let conn = context.sql.get_conn().await?;
for part in &mime_parser.parts {
if part.is_reaction {
set_msg_reaction(
@@ -1118,39 +1116,6 @@ async fn add_parts(
}
let mut txt_raw = "".to_string();
let mut stmt = conn.prepare_cached(
r#"
INSERT INTO msgs
(
id,
rfc724_mid, chat_id,
from_id, to_id, timestamp, timestamp_sent,
timestamp_rcvd, type, state, msgrmsg,
txt, subject, txt_raw, param,
bytes, mime_headers, mime_in_reply_to,
mime_references, mime_modified, error, ephemeral_timer,
ephemeral_timestamp, download_state, hop_info
)
VALUES (
?,
?, ?, ?, ?,
?, ?, ?, ?,
?, ?, ?, ?,
?, ?, ?, ?,
?, ?, ?, ?,
?, ?, ?, ?
)
ON CONFLICT (id) DO UPDATE
SET rfc724_mid=excluded.rfc724_mid, chat_id=excluded.chat_id,
from_id=excluded.from_id, to_id=excluded.to_id, timestamp=excluded.timestamp, timestamp_sent=excluded.timestamp_sent,
timestamp_rcvd=excluded.timestamp_rcvd, type=excluded.type, state=excluded.state, msgrmsg=excluded.msgrmsg,
txt=excluded.txt, subject=excluded.subject, txt_raw=excluded.txt_raw, param=excluded.param,
bytes=excluded.bytes, mime_headers=excluded.mime_headers, mime_in_reply_to=excluded.mime_in_reply_to,
mime_references=excluded.mime_references, mime_modified=excluded.mime_modified, error=excluded.error, ephemeral_timer=excluded.ephemeral_timer,
ephemeral_timestamp=excluded.ephemeral_timestamp, download_state=excluded.download_state, hop_info=excluded.hop_info
"#,
)?;
let (msg, typ): (&str, Viewtype) = if let Some(better_msg) = &better_msg {
(better_msg, Viewtype::Text)
} else {
@@ -1184,7 +1149,38 @@ SET rfc724_mid=excluded.rfc724_mid, chat_id=excluded.chat_id,
// also change `MsgId::trash()` and `delete_expired_messages()`
let trash = chat_id.is_trash() || (is_location_kml && msg.is_empty());
stmt.execute(paramsv![
let row_id = context.sql.insert(
r#"
INSERT INTO msgs
(
id,
rfc724_mid, chat_id,
from_id, to_id, timestamp, timestamp_sent,
timestamp_rcvd, type, state, msgrmsg,
txt, subject, txt_raw, param,
bytes, mime_headers, mime_in_reply_to,
mime_references, mime_modified, error, ephemeral_timer,
ephemeral_timestamp, download_state, hop_info
)
VALUES (
?,
?, ?, ?, ?,
?, ?, ?, ?,
?, ?, ?, ?,
?, ?, ?, ?,
?, ?, ?, ?,
?, ?, ?, ?
)
ON CONFLICT (id) DO UPDATE
SET rfc724_mid=excluded.rfc724_mid, chat_id=excluded.chat_id,
from_id=excluded.from_id, to_id=excluded.to_id, timestamp=excluded.timestamp, timestamp_sent=excluded.timestamp_sent,
timestamp_rcvd=excluded.timestamp_rcvd, type=excluded.type, state=excluded.state, msgrmsg=excluded.msgrmsg,
txt=excluded.txt, subject=excluded.subject, txt_raw=excluded.txt_raw, param=excluded.param,
bytes=excluded.bytes, mime_headers=excluded.mime_headers, mime_in_reply_to=excluded.mime_in_reply_to,
mime_references=excluded.mime_references, mime_modified=excluded.mime_modified, error=excluded.error, ephemeral_timer=excluded.ephemeral_timer,
ephemeral_timestamp=excluded.ephemeral_timestamp, download_state=excluded.download_state, hop_info=excluded.hop_info
"#,
paramsv![
replace_msg_id,
rfc724_mid,
if trash { DC_CHAT_ID_TRASH } else { chat_id },
@@ -1223,17 +1219,14 @@ SET rfc724_mid=excluded.rfc724_mid, chat_id=excluded.chat_id,
DownloadState::Done
},
mime_parser.hop_info
])?;
]).await?;
// We only replace placeholder with a first part,
// afterwards insert additional parts.
replace_msg_id = None;
let row_id = conn.last_insert_rowid();
drop(stmt);
created_db_entries.push(MsgId::new(u32::try_from(row_id)?));
}
drop(conn);
// check all parts whether they contain a new logging webxdc
for (part, msg_id) in mime_parser.parts.iter().zip(&created_db_entries) {

View File

@@ -2,8 +2,7 @@
use std::collections::{HashMap, HashSet};
use std::convert::TryFrom;
use std::path::Path;
use std::path::PathBuf;
use std::path::{Path, PathBuf};
use anyhow::{bail, Context as _, Result};
use rusqlite::{self, config::DbConfig, Connection, OpenFlags, TransactionBehavior};
@@ -49,7 +48,7 @@ pub(crate) fn params_iter(iter: &[impl crate::ToSql]) -> impl Iterator<Item = &d
mod migrations;
mod pool;
use pool::{Pool, PooledConnection};
use pool::Pool;
/// A wrapper around the underlying Sqlite3 object.
#[derive(Debug)]
@@ -128,10 +127,9 @@ impl Sql {
pub(crate) async fn import(&self, path: &Path, passphrase: String) -> Result<()> {
let path_str = path
.to_str()
.with_context(|| format!("path {path:?} is not valid unicode"))?;
let conn = self.get_conn().await?;
tokio::task::block_in_place(move || {
.with_context(|| format!("path {path:?} is not valid unicode"))?
.to_string();
self.call(move |conn| {
// Check that backup passphrase is correct before resetting our database.
conn.execute(
"ATTACH DATABASE ? AS backup KEY ?",
@@ -167,6 +165,7 @@ impl Sql {
res?;
Ok(())
})
.await
}
/// Creates a new connection pool.
@@ -294,22 +293,41 @@ impl Sql {
}
}
/// Allocates a connection and calls given function with the connection.
///
/// Returns the result of the function.
pub async fn call<'a, F, R>(&'a self, function: F) -> Result<R>
where
F: 'a + FnOnce(&mut Connection) -> Result<R> + Send,
R: Send + 'static,
{
let lock = self.pool.read().await;
let pool = lock.as_ref().context("no SQL connection")?;
let mut conn = pool.get().await?;
let res = tokio::task::block_in_place(move || function(&mut conn))?;
Ok(res)
}
/// Execute the given query, returning the number of affected rows.
pub async fn execute(&self, query: &str, params: impl rusqlite::Params) -> Result<usize> {
let conn = self.get_conn().await?;
tokio::task::block_in_place(move || {
pub async fn execute(
&self,
query: &str,
params: impl rusqlite::Params + Send,
) -> Result<usize> {
self.call(move |conn| {
let res = conn.execute(query, params)?;
Ok(res)
})
.await
}
/// Executes the given query, returning the last inserted row ID.
pub async fn insert(&self, query: &str, params: impl rusqlite::Params) -> Result<i64> {
let conn = self.get_conn().await?;
tokio::task::block_in_place(move || {
pub async fn insert(&self, query: &str, params: impl rusqlite::Params + Send) -> Result<i64> {
self.call(move |conn| {
conn.execute(query, params)?;
Ok(conn.last_insert_rowid())
})
.await
}
/// Prepares and executes the statement and maps a function over the resulting rows.
@@ -318,40 +336,32 @@ impl Sql {
pub async fn query_map<T, F, G, H>(
&self,
sql: &str,
params: impl rusqlite::Params,
params: impl rusqlite::Params + Send,
f: F,
mut g: G,
) -> Result<H>
where
F: FnMut(&rusqlite::Row) -> rusqlite::Result<T>,
G: FnMut(rusqlite::MappedRows<F>) -> Result<H>,
F: Send + FnMut(&rusqlite::Row) -> rusqlite::Result<T>,
G: Send + FnMut(rusqlite::MappedRows<F>) -> Result<H>,
H: Send + 'static,
{
let conn = self.get_conn().await?;
tokio::task::block_in_place(move || {
self.call(move |conn| {
let mut stmt = conn.prepare(sql)?;
let res = stmt.query_map(params, f)?;
g(res)
})
}
/// Allocates a connection from the connection pool and returns it.
pub(crate) async fn get_conn(&self) -> Result<PooledConnection> {
let lock = self.pool.read().await;
let pool = lock.as_ref().context("no SQL connection")?;
let conn = pool.get().await?;
Ok(conn)
.await
}
/// Used for executing `SELECT COUNT` statements only. Returns the resulting count.
pub async fn count(&self, query: &str, params: impl rusqlite::Params) -> Result<usize> {
pub async fn count(&self, query: &str, params: impl rusqlite::Params + Send) -> Result<usize> {
let count: isize = self.query_row(query, params, |row| row.get(0)).await?;
Ok(usize::try_from(count)?)
}
/// Used for executing `SELECT COUNT` statements only. Returns `true`, if the count is at least
/// one, `false` otherwise.
pub async fn exists(&self, sql: &str, params: impl rusqlite::Params) -> Result<bool> {
pub async fn exists(&self, sql: &str, params: impl rusqlite::Params + Send) -> Result<bool> {
let count = self.count(sql, params).await?;
Ok(count > 0)
}
@@ -360,17 +370,18 @@ impl Sql {
pub async fn query_row<T, F>(
&self,
query: &str,
params: impl rusqlite::Params,
params: impl rusqlite::Params + Send,
f: F,
) -> Result<T>
where
F: FnOnce(&rusqlite::Row) -> rusqlite::Result<T>,
F: FnOnce(&rusqlite::Row) -> rusqlite::Result<T> + Send,
T: Send + 'static,
{
let conn = self.get_conn().await?;
tokio::task::block_in_place(move || {
self.call(move |conn| {
let res = conn.query_row(query, params, f)?;
Ok(res)
})
.await
}
/// Execute the function inside a transaction.
@@ -388,8 +399,7 @@ impl Sql {
H: Send + 'static,
G: Send + FnOnce(&mut rusqlite::Transaction<'_>) -> Result<H>,
{
let mut conn = self.get_conn().await?;
tokio::task::block_in_place(move || {
self.call(move |conn| {
let mut transaction = conn.transaction_with_behavior(TransactionBehavior::Immediate)?;
let ret = callback(&mut transaction);
@@ -404,12 +414,12 @@ impl Sql {
}
}
})
.await
}
/// Query the database if the requested table already exists.
pub async fn table_exists(&self, name: &str) -> Result<bool> {
let conn = self.get_conn().await?;
tokio::task::block_in_place(move || {
self.call(move |conn| {
let mut exists = false;
conn.pragma(None, "table_info", name.to_string(), |_row| {
// will only be executed if the info was found
@@ -419,12 +429,12 @@ impl Sql {
Ok(exists)
})
.await
}
/// Check if a column exists in a given table.
pub async fn col_exists(&self, table_name: &str, col_name: &str) -> Result<bool> {
let conn = self.get_conn().await?;
tokio::task::block_in_place(move || {
self.call(move |conn| {
let mut exists = false;
// `PRAGMA table_info` returns one row per column,
// each row containing 0=cid, 1=name, 2=type, 3=notnull, 4=dflt_value
@@ -438,29 +448,27 @@ impl Sql {
Ok(exists)
})
.await
}
/// Execute a query which is expected to return zero or one row.
pub async fn query_row_optional<T, F>(
&self,
sql: &str,
params: impl rusqlite::Params,
params: impl rusqlite::Params + Send,
f: F,
) -> Result<Option<T>>
where
F: FnOnce(&rusqlite::Row) -> rusqlite::Result<T>,
F: Send + FnOnce(&rusqlite::Row) -> rusqlite::Result<T>,
T: Send + 'static,
{
let conn = self.get_conn().await?;
let res =
tokio::task::block_in_place(move || match conn.query_row(sql.as_ref(), params, f) {
Ok(res) => Ok(Some(res)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(rusqlite::Error::InvalidColumnType(_, _, rusqlite::types::Type::Null)) => {
Ok(None)
}
Err(err) => Err(err),
})?;
Ok(res)
self.call(move |conn| match conn.query_row(sql.as_ref(), params, f) {
Ok(res) => Ok(Some(res)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(rusqlite::Error::InvalidColumnType(_, _, rusqlite::types::Type::Null)) => Ok(None),
Err(err) => Err(err.into()),
})
.await
}
/// Executes a query which is expected to return one row and one
@@ -469,10 +477,10 @@ impl Sql {
pub async fn query_get_value<T>(
&self,
query: &str,
params: impl rusqlite::Params,
params: impl rusqlite::Params + Send,
) -> Result<Option<T>>
where
T: rusqlite::types::FromSql,
T: rusqlite::types::FromSql + Send + 'static,
{
self.query_row_optional(query, params, |row| row.get::<_, T>(0))
.await
@@ -935,11 +943,16 @@ mod tests {
async fn test_auto_vacuum() -> Result<()> {
let t = TestContext::new().await;
let conn = t.sql.get_conn().await?;
let auto_vacuum = conn.pragma_query_value(None, "auto_vacuum", |row| {
let auto_vacuum: i32 = row.get(0)?;
Ok(auto_vacuum)
})?;
let auto_vacuum = t
.sql
.call(|conn| {
let auto_vacuum = conn.pragma_query_value(None, "auto_vacuum", |row| {
let auto_vacuum: i32 = row.get(0)?;
Ok(auto_vacuum)
})?;
Ok(auto_vacuum)
})
.await?;
// auto_vacuum=2 is the same as auto_vacuum=INCREMENTAL
assert_eq!(auto_vacuum, 2);