Single user refactor

This commit is contained in:
Neil Alexander
2021-07-09 00:08:26 +01:00
parent 8411468451
commit 0735fa74de
15 changed files with 190 additions and 233 deletions

View File

@@ -9,7 +9,6 @@ import (
type SQLite3Storage struct {
*TableConfig
*TableUsers
*TableMailboxes
*TableMails
}
@@ -24,10 +23,6 @@ func NewSQLite3StorageStorage(filename string) (*SQLite3Storage, error) {
if err != nil {
return nil, fmt.Errorf("NewTableConfig: %w", err)
}
s.TableUsers, err = NewTableUsers(db)
if err != nil {
return nil, fmt.Errorf("NewTableUsers: %w", err)
}
s.TableMailboxes, err = NewTableMailboxes(db)
if err != nil {
return nil, fmt.Errorf("NewTableMailboxes: %w", err)

View File

@@ -3,6 +3,8 @@ package sqlite3
import (
"database/sql"
"fmt"
"golang.org/x/crypto/bcrypt"
)
type TableConfig struct {
@@ -59,3 +61,26 @@ func (t *TableConfig) ConfigSet(key, value string) error {
_, err := t.set.Exec(key, value)
return err
}
func (t *TableConfig) ConfigSetPassword(password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("bcrypt.GenerateFromPassword: %w", err)
}
return t.ConfigSet("password", string(hash))
}
func (t *TableConfig) ConfigTryPassword(password string) (bool, error) {
dbPasswordHash, err := t.ConfigGet("password")
if err != nil {
return false, err
}
if dbPasswordHash == "" {
return true, nil // TODO: Do we want to allow login if no password is set?
}
err = bcrypt.CompareHashAndPassword([]byte(dbPasswordHash), []byte(password))
if err == nil {
return true, nil
}
return false, err
}

View File

@@ -18,39 +18,38 @@ type TableMailboxes struct {
const mailboxesSchema = `
CREATE TABLE IF NOT EXISTS mailboxes (
username TEXT NOT NULL REFERENCES users(username) ON DELETE CASCADE ON UPDATE CASCADE,
mailbox TEXT NOT NULL DEFAULT('INBOX'),
subscribed BOOLEAN NOT NULL DEFAULT 1,
PRIMARY KEY(username, mailbox)
PRIMARY KEY(mailbox)
);
`
const mailboxesList = `
SELECT mailbox FROM mailboxes WHERE username = $1
SELECT mailbox FROM mailboxes
`
const mailboxesListSubscribed = `
SELECT mailbox FROM mailboxes WHERE username = $1 AND subscribed = 1
SELECT mailbox FROM mailboxes WHERE subscribed = 1
`
const mailboxesSelect = `
SELECT mailbox FROM mailboxes WHERE username = $1 AND mailbox = $2
SELECT mailbox FROM mailboxes WHERE mailbox = $1
`
const mailboxesCreate = `
INSERT INTO mailboxes (username, mailbox) VALUES($1, $2)
INSERT INTO mailboxes (mailbox) VALUES($1)
`
const mailboxesRename = `
UPDATE mailboxes SET mailbox = $1 WHERE username = $2 AND mailbox = $3
UPDATE mailboxes SET mailbox = $1 WHERE mailbox = $2
`
const mailboxesDelete = `
DELETE FROM mailboxes WHERE username = $1 AND mailbox = $2
DELETE FROM mailboxes WHERE mailbox = $1
`
const mailboxesSubscribe = `
UPDATE mailboxes SET subscribed = $1 WHERE username = $2 AND mailbox = $3
UPDATE mailboxes SET subscribed = $1 WHERE mailbox = $2
`
func NewTableMailboxes(db *sql.DB) (*TableMailboxes, error) {
@@ -92,12 +91,12 @@ func NewTableMailboxes(db *sql.DB) (*TableMailboxes, error) {
return t, nil
}
func (t *TableMailboxes) MailboxList(user string, onlySubscribed bool) ([]string, error) {
func (t *TableMailboxes) MailboxList(onlySubscribed bool) ([]string, error) {
stmt := t.listMailboxes
if onlySubscribed {
stmt = t.listMailboxesSubscribed
}
rows, err := stmt.Query(user)
rows, err := stmt.Query()
if err != nil {
return nil, fmt.Errorf("t.listMailboxes.Query: %w", err)
}
@@ -113,8 +112,8 @@ func (t *TableMailboxes) MailboxList(user string, onlySubscribed bool) ([]string
return mailboxes, nil
}
func (t *TableMailboxes) MailboxSelect(user, mailbox string) (bool, error) {
row := t.selectMailboxes.QueryRow(user, mailbox)
func (t *TableMailboxes) MailboxSelect(mailbox string) (bool, error) {
row := t.selectMailboxes.QueryRow(mailbox)
if err := row.Err(); err != nil && err != sql.ErrNoRows {
return false, fmt.Errorf("row.Err: %w", err)
} else if err == sql.ErrNoRows {
@@ -127,26 +126,26 @@ func (t *TableMailboxes) MailboxSelect(user, mailbox string) (bool, error) {
return mailbox == got, nil
}
func (t *TableMailboxes) MailboxCreate(user, name string) error {
_, err := t.createMailbox.Exec(user, name)
func (t *TableMailboxes) MailboxCreate(name string) error {
_, err := t.createMailbox.Exec(name)
return err
}
func (t *TableMailboxes) MailboxRename(user, old, new string) error {
_, err := t.renameMailbox.Exec(new, user, old)
func (t *TableMailboxes) MailboxRename(old, new string) error {
_, err := t.renameMailbox.Exec(old, new)
return err
}
func (t *TableMailboxes) MailboxDelete(user, name string) error {
_, err := t.deleteMailbox.Exec(user, name)
func (t *TableMailboxes) MailboxDelete(name string) error {
_, err := t.deleteMailbox.Exec(name)
return err
}
func (t *TableMailboxes) MailboxSubscribe(user, name string, subscribed bool) error {
func (t *TableMailboxes) MailboxSubscribe(name string, subscribed bool) error {
sn := 1
if !subscribed {
sn = 0
}
_, err := t.subscribeMailbox.Exec(sn, user, name)
_, err := t.subscribeMailbox.Exec(sn, name)
return err
}

View File

@@ -23,7 +23,6 @@ type TableMails struct {
const mailsSchema = `
CREATE TABLE IF NOT EXISTS mails (
username TEXT NOT NULL,
mailbox TEXT NOT NULL,
id INTEGER NOT NULL DEFAULT 1,
mail BLOB NOT NULL,
@@ -32,72 +31,72 @@ const mailsSchema = `
answered BOOLEAN NOT NULL DEFAULT 0, -- the mail has been replied to
flagged BOOLEAN NOT NULL DEFAULT 0, -- the mail has been flagged for later attention
deleted BOOLEAN NOT NULL DEFAULT 0, -- the email is marked for deletion at next EXPUNGE
PRIMARY KEY(username, mailbox, id),
FOREIGN KEY (username, mailbox) REFERENCES mailboxes(username, mailbox) ON DELETE CASCADE ON UPDATE CASCADE
PRIMARY KEY (mailbox, id),
FOREIGN KEY (mailbox) REFERENCES mailboxes(mailbox) ON DELETE CASCADE ON UPDATE CASCADE
);
DROP VIEW IF EXISTS inboxes;
CREATE VIEW IF NOT EXISTS inboxes AS SELECT * FROM (
SELECT ROW_NUMBER() OVER (PARTITION BY username, mailbox) AS seq, * FROM mails
SELECT ROW_NUMBER() OVER (PARTITION BY mailbox) AS seq, * FROM mails
)
ORDER BY username, mailbox, id;
ORDER BY mailbox, id;
`
const selectMailsStmt = `
SELECT * FROM inboxes
ORDER BY username, mailbox, id
ORDER BY mailbox, id
`
const selectMailStmt = `
SELECT seq, id, mail, datetime, seen, answered, flagged, deleted FROM inboxes
WHERE username = $1 AND mailbox = $2 AND id = $3
ORDER BY username, mailbox, id
WHERE mailbox = $1 AND id = $2
ORDER BY mailbox, id
`
const selectMailCountStmt = `
SELECT COUNT(*) FROM mails WHERE username = $1 AND mailbox = $2
SELECT COUNT(*) FROM mails WHERE mailbox = $1
`
const selectMailUnseenStmt = `
SELECT COUNT(*) FROM mails WHERE username = $1 AND mailbox = $2 AND seen = 0
SELECT COUNT(*) FROM mails WHERE mailbox = $1 AND seen = 0
`
const searchMailStmt = `
SELECT id FROM mails
WHERE username = $1 AND mailbox = $2
ORDER BY username, mailbox, id
WHERE mailbox = $1
ORDER BY mailbox, id
`
const insertMailStmt = `
INSERT INTO mails (username, mailbox, id, mail, datetime) VALUES(
$1, $2, (
INSERT INTO mails (mailbox, id, mail, datetime) VALUES(
$1, (
SELECT IFNULL(MAX(id)+1,1) AS id FROM mails
WHERE username = $1 AND mailbox = $2
), $3, $4
WHERE mailbox = $1
), $2, $3
)
RETURNING id;
`
const selectIDForSeqStmt = `
SELECT id FROM inboxes
WHERE username = $1 AND mailbox = $2 AND seq = $3
WHERE mailbox = $1 AND seq = $2
`
const selectMailNextID = `
SELECT IFNULL(MAX(id)+1,1) AS id FROM mails
WHERE username = $1 AND mailbox = $2
WHERE mailbox = $1
`
const updateMailFlagsStmt = `
UPDATE mails SET seen = $1, answered = $2, flagged = $3, deleted = $4 WHERE username = $5 AND mailbox = $6 AND id = $7
UPDATE mails SET seen = $1, answered = $2, flagged = $3, deleted = $4 WHERE mailbox = $5 AND id = $6
`
const deleteMailStmt = `
UPDATE mails SET deleted = 1 WHERE username = $1 AND mailbox = $2 AND id = $3
UPDATE mails SET deleted = 1 WHERE mailbox = $1 AND id = $2
`
const expungeMailStmt = `
DELETE FROM mails WHERE username = $1 AND mailbox = $2 AND deleted = 1
DELETE FROM mails WHERE mailbox = $1 AND deleted = 1
`
func NewTableMails(db *sql.DB) (*TableMails, error) {
@@ -155,24 +154,24 @@ func NewTableMails(db *sql.DB) (*TableMails, error) {
return t, nil
}
func (t *TableMails) MailCreate(user, mailbox string, data []byte) (int, error) {
func (t *TableMails) MailCreate(mailbox string, data []byte) (int, error) {
var id int
err := t.createMail.QueryRow(user, mailbox, data, time.Now().Unix()).Scan(&id)
err := t.createMail.QueryRow(mailbox, data, time.Now().Unix()).Scan(&id)
return id, err
}
func (t *TableMails) MailSelect(user, mailbox string, id int) (int, int, []byte, bool, bool, bool, bool, time.Time, error) {
func (t *TableMails) MailSelect(mailbox string, id int) (int, int, []byte, bool, bool, bool, bool, time.Time, error) {
var data []byte
var seen, answered, flagged, deleted bool
var ts int64
var seq, pid int
err := t.selectMail.QueryRow(user, mailbox, id).Scan(&seq, &pid, &data, &ts, &seen, &answered, &flagged, &deleted)
err := t.selectMail.QueryRow(mailbox, id).Scan(&seq, &pid, &data, &ts, &seen, &answered, &flagged, &deleted)
return seq, pid, data, seen, answered, flagged, deleted, time.Unix(ts, 0), err
}
func (t *TableMails) MailSearch(user, mailbox string) ([]uint32, error) {
func (t *TableMails) MailSearch(mailbox string) ([]uint32, error) {
var ids []uint32
rows, err := t.searchMail.Query(user, mailbox)
rows, err := t.searchMail.Query(mailbox)
if err != nil {
return nil, fmt.Errorf("t.searchMail.Query: %w", err)
}
@@ -187,41 +186,41 @@ func (t *TableMails) MailSearch(user, mailbox string) ([]uint32, error) {
return ids, nil
}
func (t *TableMails) MailNextID(user, mailbox string) (int, error) {
func (t *TableMails) MailNextID(mailbox string) (int, error) {
var id int
err := t.selectMailNextID.QueryRow(user, mailbox).Scan(&id)
err := t.selectMailNextID.QueryRow(mailbox).Scan(&id)
return id, err
}
func (t *TableMails) MailIDForSeq(user, mailbox string, seq int) (int, error) {
func (t *TableMails) MailIDForSeq(mailbox string, seq int) (int, error) {
var id int
err := t.selectIDForSeq.QueryRow(user, mailbox, seq).Scan(&id)
err := t.selectIDForSeq.QueryRow(mailbox, seq).Scan(&id)
return id, err
}
func (t *TableMails) MailUnseen(user, mailbox string) (int, error) {
func (t *TableMails) MailUnseen(mailbox string) (int, error) {
var unseen int
err := t.countUnseenMails.QueryRow(user, mailbox).Scan(&unseen)
err := t.countUnseenMails.QueryRow(mailbox).Scan(&unseen)
return unseen, err
}
func (t *TableMails) MailUpdateFlags(user, mailbox string, id int, seen, answered, flagged, deleted bool) error {
_, err := t.updateMailFlags.Exec(seen, answered, flagged, deleted, user, mailbox, id)
func (t *TableMails) MailUpdateFlags(mailbox string, id int, seen, answered, flagged, deleted bool) error {
_, err := t.updateMailFlags.Exec(seen, answered, flagged, deleted, mailbox, id)
return err
}
func (t *TableMails) MailDelete(user, mailbox, id string) error {
_, err := t.deleteMail.Exec(user, mailbox, id)
func (t *TableMails) MailDelete(mailbox, id string) error {
_, err := t.deleteMail.Exec(mailbox, id)
return err
}
func (t *TableMails) MailExpunge(user, mailbox string) error {
_, err := t.expungeMail.Exec(user, mailbox)
func (t *TableMails) MailExpunge(mailbox string) error {
_, err := t.expungeMail.Exec(mailbox)
return err
}
func (t *TableMails) MailCount(user, mailbox string) (int, error) {
func (t *TableMails) MailCount(mailbox string) (int, error) {
var count int
err := t.countMails.QueryRow(user, mailbox).Scan(&count)
err := t.countMails.QueryRow(mailbox).Scan(&count)
return count, err
}

View File

@@ -1,76 +0,0 @@
package sqlite3
import (
"database/sql"
"fmt"
"golang.org/x/crypto/bcrypt"
)
type TableUsers struct {
db *sql.DB
getPassword *sql.Stmt
insertUser *sql.Stmt
}
const usersSchema = `
CREATE TABLE IF NOT EXISTS users (
username TEXT NOT NULL,
password TEXT NOT NULL,
PRIMARY KEY(username)
);
`
const usersGetPassword = `
SELECT password FROM users WHERE username = $1
`
const usersInsertUser = `
INSERT INTO users (username, password) VALUES($1, $2)
`
func NewTableUsers(db *sql.DB) (*TableUsers, error) {
t := &TableUsers{
db: db,
}
_, err := db.Exec(usersSchema)
if err != nil {
return nil, fmt.Errorf("db.Exec: %w", err)
}
t.getPassword, err = db.Prepare(usersGetPassword)
if err != nil {
return nil, fmt.Errorf("db.Prepare(getPassword): %w", err)
}
t.insertUser, err = db.Prepare(usersInsertUser)
if err != nil {
return nil, fmt.Errorf("db.Prepare(usersInsertUser): %w", err)
}
return t, nil
}
func (t *TableUsers) CreateUser(username, password string) error {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("bcrypt.GenerateFromPassword: %w", err)
}
if _, err := t.insertUser.Exec(username, hash); err != nil {
return fmt.Errorf("t.insertUser.Exec: %w", err)
}
return nil
}
func (t *TableUsers) TryAuthenticate(username, password string) (bool, error) {
var dbPasswordHash []byte
err := t.getPassword.QueryRow(username).Scan(&dbPasswordHash)
if err == sql.ErrNoRows {
return false, nil
}
if err != nil {
return false, err
}
err = bcrypt.CompareHashAndPassword(dbPasswordHash, []byte(password))
if err == nil {
return true, nil
}
return false, err
}