diff --git a/src/sql.rs b/src/sql.rs index 02c709bad..af196aa72 100644 --- a/src/sql.rs +++ b/src/sql.rs @@ -325,7 +325,8 @@ impl Sql { let mut lock = self.pool.write().await; let pool = lock.take().context("SQL connection pool is not open")?; - let conn = pool.get().await?; + let query_only = false; + let conn = pool.get(query_only).await?; if !passphrase.is_empty() { conn.pragma_update(None, "rekey", passphrase.clone()) .context("Failed to set PRAGMA rekey")?; @@ -382,14 +383,14 @@ impl Sql { /// - or use `call_write()` instead. /// /// Returns the result of the function. - async fn call<'a, F, R>(&'a self, function: F) -> Result + async fn call<'a, F, R>(&'a self, query_only: bool, function: F) -> Result where F: 'a + FnOnce(&mut Connection) -> Result + Send, R: Send + 'static, { let lock = self.pool.read().await; let pool = lock.as_ref().context("no SQL connection")?; - let mut conn = pool.get().await?; + let mut conn = pool.get(query_only).await?; let res = tokio::task::block_in_place(move || function(&mut conn))?; Ok(res) } @@ -404,7 +405,8 @@ impl Sql { R: Send + 'static, { let _lock = self.write_lock().await; - self.call(function).await + let query_only = false; + self.call(query_only, function).await } /// Execute `query` assuming it is a write query, returning the number of affected rows. @@ -444,7 +446,8 @@ impl Sql { G: Send + FnMut(rusqlite::MappedRows) -> Result, H: Send + 'static, { - self.call(move |conn| { + let query_only = true; + self.call(query_only, move |conn| { let mut stmt = conn.prepare(sql)?; let res = stmt.query_map(params, f)?; g(res) @@ -476,7 +479,8 @@ impl Sql { F: FnOnce(&rusqlite::Row) -> rusqlite::Result + Send, T: Send + 'static, { - self.call(move |conn| { + let query_only = true; + self.call(query_only, move |conn| { let res = conn.query_row(query, params, f)?; Ok(res) }) @@ -512,7 +516,8 @@ impl Sql { /// Query the database if the requested table already exists. pub async fn table_exists(&self, name: &str) -> Result { - self.call(move |conn| { + let query_only = true; + self.call(query_only, move |conn| { let mut exists = false; conn.pragma(None, "table_info", name.to_string(), |_row| { // will only be executed if the info was found @@ -527,7 +532,8 @@ impl Sql { /// Check if a column exists in a given table. pub async fn col_exists(&self, table_name: &str, col_name: &str) -> Result { - self.call(move |conn| { + let query_only = true; + self.call(query_only, move |conn| { let mut exists = false; // `PRAGMA table_info` returns one row per column, // each row containing 0=cid, 1=name, 2=type, 3=notnull, 4=dflt_value @@ -555,10 +561,13 @@ impl Sql { F: Send + FnOnce(&rusqlite::Row) -> rusqlite::Result, T: Send + 'static, { - self.call(move |conn| match conn.query_row(sql.as_ref(), params, f) { - Ok(res) => Ok(Some(res)), - Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), - Err(err) => Err(err.into()), + let query_only = true; + self.call(query_only, move |conn| { + match conn.query_row(sql.as_ref(), params, f) { + Ok(res) => Ok(Some(res)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(err) => Err(err.into()), + } }) .await } @@ -1092,9 +1101,10 @@ mod tests { async fn test_auto_vacuum() -> Result<()> { let t = TestContext::new().await; + let query_only = true; let auto_vacuum = t .sql - .call(|conn| { + .call(query_only, |conn| { let auto_vacuum = conn.pragma_query_value(None, "auto_vacuum", |row| { let auto_vacuum: i32 = row.get(0)?; Ok(auto_vacuum) @@ -1320,8 +1330,9 @@ mod tests { { let lock = sql.pool.read().await; let pool = lock.as_ref().unwrap(); - let conn1 = pool.get().await?; - let conn2 = pool.get().await?; + let query_only = true; + let conn1 = pool.get(query_only).await?; + let conn2 = pool.get(query_only).await?; conn1 .query_row("SELECT count(*) FROM sqlite_master", [], |_row| Ok(())) .unwrap(); @@ -1346,4 +1357,62 @@ mod tests { Ok(()) } + + #[tokio::test(flavor = "multi_thread", worker_threads = 2)] + async fn test_query_only() -> Result<()> { + let t = TestContext::new().await; + + // `query_row` does not acquire write lock + // and operates on read-only connection. + // Using it to `INSERT` should fail. + let res = t + .sql + .query_row( + "INSERT INTO config (keyname, value) VALUES (?, ?) RETURNING 1", + ("xyz", "ijk"), + |row| { + let res: u32 = row.get(0)?; + Ok(res) + }, + ) + .await; + assert!(res.is_err()); + + // If you want to `INSERT` and get value via `RETURNING`, + // use `call_write` or `transaction`. + + let res: Result = t + .sql + .call_write(|conn| { + let val = conn.query_row( + "INSERT INTO config (keyname, value) VALUES (?, ?) RETURNING 2", + ("foo", "bar"), + |row| { + let res: u32 = row.get(0)?; + Ok(res) + }, + )?; + Ok(val) + }) + .await; + assert_eq!(res.unwrap(), 2); + + let res = t + .sql + .transaction(|t| { + let val = t.query_row( + "INSERT INTO config (keyname, value) VALUES (?, ?) RETURNING 3", + ("abc", "def"), + |row| { + let res: u32 = row.get(0)?; + Ok(res) + }, + )?; + Ok(val) + }) + .await; + assert_eq!(res.unwrap(), 3); + + Ok(()) + } } diff --git a/src/sql/pool.rs b/src/sql/pool.rs index b7459976a..b3342e25d 100644 --- a/src/sql/pool.rs +++ b/src/sql/pool.rs @@ -93,7 +93,13 @@ impl Pool { } /// Retrieves a connection from the pool. - pub async fn get(&self) -> Result { + /// + /// Sets `query_only` pragma to the provided value + /// to prevent accidentaly misuse of connection + /// for writing when reading is intended. + /// Only pass `query_only=false` if you want + /// to use the connection for writing. + pub async fn get(&self, query_only: bool) -> Result { let permit = self.inner.semaphore.clone().acquire_owned().await?; let mut connections = self.inner.connections.lock(); let conn = connections @@ -104,6 +110,15 @@ impl Pool { conn: Some(conn), _permit: permit, }; + conn.pragma_update( + None, + "query_only", + if query_only { + "1".to_string() + } else { + "0".to_string() + }, + )?; Ok(conn) } }