Skip to content

Commit

Permalink
Refactor database tables and query wrappers
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Dec 30, 2023
1 parent 18e0063 commit 4bf3bc4
Show file tree
Hide file tree
Showing 25 changed files with 1,454 additions and 1,376 deletions.
21 changes: 15 additions & 6 deletions commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
package main

import (
"context"
"strings"

"github.com/google/uuid"
"github.com/skip2/go-qrcode"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/bridge/commands"
Expand Down Expand Up @@ -88,7 +90,7 @@ func fnSetRelay(ce *WrappedCommandEvent) {
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
} else {
ce.Portal.RelayUserID = ce.User.MXID
ce.Portal.Update()
ce.Portal.Update(context.TODO())
ce.Reply("Messages from non-logged-in users in this room will now be bridged through your Signal account")
}
}
Expand All @@ -110,7 +112,7 @@ func fnUnsetRelay(ce *WrappedCommandEvent) {
ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge")
} else {
ce.Portal.RelayUserID = ""
ce.Portal.Update()
ce.Portal.Update(context.TODO())
ce.Reply("Messages from non-logged-in users will no longer be bridged in this room")
}
}
Expand Down Expand Up @@ -143,7 +145,7 @@ var cmdPing = &commands.FullHandler{
}

func fnPing(ce *WrappedCommandEvent) {
if ce.User.SignalID == "" {
if ce.User.SignalID == uuid.Nil {
ce.Reply("You're not logged in")
} else if !ce.User.SignalDevice.IsDeviceLoggedIn() {
ce.Reply("You were logged in at some point, but are not anymore")
Expand Down Expand Up @@ -306,13 +308,20 @@ func fnLogin(ce *WrappedCommandEvent) {

// Update user with SignalID
if signalID != "" {
ce.User.SignalID = signalID
ce.User.SignalID, err = uuid.Parse(signalID)
if err != nil {
ce.Reply("Problem logging in - SignalID is not a valid UUID")
return
}
ce.User.SignalUsername = signalUsername
} else {
ce.Reply("Problem logging in - No SignalID received")
return
}
ce.User.Update()
err = ce.User.Update(context.TODO())
if err != nil {
ce.ZLog.Err(err).Msg("Failed to save user to database")
}

// Connect to Signal
ce.User.Connect()
Expand Down Expand Up @@ -415,7 +424,7 @@ var cmdDeleteAllPortals = &commands.FullHandler{
}

func fnDeleteAllPortals(ce *WrappedCommandEvent) {
portals := ce.Bridge.getAllPortals()
portals := ce.Bridge.GetAllPortalsWithMXID()
var portalsToDelete []*Portal

if ce.User.Admin {
Expand Down
19 changes: 14 additions & 5 deletions custompuppet.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,20 @@
package main

import (
"context"
"fmt"

"maunium.net/go/mautrix/id"
)

func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error {
puppet.CustomMXID = mxid
puppet.AccessToken = accessToken
puppet.Update()
err := puppet.StartCustomMXID(false)
err := puppet.Update(context.TODO())
if err != nil {
return fmt.Errorf("failed to save access token: %w", err)
}
err = puppet.StartCustomMXID(false)
if err != nil {
return err
}
Expand All @@ -44,7 +50,10 @@ func (puppet *Puppet) ClearCustomMXID() {
puppet.customIntent = nil
puppet.customUser = nil
if save {
puppet.Update()
err := puppet.Update(context.TODO())
if err != nil {
puppet.log.Err(err).Msg("Failed to clear custom MXID")
}
}
}

Expand All @@ -59,11 +68,11 @@ func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error {
puppet.bridge.puppetsLock.Unlock()
if puppet.AccessToken != newAccessToken {
puppet.AccessToken = newAccessToken
puppet.Update()
err = puppet.Update(context.TODO())
}
puppet.customIntent = newIntent
puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID)
return nil
return err
}

func (user *User) tryAutomaticDoublePuppeting() {
Expand Down
46 changes: 11 additions & 35 deletions database/database.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2023 Scott Weber
// Copyright (C) 2023 Scott Weber, Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
Expand All @@ -21,8 +21,8 @@ import (

_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"

"go.mau.fi/util/dbutil"
"maunium.net/go/maulogger/v2"

"go.mau.fi/mautrix-signal/database/upgrades"
)
Expand All @@ -38,39 +38,15 @@ type Database struct {
DisappearingMessage *DisappearingMessageQuery
}

func New(baseDB *dbutil.Database, log maulogger.Logger) *Database {
db := &Database{Database: baseDB}
func New(db *dbutil.Database) *Database {
db.UpgradeTable = upgrades.Table
db.User = &UserQuery{
db: db,
log: log.Sub("User"),
}
db.Portal = &PortalQuery{
db: db,
log: log.Sub("Portal"),
}
db.Puppet = &PuppetQuery{
db: db,
log: log.Sub("Puppet"),
}
db.Message = &MessageQuery{
db: db,
log: log.Sub("Message"),
}
db.Reaction = &ReactionQuery{
db: db,
log: log.Sub("Reaction"),
}
db.DisappearingMessage = &DisappearingMessageQuery{
db: db,
log: log.Sub("DisappearingMessage"),
}
return db
}

func strPtr(val string) *string {
if val == "" {
return nil
return &Database{
Database: db,
User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)},
Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)},
Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)},
Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)},
Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)},
DisappearingMessage: &DisappearingMessageQuery{dbutil.MakeQueryHelper(db, newDisappearingMessage)},
}
return &val
}
154 changes: 64 additions & 90 deletions database/disappearingmessage.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// mautrix-signal - A Matrix-signal puppeting bridge.
// Copyright (C) 2023 Scott Weber
// Copyright (C) 2023 Scott Weber, Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
Expand All @@ -17,136 +17,110 @@
package database

