diff --git a/CHANGELOG.md b/CHANGELOG.md index 34cebb853..3abb9ca41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Organize the connection pool as a stack rather than a queue to ensure that connection page cache is reused more often. #4065 - Use transaction in `update_blocked_mailinglist_contacts`. #4058 +- Remove `Sql.get_conn()` interface in favor of `.call()` and `.transaction()`. #4055 ### Fixes - Start SQL transactions with IMMEDIATE behaviour rather than default DEFERRED one. #4063 diff --git a/src/chat.rs b/src/chat.rs index dd88cdd30..522f55aa7 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -907,7 +907,8 @@ impl ChatId { async fn parent_query(self, context: &Context, fields: &str, f: F) -> Result> where - F: FnOnce(&rusqlite::Row) -> rusqlite::Result, + F: Send + FnOnce(&rusqlite::Row) -> rusqlite::Result, + T: Send + 'static, { let sql = &context.sql; let query = format!( diff --git a/src/imex.rs b/src/imex.rs index f23496753..a3de4c201 100644 --- a/src/imex.rs +++ b/src/imex.rs @@ -540,25 +540,27 @@ async fn export_backup(context: &Context, dir: &Path, passphrase: String) -> Res .to_str() .with_context(|| format!("path {temp_db_path:?} is not valid unicode"))?; - let conn = context.sql.get_conn().await?; - tokio::task::block_in_place(move || { - if let Err(err) = conn.execute("VACUUM", params![]) { - info!(context, "Vacuum failed, exporting anyway: {:#}.", err); - } - conn.execute( - "ATTACH DATABASE ? AS backup KEY ?", - paramsv![path_str, passphrase], - ) - .context("failed to attach backup database")?; - let res = conn - .query_row("SELECT sqlcipher_export('backup')", [], |_row| Ok(())) - .context("failed to export to attached backup database"); - conn.execute("DETACH DATABASE backup", []) - .context("failed to detach backup database")?; - res?; + context + .sql + .call(|conn| { + if let Err(err) = conn.execute("VACUUM", params![]) { + info!(context, "Vacuum failed, exporting anyway: {:#}.", err); + } + conn.execute( + "ATTACH DATABASE ? AS backup KEY ?", + paramsv![path_str, passphrase], + ) + .context("failed to attach backup database")?; + let res = conn + .query_row("SELECT sqlcipher_export('backup')", [], |_row| Ok(())) + .context("failed to export to attached backup database"); + conn.execute("DETACH DATABASE backup", []) + .context("failed to detach backup database")?; + res?; - Ok::<_, Error>(()) - })?; + Ok::<_, Error>(()) + }) + .await?; let res = export_backup_inner(context, &temp_db_path, &temp_path).await; diff --git a/src/key.rs b/src/key.rs index bee59c60b..3b48fe19e 100644 --- a/src/key.rs +++ b/src/key.rs @@ -289,39 +289,41 @@ pub async fn store_self_keypair( keypair: &KeyPair, default: KeyPairUse, ) -> Result<()> { - let mut conn = context.sql.get_conn().await?; - let transaction = conn.transaction()?; + context + .sql + .transaction(|transaction| { + let public_key = DcKey::to_bytes(&keypair.public); + let secret_key = DcKey::to_bytes(&keypair.secret); + transaction + .execute( + "DELETE FROM keypairs WHERE public_key=? OR private_key=?;", + paramsv![public_key, secret_key], + ) + .context("failed to remove old use of key")?; + if default == KeyPairUse::Default { + transaction + .execute("UPDATE keypairs SET is_default=0;", paramsv![]) + .context("failed to clear default")?; + } + let is_default = match default { + KeyPairUse::Default => i32::from(true), + KeyPairUse::ReadOnly => i32::from(false), + }; - let public_key = DcKey::to_bytes(&keypair.public); - let secret_key = DcKey::to_bytes(&keypair.secret); - transaction - .execute( - "DELETE FROM keypairs WHERE public_key=? OR private_key=?;", - paramsv![public_key, secret_key], - ) - .context("failed to remove old use of key")?; - if default == KeyPairUse::Default { - transaction - .execute("UPDATE keypairs SET is_default=0;", paramsv![]) - .context("failed to clear default")?; - } - let is_default = match default { - KeyPairUse::Default => i32::from(true), - KeyPairUse::ReadOnly => i32::from(false), - }; + let addr = keypair.addr.to_string(); + let t = time(); - let addr = keypair.addr.to_string(); - let t = time(); - - transaction - .execute( - "INSERT INTO keypairs (addr, is_default, public_key, private_key, created) + transaction + .execute( + "INSERT INTO keypairs (addr, is_default, public_key, private_key, created) VALUES (?,?,?,?,?);", - paramsv![addr, is_default, public_key, secret_key, t], - ) - .context("failed to insert keypair")?; + paramsv![addr, is_default, public_key, secret_key, t], + ) + .context("failed to insert keypair")?; - transaction.commit()?; + Ok(()) + }) + .await?; Ok(()) } diff --git a/src/location.rs b/src/location.rs index 24a202c17..6698dea3b 100644 --- a/src/location.rs +++ b/src/location.rs @@ -601,32 +601,38 @@ pub(crate) async fn save( .. } = location; - let conn = context.sql.get_conn().await?; - let mut stmt_test = - conn.prepare_cached("SELECT id FROM locations WHERE timestamp=? AND from_id=?")?; - let mut stmt_insert = conn.prepare_cached(stmt_insert)?; + context + .sql + .call(|conn| { + let mut stmt_test = conn + .prepare_cached("SELECT id FROM locations WHERE timestamp=? AND from_id=?")?; + let mut stmt_insert = conn.prepare_cached(stmt_insert)?; - let exists = stmt_test.exists(paramsv![timestamp, contact_id])?; + let exists = stmt_test.exists(paramsv![timestamp, contact_id])?; - if independent || !exists { - stmt_insert.execute(paramsv![ - timestamp, - contact_id, - chat_id, - latitude, - longitude, - accuracy, - independent, - ])?; + if independent || !exists { + stmt_insert.execute(paramsv![ + timestamp, + contact_id, + chat_id, + latitude, + longitude, + accuracy, + independent, + ])?; - if timestamp > newest_timestamp { - // okay to drop, as we use cached prepared statements - drop(stmt_test); - drop(stmt_insert); - newest_timestamp = timestamp; - newest_location_id = Some(u32::try_from(conn.last_insert_rowid())?); - } - } + if timestamp > newest_timestamp { + // okay to drop, as we use cached prepared statements + drop(stmt_test); + drop(stmt_insert); + newest_timestamp = timestamp; + newest_location_id = Some(u32::try_from(conn.last_insert_rowid())?); + } + } + + Ok(()) + }) + .await?; } Ok(newest_location_id) diff --git a/src/peerstate.rs b/src/peerstate.rs index acac59a65..e7433f0f3 100644 --- a/src/peerstate.rs +++ b/src/peerstate.rs @@ -186,7 +186,7 @@ impl Peerstate { async fn from_stmt( context: &Context, query: &str, - params: impl rusqlite::Params, + params: impl rusqlite::Params + Send, ) -> Result> { let peerstate = context .sql diff --git a/src/receive_imf.rs b/src/receive_imf.rs index 1896e844b..2540d7442 100644 --- a/src/receive_imf.rs +++ b/src/receive_imf.rs @@ -1085,8 +1085,6 @@ async fn add_parts( let mut created_db_entries = Vec::with_capacity(mime_parser.parts.len()); - let conn = context.sql.get_conn().await?; - for part in &mime_parser.parts { if part.is_reaction { set_msg_reaction( @@ -1118,39 +1116,6 @@ async fn add_parts( } let mut txt_raw = "".to_string(); - let mut stmt = conn.prepare_cached( - r#" -INSERT INTO msgs - ( - id, - rfc724_mid, chat_id, - from_id, to_id, timestamp, timestamp_sent, - timestamp_rcvd, type, state, msgrmsg, - txt, subject, txt_raw, param, - bytes, mime_headers, mime_in_reply_to, - mime_references, mime_modified, error, ephemeral_timer, - ephemeral_timestamp, download_state, hop_info - ) - VALUES ( - ?, - ?, ?, ?, ?, - ?, ?, ?, ?, - ?, ?, ?, ?, - ?, ?, ?, ?, - ?, ?, ?, ?, - ?, ?, ?, ? - ) -ON CONFLICT (id) DO UPDATE -SET rfc724_mid=excluded.rfc724_mid, chat_id=excluded.chat_id, - from_id=excluded.from_id, to_id=excluded.to_id, timestamp=excluded.timestamp, timestamp_sent=excluded.timestamp_sent, - timestamp_rcvd=excluded.timestamp_rcvd, type=excluded.type, state=excluded.state, msgrmsg=excluded.msgrmsg, - txt=excluded.txt, subject=excluded.subject, txt_raw=excluded.txt_raw, param=excluded.param, - bytes=excluded.bytes, mime_headers=excluded.mime_headers, mime_in_reply_to=excluded.mime_in_reply_to, - mime_references=excluded.mime_references, mime_modified=excluded.mime_modified, error=excluded.error, ephemeral_timer=excluded.ephemeral_timer, - ephemeral_timestamp=excluded.ephemeral_timestamp, download_state=excluded.download_state, hop_info=excluded.hop_info -"#, - )?; - let (msg, typ): (&str, Viewtype) = if let Some(better_msg) = &better_msg { (better_msg, Viewtype::Text) } else { @@ -1184,7 +1149,38 @@ SET rfc724_mid=excluded.rfc724_mid, chat_id=excluded.chat_id, // also change `MsgId::trash()` and `delete_expired_messages()` let trash = chat_id.is_trash() || (is_location_kml && msg.is_empty()); - stmt.execute(paramsv![ + let row_id = context.sql.insert( + r#" +INSERT INTO msgs + ( + id, + rfc724_mid, chat_id, + from_id, to_id, timestamp, timestamp_sent, + timestamp_rcvd, type, state, msgrmsg, + txt, subject, txt_raw, param, + bytes, mime_headers, mime_in_reply_to, + mime_references, mime_modified, error, ephemeral_timer, + ephemeral_timestamp, download_state, hop_info + ) + VALUES ( + ?, + ?, ?, ?, ?, + ?, ?, ?, ?, + ?, ?, ?, ?, + ?, ?, ?, ?, + ?, ?, ?, ?, + ?, ?, ?, ? + ) +ON CONFLICT (id) DO UPDATE +SET rfc724_mid=excluded.rfc724_mid, chat_id=excluded.chat_id, + from_id=excluded.from_id, to_id=excluded.to_id, timestamp=excluded.timestamp, timestamp_sent=excluded.timestamp_sent, + timestamp_rcvd=excluded.timestamp_rcvd, type=excluded.type, state=excluded.state, msgrmsg=excluded.msgrmsg, + txt=excluded.txt, subject=excluded.subject, txt_raw=excluded.txt_raw, param=excluded.param, + bytes=excluded.bytes, mime_headers=excluded.mime_headers, mime_in_reply_to=excluded.mime_in_reply_to, + mime_references=excluded.mime_references, mime_modified=excluded.mime_modified, error=excluded.error, ephemeral_timer=excluded.ephemeral_timer, + ephemeral_timestamp=excluded.ephemeral_timestamp, download_state=excluded.download_state, hop_info=excluded.hop_info +"#, + paramsv![ replace_msg_id, rfc724_mid, if trash { DC_CHAT_ID_TRASH } else { chat_id }, @@ -1223,17 +1219,14 @@ SET rfc724_mid=excluded.rfc724_mid, chat_id=excluded.chat_id, DownloadState::Done }, mime_parser.hop_info - ])?; + ]).await?; // We only replace placeholder with a first part, // afterwards insert additional parts. replace_msg_id = None; - let row_id = conn.last_insert_rowid(); - drop(stmt); created_db_entries.push(MsgId::new(u32::try_from(row_id)?)); } - drop(conn); // check all parts whether they contain a new logging webxdc for (part, msg_id) in mime_parser.parts.iter().zip(&created_db_entries) { diff --git a/src/sql.rs b/src/sql.rs index 4b4b190eb..d96c85b66 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -2,8 +2,7 @@ use std::collections::{HashMap, HashSet}; use std::convert::TryFrom; -use std::path::Path; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use anyhow::{bail, Context as _, Result}; use rusqlite::{self, config::DbConfig, Connection, OpenFlags, TransactionBehavior}; @@ -49,7 +48,7 @@ pub(crate) fn params_iter(iter: &[impl crate::ToSql]) -> impl Iterator Result<()> { let path_str = path .to_str() - .with_context(|| format!("path {path:?} is not valid unicode"))?; - let conn = self.get_conn().await?; - - tokio::task::block_in_place(move || { + .with_context(|| format!("path {path:?} is not valid unicode"))? + .to_string(); + self.call(move |conn| { // Check that backup passphrase is correct before resetting our database. conn.execute( "ATTACH DATABASE ? AS backup KEY ?", @@ -167,6 +165,7 @@ impl Sql { res?; Ok(()) }) + .await } /// Creates a new connection pool. @@ -294,22 +293,41 @@ impl Sql { } } + /// Allocates a connection and calls given function with the connection. + /// + /// Returns the result of the function. + pub async fn call<'a, F, R>(&'a self, function: F) -> Result + where + F: 'a + FnOnce(&mut Connection) -> Result + Send, + R: Send + 'static, + { + let lock = self.pool.read().await; + let pool = lock.as_ref().context("no SQL connection")?; + let mut conn = pool.get().await?; + let res = tokio::task::block_in_place(move || function(&mut conn))?; + Ok(res) + } + /// Execute the given query, returning the number of affected rows. - pub async fn execute(&self, query: &str, params: impl rusqlite::Params) -> Result { - let conn = self.get_conn().await?; - tokio::task::block_in_place(move || { + pub async fn execute( + &self, + query: &str, + params: impl rusqlite::Params + Send, + ) -> Result { + self.call(move |conn| { let res = conn.execute(query, params)?; Ok(res) }) + .await } /// Executes the given query, returning the last inserted row ID. - pub async fn insert(&self, query: &str, params: impl rusqlite::Params) -> Result { - let conn = self.get_conn().await?; - tokio::task::block_in_place(move || { + pub async fn insert(&self, query: &str, params: impl rusqlite::Params + Send) -> Result { + self.call(move |conn| { conn.execute(query, params)?; Ok(conn.last_insert_rowid()) }) + .await } /// Prepares and executes the statement and maps a function over the resulting rows. @@ -318,40 +336,32 @@ impl Sql { pub async fn query_map( &self, sql: &str, - params: impl rusqlite::Params, + params: impl rusqlite::Params + Send, f: F, mut g: G, ) -> Result where - F: FnMut(&rusqlite::Row) -> rusqlite::Result, - G: FnMut(rusqlite::MappedRows) -> Result, + F: Send + FnMut(&rusqlite::Row) -> rusqlite::Result, + G: Send + FnMut(rusqlite::MappedRows) -> Result, + H: Send + 'static, { - let conn = self.get_conn().await?; - tokio::task::block_in_place(move || { + self.call(move |conn| { let mut stmt = conn.prepare(sql)?; let res = stmt.query_map(params, f)?; g(res) }) - } - - /// Allocates a connection from the connection pool and returns it. - pub(crate) async fn get_conn(&self) -> Result { - let lock = self.pool.read().await; - let pool = lock.as_ref().context("no SQL connection")?; - let conn = pool.get().await?; - - Ok(conn) + .await } /// Used for executing `SELECT COUNT` statements only. Returns the resulting count. - pub async fn count(&self, query: &str, params: impl rusqlite::Params) -> Result { + pub async fn count(&self, query: &str, params: impl rusqlite::Params + Send) -> Result { let count: isize = self.query_row(query, params, |row| row.get(0)).await?; Ok(usize::try_from(count)?) } /// Used for executing `SELECT COUNT` statements only. Returns `true`, if the count is at least /// one, `false` otherwise. - pub async fn exists(&self, sql: &str, params: impl rusqlite::Params) -> Result { + pub async fn exists(&self, sql: &str, params: impl rusqlite::Params + Send) -> Result { let count = self.count(sql, params).await?; Ok(count > 0) } @@ -360,17 +370,18 @@ impl Sql { pub async fn query_row( &self, query: &str, - params: impl rusqlite::Params, + params: impl rusqlite::Params + Send, f: F, ) -> Result where - F: FnOnce(&rusqlite::Row) -> rusqlite::Result, + F: FnOnce(&rusqlite::Row) -> rusqlite::Result + Send, + T: Send + 'static, { - let conn = self.get_conn().await?; - tokio::task::block_in_place(move || { + self.call(move |conn| { let res = conn.query_row(query, params, f)?; Ok(res) }) + .await } /// Execute the function inside a transaction. @@ -388,8 +399,7 @@ impl Sql { H: Send + 'static, G: Send + FnOnce(&mut rusqlite::Transaction<'_>) -> Result, { - let mut conn = self.get_conn().await?; - tokio::task::block_in_place(move || { + self.call(move |conn| { let mut transaction = conn.transaction_with_behavior(TransactionBehavior::Immediate)?; let ret = callback(&mut transaction); @@ -404,12 +414,12 @@ impl Sql { } } }) + .await } /// Query the database if the requested table already exists. pub async fn table_exists(&self, name: &str) -> Result { - let conn = self.get_conn().await?; - tokio::task::block_in_place(move || { + self.call(move |conn| { let mut exists = false; conn.pragma(None, "table_info", name.to_string(), |_row| { // will only be executed if the info was found @@ -419,12 +429,12 @@ impl Sql { Ok(exists) }) + .await } /// Check if a column exists in a given table. pub async fn col_exists(&self, table_name: &str, col_name: &str) -> Result { - let conn = self.get_conn().await?; - tokio::task::block_in_place(move || { + self.call(move |conn| { let mut exists = false; // `PRAGMA table_info` returns one row per column, // each row containing 0=cid, 1=name, 2=type, 3=notnull, 4=dflt_value @@ -438,29 +448,27 @@ impl Sql { Ok(exists) }) + .await } /// Execute a query which is expected to return zero or one row. pub async fn query_row_optional( &self, sql: &str, - params: impl rusqlite::Params, + params: impl rusqlite::Params + Send, f: F, ) -> Result> where - F: FnOnce(&rusqlite::Row) -> rusqlite::Result, + F: Send + FnOnce(&rusqlite::Row) -> rusqlite::Result, + T: Send + 'static, { - let conn = self.get_conn().await?; - let res = - tokio::task::block_in_place(move || match conn.query_row(sql.as_ref(), params, f) { - Ok(res) => Ok(Some(res)), - Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), - Err(rusqlite::Error::InvalidColumnType(_, _, rusqlite::types::Type::Null)) => { - Ok(None) - } - Err(err) => Err(err), - })?; - Ok(res) + self.call(move |conn| match conn.query_row(sql.as_ref(), params, f) { + Ok(res) => Ok(Some(res)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(rusqlite::Error::InvalidColumnType(_, _, rusqlite::types::Type::Null)) => Ok(None), + Err(err) => Err(err.into()), + }) + .await } /// Executes a query which is expected to return one row and one @@ -469,10 +477,10 @@ impl Sql { pub async fn query_get_value( &self, query: &str, - params: impl rusqlite::Params, + params: impl rusqlite::Params + Send, ) -> Result> where - T: rusqlite::types::FromSql, + T: rusqlite::types::FromSql + Send + 'static, { self.query_row_optional(query, params, |row| row.get::<_, T>(0)) .await @@ -935,11 +943,16 @@ mod tests { async fn test_auto_vacuum() -> Result<()> { let t = TestContext::new().await; - let conn = t.sql.get_conn().await?; - let auto_vacuum = conn.pragma_query_value(None, "auto_vacuum", |row| { - let auto_vacuum: i32 = row.get(0)?; - Ok(auto_vacuum) - })?; + let auto_vacuum = t + .sql + .call(|conn| { + let auto_vacuum = conn.pragma_query_value(None, "auto_vacuum", |row| { + let auto_vacuum: i32 = row.get(0)?; + Ok(auto_vacuum) + })?; + Ok(auto_vacuum) + }) + .await?; // auto_vacuum=2 is the same as auto_vacuum=INCREMENTAL assert_eq!(auto_vacuum, 2);