fix(query_row_optional): do not treat rows with NULL as missing rows

Instead of treating NULL type error
as absence of the row,
handle NULL values with SQL.
Previously we sometimes
accidentally treated a single column
being NULL as the lack of the whole row.
This commit is contained in:
link2xt
2024-10-02 02:11:40 +00:00
parent 5711f2fe3a
commit 8a88479d8f
6 changed files with 60 additions and 22 deletions

View File

@@ -640,7 +640,7 @@ pub struct MessageInfo {
error: Option<String>,
rfc724_mid: String,
server_urls: Vec<String>,
hop_info: Option<String>,
hop_info: String,
}
impl MessageInfo {

View File

@@ -1042,7 +1042,13 @@ impl ChatId {
pub(crate) async fn get_timestamp(self, context: &Context) -> Result<Option<i64>> {
let timestamp = context
.sql
.query_get_value("SELECT MAX(timestamp) FROM msgs WHERE chat_id=?", (self,))
.query_get_value(
"SELECT MAX(timestamp)
FROM msgs
WHERE chat_id=?
HAVING COUNT(*) > 0",
(self,),
)
.await?;
Ok(timestamp)
}
@@ -1251,7 +1257,7 @@ impl ChatId {
) -> Result<Option<(String, String, String)>> {
self.parent_query(
context,
"rfc724_mid, mime_in_reply_to, mime_references",
"rfc724_mid, mime_in_reply_to, IFNULL(mime_references, '')",
state_out_min,
|row: &rusqlite::Row| {
let rfc724_mid: String = row.get(0)?;
@@ -1405,7 +1411,10 @@ impl ChatId {
context
.sql
.query_get_value(
"SELECT MAX(timestamp) FROM msgs WHERE chat_id=? AND state!=?",
"SELECT MAX(timestamp)
FROM msgs
WHERE chat_id=? AND state!=?
HAVING COUNT(*) > 0",
(self, MessageState::OutDraft),
)
.await?
@@ -1421,7 +1430,10 @@ impl ChatId {
context
.sql
.query_get_value(
"SELECT MAX(timestamp) FROM msgs WHERE chat_id=? AND hidden=0 AND state>?",
"SELECT MAX(timestamp)
FROM msgs
WHERE chat_id=? AND hidden=0 AND state>?
HAVING COUNT(*) > 0",
(self, MessageState::InFresh),
)
.await?
@@ -4345,7 +4357,10 @@ pub async fn add_device_msg_with_importance(
if let Some(last_msg_time) = context
.sql
.query_get_value(
"SELECT MAX(timestamp) FROM msgs WHERE chat_id=?",
"SELECT MAX(timestamp)
FROM msgs
WHERE chat_id=?
HAVING COUNT(*) > 0",
(chat_id,),
)
.await?

View File

@@ -69,7 +69,7 @@ use std::num::ParseIntError;
use std::str::FromStr;
use std::time::{Duration, UNIX_EPOCH};
use anyhow::{ensure, Result};
use anyhow::{ensure, Context as _, Result};
use async_channel::Receiver;
use serde::{Deserialize, Serialize};
use tokio::time::timeout;
@@ -176,9 +176,13 @@ impl ChatId {
pub async fn get_ephemeral_timer(self, context: &Context) -> Result<Timer> {
let timer = context
.sql
.query_get_value("SELECT ephemeral_timer FROM chats WHERE id=?;", (self,))
.await?;
Ok(timer.unwrap_or_default())
.query_get_value(
"SELECT IFNULL(ephemeral_timer, 0) FROM chats WHERE id=?",
(self,),
)
.await?
.with_context(|| format!("Chat {self} not found"))?;
Ok(timer)
}
/// Set ephemeral timer value without sending a message.
@@ -509,7 +513,8 @@ async fn next_delete_device_after_timestamp(context: &Context) -> Result<Option<
FROM msgs
WHERE chat_id > ?
AND chat_id != ?
AND chat_id != ?;
AND chat_id != ?
HAVING count(*) > 0
"#,
(DC_CHAT_ID_TRASH, self_chat_id, device_chat_id),
)
@@ -533,7 +538,8 @@ async fn next_expiration_timestamp(context: &Context) -> Option<i64> {
SELECT min(ephemeral_timestamp)
FROM msgs
WHERE ephemeral_timestamp != 0
AND chat_id != ?;
AND chat_id != ?
HAVING count(*) > 0
"#,
(DC_CHAT_ID_TRASH,), // Trash contains already deleted messages, skip them
)
@@ -1410,4 +1416,14 @@ mod tests {
Ok(())
}
/// Tests that `.get_ephemeral_timer()` returns an error for invalid chat ID.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_get_ephemeral_timer_wrong_chat_id() -> Result<()> {
let context = TestContext::new().await;
let chat_id = ChatId::new(12345);
assert!(chat_id.get_ephemeral_timer(&context).await.is_err());
Ok(())
}
}

View File

@@ -219,11 +219,13 @@ impl MsgId {
}
/// Returns information about hops of a message, used for message info
pub async fn hop_info(self, context: &Context) -> Result<Option<String>> {
context
pub async fn hop_info(self, context: &Context) -> Result<String> {
let hop_info = context
.sql
.query_get_value("SELECT hop_info FROM msgs WHERE id=?", (self,))
.await
.query_get_value("SELECT IFNULL(hop_info, '') FROM msgs WHERE id=?", (self,))
.await?
.with_context(|| format!("Message {self} not found"))?;
Ok(hop_info)
}
/// Returns detailed message information in a multi-line text form.
@@ -366,7 +368,11 @@ impl MsgId {
let hop_info = self.hop_info(context).await?;
ret += "\n\n";
ret += &hop_info.unwrap_or_else(|| "No Hop Info".to_owned());
if hop_info.is_empty() {
ret += "No Hop Info";
} else {
ret += &hop_info;
}
Ok(ret)
}
@@ -1998,7 +2004,9 @@ pub(crate) async fn rfc724_mid_exists_ex(
.query_row_optional(
&("SELECT id, timestamp_sent, MIN(".to_string()
+ expr
+ ") FROM msgs WHERE rfc724_mid=? ORDER BY timestamp_sent DESC"),
+ ") FROM msgs WHERE rfc724_mid=?
HAVING COUNT(*) > 0 -- Prevent MIN(expr) from returning NULL when there are no rows.
ORDER BY timestamp_sent DESC"),
(rfc724_mid,),
|row| {
let msg_id: MsgId = row.get(0)?;

View File

@@ -201,7 +201,8 @@ impl MimeFactory {
let (in_reply_to, references) = context
.sql
.query_row(
"SELECT mime_in_reply_to, mime_references FROM msgs WHERE id=?",
"SELECT mime_in_reply_to, IFNULL(mime_references, '')
FROM msgs WHERE id=?",
(msg.id,),
|row| {
let in_reply_to: String = row.get(0)?;

View File

@@ -558,15 +558,13 @@ impl Sql {
self.call(move |conn| match conn.query_row(sql.as_ref(), params, f) {
Ok(res) => Ok(Some(res)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(rusqlite::Error::InvalidColumnType(_, _, rusqlite::types::Type::Null)) => Ok(None),
Err(err) => Err(err.into()),
})
.await
}
/// Executes a query which is expected to return one row and one
/// column. If the query does not return a value or returns SQL
/// `NULL`, returns `Ok(None)`.
/// column. If the query does not return any rows, returns `Ok(None)`.
pub async fn query_get_value<T>(
&self,
query: &str,