import (
"context"
"database/sql"
"errors"
"time"

"go.mau.fi/util/dbutil"
log "maunium.net/go/maulogger/v2"

"maunium.net/go/mautrix/id"
)

const (
getUnscheduledDisappearingMessagesForRoomQuery = `
SELECT room_id, mxid, expiration_seconds, expiration_ts
FROM disappearing_message WHERE expiration_ts IS NULL AND room_id = $1
`
getExpiredDisappearingMessagesQuery = `
SELECT room_id, mxid, expiration_seconds, expiration_ts
FROM disappearing_message WHERE expiration_ts IS NOT NULL AND expiration_ts <= $1
`
getNextDisappearingMessageQuery = `
SELECT room_id, mxid, expiration_seconds, expiration_ts
FROM disappearing_message WHERE expiration_ts IS NOT NULL ORDER BY expiration_ts ASC LIMIT 1
`
insertDisappearingMessageQuery = `
INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts) VALUES ($1, $2, $3, $4)
`
updateDisappearingMessageQuery = `
UPDATE disappearing_message SET expiration_ts=$2 WHERE mxid=$1
`
deleteDisappearingMessageQuery = `
DELETE FROM disappearing_message WHERE mxid=$1
`
)

type DisappearingMessageQuery struct {
db *Database
log log.Logger
*dbutil.QueryHelper[*DisappearingMessage]
}

func (dmq *DisappearingMessageQuery) New() *DisappearingMessage {
return &DisappearingMessage{
db: dmq.db,
log: dmq.log,
}
}
type DisappearingMessage struct {
qh *dbutil.QueryHelper[*DisappearingMessage]

func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireInSeconds int64, expireAt time.Time) *DisappearingMessage {
dm := &DisappearingMessage{
db: dmq.db,
log: dmq.log,
RoomID: roomID,
EventID: eventID,
ExpireInSeconds: expireInSeconds,
ExpireAt: expireAt,
}
return dm
RoomID id.RoomID
EventID id.EventID
ExpireIn time.Duration
ExpireAt time.Time
}

