fix: replace old draft with a new one atomically

This prevents creation of multiple drafts per chat.
This commit is contained in:
link2xt
2024-10-16 04:09:55 +00:00
parent 7db7c0aab1
commit 07fa9c35ee

View File

@@ -944,12 +944,18 @@ impl ChatId {
}
}
// insert new draft
self.maybe_delete_draft(context).await?;
let row_id = context
.sql
.insert(
"INSERT INTO msgs (
.transaction(|transaction| {
// Delete existing draft if it exists.
transaction.execute(
"DELETE FROM msgs WHERE chat_id=? AND state=?",
(self, MessageState::OutDraft),
)?;
// Insert new draft.
transaction.execute(
"INSERT INTO msgs (
chat_id,
from_id,
timestamp,
@@ -961,19 +967,22 @@ impl ChatId {
hidden,
mime_in_reply_to)
VALUES (?,?,?,?,?,?,?,?,?,?);",
(
self,
ContactId::SELF,
time(),
msg.viewtype,
MessageState::OutDraft,
&msg.text,
message::normalize_text(&msg.text),
msg.param.to_string(),
1,
msg.in_reply_to.as_deref().unwrap_or_default(),
),
)
(
self,
ContactId::SELF,
time(),
msg.viewtype,
MessageState::OutDraft,
&msg.text,
message::normalize_text(&msg.text),
msg.param.to_string(),
1,
msg.in_reply_to.as_deref().unwrap_or_default(),
),
)?;
Ok(transaction.last_insert_rowid())
})
.await?;
msg.id = MsgId::new(row_id.try_into()?);
Ok(true)
@@ -4854,6 +4863,37 @@ mod tests {
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_only_one_draft_per_chat() -> Result<()> {
let t = TestContext::new_alice().await;
let chat_id = create_group_chat(&t, ProtectionStatus::Unprotected, "abc").await?;
let msgs: Vec<message::Message> = (1..=1000)
.map(|i| {
let mut msg = Message::new(Viewtype::Text);
msg.set_text(i.to_string());
msg
})
.collect();
let mut tasks = Vec::new();
for mut msg in msgs {
let ctx = t.clone();
let task = tokio::spawn(async move {
let ctx = ctx;
chat_id.set_draft(&ctx, Some(&mut msg)).await
});
tasks.push(task);
}
futures::future::join_all(tasks.into_iter()).await;
assert!(chat_id.get_draft(&t).await?.is_some());
chat_id.set_draft(&t, None).await?;
assert!(chat_id.get_draft(&t).await?.is_none());
Ok(())
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn test_change_quotes_on_reused_message_object() -> Result<()> {
let t = TestContext::new_alice().await;