diff --git a/internal/storage/sqlite3/sqlite3.go b/internal/storage/sqlite3/sqlite3.go index b7f045d..6f5a39a 100644 --- a/internal/storage/sqlite3/sqlite3.go +++ b/internal/storage/sqlite3/sqlite3.go @@ -5,6 +5,7 @@ import ( "fmt" _ "github.com/mattn/go-sqlite3" + "go.uber.org/atomic" ) type SQLite3Storage struct { @@ -12,6 +13,7 @@ type SQLite3Storage struct { *TableMailboxes *TableMails *TableQueue + writer *Writer } func NewSQLite3StorageStorage(filename string) (*SQLite3Storage, error) { @@ -19,22 +21,81 @@ func NewSQLite3StorageStorage(filename string) (*SQLite3Storage, error) { if err != nil { return nil, fmt.Errorf("sql.Open: %w", err) } - s := &SQLite3Storage{} - s.TableConfig, err = NewTableConfig(db) + s := &SQLite3Storage{ + writer: &Writer{ + todo: make(chan writerTask), + }, + } + s.TableConfig, err = NewTableConfig(db, s.writer) if err != nil { return nil, fmt.Errorf("NewTableConfig: %w", err) } - s.TableMailboxes, err = NewTableMailboxes(db) + s.TableMailboxes, err = NewTableMailboxes(db, s.writer) if err != nil { return nil, fmt.Errorf("NewTableMailboxes: %w", err) } - s.TableMails, err = NewTableMails(db) + s.TableMails, err = NewTableMails(db, s.writer) if err != nil { return nil, fmt.Errorf("NewTableMails: %w", err) } - s.TableQueue, err = NewTableQueue(db) + s.TableQueue, err = NewTableQueue(db, s.writer) if err != nil { return nil, fmt.Errorf("NewTableQueue: %w", err) } return s, nil } + +type Writer struct { + running atomic.Bool + todo chan writerTask +} + +type writerTask struct { + db *sql.DB + txn *sql.Tx + f func(txn *sql.Tx) error + wait chan error +} + +func (w *Writer) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error { + if !w.running.Load() { + go w.run() + } + task := writerTask{ + db: db, + txn: txn, + f: f, + wait: make(chan error, 1), + } + w.todo <- task + return <-task.wait +} + +func (w *Writer) run() { + if !w.running.CAS(false, true) { + return + } + defer w.running.Store(false) + for task := range w.todo { + if task.db != nil && task.txn != nil { + task.wait <- task.f(task.txn) + } else if task.db != nil && task.txn == nil { + func() { + txn, err := task.db.Begin() + if err != nil { + return + } + err = task.f(txn) + task.wait <- err + if err == nil { + _ = txn.Commit() + } else { + _ = txn.Rollback() + } + }() + } else { + task.wait <- task.f(nil) + } + close(task.wait) + } +} diff --git a/internal/storage/sqlite3/table_config.go b/internal/storage/sqlite3/table_config.go index 2ba200b..b4c8cdf 100644 --- a/internal/storage/sqlite3/table_config.go +++ b/internal/storage/sqlite3/table_config.go @@ -8,9 +8,10 @@ import ( ) type TableConfig struct { - db *sql.DB - get *sql.Stmt - set *sql.Stmt + db *sql.DB + writer *Writer + get *sql.Stmt + set *sql.Stmt } const configSchema = ` @@ -29,9 +30,10 @@ const configSet = ` INSERT OR REPLACE INTO config (key, value) VALUES($1, $2) ` -func NewTableConfig(db *sql.DB) (*TableConfig, error) { +func NewTableConfig(db *sql.DB, writer *Writer) (*TableConfig, error) { t := &TableConfig{ - db: db, + db: db, + writer: writer, } _, err := db.Exec(configSchema) if err != nil { @@ -58,8 +60,10 @@ func (t *TableConfig) ConfigGet(key string) (string, error) { } func (t *TableConfig) ConfigSet(key, value string) error { - _, err := t.set.Exec(key, value) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.set.Exec(key, value) + return err + }) } func (t *TableConfig) ConfigSetPassword(password string) error { diff --git a/internal/storage/sqlite3/table_mailboxes.go b/internal/storage/sqlite3/table_mailboxes.go index f6cf886..c481262 100644 --- a/internal/storage/sqlite3/table_mailboxes.go +++ b/internal/storage/sqlite3/table_mailboxes.go @@ -7,6 +7,7 @@ import ( type TableMailboxes struct { db *sql.DB + writer *Writer selectMailboxes *sql.Stmt listMailboxes *sql.Stmt listMailboxesSubscribed *sql.Stmt @@ -52,9 +53,10 @@ const mailboxesSubscribe = ` UPDATE mailboxes SET subscribed = $1 WHERE mailbox = $2 ` -func NewTableMailboxes(db *sql.DB) (*TableMailboxes, error) { +func NewTableMailboxes(db *sql.DB, writer *Writer) (*TableMailboxes, error) { t := &TableMailboxes{ - db: db, + db: db, + writer: writer, } _, err := db.Exec(mailboxesSchema) if err != nil { @@ -127,18 +129,24 @@ func (t *TableMailboxes) MailboxSelect(mailbox string) (bool, error) { } func (t *TableMailboxes) MailboxCreate(name string) error { - _, err := t.createMailbox.Exec(name) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.createMailbox.Exec(name) + return err + }) } func (t *TableMailboxes) MailboxRename(old, new string) error { - _, err := t.renameMailbox.Exec(old, new) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.renameMailbox.Exec(old, new) + return err + }) } func (t *TableMailboxes) MailboxDelete(name string) error { - _, err := t.deleteMailbox.Exec(name) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.deleteMailbox.Exec(name) + return err + }) } func (t *TableMailboxes) MailboxSubscribe(name string, subscribed bool) error { @@ -146,6 +154,8 @@ func (t *TableMailboxes) MailboxSubscribe(name string, subscribed bool) error { if !subscribed { sn = 0 } - _, err := t.subscribeMailbox.Exec(sn, name) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.subscribeMailbox.Exec(sn, name) + return err + }) } diff --git a/internal/storage/sqlite3/table_mails.go b/internal/storage/sqlite3/table_mails.go index edef371..812ab35 100644 --- a/internal/storage/sqlite3/table_mails.go +++ b/internal/storage/sqlite3/table_mails.go @@ -10,6 +10,7 @@ import ( type TableMails struct { db *sql.DB + writer *Writer selectMails *sql.Stmt selectMail *sql.Stmt selectMailNextID *sql.Stmt @@ -100,9 +101,10 @@ const expungeMailStmt = ` DELETE FROM mails WHERE mailbox = $1 AND deleted = 1 ` -func NewTableMails(db *sql.DB) (*TableMails, error) { +func NewTableMails(db *sql.DB, writer *Writer) (*TableMails, error) { t := &TableMails{ - db: db, + db: db, + writer: writer, } _, err := db.Exec(mailsSchema) if err != nil { @@ -157,7 +159,9 @@ func NewTableMails(db *sql.DB) (*TableMails, error) { func (t *TableMails) MailCreate(mailbox string, data []byte) (int, error) { var id int - err := t.createMail.QueryRow(mailbox, data, time.Now().Unix()).Scan(&id) + err := t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + return t.createMail.QueryRow(mailbox, data, time.Now().Unix()).Scan(&id) + }) return id, err } @@ -209,18 +213,24 @@ func (t *TableMails) MailUnseen(mailbox string) (int, error) { } func (t *TableMails) MailUpdateFlags(mailbox string, id int, seen, answered, flagged, deleted bool) error { - _, err := t.updateMailFlags.Exec(seen, answered, flagged, deleted, mailbox, id) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.updateMailFlags.Exec(seen, answered, flagged, deleted, mailbox, id) + return err + }) } func (t *TableMails) MailDelete(mailbox string, id int) error { - _, err := t.deleteMail.Exec(mailbox, id) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.deleteMail.Exec(mailbox, id) + return err + }) } func (t *TableMails) MailExpunge(mailbox string) error { - _, err := t.expungeMail.Exec(mailbox) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.expungeMail.Exec(mailbox) + return err + }) } func (t *TableMails) MailCount(mailbox string) (int, error) { diff --git a/internal/storage/sqlite3/table_queue.go b/internal/storage/sqlite3/table_queue.go index 518bdf0..70dd4ec 100644 --- a/internal/storage/sqlite3/table_queue.go +++ b/internal/storage/sqlite3/table_queue.go @@ -9,6 +9,7 @@ import ( type TableQueue struct { db *sql.DB + writer *Writer queueSelectDestinations *sql.Stmt queueSelectIDsForDestination *sql.Stmt queueInsertDestinationForID *sql.Stmt @@ -49,9 +50,10 @@ const queueSelectIsMessagePendingSendStmt = ` SELECT COUNT(*) FROM queue WHERE mailbox = $1 AND id = $2 ` -func NewTableQueue(db *sql.DB) (*TableQueue, error) { +func NewTableQueue(db *sql.DB, writer *Writer) (*TableQueue, error) { t := &TableQueue{ - db: db, + db: db, + writer: writer, } _, err := db.Exec(queueSchema) if err != nil { @@ -126,13 +128,17 @@ func (t *TableQueue) QueueMailIDsForDestination(destination string) ([]types.Que } func (t *TableQueue) QueueInsertDestinationForID(destination string, id int, from, rcpt string) error { - _, err := t.queueInsertDestinationForID.Exec(destination, "Outbox", id, from, rcpt) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.queueInsertDestinationForID.Exec(destination, "Outbox", id, from, rcpt) + return err + }) } func (t *TableQueue) QueueDeleteDestinationForID(destination string, id int) error { - _, err := t.queueDeleteIDForDestination.Exec(destination, "Outbox", id) - return err + return t.writer.Do(t.db, nil, func(txn *sql.Tx) error { + _, err := t.queueDeleteIDForDestination.Exec(destination, "Outbox", id) + return err + }) } func (t *TableQueue) QueueSelectIsMessagePendingSend(mailbox string, id int) (bool, error) {