func (dmq *DisappearingMessageQuery) GetUnscheduledForRoom(roomID id.RoomID) (messages []*DisappearingMessage) {
const getUnscheduledQuery = `
SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message WHERE expiration_ts IS NULL AND room_id = $1
`
rows, err := dmq.db.Query(getUnscheduledQuery, roomID)
if err != nil || rows == nil {
dmq.log.Warnln("Failed to get unscheduled disappearing messages:", err)
return nil
}
for rows.Next() {
messages = append(messages, dmq.New().Scan(rows))
}
return
func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage {
return &DisappearingMessage{qh: qh}
}

func (dmq *DisappearingMessageQuery) GetExpiredMessages() (messages []*DisappearingMessage) {
const getExpiredQuery = `
SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message WHERE expiration_ts IS NOT NULL AND expiration_ts <= $1
`
const wiggleRoom = 1
rows, err := dmq.db.Query(getExpiredQuery, time.Now().Unix()+wiggleRoom)
if err != nil || rows == nil {
dmq.log.Warnln("Failed to get expired disappearing messages:", err)
return nil
}
for rows.Next() {
messages = append(messages, dmq.New().Scan(rows))
func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage {
return &DisappearingMessage{
qh: dmq.QueryHelper,
RoomID: roomID,
EventID: eventID,
ExpireIn: expireIn,
ExpireAt: expireAt,
}
return
}

func (dmq *DisappearingMessageQuery) GetNextScheduledMessage() (message *DisappearingMessage) {
const getNextScheduledQuery = `
SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message WHERE expiration_ts IS NOT NULL ORDER BY expiration_ts ASC LIMIT 1
`
row := dmq.db.QueryRow(getNextScheduledQuery)
if row == nil {
return nil
}
return dmq.New().Scan(row)
func (dmq *DisappearingMessageQuery) GetUnscheduledForRoom(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) {
return dmq.QueryMany(ctx, getUnscheduledDisappearingMessagesForRoomQuery, roomID)
}

type DisappearingMessage struct {
db *Database
log log.Logger
func (dmq *DisappearingMessageQuery) GetExpiredMessages(ctx context.Context) ([]*DisappearingMessage, error) {
return dmq.QueryMany(ctx, getExpiredDisappearingMessagesQuery, time.Now().Unix()+1)
}

RoomID id.RoomID
EventID id.EventID
ExpireInSeconds int64
ExpireAt time.Time
func (dmq *DisappearingMessageQuery) GetNextScheduledMessage(ctx context.Context) (*DisappearingMessage, error) {
return dmq.QueryOne(ctx, getNextDisappearingMessageQuery)
}

func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage {
func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) {
var expireIn int64
var expireAt sql.NullInt64
err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
msg.log.Errorln("Database scan failed:", err)
}
return nil
return nil, err
}
msg.ExpireInSeconds = expireIn
msg.ExpireIn = time.Duration(expireIn) * time.Second
if expireAt.Valid {
msg.ExpireAt = time.Unix(expireAt.Int64, 0)
}
return msg
return msg, nil
}

func (msg *DisappearingMessage) Insert(txn dbutil.Execable) {
if txn == nil {
txn = msg.db
}
func (msg *DisappearingMessage) sqlVariables() []any {
var expireAt sql.NullInt64
if !msg.ExpireAt.IsZero() {
expireAt.Valid = true
expireAt.Int64 = msg.ExpireAt.Unix()
}
_, err := txn.Exec(`INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts) VALUES ($1, $2, $3, $4)`,
msg.RoomID, msg.EventID, msg.ExpireInSeconds, expireAt)
if err != nil {
msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err)
}
return []any{msg.RoomID, msg.EventID, int64(msg.ExpireIn.Seconds()), expireAt}
}

func (msg *DisappearingMessage) StartExpirationTimer() {
msg.ExpireAt = time.Now().Add(time.Duration(msg.ExpireInSeconds) * time.Second)
_, err := msg.db.Exec("UPDATE disappearing_message SET expiration_ts=$1 WHERE room_id=$2 AND mxid=$3", msg.ExpireAt.Unix(), msg.RoomID, msg.EventID)
if err != nil {
msg.log.Warnfln("Failed to update %s/%s: %v", msg.RoomID, msg.EventID, err)
}
func (msg *DisappearingMessage) Insert(ctx context.Context) error {
return msg.qh.Exec(ctx, insertDisappearingMessageQuery, msg.sqlVariables()...)
}

func (msg *DisappearingMessage) Delete() {
_, err := msg.db.Exec("DELETE FROM disappearing_message WHERE room_id=$1 AND mxid=$2", msg.RoomID, msg.EventID)
if err != nil {
msg.log.Warnfln("Failed to delete %s/%s: %v", msg.RoomID, msg.EventID, err)
}
func (msg *DisappearingMessage) StartExpirationTimer(ctx context.Context) error {
msg.ExpireAt = time.Now().Add(msg.ExpireIn)
return msg.qh.Exec(ctx, updateDisappearingMessageQuery, msg.EventID, msg.ExpireAt.Unix())
}

func (msg *DisappearingMessage) Delete(ctx context.Context) error {
return msg.qh.Exec(ctx, deleteDisappearingMessageQuery, msg.EventID)
}
Loading

0 comments on commit 4bf3bc4

Please sign in to comment.