diff --git a/commands.go b/commands.go index ade0bcb3..27e52973 100644 --- a/commands.go +++ b/commands.go @@ -17,8 +17,10 @@ package main import ( + "context" "strings" + "github.com/google/uuid" "github.com/skip2/go-qrcode" "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridge/commands" @@ -55,6 +57,7 @@ func (br *SignalBridge) RegisterCommands() { cmdUnsetRelay, cmdDeletePortal, cmdDeleteAllPortals, + cmdCleanupLostPortals, ) } @@ -88,7 +91,7 @@ func fnSetRelay(ce *WrappedCommandEvent) { ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge") } else { ce.Portal.RelayUserID = ce.User.MXID - ce.Portal.Update() + ce.Portal.Update(context.TODO()) ce.Reply("Messages from non-logged-in users in this room will now be bridged through your Signal account") } } @@ -110,7 +113,7 @@ func fnUnsetRelay(ce *WrappedCommandEvent) { ce.Reply("Only bridge admins are allowed to enable relay mode on this instance of the bridge") } else { ce.Portal.RelayUserID = "" - ce.Portal.Update() + ce.Portal.Update(context.TODO()) ce.Reply("Messages from non-logged-in users will no longer be bridged in this room") } } @@ -143,7 +146,7 @@ var cmdPing = &commands.FullHandler{ } func fnPing(ce *WrappedCommandEvent) { - if ce.User.SignalID == "" { + if ce.User.SignalID == uuid.Nil { ce.Reply("You're not logged in") } else if !ce.User.SignalDevice.IsDeviceLoggedIn() { ce.Reply("You were logged in at some point, but are not anymore") @@ -306,13 +309,20 @@ func fnLogin(ce *WrappedCommandEvent) { // Update user with SignalID if signalID != "" { - ce.User.SignalID = signalID + ce.User.SignalID, err = uuid.Parse(signalID) + if err != nil { + ce.Reply("Problem logging in - SignalID is not a valid UUID") + return + } ce.User.SignalUsername = signalUsername } else { ce.Reply("Problem logging in - No SignalID received") return } - ce.User.Update() + err = ce.User.Update(context.TODO()) + if err != nil { + ce.ZLog.Err(err).Msg("Failed to save user to database") + } // Connect to Signal ce.User.Connect() @@ -415,7 +425,7 @@ var cmdDeleteAllPortals = &commands.FullHandler{ } func fnDeleteAllPortals(ce *WrappedCommandEvent) { - portals := ce.Bridge.getAllPortals() + portals := ce.Bridge.GetAllPortalsWithMXID() var portalsToDelete []*Portal if ce.User.Admin { @@ -465,3 +475,39 @@ func fnDeleteAllPortals(ce *WrappedCommandEvent) { ce.Reply("Finished background cleanup of deleted portal rooms.") }() } + +var cmdCleanupLostPortals = &commands.FullHandler{ + Func: wrapCommand(fnCleanupLostPortals), + Name: "cleanup-lost-portals", + Help: commands.HelpMeta{ + Section: HelpSectionPortalManagement, + Description: "Clean up portals that were discarded due to the receiver not being logged into the bridge", + }, + RequiresAdmin: true, +} + +func fnCleanupLostPortals(ce *WrappedCommandEvent) { + portals, err := ce.Bridge.DB.LostPortal.GetAll(context.TODO()) + if err != nil { + ce.Reply("Failed to get portals: %v", err) + return + } else if len(portals) == 0 { + ce.Reply("No lost portals found") + return + } + + ce.Reply("Found %d lost portals, deleting...", len(portals)) + for _, portal := range portals { + dmUUID, err := uuid.Parse(portal.ChatID) + intent := ce.Bot + if err == nil { + intent = ce.Bridge.GetPuppetBySignalID(dmUUID).DefaultIntent() + } + ce.Bridge.CleanupRoom(ce.ZLog, intent, portal.MXID, false) + err = portal.Delete(context.TODO()) + if err != nil { + ce.ZLog.Err(err).Msg("Failed to delete lost portal from database after cleanup") + } + } + ce.Reply("Finished cleaning up portals") +} diff --git a/custompuppet.go b/custompuppet.go index a17d4e3d..31cc1194 100644 --- a/custompuppet.go +++ b/custompuppet.go @@ -17,14 +17,20 @@ package main import ( + "context" + "fmt" + "maunium.net/go/mautrix/id" ) func (puppet *Puppet) SwitchCustomMXID(accessToken string, mxid id.UserID) error { puppet.CustomMXID = mxid puppet.AccessToken = accessToken - puppet.Update() - err := puppet.StartCustomMXID(false) + err := puppet.Update(context.TODO()) + if err != nil { + return fmt.Errorf("failed to save access token: %w", err) + } + err = puppet.StartCustomMXID(false) if err != nil { return err } @@ -44,7 +50,10 @@ func (puppet *Puppet) ClearCustomMXID() { puppet.customIntent = nil puppet.customUser = nil if save { - puppet.Update() + err := puppet.Update(context.TODO()) + if err != nil { + puppet.log.Err(err).Msg("Failed to clear custom MXID") + } } } @@ -59,11 +68,11 @@ func (puppet *Puppet) StartCustomMXID(reloginOnFail bool) error { puppet.bridge.puppetsLock.Unlock() if puppet.AccessToken != newAccessToken { puppet.AccessToken = newAccessToken - puppet.Update() + err = puppet.Update(context.TODO()) } puppet.customIntent = newIntent puppet.customUser = puppet.bridge.GetUserByMXID(puppet.CustomMXID) - return nil + return err } func (user *User) tryAutomaticDoublePuppeting() { diff --git a/database/database.go b/database/database.go index da201bb8..daa365fb 100644 --- a/database/database.go +++ b/database/database.go @@ -1,5 +1,5 @@ // mautrix-signal - A Matrix-signal puppeting bridge. -// Copyright (C) 2023 Scott Weber +// Copyright (C) 2023 Scott Weber, Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -22,7 +22,6 @@ import ( _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" "go.mau.fi/util/dbutil" - "maunium.net/go/maulogger/v2" "go.mau.fi/mautrix-signal/database/upgrades" ) @@ -32,45 +31,23 @@ type Database struct { User *UserQuery Portal *PortalQuery + LostPortal *LostPortalQuery Puppet *PuppetQuery Message *MessageQuery Reaction *ReactionQuery DisappearingMessage *DisappearingMessageQuery } -func New(baseDB *dbutil.Database, log maulogger.Logger) *Database { - db := &Database{Database: baseDB} +func New(db *dbutil.Database) *Database { db.UpgradeTable = upgrades.Table - db.User = &UserQuery{ - db: db, - log: log.Sub("User"), + return &Database{ + Database: db, + User: &UserQuery{dbutil.MakeQueryHelper(db, newUser)}, + Portal: &PortalQuery{dbutil.MakeQueryHelper(db, newPortal)}, + LostPortal: &LostPortalQuery{dbutil.MakeQueryHelper(db, newLostPortal)}, + Puppet: &PuppetQuery{dbutil.MakeQueryHelper(db, newPuppet)}, + Message: &MessageQuery{dbutil.MakeQueryHelper(db, newMessage)}, + Reaction: &ReactionQuery{dbutil.MakeQueryHelper(db, newReaction)}, + DisappearingMessage: &DisappearingMessageQuery{dbutil.MakeQueryHelper(db, newDisappearingMessage)}, } - db.Portal = &PortalQuery{ - db: db, - log: log.Sub("Portal"), - } - db.Puppet = &PuppetQuery{ - db: db, - log: log.Sub("Puppet"), - } - db.Message = &MessageQuery{ - db: db, - log: log.Sub("Message"), - } - db.Reaction = &ReactionQuery{ - db: db, - log: log.Sub("Reaction"), - } - db.DisappearingMessage = &DisappearingMessageQuery{ - db: db, - log: log.Sub("DisappearingMessage"), - } - return db -} - -func strPtr(val string) *string { - if val == "" { - return nil - } - return &val } diff --git a/database/disappearingmessage.go b/database/disappearingmessage.go index ea2318d9..b32a8eec 100644 --- a/database/disappearingmessage.go +++ b/database/disappearingmessage.go @@ -1,5 +1,5 @@ // mautrix-signal - A Matrix-signal puppeting bridge. -// Copyright (C) 2023 Scott Weber +// Copyright (C) 2023 Scott Weber, Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,136 +17,109 @@ package database import ( + "context" "database/sql" - "errors" "time" "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" - "maunium.net/go/mautrix/id" ) +const ( + getUnscheduledDisappearingMessagesForRoomQuery = ` + SELECT room_id, mxid, expiration_seconds, expiration_ts + FROM disappearing_message WHERE expiration_ts IS NULL AND room_id = $1 + ` + getExpiredDisappearingMessagesQuery = ` + SELECT room_id, mxid, expiration_seconds, expiration_ts + FROM disappearing_message WHERE expiration_ts IS NOT NULL AND expiration_ts <= $1 + ` + getNextDisappearingMessageQuery = ` + SELECT room_id, mxid, expiration_seconds, expiration_ts + FROM disappearing_message WHERE expiration_ts IS NOT NULL ORDER BY expiration_ts ASC LIMIT 1 + ` + insertDisappearingMessageQuery = ` + INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts) VALUES ($1, $2, $3, $4) + ` + updateDisappearingMessageQuery = ` + UPDATE disappearing_message SET expiration_ts=$2 WHERE mxid=$1 + ` + deleteDisappearingMessageQuery = ` + DELETE FROM disappearing_message WHERE mxid=$1 + ` +) + type DisappearingMessageQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*DisappearingMessage] } -func (dmq *DisappearingMessageQuery) New() *DisappearingMessage { - return &DisappearingMessage{ - db: dmq.db, - log: dmq.log, - } -} +type DisappearingMessage struct { + qh *dbutil.QueryHelper[*DisappearingMessage] -func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireInSeconds int64, expireAt time.Time) *DisappearingMessage { - dm := &DisappearingMessage{ - db: dmq.db, - log: dmq.log, - RoomID: roomID, - EventID: eventID, - ExpireInSeconds: expireInSeconds, - ExpireAt: expireAt, - } - return dm + RoomID id.RoomID + EventID id.EventID + ExpireIn time.Duration + ExpireAt time.Time } -func (dmq *DisappearingMessageQuery) GetUnscheduledForRoom(roomID id.RoomID) (messages []*DisappearingMessage) { - const getUnscheduledQuery = ` - SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message WHERE expiration_ts IS NULL AND room_id = $1 - ` - rows, err := dmq.db.Query(getUnscheduledQuery, roomID) - if err != nil || rows == nil { - dmq.log.Warnln("Failed to get unscheduled disappearing messages:", err) - return nil - } - for rows.Next() { - messages = append(messages, dmq.New().Scan(rows)) - } - return +func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage { + return &DisappearingMessage{qh: qh} } -func (dmq *DisappearingMessageQuery) GetExpiredMessages() (messages []*DisappearingMessage) { - const getExpiredQuery = ` - SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message WHERE expiration_ts IS NOT NULL AND expiration_ts <= $1 - ` - const wiggleRoom = 1 - rows, err := dmq.db.Query(getExpiredQuery, time.Now().Unix()+wiggleRoom) - if err != nil || rows == nil { - dmq.log.Warnln("Failed to get expired disappearing messages:", err) - return nil - } - for rows.Next() { - messages = append(messages, dmq.New().Scan(rows)) +func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage { + return &DisappearingMessage{ + qh: dmq.QueryHelper, + RoomID: roomID, + EventID: eventID, + ExpireIn: expireIn, + ExpireAt: expireAt, } - return } -func (dmq *DisappearingMessageQuery) GetNextScheduledMessage() (message *DisappearingMessage) { - const getNextScheduledQuery = ` - SELECT room_id, mxid, expiration_seconds, expiration_ts FROM disappearing_message WHERE expiration_ts IS NOT NULL ORDER BY expiration_ts ASC LIMIT 1 - ` - row := dmq.db.QueryRow(getNextScheduledQuery) - if row == nil { - return nil - } - return dmq.New().Scan(row) +func (dmq *DisappearingMessageQuery) GetUnscheduledForRoom(ctx context.Context, roomID id.RoomID) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getUnscheduledDisappearingMessagesForRoomQuery, roomID) } -type DisappearingMessage struct { - db *Database - log log.Logger +func (dmq *DisappearingMessageQuery) GetExpiredMessages(ctx context.Context) ([]*DisappearingMessage, error) { + return dmq.QueryMany(ctx, getExpiredDisappearingMessagesQuery, time.Now().Unix()+1) +} - RoomID id.RoomID - EventID id.EventID - ExpireInSeconds int64 - ExpireAt time.Time +func (dmq *DisappearingMessageQuery) GetNextScheduledMessage(ctx context.Context) (*DisappearingMessage, error) { + return dmq.QueryOne(ctx, getNextDisappearingMessageQuery) } -func (msg *DisappearingMessage) Scan(row dbutil.Scannable) *DisappearingMessage { +func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage, error) { var expireIn int64 var expireAt sql.NullInt64 err := row.Scan(&msg.RoomID, &msg.EventID, &expireIn, &expireAt) if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - msg.log.Errorln("Database scan failed:", err) - } - return nil + return nil, err } - msg.ExpireInSeconds = expireIn + msg.ExpireIn = time.Duration(expireIn) * time.Second if expireAt.Valid { msg.ExpireAt = time.Unix(expireAt.Int64, 0) } - return msg + return msg, nil } -func (msg *DisappearingMessage) Insert(txn dbutil.Execable) { - if txn == nil { - txn = msg.db - } +func (msg *DisappearingMessage) sqlVariables() []any { var expireAt sql.NullInt64 if !msg.ExpireAt.IsZero() { expireAt.Valid = true expireAt.Int64 = msg.ExpireAt.Unix() } - _, err := txn.Exec(`INSERT INTO disappearing_message (room_id, mxid, expiration_seconds, expiration_ts) VALUES ($1, $2, $3, $4)`, - msg.RoomID, msg.EventID, msg.ExpireInSeconds, expireAt) - if err != nil { - msg.log.Warnfln("Failed to insert %s/%s: %v", msg.RoomID, msg.EventID, err) - } + return []any{msg.RoomID, msg.EventID, int64(msg.ExpireIn.Seconds()), expireAt} } -func (msg *DisappearingMessage) StartExpirationTimer() { - msg.ExpireAt = time.Now().Add(time.Duration(msg.ExpireInSeconds) * time.Second) - _, err := msg.db.Exec("UPDATE disappearing_message SET expiration_ts=$1 WHERE room_id=$2 AND mxid=$3", msg.ExpireAt.Unix(), msg.RoomID, msg.EventID) - if err != nil { - msg.log.Warnfln("Failed to update %s/%s: %v", msg.RoomID, msg.EventID, err) - } +func (msg *DisappearingMessage) Insert(ctx context.Context) error { + return msg.qh.Exec(ctx, insertDisappearingMessageQuery, msg.sqlVariables()...) } -func (msg *DisappearingMessage) Delete() { - _, err := msg.db.Exec("DELETE FROM disappearing_message WHERE room_id=$1 AND mxid=$2", msg.RoomID, msg.EventID) - if err != nil { - msg.log.Warnfln("Failed to delete %s/%s: %v", msg.RoomID, msg.EventID, err) - } +func (msg *DisappearingMessage) StartExpirationTimer(ctx context.Context) error { + msg.ExpireAt = time.Now().Add(msg.ExpireIn) + return msg.qh.Exec(ctx, updateDisappearingMessageQuery, msg.EventID, msg.ExpireAt.Unix()) +} + +func (msg *DisappearingMessage) Delete(ctx context.Context) error { + return msg.qh.Exec(ctx, deleteDisappearingMessageQuery, msg.EventID) } diff --git a/database/lostportal.go b/database/lostportal.go new file mode 100644 index 00000000..29779b86 --- /dev/null +++ b/database/lostportal.go @@ -0,0 +1,58 @@ +// mautrix-signal - A Matrix-signal puppeting bridge. +// Copyright (C) 2023 Tulir Asokan +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + +package database + +import ( + "context" + + "go.mau.fi/util/dbutil" + "maunium.net/go/mautrix/id" +) + +const ( + getLostPortalsQuery = `SELECT chat_id, receiver, mxid FROM lost_portals` + deleteLostPortalQuery = `DELETE FROM lost_portals WHERE mxid=$1` +) + +type LostPortalQuery struct { + *dbutil.QueryHelper[*LostPortal] +} + +func (lpq *LostPortalQuery) GetAll(ctx context.Context) ([]*LostPortal, error) { + return lpq.QueryMany(ctx, getLostPortalsQuery) +} + +type LostPortal struct { + qh *dbutil.QueryHelper[*LostPortal] + + ChatID string + Receiver string + MXID id.RoomID +} + +func newLostPortal(qh *dbutil.QueryHelper[*LostPortal]) *LostPortal { + return &LostPortal{qh: qh} +} + +func (l *LostPortal) Scan(row dbutil.Scannable) (*LostPortal, error) { + err := row.Scan(&l.ChatID, &l.Receiver, &l.MXID) + return l, err +} + +func (l *LostPortal) Delete(ctx context.Context) error { + return l.qh.Exec(ctx, deleteLostPortalQuery, l.MXID) +} diff --git a/database/message.go b/database/message.go index e72168a0..3f842154 100644 --- a/database/message.go +++ b/database/message.go @@ -1,5 +1,5 @@ // mautrix-signal - A Matrix-signal puppeting bridge. -// Copyright (C) 2023 Scott Weber +// Copyright (C) 2023 Scott Weber, Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,183 +17,133 @@ package database import ( - "database/sql" - "errors" + "context" + "fmt" + "strings" + "github.com/google/uuid" "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" ) -type MessageQuery struct { - db *Database - log log.Logger -} - -func (mq *MessageQuery) New() *Message { - return &Message{ - db: mq.db, - log: mq.log, - } -} - -type Message struct { - db *Database - log log.Logger - - MXID id.EventID - MXRoom id.RoomID - Sender string - Timestamp uint64 - SignalChatID string - SignalReceiver string -} - const ( - getAllMessagesQuery = ` - SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message - WHERE signal_chat_id=$1 AND signal_receiver=$2 - ` getMessageByMXIDQuery = ` - SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message + SELECT sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room FROM message WHERE mxid=$1 ` - getMessagesBySignalIDQuery = ` - SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message - WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4 + getMessagePartBySignalIDQuery = ` + SELECT sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room FROM message + WHERE sender=$1 AND timestamp=$2 AND part_index=$3 AND signal_receiver=$4 + ` + getMessagePartBySignalIDWithUnknownReceiverQuery = ` + SELECT sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room FROM message + WHERE sender=$1 AND timestamp=$2 AND part_index=$3 AND (signal_receiver=$4 OR signal_receiver='00000000-0000-0000-0000-000000000000') + ` + getLastMessagePartBySignalIDQuery = ` + SELECT sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room FROM message + WHERE sender=$1 AND timestamp=$2 AND signal_receiver=$3 + ORDER BY part_index DESC LIMIT 1 ` - findBySenderAndTimestampQuery = ` - SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message - WHERE sender=$1 AND timestamp=$2 + getAllMessagePartsBySignalIDQuery = ` + SELECT sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room FROM message + WHERE sender=$1 AND timestamp=$2 AND signal_receiver=$3 + ` + getManyMessagesBySignalIDQueryPostgres = ` + SELECT sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room FROM message + WHERE sender=$1 AND signal_receiver=$2 AND timestamp=ANY($3) + ` + getManyMessagesBySignalIDQuerySQLite = ` + SELECT sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room FROM message + WHERE sender=?1 AND signal_receiver=?2 AND timestamp IN (?3) ` getFirstBeforeQuery = ` - SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message + SELECT sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room FROM message WHERE mx_room=$1 AND timestamp <= $2 ORDER BY timestamp DESC LIMIT 1 ` + insertMessageQuery = ` + INSERT INTO message (sender, timestamp, part_index, signal_chat_id, signal_receiver, mxid, mx_room) + VALUES ($1, $2, $3, $4, $5, $6, $7) + ` + deleteMessageQuery = ` + DELETE FROM message + WHERE sender=$1 AND timestamp=$2 AND part_index=$3 AND signal_receiver=$4 + ` ) -func (msg *Message) Insert(txn dbutil.Execable) { - if txn == nil { - txn = msg.db - } - _, err := txn.Exec(` - INSERT INTO message (mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver) - VALUES ($1, $2, $3, $4, $5, $6) - `, - msg.MXID.String(), msg.MXRoom, msg.Sender, msg.Timestamp, msg.SignalChatID, msg.SignalReceiver) - msg.log.Debugfln("Inserting message", msg.MXID, msg.MXRoom, msg.Sender, msg.Timestamp, msg.SignalChatID, msg.SignalReceiver) - if err != nil { - msg.log.Warnfln("Failed to insert %s, %s: %v", msg.SignalChatID, msg.MXID, err) - } +type MessageQuery struct { + *dbutil.QueryHelper[*Message] } -func (msg *Message) Delete(txn dbutil.Execable) { - if txn == nil { - txn = msg.db - } - _, err := txn.Exec(` - DELETE FROM message - WHERE sender=$1 AND timestamp=$2 AND signal_chat_id=$3 AND signal_receiver=$4 - `, - msg.Sender, msg.Timestamp, msg.SignalChatID, msg.SignalReceiver) - if err != nil { - msg.log.Warnfln("Failed to delete %s, %s: %v", msg.SignalChatID, msg.MXID, err) - } -} +type Message struct { + qh *dbutil.QueryHelper[*Message] -func (msg *Message) Scan(row dbutil.Scannable) *Message { - var timestamp sql.NullInt64 - var signalChatID, signalReceiver sql.NullString - err := row.Scan( - &msg.MXID, - &msg.MXRoom, - &msg.Sender, - ×tamp, - &signalChatID, - &signalReceiver, - ) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - msg.log.Errorln("Database scan failed:", err) - } - return nil - } - msg.Timestamp = uint64(timestamp.Int64) - msg.SignalChatID = signalChatID.String - msg.SignalReceiver = signalReceiver.String - return msg + Sender uuid.UUID + Timestamp uint64 + PartIndex int + + SignalChatID string + SignalReceiver uuid.UUID + + MXID id.EventID + RoomID id.RoomID } -func (mq *MessageQuery) maybeScan(row *sql.Row) *Message { - if row == nil { - return nil - } - return mq.New().Scan(row) +func newMessage(qh *dbutil.QueryHelper[*Message]) *Message { + return &Message{qh: qh} } -func (mq *MessageQuery) DeleteAll(roomID string) { - _, err := mq.db.Exec(` - DELETE FROM message WHERE mx_room=$1 - `, roomID) - if err != nil { - mq.log.Warnfln("Failed to delete messages in %s: %v", roomID, err) - } +func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Message, error) { + return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid) } -func (mq *MessageQuery) GetAll(chatID string, receiver string) (messages []*Message) { - rows, err := mq.db.Query(getAllMessagesQuery, chatID, receiver) - if err != nil || rows == nil { - return nil - } - for rows.Next() { - messages = append(messages, mq.New().Scan(rows)) - } - return +func (mq *MessageQuery) GetBySignalIDWithUnknownReceiver(ctx context.Context, sender uuid.UUID, timestamp uint64, partIndex int, receiver uuid.UUID) (*Message, error) { + return mq.QueryOne(ctx, getMessagePartBySignalIDWithUnknownReceiverQuery, sender, timestamp, partIndex, receiver) } -func (mq *MessageQuery) GetByMXID(mxid id.EventID) *Message { - return mq.maybeScan(mq.db.QueryRow(getMessageByMXIDQuery, mxid)) +func (mq *MessageQuery) GetBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, partIndex int, receiver uuid.UUID) (*Message, error) { + return mq.QueryOne(ctx, getMessagePartBySignalIDQuery, sender, timestamp, partIndex, receiver) } -func (mq *MessageQuery) GetBySignalID(sender string, timestamp uint64, chatID string, receiver string) *Message { - return mq.maybeScan(mq.db.QueryRow(getMessagesBySignalIDQuery, sender, timestamp, chatID, receiver)) +func (mq *MessageQuery) GetLastPartBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, receiver uuid.UUID) (*Message, error) { + return mq.QueryOne(ctx, getLastMessagePartBySignalIDQuery, sender, timestamp, receiver) } -func (mq *MessageQuery) FindByTimestamps(timestamps []uint64) []*Message { - var messages []*Message - var rows dbutil.Rows - var err error +func (mq *MessageQuery) GetAllPartsBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, receiver uuid.UUID) ([]*Message, error) { + return mq.QueryMany(ctx, getAllMessagePartsBySignalIDQuery, sender, timestamp, receiver) +} - if mq.db.Dialect == dbutil.Postgres { - rows, err = mq.db.Query(` - SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message - WHERE timestamp=ANY($1) - `, timestamps) +func (mq *MessageQuery) GetManyBySignalID(ctx context.Context, sender uuid.UUID, timestamps []uint64, receiver uuid.UUID) ([]*Message, error) { + if mq.GetDB().Dialect == dbutil.Postgres { + return mq.QueryMany(ctx, getManyMessagesBySignalIDQueryPostgres, sender, receiver, timestamps) } else { - placeholders := "" - for i := 0; i < len(timestamps); i++ { - placeholders += "?" + arguments := make([]any, len(timestamps)+2) + placeholders := make([]string, len(timestamps)) + arguments[0] = sender + arguments[1] = receiver + for i, timestamp := range timestamps { + arguments[i+2] = timestamp + placeholders[i] = fmt.Sprintf("?%d", i+3) } - rows, err = mq.db.Query(` - SELECT mxid, mx_room, sender, timestamp, signal_chat_id, signal_receiver FROM message - WHERE timestamp IN ($1) - `, timestamps) - } - if err != nil { - mq.log.Errorln("FindByTimestamps failed:", err) + return mq.QueryMany(ctx, strings.Replace(getManyMessagesBySignalIDQuerySQLite, "?3", strings.Join(placeholders, ", ?"), 1), arguments...) } - for rows.Next() { - messages = append(messages, mq.New().Scan(rows)) - } - return messages } -func (mq *MessageQuery) FindBySenderAndTimestamp(sender string, timestamp uint64) *Message { - return mq.New().Scan(mq.db.QueryRow(findBySenderAndTimestampQuery, sender, timestamp)) +func (msg *Message) Scan(row dbutil.Scannable) (*Message, error) { + return dbutil.ValueOrErr(msg, row.Scan( + &msg.Sender, &msg.Timestamp, &msg.PartIndex, &msg.SignalChatID, &msg.SignalReceiver, &msg.MXID, &msg.RoomID, + )) +} + +func (msg *Message) sqlVariables() []any { + return []any{msg.Sender, msg.Timestamp, msg.PartIndex, msg.SignalChatID, msg.SignalReceiver, msg.MXID, msg.RoomID} +} + +func (msg *Message) Insert(ctx context.Context) error { + return msg.qh.Exec(ctx, insertMessageQuery, msg.sqlVariables()...) } -func (mq *MessageQuery) GetFirstBefore(room string, timestamp uint64) *Message { - return mq.maybeScan(mq.db.QueryRow(getFirstBeforeQuery, room, timestamp)) +func (msg *Message) Delete(ctx context.Context) error { + return msg.qh.Exec(ctx, deleteMessageQuery, msg.Sender, msg.Timestamp, msg.PartIndex, msg.SignalReceiver) } diff --git a/database/portal.go b/database/portal.go index 2503ff6e..365ef8d4 100644 --- a/database/portal.go +++ b/database/portal.go @@ -1,5 +1,5 @@ // mautrix-signal - A Matrix-signal puppeting bridge. -// Copyright (C) 2023 Scott Weber +// Copyright (C) 2023 Scott Weber, Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,49 +17,74 @@ package database import ( + "context" "database/sql" - "errors" - "fmt" + "github.com/google/uuid" "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" + + "go.mau.fi/mautrix-signal/pkg/signalmeow" ) -type PortalQuery struct { - db *Database - log log.Logger -} +const ( + portalBaseSelect = ` + SELECT chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set, + revision, encrypted, relay_user_id, expiration_time + FROM portal + ` + getPortalByMXIDQuery = portalBaseSelect + `WHERE mxid=$1` + getPortalByChatIDQuery = portalBaseSelect + `WHERE chat_id=$1 AND receiver=$2` + getPortalsByReceiver = portalBaseSelect + `WHERE receiver=$1` + getAllPortalsWithMXIDQuery = portalBaseSelect + `WHERE mxid IS NOT NULL` + insertPortalQuery = ` + INSERT INTO portal ( + chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set, + revision, encrypted, relay_user_id, expiration_time + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + ` + updatePortalQuery = ` + UPDATE portal SET + mxid=$3, name=$4, topic=$5, avatar_hash=$6, avatar_url=$7, name_set=$8, + avatar_set=$9, revision=$10, encrypted=$11, relay_user_id=$12, + expiration_time=$13 + WHERE chat_id=$1 AND receiver=$2 + ` + deletePortalQuery = `DELETE FROM portal WHERE chat_id=$1 AND receiver=$2` +) -func (pq *PortalQuery) New() *Portal { - return &Portal{ - db: pq.db, - log: pq.log, - } +type PortalQuery struct { + *dbutil.QueryHelper[*Portal] } type PortalKey struct { ChatID string - Receiver string + Receiver uuid.UUID } -func NewPortalKey(chatID, receiver string) PortalKey { +func (pk *PortalKey) UserID() uuid.UUID { + parsed, _ := uuid.Parse(pk.ChatID) + return parsed +} + +func (pk *PortalKey) GroupID() signalmeow.GroupIdentifier { + if len(pk.ChatID) == 44 { + return signalmeow.GroupIdentifier(pk.ChatID) + } + return "" +} + +func NewPortalKey(chatID string, receiver uuid.UUID) PortalKey { return PortalKey{ ChatID: chatID, Receiver: receiver, } } -func (key PortalKey) String() string { - return fmt.Sprintf("%s:%s", key.ChatID, key.Receiver) -} - type Portal struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*Portal] - ChatID string - Receiver string + PortalKey MXID id.RoomID Name string Topic string @@ -73,212 +98,76 @@ type Portal struct { ExpirationTime int } -func (p *Portal) values() []interface{} { - return []interface{}{ - p.ChatID, - p.Receiver, - p.MXID, - p.Name, - p.Topic, - p.AvatarHash, - p.AvatarURL.String(), - p.NameSet, - p.AvatarSet, - p.Revision, - p.Encrypted, - p.RelayUserID, - p.ExpirationTime, - } +func newPortal(qh *dbutil.QueryHelper[*Portal]) *Portal { + return &Portal{qh: qh} } -func (p *Portal) Scan(row dbutil.Scannable) *Portal { - if row == nil { - p.log.Debugln("nil row passed to Portal.Scan") - return nil - } - var chatID, receiver, mxid, name, topic, avatarHash, avatarURL, relayUserID sql.NullString - var expirationTime sql.NullInt64 +func (pq *PortalQuery) GetByMXID(ctx context.Context, mxid id.RoomID) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByMXIDQuery, mxid) +} + +func (pq *PortalQuery) GetByChatID(ctx context.Context, pk PortalKey) (*Portal, error) { + return pq.QueryOne(ctx, getPortalByChatIDQuery, pk.ChatID, pk.Receiver) +} + +func (pq *PortalQuery) FindPrivateChatsOf(ctx context.Context, receiver uuid.UUID) ([]*Portal, error) { + return pq.QueryMany(ctx, getPortalsByReceiver, receiver) +} + +func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { + return pq.QueryMany(ctx, getAllPortalsWithMXIDQuery) +} + +func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { + var mxid sql.NullString err := row.Scan( - &chatID, - &receiver, + &p.ChatID, + &p.Receiver, &mxid, - &name, - &topic, - &avatarHash, - &avatarURL, + &p.Name, + &p.Topic, + &p.AvatarHash, + &p.AvatarURL, &p.NameSet, &p.AvatarSet, &p.Revision, &p.Encrypted, - &relayUserID, - &expirationTime, + &p.RelayUserID, + &p.ExpirationTime, ) if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - p.log.Warnfln("Error scanning portal row: %w", err) - } else { - p.log.Debugln("No portal row found") - } - return nil + return nil, err } - p.ChatID = chatID.String - p.Receiver = receiver.String p.MXID = id.RoomID(mxid.String) - p.Name = name.String - p.Topic = topic.String - p.AvatarHash = avatarHash.String - p.RelayUserID = id.UserID(relayUserID.String) - p.ExpirationTime = int(expirationTime.Int64) - parsedAvatarURL, err := id.ParseContentURI(avatarURL.String) - if err != nil { - p.log.Warnfln("Error parsing avatar URL: %w", err) - p.AvatarURL = id.ContentURI{} - } else { - p.AvatarURL = parsedAvatarURL - } - return p -} - -func (p *Portal) SetPortalKey(pk PortalKey) { - p.ChatID = pk.ChatID - p.Receiver = pk.Receiver + return p, nil } -func (p *Portal) Key() PortalKey { - return PortalKey{ - ChatID: p.ChatID, - Receiver: p.Receiver, - } -} - -func (p *Portal) Insert() error { - q := ` - INSERT INTO portal ( - chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set, - revision, encrypted, relay_user_id, expiration_time - ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) - ` - _, err := p.db.Exec(q, p.values()...) - return err -} - -func (p *Portal) Update() error { - q := ` - UPDATE portal SET mxid=$3, name=$4, topic=$5, avatar_hash=$6, avatar_url=$7, name_set=$8, - avatar_set=$9, revision=$10, encrypted=$11, relay_user_id=$12, - expiration_time=$13 - WHERE chat_id=$1 AND receiver=$2 - ` - _, err := p.db.Exec(q, p.values()...) - return err -} - -func (p *Portal) Delete() error { - q := "DELETE FROM portal WHERE chat_id=$1 AND receiver=$2" - _, err := p.db.Exec(q, p.ChatID, p.Receiver) - return err -} - -const ( - portalColumns = ` - chat_id, receiver, mxid, name, topic, avatar_hash, avatar_url, name_set, avatar_set, - revision, encrypted, relay_user_id, expiration_time - ` -) - -func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { - q := fmt.Sprintf("SELECT %s FROM portal WHERE mxid=$1", portalColumns) - log.Debugfln("mxid: %s", mxid.String()) - log.Debugfln("QUERY: %s", q) - row := pq.db.QueryRow(q, mxid.String()) - log.Debugfln("ROW: %s", row) - p := pq.New() - return p.Scan(row) -} - -func (pq *PortalQuery) GetByChatID(pk PortalKey) *Portal { - q := fmt.Sprintf("SELECT %s FROM portal WHERE chat_id=$1 AND receiver=$2", portalColumns) - log.Debugfln("QUERY: %s", q) - row := pq.db.QueryRow(q, pk.ChatID, pk.Receiver) - log.Debugfln("ROW: %s", row) - p := pq.New() - return p.Scan(row) -} - -func (pq *PortalQuery) FindPrivateChatsOf(receiver string) []*Portal { - q := fmt.Sprintf("SELECT %s FROM portal WHERE receiver=$1", portalColumns) - rows, err := pq.db.Query(q, receiver) - log.Debugfln("receiver: %s", receiver) - log.Debugfln("QUERY: %s", q) - if err != nil { - pq.log.Warnfln("Error querying private chats of %s: %w", receiver, err) - return nil - } - defer rows.Close() - var portals []*Portal - for rows.Next() { - p := pq.New() - if p.Scan(rows) != nil { - portals = append(portals, p) - } +func (p *Portal) sqlVariables() []any { + return []any{ + p.ChatID, + p.Receiver, + dbutil.StrPtr(p.MXID), + p.Name, + p.Topic, + p.AvatarHash, + p.AvatarURL, + p.NameSet, + p.AvatarSet, + p.Revision, + p.Encrypted, + p.RelayUserID, + p.ExpirationTime, } - return portals } -func (pq *PortalQuery) FindPrivateChatsWith(otherUser string) []*Portal { - q := fmt.Sprintf("SELECT %s FROM portal WHERE chat_id=$1 AND receiver<>''", portalColumns) - rows, err := pq.db.Query(q, otherUser) - log.Debugfln("otherUser: %s", otherUser) - log.Debugfln("QUERY: %s", q) - if err != nil { - pq.log.Warnfln("Error querying private chats with %s: %w", otherUser, err) - return nil - } - defer rows.Close() - var portals []*Portal - for rows.Next() { - p := pq.New() - if p.Scan(rows) != nil { - portals = append(portals, p) - } - } - return portals +func (p *Portal) Insert(ctx context.Context) error { + return p.qh.Exec(ctx, insertPortalQuery, p.sqlVariables()...) } -func (pq *PortalQuery) AllWithRoom() []*Portal { - q := fmt.Sprintf("SELECT %s FROM portal WHERE mxid IS NOT NULL", portalColumns) - rows, err := pq.db.Query(q) - log.Debugfln("BY ISTESF QUERY: %s", q) - if err != nil { - pq.log.Warnfln("Error querying all portals with room: %w", err) - return nil - } - defer rows.Close() - var portals []*Portal - for rows.Next() { - p := pq.New() - if p.Scan(rows) != nil { - portals = append(portals, p) - } - } - return portals +func (p *Portal) Update(ctx context.Context) error { + return p.qh.Exec(ctx, updatePortalQuery, p.sqlVariables()...) } -func (pq *PortalQuery) GetAll() []*Portal { - q := fmt.Sprintf("SELECT %s FROM portal", portalColumns) - rows, err := pq.db.Query(q) - log.Debugfln("ALLQUERY: %s", q) - if err != nil { - pq.log.Warnfln("Error querying all portals: %w", err) - return nil - } - defer rows.Close() - var portals []*Portal - for rows.Next() { - p := pq.New() - if p.Scan(rows) != nil { - portals = append(portals, p) - } - } - return portals +func (p *Portal) Delete(ctx context.Context) error { + return p.qh.Exec(ctx, deletePortalQuery, p.ChatID, p.Receiver) } diff --git a/database/puppet.go b/database/puppet.go index c72ffa0a..5e3bdd01 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -1,5 +1,5 @@ // mautrix-signal - A Matrix-signal puppeting bridge. -// Copyright (C) 2023 Scott Weber +// Copyright (C) 2023 Scott Weber, Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,33 +17,52 @@ package database import ( + "context" "database/sql" - "errors" - "fmt" + "github.com/google/uuid" "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" ) -type PuppetQuery struct { - db *Database - log log.Logger -} +const ( + puppetBaseSelect = ` + SELECT uuid, number, name, name_quality, avatar_hash, avatar_url, name_set, avatar_set, + contact_info_set, is_registered, custom_mxid, access_token + FROM puppet + ` + getPuppetBySignalIDQuery = puppetBaseSelect + `WHERE uuid=$1` + getPuppetByNumberQuery = puppetBaseSelect + `WHERE number=$1` + getPuppetByCustomMXIDQuery = puppetBaseSelect + `WHERE custom_mxid=$1` + getPuppetsWithCustomMXID = puppetBaseSelect + `WHERE custom_mxid<>''` + updatePuppetQuery = ` + UPDATE puppet SET + number=$2, name=$3, name_quality=$4, avatar_hash=$5, avatar_url=$6, + name_set=$7, avatar_set=$8, contact_info_set=$9, is_registered=$10, + custom_mxid=$11, access_token=$12 + WHERE uuid=$1 + ` + insertPuppetQuery = ` + INSERT INTO puppet ( + uuid, number, name, name_quality, avatar_hash, avatar_url, + name_set, avatar_set, contact_info_set, is_registered, + custom_mxid, access_token + ) + VALUES ( + $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12 + ) + ` +) -func (pq *PuppetQuery) New() *Puppet { - return &Puppet{ - db: pq.db, - log: pq.log, - } +type PuppetQuery struct { + *dbutil.QueryHelper[*Puppet] } type Puppet struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*Puppet] - SignalID string - Number *string + SignalID uuid.UUID + Number string Name string NameQuality int AvatarHash string @@ -58,179 +77,71 @@ type Puppet struct { ContactInfoSet bool } -func (p *Puppet) values() []interface{} { - return []interface{}{ - p.SignalID, - p.Number, - p.Name, - p.NameQuality, - p.AvatarHash, - p.AvatarURL.String(), - p.NameSet, - p.AvatarSet, - p.ContactInfoSet, - p.IsRegistered, - p.CustomMXID.String(), - p.AccessToken, - } +func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet { + return &Puppet{qh: qh} +} + +func (pq *PuppetQuery) GetBySignalID(ctx context.Context, signalID uuid.UUID) (*Puppet, error) { + return pq.QueryOne(ctx, getPuppetBySignalIDQuery, signalID) } -func (p *Puppet) Scan(row dbutil.Scannable) *Puppet { - var number, name, avatarHash, avatarURL, customMXID, accessToken sql.NullString +func (pq *PuppetQuery) GetByNumber(ctx context.Context, number string) (*Puppet, error) { + return pq.QueryOne(ctx, getPuppetByNumberQuery, number) +} + +func (pq *PuppetQuery) GetByCustomMXID(ctx context.Context, mxid id.UserID) (*Puppet, error) { + return pq.QueryOne(ctx, getPuppetByCustomMXIDQuery, mxid) +} + +func (pq *PuppetQuery) GetAllWithCustomMXID(ctx context.Context) ([]*Puppet, error) { + return pq.QueryMany(ctx, getPuppetsWithCustomMXID) +} + +func (p *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) { + var number, customMXID sql.NullString err := row.Scan( &p.SignalID, &number, - &name, + &p.Name, &p.NameQuality, - &avatarHash, - &avatarURL, + &p.AvatarHash, + &p.AvatarURL, &p.NameSet, &p.AvatarSet, &p.ContactInfoSet, &p.IsRegistered, &customMXID, - &accessToken, + &p.AccessToken, ) if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - p.log.Warnfln("Error scanning puppet row: %w", err) - } - return nil - } - parsedAvatarURL, err := id.ParseContentURI(avatarURL.String) - if err != nil { - p.log.Warnfln("Error parsing avatar URL: %w", err) - p.AvatarURL = id.ContentURI{} - } else { - p.AvatarURL = parsedAvatarURL + return nil, nil } - - if number.Valid { - p.Number = &number.String - } else { - p.Number = nil - } - p.Name = name.String - p.AvatarHash = avatarHash.String + p.Number = number.String p.CustomMXID = id.UserID(customMXID.String) - p.AccessToken = accessToken.String - return p -} - -func (p *Puppet) deleteExistingNumber(tx *dbutil.LoggingTxn) error { - if p.Number == nil || *p.Number == "" { - return nil - } - _, err := tx.Exec("UPDATE puppet SET number=null WHERE number=$1 AND uuid<>$2", p.Number, p.SignalID) - return err -} - -func (p *Puppet) Insert() error { - q := ` - INSERT INTO puppet (uuid, number, name, name_quality, avatar_hash, avatar_url, - name_set, avatar_set, contact_info_set, is_registered, - custom_mxid, access_token) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, - $11, $12) - ` - tx, err := p.db.Begin() - if err != nil { - return fmt.Errorf("error starting transaction: %w", err) - } - defer tx.Rollback() - err = p.deleteExistingNumber(tx) - if err != nil { - return fmt.Errorf("error deleting existing number: %w", err) - } - _, err = tx.Exec(q, p.values()...) - if err != nil { - return fmt.Errorf("error inserting puppet: %w", err) - } - err = tx.Commit() - if err != nil { - return fmt.Errorf("error committing transaction: %w", err) - } - return nil + return p, nil } -func (p *Puppet) UpdateNumber() error { - q := "UPDATE puppet SET number=$1 WHERE uuid=$2" - tx, err := p.db.Begin() - if err != nil { - return fmt.Errorf("error starting transaction: %w", err) - } - defer tx.Rollback() - err = p.deleteExistingNumber(tx) - if err != nil { - return fmt.Errorf("error deleting existing number: %w", err) - } - _, err = tx.Exec(q, p.Number, p.SignalID) - if err != nil { - return fmt.Errorf("error updating puppet number: %w", err) - } - err = tx.Commit() - if err != nil { - return fmt.Errorf("error committing transaction: %w", err) - } - return nil -} - -func (p *Puppet) Update() error { - q := ` - UPDATE puppet SET - number=$2, name=$3, name_quality=$4, avatar_hash=$5, avatar_url=$6, - name_set=$7, avatar_set=$8, contact_info_set=$9, is_registered=$10, - custom_mxid=$11, access_token=$12 - WHERE uuid=$1 - ` - // check for db - if p.db == nil { - return fmt.Errorf("no database connection") - } - _, err := p.db.Exec(q, p.values()...) - if err != nil { - return fmt.Errorf("error updating puppet: %w", err) +func (p *Puppet) sqlVariables() []any { + return []any{ + p.SignalID, + dbutil.StrPtr(p.Number), + p.Name, + p.NameQuality, + p.AvatarHash, + p.AvatarURL, + p.NameSet, + p.AvatarSet, + p.ContactInfoSet, + p.IsRegistered, + dbutil.StrPtr(p.CustomMXID), + p.AccessToken, } - return nil } -const ( - selectBase = ` - SELECT uuid, number, name, name_quality, avatar_hash, avatar_url, name_set, avatar_set, - contact_info_set, is_registered, custom_mxid, access_token - FROM puppet - ` -) - -func (pq *PuppetQuery) GetBySignalID(signalID string) *Puppet { - q := selectBase + " WHERE uuid=$1" - row := pq.db.QueryRow(q, signalID) - return pq.New().Scan(row) -} - -func (pq *PuppetQuery) GetByNumber(number string) *Puppet { - q := selectBase + " WHERE number=$1" - row := pq.db.QueryRow(q, number) - return pq.New().Scan(row) +func (p *Puppet) Insert(ctx context.Context) error { + return p.qh.Exec(ctx, insertPuppetQuery, p.sqlVariables()...) } -func (pq *PuppetQuery) GetByCustomMXID(mxid id.UserID) *Puppet { - q := selectBase + " WHERE custom_mxid=$1" - row := pq.db.QueryRow(q, mxid.String()) - return pq.New().Scan(row) -} - -func (pq *PuppetQuery) GetAllWithCustomMXID() ([]*Puppet, error) { - q := selectBase + " WHERE custom_mxid IS NOT NULL AND custom_mxid <> ''" - rows, err := pq.db.Query(q) - if err != nil { - return nil, fmt.Errorf("error getting all puppets with custom mxid: %w", err) - } - defer rows.Close() - puppets := []*Puppet{} - for rows.Next() { - pq.New().Scan(rows) - puppets = append(puppets, pq.New().Scan(rows)) - } - return puppets, nil +func (p *Puppet) Update(ctx context.Context) error { + return p.qh.Exec(ctx, updatePuppetQuery, p.sqlVariables()...) } diff --git a/database/reaction.go b/database/reaction.go index 6a15af89..81980c6e 100644 --- a/database/reaction.go +++ b/database/reaction.go @@ -1,5 +1,5 @@ // mautrix-signal - A Matrix-signal puppeting bridge. -// Copyright (C) 2023 Scott Weber +// Copyright (C) 2023 Scott Weber, Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,103 +17,72 @@ package database import ( - "database/sql" - "errors" + "context" + "github.com/google/uuid" "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" ) +const ( + getReactionByMXIDQuery = `SELECT msg_author, msg_timestamp, author, emoji, signal_chat_id, signal_receiver, mxid, mx_room FROM reaction WHERE mxid=$1` + getReactionBySignalIDQuery = `SELECT msg_author, msg_timestamp, author, emoji, signal_chat_id, signal_receiver, mxid, mx_room FROM reaction WHERE msg_author=$1 AND msg_timestamp=$2 AND author=$3 AND signal_receiver=$4` + insertReactionQuery = ` + INSERT INTO reaction (msg_author, msg_timestamp, author, emoji, signal_chat_id, signal_receiver, mxid, mx_room) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ` + deleteReactionQuery = ` + DELETE FROM reaction WHERE msg_author=$1 AND msg_timestamp=$2 AND author=$3 AND signal_receiver=$4 + ` +) + type ReactionQuery struct { - db *Database - log log.Logger + *dbutil.QueryHelper[*Reaction] } -func (mq *ReactionQuery) New() *Reaction { - return &Reaction{ - db: mq.db, - log: mq.log, - } +func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction { + return &Reaction{qh: qh} } type Reaction struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*Reaction] - MXID id.EventID - MXRoom id.RoomID + MsgAuthor uuid.UUID + MsgTimestamp uint64 + Author uuid.UUID + Emoji string SignalChatID string - SignalReceiver string + SignalReceiver uuid.UUID - Author string - MsgAuthor string - MsgTimestamp uint64 - Emoji string + MXID id.EventID + RoomID id.RoomID } -func (r *Reaction) Insert(txn dbutil.Execable) { - if txn == nil { - txn = r.db - } - _, err := txn.Exec(` - INSERT INTO reaction (mxid, mx_room, signal_chat_id, signal_receiver, author, msg_author, msg_timestamp, emoji) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8) - `, - r.MXID.String(), r.MXRoom, r.SignalChatID, r.SignalReceiver, r.Author, r.MsgAuthor, r.MsgTimestamp, r.Emoji, - ) - r.log.Debugfln("Inserting reaction", r.MXID, r.MXRoom, r.SignalChatID, r.SignalReceiver, r.Author, r.MsgAuthor, r.MsgTimestamp, r.Emoji) - if err != nil { - r.log.Warnfln("Failed to insert %s, %s: %v", r.SignalChatID, r.MXID, err) - } +func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid) } -func (r *Reaction) Delete(txn dbutil.Execable) { - if txn == nil { - txn = r.db - } - _, err := txn.Exec(` - DELETE FROM reaction - WHERE signal_chat_id=$1 AND signal_receiver=$2 AND author=$3 AND msg_author=$4 AND msg_timestamp=$5 - `, - r.SignalChatID, r.SignalReceiver, r.Author, r.MsgAuthor, r.MsgTimestamp, - ) - if err != nil { - r.log.Warnfln("Failed to delete %s, %s: %v", r.SignalChatID, r.MXID, err) - } +func (rq *ReactionQuery) GetBySignalID(ctx context.Context, msgAuthor uuid.UUID, msgTimestamp uint64, author, signalReceiver uuid.UUID) (*Reaction, error) { + return rq.QueryOne(ctx, getReactionBySignalIDQuery, msgAuthor, msgTimestamp, author, signalReceiver) } -func (r *Reaction) Scan(row dbutil.Scannable) *Reaction { - err := row.Scan(&r.MXID, &r.MXRoom, &r.SignalChatID, &r.SignalReceiver, &r.Author, &r.MsgAuthor, &r.MsgTimestamp, &r.Emoji) - if err != nil { - if !errors.Is(err, sql.ErrNoRows) { - r.log.Errorln("Database scan failed:", err) - } - return nil - } - return r +func (r *Reaction) Scan(row dbutil.Scannable) (*Reaction, error) { + return dbutil.ValueOrErr(r, row.Scan( + &r.MsgAuthor, &r.MsgTimestamp, &r.Author, &r.Emoji, &r.SignalChatID, &r.SignalReceiver, &r.MXID, &r.RoomID, + )) } -func (rq *ReactionQuery) maybeScan(row *sql.Row) *Reaction { - if row == nil { - return nil +func (r *Reaction) sqlVariables() []any { + return []any{ + r.MsgAuthor, r.MsgTimestamp, r.Author, r.Emoji, r.SignalChatID, r.SignalReceiver, r.MXID, r.RoomID, } - return rq.New().Scan(row) } -func (rq *ReactionQuery) GetByMXID(mxid id.EventID, roomID id.RoomID) *Reaction { - const getReactionByMXIDQuery = ` - SELECT mxid, mx_room, signal_chat_id, signal_receiver, author, msg_author, msg_timestamp, emoji FROM reaction - WHERE mxid=$1 and mx_room=$2 - ` - return rq.maybeScan(rq.db.QueryRow(getReactionByMXIDQuery, mxid, roomID)) +func (r *Reaction) Insert(ctx context.Context) error { + return r.qh.Exec(ctx, insertReactionQuery, r.sqlVariables()...) } -func (rq *ReactionQuery) GetBySignalID(signalChatID string, signalReceiver string, author string, msgAuthor string, msgTimestamp uint64) *Reaction { - const getReactionBySignalIDQuery = ` - SELECT mxid, mx_room, signal_chat_id, signal_receiver, author, msg_author, msg_timestamp, emoji FROM reaction - WHERE signal_chat_id=$1 AND signal_receiver=$2 AND author=$3 AND msg_author=$4 AND msg_timestamp=$5 - ` - return rq.maybeScan(rq.db.QueryRow(getReactionBySignalIDQuery, signalChatID, signalReceiver, author, msgAuthor, msgTimestamp)) +func (r *Reaction) Delete(ctx context.Context) error { + return r.qh.Exec(ctx, deleteReactionQuery, r.MsgAuthor, r.MsgTimestamp, r.Author, r.SignalReceiver) } diff --git a/database/upgrades/00-latest.sql b/database/upgrades/00-latest.sql index eda8b846..b62643ab 100644 --- a/database/upgrades/00-latest.sql +++ b/database/upgrades/00-latest.sql @@ -1,87 +1,97 @@ --- v0 -> v15: Latest revision +-- v0 -> v17: Latest revision CREATE TABLE portal ( - chat_id TEXT, - receiver TEXT, + chat_id TEXT NOT NULL, + receiver uuid NOT NULL, mxid TEXT, - name TEXT, - topic TEXT, + name TEXT NOT NULL, + topic TEXT NOT NULL, encrypted BOOLEAN NOT NULL DEFAULT false, - avatar_hash TEXT, - avatar_url TEXT, + avatar_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL, name_set BOOLEAN NOT NULL DEFAULT false, avatar_set BOOLEAN NOT NULL DEFAULT false, revision INTEGER NOT NULL DEFAULT 0, - expiration_time BIGINT, - relay_user_id TEXT, - PRIMARY KEY (chat_id, receiver) + expiration_time BIGINT NOT NULL, + relay_user_id TEXT NOT NULL, + + PRIMARY KEY (chat_id, receiver), + CONSTRAINT portal_mxid_unique UNIQUE(mxid) ); CREATE TABLE puppet ( - uuid UUID PRIMARY KEY, - number TEXT UNIQUE, - name TEXT, - name_quality INTEGER NOT NULL DEFAULT 0, - avatar_hash TEXT, - avatar_url TEXT, + uuid uuid PRIMARY KEY, + number TEXT UNIQUE, + name TEXT NOT NULL, + name_quality INTEGER NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL, name_set BOOLEAN NOT NULL DEFAULT false, avatar_set BOOLEAN NOT NULL DEFAULT false, - is_registered BOOLEAN NOT NULL DEFAULT false, + is_registered BOOLEAN NOT NULL DEFAULT false, + contact_info_set BOOLEAN NOT NULL DEFAULT false, custom_mxid TEXT, - access_token TEXT, - contact_info_set BOOLEAN NOT NULL DEFAULT false + access_token TEXT NOT NULL, + + CONSTRAINT puppet_custom_mxid_unique UNIQUE(custom_mxid) ); CREATE TABLE "user" ( - mxid TEXT PRIMARY KEY, - username TEXT, - uuid UUID, - management_room TEXT + mxid TEXT PRIMARY KEY, + uuid uuid, + phone TEXT, + + management_room TEXT, + + CONSTRAINT user_uuid_unique UNIQUE(uuid) ); CREATE TABLE message ( + sender uuid NOT NULL, + timestamp BIGINT NOT NULL, + part_index INTEGER NOT NULL, + + signal_chat_id TEXT NOT NULL, + signal_receiver uuid NOT NULL, + mxid TEXT NOT NULL, mx_room TEXT NOT NULL, - sender UUID, - timestamp BIGINT, - signal_chat_id TEXT, - signal_receiver TEXT, - PRIMARY KEY (sender, timestamp, signal_chat_id, signal_receiver), - FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver) ON DELETE CASCADE, + PRIMARY KEY (sender, timestamp, part_index, signal_receiver), + CONSTRAINT message_portal_fkey FOREIGN KEY (signal_chat_id, signal_receiver) + REFERENCES portal(chat_id, receiver) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (sender) REFERENCES puppet(uuid) ON DELETE CASCADE, - UNIQUE (mxid, mx_room) + CONSTRAINT message_mxid_unique UNIQUE (mxid) ); CREATE TABLE reaction ( - mxid TEXT NOT NULL, - mx_room TEXT NOT NULL, - - signal_chat_id TEXT NOT NULL, - signal_receiver TEXT NOT NULL, - - author UUID NOT NULL, - msg_author UUID NOT NULL, - msg_timestamp BIGINT NOT NULL, - emoji TEXT NOT NULL, - - PRIMARY KEY (signal_chat_id, signal_receiver, msg_author, msg_timestamp, author), - CONSTRAINT reaction_message_fkey - FOREIGN KEY (msg_author, msg_timestamp, signal_chat_id, signal_receiver) - REFERENCES message(sender, timestamp, signal_chat_id, signal_receiver) - ON DELETE CASCADE, + msg_author uuid NOT NULL, + msg_timestamp BIGINT NOT NULL, + -- part_index is not used in reactions, but is required for the foreign key. + _part_index INTEGER NOT NULL DEFAULT 0, + + author uuid NOT NULL, + emoji TEXT NOT NULL, + + signal_chat_id TEXT NOT NULL, + signal_receiver uuid NOT NULL, + + mxid TEXT NOT NULL, + mx_room TEXT NOT NULL, + + PRIMARY KEY (msg_author, msg_timestamp, author, signal_receiver), + CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) + REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (author) REFERENCES puppet(uuid) ON DELETE CASCADE, - UNIQUE (mxid, mx_room) + CONSTRAINT reaction_mxid_unique UNIQUE (mxid) ); CREATE TABLE disappearing_message ( - room_id TEXT, - mxid TEXT, - expiration_seconds BIGINT, - expiration_ts BIGINT, - - PRIMARY KEY (room_id, mxid) + mxid TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + expiration_seconds BIGINT NOT NULL, + expiration_ts BIGINT ); diff --git a/database/upgrades/16-refactor-postgres.sql b/database/upgrades/16-refactor-postgres.sql new file mode 100644 index 00000000..e1bb17e9 --- /dev/null +++ b/database/upgrades/16-refactor-postgres.sql @@ -0,0 +1,104 @@ +-- v16: Refactor types (Postgres) +-- only: postgres + +-- Drop constraints so we can fix timestamps. +ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; +ALTER TABLE message DROP CONSTRAINT message_pkey; + +-- Add part index to message and fix the hacky timestamps +ALTER TABLE message ADD COLUMN part_index INTEGER; +UPDATE message + SET timestamp=CASE WHEN timestamp > 1500000000000000 THEN timestamp / 1000 ELSE timestamp END, + part_index=CASE WHEN timestamp > 1500000000000000 THEN timestamp % 1000 ELSE 0 END; +ALTER TABLE message ALTER COLUMN part_index SET NOT NULL; +ALTER TABLE reaction ADD COLUMN _part_index INTEGER NOT NULL DEFAULT 0; + +-- Re-add the dropped constraints (but with part index and no chat) +ALTER TABLE message ADD PRIMARY KEY (sender, timestamp, part_index, signal_receiver); +ALTER TABLE message DROP CONSTRAINT message_signal_chat_id_signal_receiver_fkey; +ALTER TABLE message ADD CONSTRAINT message_portal_fkey + FOREIGN KEY (signal_chat_id, signal_receiver) + REFERENCES portal (chat_id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE; +ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) + REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE ON UPDATE CASCADE; +-- Also update the reaction primary key +ALTER TABLE reaction DROP CONSTRAINT reaction_pkey; +ALTER TABLE reaction ADD PRIMARY KEY (author, msg_author, msg_timestamp, signal_receiver); + +-- Change unique constraint from (mxid, mx_room) to just mxid. +ALTER TABLE message DROP CONSTRAINT message_mxid_mx_room_key; +ALTER TABLE message ADD CONSTRAINT message_mxid_unique UNIQUE (mxid); +ALTER TABLE reaction DROP CONSTRAINT reaction_mxid_mx_room_key; +ALTER TABLE reaction ADD CONSTRAINT reaction_mxid_unique UNIQUE (mxid); + +CREATE TABLE lost_portals ( + mxid TEXT PRIMARY KEY, + chat_id TEXT, + receiver TEXT +); +INSERT INTO lost_portals SELECT mxid, chat_id, receiver FROM portal WHERE mxid<>''; + +-- Make mxid column unique (requires using nulls for missing values) +UPDATE portal SET mxid=NULL WHERE mxid=''; +ALTER TABLE portal ADD CONSTRAINT portal_mxid_unique UNIQUE(mxid); +-- Delete any portals that aren't associated with logged-in users. +DELETE FROM portal WHERE receiver<>'' AND receiver NOT IN (SELECT username FROM "user" WHERE uuid IS NOT NULL); +-- Change receiver to uuid instead of phone number, also add nil uuid for groups. +UPDATE portal SET receiver=(SELECT uuid FROM "user" WHERE username=receiver) WHERE receiver<>''; +UPDATE portal SET receiver='00000000-0000-0000-0000-000000000000' WHERE receiver=''; +-- Drop the foreign keys again to allow changing types (the ON UPDATE CASCADEs are needed for the above step) +ALTER TABLE message DROP CONSTRAINT message_portal_fkey; +ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; +ALTER TABLE portal ALTER COLUMN receiver TYPE uuid USING receiver::uuid; +ALTER TABLE message ALTER COLUMN signal_receiver TYPE uuid USING signal_receiver::uuid; +ALTER TABLE reaction ALTER COLUMN signal_receiver TYPE uuid USING signal_receiver::uuid; +-- Re-add the dropped constraints again +ALTER TABLE message ADD CONSTRAINT message_portal_fkey + FOREIGN KEY (signal_chat_id, signal_receiver) + REFERENCES portal (chat_id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE; +ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) + REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE ON UPDATE CASCADE; +-- Delete group v1 portal entries +DELETE FROM portal WHERE chat_id NOT LIKE '________-____-____-____-____________' AND LENGTH(chat_id) <> 44; +DELETE FROM lost_portals WHERE mxid IN (SELECT mxid FROM portal WHERE mxid<>''); + +-- Remove unnecessary nullables in portal +UPDATE portal SET name='' WHERE name IS NULL; +UPDATE portal SET topic='' WHERE topic IS NULL; +UPDATE portal SET avatar_hash='' WHERE avatar_hash IS NULL; +UPDATE portal SET avatar_url='' WHERE avatar_url IS NULL; +UPDATE portal SET expiration_time=0 WHERE expiration_time IS NULL; +UPDATE portal SET relay_user_id='' WHERE relay_user_id IS NULL; +ALTER TABLE portal ALTER COLUMN name SET NOT NULL; +ALTER TABLE portal ALTER COLUMN topic SET NOT NULL; +ALTER TABLE portal ALTER COLUMN avatar_hash SET NOT NULL; +ALTER TABLE portal ALTER COLUMN avatar_url SET NOT NULL; +ALTER TABLE portal ALTER COLUMN expiration_time SET NOT NULL; +ALTER TABLE portal ALTER COLUMN relay_user_id SET NOT NULL; + +-- Add unique constraint to custom_mxid +UPDATE puppet SET custom_mxid=NULL WHERE custom_mxid=''; +ALTER TABLE puppet ADD CONSTRAINT puppet_custom_mxid_unique UNIQUE(custom_mxid); +-- Remove unnecessary nullables in puppet +UPDATE puppet SET name='' WHERE name IS NULL; +UPDATE puppet SET avatar_hash='' WHERE avatar_hash IS NULL; +UPDATE puppet SET avatar_url='' WHERE avatar_url IS NULL; +UPDATE puppet SET access_token='' WHERE access_token IS NULL; +ALTER TABLE puppet ALTER COLUMN name SET NOT NULL; +ALTER TABLE puppet ALTER COLUMN avatar_hash SET NOT NULL; +ALTER TABLE puppet ALTER COLUMN avatar_url SET NOT NULL; +ALTER TABLE puppet ALTER COLUMN access_token SET NOT NULL; +ALTER TABLE puppet ALTER COLUMN name_quality DROP DEFAULT; + +ALTER TABLE "user" ADD CONSTRAINT user_uuid_unique UNIQUE(uuid); +ALTER TABLE "user" RENAME COLUMN username TO phone; + +-- Drop room_id from disappearing message primary key +ALTER TABLE disappearing_message DROP CONSTRAINT disappearing_message_pkey; +ALTER TABLE disappearing_message ADD PRIMARY KEY (mxid); +-- Remove unnecessary nullables in disappearing_message +ALTER TABLE disappearing_message ALTER COLUMN room_id SET NOT NULL; +UPDATE disappearing_message SET expiration_seconds=0 WHERE expiration_seconds IS NULL; +ALTER TABLE disappearing_message ALTER COLUMN expiration_seconds SET NOT NULL; diff --git a/database/upgrades/17-refactor-sqlite.sql b/database/upgrades/17-refactor-sqlite.sql new file mode 100644 index 00000000..78f62eb2 --- /dev/null +++ b/database/upgrades/17-refactor-sqlite.sql @@ -0,0 +1,187 @@ +-- v17: Refactor types (SQLite) +-- transaction: off +-- only: sqlite + +-- This is separate from v16 so that postgres can run with transaction: on +-- (split upgrades by dialect don't currently allow disabling transaction in only one dialect) + +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE message_new ( + sender uuid NOT NULL, + timestamp BIGINT NOT NULL, + part_index INTEGER NOT NULL, + + signal_chat_id TEXT NOT NULL, + signal_receiver TEXT NOT NULL, + + mxid TEXT NOT NULL, + mx_room TEXT NOT NULL, + + PRIMARY KEY (sender, timestamp, part_index, signal_receiver), + CONSTRAINT message_portal_fkey FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (sender) REFERENCES puppet(uuid) ON DELETE CASCADE, + CONSTRAINT message_mxid_unique UNIQUE (mxid) +); + +CREATE TABLE reaction_new ( + msg_author uuid NOT NULL, + msg_timestamp BIGINT NOT NULL, + -- part_index is not used in reactions, but is required for the foreign key. + _part_index INTEGER NOT NULL DEFAULT 0, + + author uuid NOT NULL, + emoji TEXT NOT NULL, + + signal_chat_id TEXT NOT NULL, + signal_receiver TEXT NOT NULL, + + mxid TEXT NOT NULL, + mx_room TEXT NOT NULL, + + PRIMARY KEY (msg_author, msg_timestamp, author, signal_receiver), + CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) + REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (author) REFERENCES puppet(uuid) ON DELETE CASCADE, + CONSTRAINT reaction_mxid_unique UNIQUE (mxid) +); + + +INSERT INTO message_new +SELECT sender, + CASE WHEN timestamp > 1500000000000000 THEN timestamp / 1000 ELSE timestamp END, + CASE WHEN timestamp > 1500000000000000 THEN timestamp % 1000 ELSE 0 END, + COALESCE(signal_chat_id, ''), + COALESCE(signal_receiver, ''), + mxid, + mx_room +FROM message; + +INSERT INTO reaction_new +SELECT msg_author, + msg_timestamp, + 0, -- _part_index + author, + emoji, + COALESCE(signal_chat_id, ''), + COALESCE(signal_receiver, ''), + mxid, + mx_room +FROM reaction; + +DROP TABLE message; +DROP TABLE reaction; +ALTER TABLE message_new RENAME TO message; +ALTER TABLE reaction_new RENAME TO reaction; + +PRAGMA foreign_key_check; +COMMIT; + +PRAGMA foreign_keys = ON; + +BEGIN; +CREATE TABLE lost_portals ( + mxid TEXT PRIMARY KEY, + chat_id TEXT, + receiver TEXT +); +INSERT INTO lost_portals SELECT mxid, chat_id, receiver FROM portal WHERE mxid<>''; +DELETE FROM portal WHERE receiver<>'' AND receiver NOT IN (SELECT username FROM "user" WHERE uuid<>''); +UPDATE portal SET receiver=(SELECT uuid FROM "user" WHERE username=receiver) WHERE receiver<>''; +UPDATE portal SET receiver='00000000-0000-0000-0000-000000000000' WHERE receiver=''; +DELETE FROM portal WHERE chat_id NOT LIKE '________-____-____-____-____________' AND LENGTH(chat_id) <> 44; +DELETE FROM lost_portals WHERE mxid IN (SELECT mxid FROM portal WHERE mxid<>''); +COMMIT; + +PRAGMA foreign_keys = OFF; + +BEGIN; + +CREATE TABLE portal_new ( + chat_id TEXT NOT NULL, + receiver uuid NOT NULL, + mxid TEXT, + name TEXT NOT NULL, + topic TEXT NOT NULL, + encrypted BOOLEAN NOT NULL DEFAULT false, + avatar_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL, + name_set BOOLEAN NOT NULL DEFAULT false, + avatar_set BOOLEAN NOT NULL DEFAULT false, + revision INTEGER NOT NULL DEFAULT 0, + + expiration_time BIGINT NOT NULL, + relay_user_id TEXT NOT NULL, + + PRIMARY KEY (chat_id, receiver), + CONSTRAINT portal_mxid_unique UNIQUE(mxid) +); + +INSERT INTO portal_new + SELECT chat_id, receiver, CASE WHEN mxid='' THEN NULL ELSE mxid END, + COALESCE(name, ''), COALESCE(topic, ''), encrypted, COALESCE(avatar_hash, ''), COALESCE(avatar_url, ''), + name_set, avatar_set, revision, COALESCE(expiration_time, 0), COALESCE(relay_user_id, '') + FROM portal; +DROP TABLE portal; +ALTER TABLE portal_new RENAME TO portal; + +CREATE TABLE puppet_new ( + uuid uuid PRIMARY KEY, + number TEXT UNIQUE, + name TEXT NOT NULL, + name_quality INTEGER NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL, + name_set BOOLEAN NOT NULL DEFAULT false, + avatar_set BOOLEAN NOT NULL DEFAULT false, + + is_registered BOOLEAN NOT NULL DEFAULT false, + contact_info_set BOOLEAN NOT NULL DEFAULT false, + + custom_mxid TEXT, + access_token TEXT NOT NULL, + + CONSTRAINT puppet_custom_mxid_unique UNIQUE(custom_mxid) +); + +INSERT INTO puppet_new + SELECT uuid, number, COALESCE(name, ''), COALESCE(name_quality, 0), COALESCE(avatar_hash, ''), + COALESCE(avatar_url, ''), name_set, avatar_set, is_registered, contact_info_set, + CASE WHEN custom_mxid='' THEN NULL ELSE custom_mxid END, COALESCE(access_token, '') + FROM puppet; +DROP TABLE puppet; +ALTER TABLE puppet_new RENAME TO puppet; + +CREATE TABLE user_new ( + mxid TEXT PRIMARY KEY, + uuid uuid, + phone TEXT, + + management_room TEXT, + + CONSTRAINT user_uuid_unique UNIQUE(uuid) +); + +INSERT INTO user_new + SELECT mxid, uuid, username, management_room + FROM user; +DROP TABLE user; +ALTER TABLE user_new RENAME TO user; + +CREATE TABLE disappearing_message_new ( + mxid TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + expiration_seconds BIGINT NOT NULL, + expiration_ts BIGINT +); + +INSERT INTO disappearing_message_new + SELECT mxid, room_id, COALESCE(expiration_seconds, 0), expiration_ts + FROM disappearing_message; +DROP TABLE disappearing_message; +ALTER TABLE disappearing_message_new RENAME TO disappearing_message; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/database/user.go b/database/user.go index 28581d8d..3a8e465d 100644 --- a/database/user.go +++ b/database/user.go @@ -1,5 +1,5 @@ // mautrix-signal - A Matrix-signal puppeting bridge. -// Copyright (C) 2023 Scott Weber +// Copyright (C) 2023 Scott Weber, Tulir Asokan // // This program is free software: you can redistribute it and/or modify // it under the terms of the GNU Affero General Public License as published by @@ -17,124 +17,85 @@ package database import ( + "context" "database/sql" + "github.com/google/uuid" "go.mau.fi/util/dbutil" - log "maunium.net/go/maulogger/v2" "maunium.net/go/mautrix/id" ) -type UserQuery struct { - db *Database - log log.Logger -} +const ( + getUserByMXIDQuery = `SELECT mxid, phone, uuid, management_room FROM "user" WHERE mxid=$1` + getUserByPhoneQuery = `SELECT mxid, phone, uuid, management_room FROM "user" WHERE phone=$1` + getUserByUUIDQuery = `SELECT mxid, phone, uuid, management_room FROM "user" WHERE uuid=$1` + getAllLoggedInUsersQuery = `SELECT mxid, phone, uuid, management_room FROM "user" WHERE phone IS NOT NULL` + insertUserQuery = `INSERT INTO "user" (mxid, phone, uuid, management_room) VALUES ($1, $2, $3, $4)` + updateUserQuery = `UPDATE "user" SET phone=$2, uuid=$3, management_room=$4 WHERE mxid=$1` +) -func (uq *UserQuery) New() *User { - return &User{ - db: uq.db, - log: uq.log, - } +type UserQuery struct { + *dbutil.QueryHelper[*User] } type User struct { - db *Database - log log.Logger + qh *dbutil.QueryHelper[*User] MXID id.UserID SignalUsername string - SignalID string + SignalID uuid.UUID ManagementRoom id.RoomID } -func (u *User) sqlVariables() []any { - var username, signalID, managementRoom *string - if u.SignalUsername != "" { - username = &u.SignalUsername - } - if u.SignalID != "" { - signalID = &u.SignalID - } - if u.ManagementRoom != "" { - managementRoom = (*string)(&u.ManagementRoom) - } - return []any{u.MXID, username, signalID, managementRoom} +func newUser(qh *dbutil.QueryHelper[*User]) *User { + return &User{qh: qh} } -func (u *User) Insert() error { - q := `INSERT INTO "user" (mxid, username, uuid, management_room) VALUES ($1, $2, $3, $4)` - _, err := u.db.Exec(q, u.sqlVariables()...) - return err +func (uq *UserQuery) GetByMXID(ctx context.Context, mxid id.UserID) (*User, error) { + return uq.QueryOne(ctx, getUserByMXIDQuery, mxid) } -func (u *User) Update() error { - q := `UPDATE "user" SET username=$2, uuid=$3, management_room=$4 WHERE mxid=$1` - _, err := u.db.Exec(q, u.sqlVariables()...) - return err +func (uq *UserQuery) GetByPhone(ctx context.Context, phone string) (*User, error) { + return uq.QueryOne(ctx, getUserByPhoneQuery, phone) } -func (u *User) Scan(row dbutil.Scannable) *User { - var username, managementRoom, signalID sql.NullString - err := row.Scan( - &u.MXID, - &username, - &signalID, - &managementRoom, - ) - if err != nil { - if err != sql.ErrNoRows { - u.log.Errorln("Database scan failed:", err) - } - return nil - } - u.SignalUsername = username.String - u.SignalID = signalID.String - u.ManagementRoom = id.RoomID(managementRoom.String) - return u +func (uq *UserQuery) GetBySignalID(ctx context.Context, uuid uuid.UUID) (*User, error) { + return uq.QueryOne(ctx, getUserByUUIDQuery, uuid) } -func (uq *UserQuery) GetByMXID(mxid id.UserID) *User { - q := `SELECT mxid, username, uuid, management_room FROM "user" WHERE mxid=$1` - row := uq.db.QueryRow(q, mxid) - if row == nil { - return nil - } - return uq.New().Scan(row) +func (uq *UserQuery) GetAllLoggedIn(ctx context.Context) ([]*User, error) { + return uq.QueryMany(ctx, getAllLoggedInUsersQuery) } -func (uq *UserQuery) GetByUsername(username string) *User { - q := `SELECT mxid, username, uuid, management_room FROM "user" WHERE username=$1` - row := uq.db.QueryRow(q, username) - if row == nil { - return nil - } - return uq.New().Scan(row) +func (u *User) sqlVariables() []any { + var nu uuid.NullUUID + nu.UUID = u.SignalID + nu.Valid = u.SignalID != uuid.Nil + return []any{u.MXID, dbutil.StrPtr(u.SignalUsername), nu, dbutil.StrPtr(u.ManagementRoom)} } -func (uq *UserQuery) GetBySignalID(uuid string) *User { - q := `SELECT mxid, username, uuid, management_room FROM "user" WHERE uuid=$1` - row := uq.db.QueryRow(q, uuid) - if row == nil { - return nil - } - return uq.New().Scan(row) +func (u *User) Insert(ctx context.Context) error { + return u.qh.Exec(ctx, insertUserQuery, u.sqlVariables()...) +} + +func (u *User) Update(ctx context.Context) error { + return u.qh.Exec(ctx, updateUserQuery, u.sqlVariables()...) } -func (uq *UserQuery) AllLoggedIn() []*User { - q := `SELECT mxid, username, uuid, management_room FROM "user" WHERE username IS NOT NULL` - rows, err := uq.db.Query(q) +func (u *User) Scan(row dbutil.Scannable) (*User, error) { + var phone, managementRoom sql.NullString + var signalID uuid.NullUUID + err := row.Scan( + &u.MXID, + &phone, + &signalID, + &managementRoom, + ) if err != nil { - uq.log.Errorln("Database query failed:", err) - return nil - } - defer rows.Close() - - var users []*User - for rows.Next() { - u := uq.New().Scan(rows) - if u == nil { - continue - } - users = append(users, u) + return nil, err } - return users + u.SignalUsername = phone.String + u.SignalID = signalID.UUID + u.ManagementRoom = id.RoomID(managementRoom.String) + return u, nil } diff --git a/disappearing.go b/disappearing.go index c93dd443..8bedbbb6 100644 --- a/disappearing.go +++ b/disappearing.go @@ -35,12 +35,23 @@ type DisappearingMessagesManager struct { checkMessagesChan chan struct{} } -func (dmm *DisappearingMessagesManager) ScheduleDisappearingForRoom(roomID id.RoomID) { - dmm.Log.Debug().Msgf("Scheduling disappearing messages for %s", roomID) - disappearingMessages := dmm.DB.DisappearingMessage.GetUnscheduledForRoom(roomID) +func (dmm *DisappearingMessagesManager) ScheduleDisappearingForRoom(ctx context.Context, roomID id.RoomID) { + log := dmm.Log.With().Str("room_id", roomID.String()).Logger() + disappearingMessages, err := dmm.DB.DisappearingMessage.GetUnscheduledForRoom(ctx, roomID) + if err != nil { + log.Err(err).Msg("Failed to get unscheduled disappearing messages") + return + } for _, disappearingMessage := range disappearingMessages { - dmm.Log.Debug().Msgf("Scheduling disappearing message %s", disappearingMessage.EventID) - disappearingMessage.StartExpirationTimer() + err = disappearingMessage.StartExpirationTimer(ctx) + if err != nil { + log.Err(err).Msg("Failed to schedule disappearing message") + } else { + log.Debug(). + Str("event_id", disappearingMessage.EventID.String()). + Time("expire_at", disappearingMessage.ExpireAt). + Msg("Scheduling disappearing message") + } } // Tell the disappearing messages loop to check again @@ -50,69 +61,88 @@ func (dmm *DisappearingMessagesManager) ScheduleDisappearingForRoom(roomID id.Ro func (dmm *DisappearingMessagesManager) StartDisappearingLoop(ctx context.Context) { dmm.checkMessagesChan = make(chan struct{}, 1) go func() { + log := dmm.Log.With().Str("action", "loop").Logger() + ctx = log.WithContext(ctx) for { - dmm.redactExpiredMessages() + dmm.redactExpiredMessages(ctx) duration := 10 * time.Minute // Check again in 10 minutes just in case - nextMsg := dmm.DB.DisappearingMessage.GetNextScheduledMessage() - if nextMsg != nil { - dmm.Log.Debug().Msgf("Next message to expire is %s in %s", nextMsg.EventID, nextMsg.ExpireAt.Sub(time.Now())) + nextMsg, err := dmm.DB.DisappearingMessage.GetNextScheduledMessage(ctx) + if err != nil { + if ctx.Err() != nil { + return + } + log.Err(err).Msg("Failed to get next disappearing message") + continue + } else if nextMsg != nil { duration = nextMsg.ExpireAt.Sub(time.Now()) } select { case <-time.After(duration): - // We should have at least one expired message now, so we should check again - dmm.Log.Debug().Msgf("Duration (%s) is up, checking for expired messages", duration) case <-dmm.checkMessagesChan: - // There are new messages in the DB, so we should check again - dmm.Log.Debug().Msg("New messages in DB, checking again") case <-ctx.Done(): - // We've been told to stop - dmm.Log.Debug().Msg("Stopping disappearing messages loop") return } } }() } -func (dmm *DisappearingMessagesManager) redactExpiredMessages() { - // Get all expired messages and redact them - expiredMessages := dmm.DB.DisappearingMessage.GetExpiredMessages() +func (dmm *DisappearingMessagesManager) redactExpiredMessages(ctx context.Context) { + log := zerolog.Ctx(ctx) + expiredMessages, err := dmm.DB.DisappearingMessage.GetExpiredMessages(ctx) + if err != nil { + log.Err(err).Msg("Failed to get expired disappearing messages") + return + } for _, msg := range expiredMessages { portal := dmm.Bridge.GetPortalByMXID(msg.RoomID) if portal == nil { - dmm.Log.Warn().Msgf("Failed to redact message %s in room %s: portal not found", msg.EventID, msg.RoomID) - return + log.Warn().Str("event_id", msg.EventID.String()).Str("room_id", msg.RoomID.String()).Msg("Failed to redact message: portal not found") + continue } - // Redact the message - _, err := portal.MainIntent().RedactEvent(msg.RoomID, msg.EventID, mautrix.ReqRedact{ + _, err = portal.MainIntent().RedactEvent(msg.RoomID, msg.EventID, mautrix.ReqRedact{ Reason: "Message expired", - TxnID: fmt.Sprintf("mxsig_disappear_%s", msg.EventID), + TxnID: fmt.Sprintf("mxsg_disappear_%s", msg.EventID), }) if err != nil { - portal.log.Warn().Msgf("Failed to make %s disappear: %v", msg.EventID, err) + log.Err(err). + Str("event_id", msg.EventID.String()). + Str("room_id", msg.RoomID.String()). + Msg("Failed to redact message") } else { - portal.log.Debug().Msgf("Disappeared %s", msg.EventID) + log.Err(err). + Str("event_id", msg.EventID.String()). + Str("room_id", msg.RoomID.String()). + Msg("Redacted message") + } + err = msg.Delete(ctx) + if err != nil { + log.Err(err). + Str("event_id", msg.EventID.String()). + Msg("Failed to delete disappearing message row in database") } - msg.Delete() } } -func (dmm *DisappearingMessagesManager) AddDisappearingMessage(eventID id.EventID, roomID id.RoomID, expireInSeconds int64, startTimerNow bool) { - if expireInSeconds == 0 { - dmm.Log.Debug().Msgf("Not adding disappearing message %s: expireIn is 0", eventID) +func (dmm *DisappearingMessagesManager) AddDisappearingMessage(ctx context.Context, eventID id.EventID, roomID id.RoomID, expireIn time.Duration, startTimerNow bool) { + if expireIn == 0 { return } - dmm.Log.Debug().Msgf("Adding disappearing message %s", eventID) - expireAt := time.Time{} + var expireAt time.Time if startTimerNow { - expireAt = time.Now().Add(time.Duration(expireInSeconds) * time.Second) + expireAt = time.Now().Add(expireIn) } - disappearingMessage := dmm.DB.DisappearingMessage.NewWithValues(roomID, eventID, expireInSeconds, expireAt) - disappearingMessage.Insert(nil) - + disappearingMessage := dmm.DB.DisappearingMessage.NewWithValues(roomID, eventID, expireIn, expireAt) + err := disappearingMessage.Insert(ctx) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("event_id", eventID.String()). + Msg("Failed to add disappearing message to database") + return + } + zerolog.Ctx(ctx).Debug().Str("event_id", eventID.String()). + Msg("Added disappearing message row to database") if startTimerNow { // Tell the disappearing messages loop to check again dmm.checkMessagesChan <- struct{}{} diff --git a/main.go b/main.go index 641228cf..85730036 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "os" "sync" + "github.com/google/uuid" "github.com/rs/zerolog" "go.mau.fi/util/configupgrade" "go.mau.fi/util/dbutil" @@ -62,7 +63,7 @@ type SignalBridge struct { provisioning *ProvisioningAPI usersByMXID map[id.UserID]*User - usersBySignalID map[string]*User + usersBySignalID map[uuid.UUID]*User usersLock sync.Mutex managementRooms map[id.RoomID]*User @@ -72,7 +73,7 @@ type SignalBridge struct { portalsByID map[database.PortalKey]*Portal portalsLock sync.Mutex - puppets map[string]*Puppet + puppets map[uuid.UUID]*Puppet puppetsByCustomMXID map[id.UserID]*Puppet puppetsByNumber map[string]*Puppet puppetsLock sync.Mutex @@ -101,7 +102,7 @@ func (br *SignalBridge) Init() { signalmeow.SetLogger(br.ZLog.With().Str("component", "signalmeow").Logger().Level(zerolog.DebugLevel)) //signalmeow.SetLogger(br.ZLog.With().Str("component", "signalmeow").Caller().Logger()) - br.DB = database.New(br.Bridge.DB, br.Log.Sub("Database")) + br.DB = database.New(br.Bridge.DB) br.MeowStore = signalmeow.NewStore(br.Bridge.DB, dbutil.ZeroLogger(br.ZLog.With().Str("db_section", "signalmeow").Logger())) ss := br.Config.Bridge.Provisioning.SharedSecret @@ -110,7 +111,7 @@ func (br *SignalBridge) Init() { } br.disappearingMessagesManager = &DisappearingMessagesManager{ DB: br.DB, - Log: br.ZLog.With().Str("component", "disappearingMessagesManager").Logger(), + Log: br.ZLog.With().Str("component", "disappearing messages").Logger(), Bridge: br, } @@ -118,12 +119,12 @@ func (br *SignalBridge) Init() { br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent signalFormatParams = &signalfmt.FormatParams{ - GetUserInfo: func(uuid string) signalfmt.UserInfo { - puppet := br.GetPuppetBySignalID(uuid) + GetUserInfo: func(u uuid.UUID) signalfmt.UserInfo { + puppet := br.GetPuppetBySignalID(u) if puppet == nil { return signalfmt.UserInfo{} } - user := br.GetUserBySignalID(uuid) + user := br.GetUserBySignalID(u) if user != nil { return signalfmt.UserInfo{ MXID: user.MXID, @@ -137,24 +138,44 @@ func (br *SignalBridge) Init() { }, } matrixFormatParams = &matrixfmt.HTMLParser{ - GetUUIDFromMXID: func(userID id.UserID) string { + GetUUIDFromMXID: func(userID id.UserID) uuid.UUID { parsed, ok := br.ParsePuppetMXID(userID) if ok { return parsed } // TODO only get if exists user := br.GetUserByMXID(userID) - if user != nil && user.SignalID != "" { + if user != nil && user.SignalID != uuid.Nil { return user.SignalID } - return "" + return uuid.Nil }, } signalmeow.HackyCaptionToggle = br.Config.Bridge.CaptionInMessage } +func (br *SignalBridge) logLostPortals(ctx context.Context) { + lostPortals, err := br.DB.LostPortal.GetAll(ctx) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get lost portals") + return + } else if len(lostPortals) == 0 { + return + } + lostCountByReceiver := make(map[string]int) + for _, lost := range lostPortals { + lostCountByReceiver[lost.Receiver]++ + } + br.ZLog.Warn(). + Any("count_by_receiver", lostCountByReceiver). + Msg("Some portals were discarded due to the receiver not being logged into the bridge anymore. " + + "Use `!signal cleanup-lost-portals` to remove them from the database. " + + "Alternatively, you can re-insert the data into the portal table with the appropriate receiver column to restore the portals.") +} + func (br *SignalBridge) Start() { + go br.logLostPortals(context.TODO()) err := br.MeowStore.Upgrade() if err != nil { br.Log.Fatalln("Failed to upgrade signalmeow database: %v", err) @@ -212,7 +233,7 @@ func (br *SignalBridge) CreatePrivatePortal(roomID id.RoomID, brInviter bridge.U br.Log.Debugln("CreatePrivatePortal", roomID, brInviter, brGhost) inviter := brInviter.(*User) puppet := brGhost.(*Puppet) - key := database.NewPortalKey(puppet.SignalID, inviter.SignalUsername) + key := database.NewPortalKey(puppet.SignalID.String(), inviter.SignalID) portal := br.GetPortalByChatID(key) if len(portal.MXID) == 0 { @@ -284,7 +305,10 @@ func (br *SignalBridge) createPrivatePortalFromInvite(roomID id.RoomID, inviter _, err = portal.MainIntent().SetRoomAvatar(portal.MXID, portal.AvatarURL) portal.AvatarSet = err == nil } - portal.Update() + err = portal.Update(context.TODO()) + if err != nil { + portal.log.Err(err).Msg("Failed to update portal in database") + } portal.UpdateBridgeInfo() _, _ = intent.SendNotice(roomID, "Private chat portal created") } @@ -292,14 +316,14 @@ func (br *SignalBridge) createPrivatePortalFromInvite(roomID id.RoomID, inviter func main() { br := &SignalBridge{ usersByMXID: make(map[id.UserID]*User), - usersBySignalID: make(map[string]*User), + usersBySignalID: make(map[uuid.UUID]*User), managementRooms: make(map[id.RoomID]*User), portalsByMXID: make(map[id.RoomID]*Portal), portalsByID: make(map[database.PortalKey]*Portal), - puppets: make(map[string]*Puppet), + puppets: make(map[uuid.UUID]*Puppet), puppetsByCustomMXID: make(map[id.UserID]*Puppet), puppetsByNumber: make(map[string]*Puppet), } diff --git a/msgconv/matrixfmt/convert_test.go b/msgconv/matrixfmt/convert_test.go index 719e4be0..0ee25141 100644 --- a/msgconv/matrixfmt/convert_test.go +++ b/msgconv/matrixfmt/convert_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -13,11 +14,11 @@ import ( ) var formatParams = &matrixfmt.HTMLParser{ - GetUUIDFromMXID: func(id id.UserID) string { + GetUUIDFromMXID: func(id id.UserID) uuid.UUID { if id.Homeserver() == "signal" { - return id.Localpart() + return uuid.MustParse(id.Localpart()) } - return "" + return uuid.Nil }, } diff --git a/msgconv/matrixfmt/html.go b/msgconv/matrixfmt/html.go index 28daec6c..074caff5 100644 --- a/msgconv/matrixfmt/html.go +++ b/msgconv/matrixfmt/html.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" + "github.com/google/uuid" "golang.org/x/exp/slices" "golang.org/x/net/html" "maunium.net/go/mautrix/event" @@ -223,7 +224,7 @@ func (ctx Context) WithWhitespace() Context { // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { - GetUUIDFromMXID func(id.UserID) string + GetUUIDFromMXID func(id.UserID) uuid.UUID } // TaggedString is a string that also contains a HTML tag. @@ -355,8 +356,8 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) *EntityStri // Mention not allowed, use name as-is return str } - uuid := parser.GetUUIDFromMXID(mxid) - if uuid == "" { + u := parser.GetUUIDFromMXID(mxid) + if u == uuid.Nil { // Don't include the link for mentions of non-Signal users, the name is enough return str } @@ -365,7 +366,7 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) *EntityStri MXID: mxid, Name: str.String.String(), }, - UUID: uuid, + UUID: u, }) } if str.String.String() == href { diff --git a/msgconv/signalfmt/convert.go b/msgconv/signalfmt/convert.go index d88ab747..317374d8 100644 --- a/msgconv/signalfmt/convert.go +++ b/msgconv/signalfmt/convert.go @@ -20,6 +20,7 @@ import ( "html" "strings" + "github.com/google/uuid" "golang.org/x/exp/maps" "golang.org/x/exp/slices" "maunium.net/go/mautrix/event" @@ -34,7 +35,7 @@ type UserInfo struct { } type FormatParams struct { - GetUserInfo func(uuid string) UserInfo + GetUserInfo func(uuid uuid.UUID) UserInfo } type formatContext struct { @@ -87,7 +88,11 @@ func Parse(message string, ranges []*signalpb.BodyRange, params *FormatParams) * case *signalpb.BodyRange_Style_: br.Value = Style(rv.Style) case *signalpb.BodyRange_MentionAci: - userInfo := params.GetUserInfo(rv.MentionAci) + parsed, err := uuid.Parse(rv.MentionAci) + if err != nil { + continue + } + userInfo := params.GetUserInfo(parsed) if userInfo.MXID == "" { continue } @@ -96,7 +101,7 @@ func Parse(message string, ranges []*signalpb.BodyRange, params *FormatParams) * // Maybe use NewUTF16String and do index replacements for the plaintext body too, // or just replace the plaintext body by parsing the generated HTML. content.Body = strings.Replace(content.Body, "\uFFFC", userInfo.Name, 1) - br.Value = Mention{UserInfo: userInfo, UUID: rv.MentionAci} + br.Value = Mention{UserInfo: userInfo, UUID: parsed} } lrt.Add(br) } diff --git a/msgconv/signalfmt/convert_test.go b/msgconv/signalfmt/convert_test.go index 1c066c63..eaed12dc 100644 --- a/msgconv/signalfmt/convert_test.go +++ b/msgconv/signalfmt/convert_test.go @@ -29,11 +29,11 @@ import ( signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" ) -var realUser = uuid.NewString() +var realUser = uuid.New() func TestParse(t *testing.T) { formatParams := &signalfmt.FormatParams{ - GetUserInfo: func(uuid string) signalfmt.UserInfo { + GetUserInfo: func(uuid uuid.UUID) signalfmt.UserInfo { if uuid == realUser { return signalfmt.UserInfo{ MXID: "@test:example.com", @@ -41,7 +41,7 @@ func TestParse(t *testing.T) { } } else { return signalfmt.UserInfo{ - MXID: id.UserID("@signal_" + uuid + ":example.com"), + MXID: id.UserID("@signal_" + uuid.String() + ":example.com"), Name: "Signal User", } } @@ -79,7 +79,7 @@ func TestParse(t *testing.T) { Start: proto.Uint32(6), Length: proto.Uint32(1), AssociatedValue: &signalpb.BodyRange_MentionAci{ - MentionAci: realUser, + MentionAci: realUser.String(), }, }}, body: "Hello Matrix User", diff --git a/msgconv/signalfmt/tags.go b/msgconv/signalfmt/tags.go index b81709f9..043bb437 100644 --- a/msgconv/signalfmt/tags.go +++ b/msgconv/signalfmt/tags.go @@ -19,6 +19,8 @@ package signalfmt import ( "fmt" + "github.com/google/uuid" + signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" ) @@ -30,7 +32,7 @@ type BodyRangeValue interface { type Mention struct { UserInfo - UUID string + UUID uuid.UUID } func (m Mention) String() string { @@ -39,7 +41,7 @@ func (m Mention) String() string { func (m Mention) Proto() signalpb.BodyRangeAssociatedValue { return &signalpb.BodyRange_MentionAci{ - MentionAci: m.UUID, + MentionAci: m.UUID.String(), } } diff --git a/pkg/signalmeow/events/message.go b/pkg/signalmeow/events/message.go new file mode 100644 index 00000000..d4f59e53 --- /dev/null +++ b/pkg/signalmeow/events/message.go @@ -0,0 +1,16 @@ +package events + +import ( + "github.com/google/uuid" + + signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" +) + +type MessageInfo struct { + Sender uuid.UUID + Chat string +} + +type Receipt struct { + *signalpb.ReceiptMessage +} diff --git a/pkg/signalmeow/incoming_messages.go b/pkg/signalmeow/incoming_messages.go index 78b723c4..fd766ddd 100644 --- a/pkg/signalmeow/incoming_messages.go +++ b/pkg/signalmeow/incoming_messages.go @@ -28,6 +28,7 @@ type IncomingSignalMessageBase struct { RecipientUUID string // Usually our UUID, unless this is a message we sent on another device GroupID *GroupIdentifier // Unique identifier for the group chat, or nil for 1:1 chats Timestamp uint64 // With SenderUUID, treated as a unique identifier for a specific Signal message + PartIndex int // Quote *IncomingSignalMessageQuoteData // If this message is a quote (reply), this will be non-nil ExpiresIn int64 // If this message is ephemeral, this will be non-zero (in seconds) } diff --git a/pkg/signalmeow/receiving.go b/pkg/signalmeow/receiving.go index 57e51f8b..e0d82eb7 100644 --- a/pkg/signalmeow/receiving.go +++ b/pkg/signalmeow/receiving.go @@ -828,7 +828,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa expiresIn = int64(dataMessage.GetExpireTimer()) } - tsIndex := 0 + partIndex := 0 // If there's attachements, handle them (one at a time for now) if dataMessage.Attachments != nil { for index, attachmentPointer := range dataMessage.Attachments { @@ -837,19 +837,14 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa zlog.Err(err).Msg("fetchAndDecryptAttachment error") continue } - // TODO: this is a hack to make sure each attachment has a unique timestamp - // this allows us to associate up to 999 attachments with a single Signal message - var timestamp uint64 = dataMessage.GetTimestamp() - if index > 0 { - timestamp = (dataMessage.GetTimestamp() * 1000) + uint64(index) - } // TODO: right now this will be one message per image, each with the same caption incomingMessage := IncomingSignalMessageAttachment{ IncomingSignalMessageBase: IncomingSignalMessageBase{ SenderUUID: senderUUID, RecipientUUID: recipientUUID, GroupID: gidPointer, - Timestamp: timestamp, + Timestamp: dataMessage.GetTimestamp(), + PartIndex: partIndex, Quote: quoteData, ExpiresIn: expiresIn, }, @@ -861,33 +856,31 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa Height: attachmentPointer.GetHeight(), BlurHash: attachmentPointer.GetBlurHash(), } + partIndex++ if HackyCaptionToggle && index == 0 { incomingMessage.Caption = dataMessage.GetBody() incomingMessage.CaptionRanges = dataMessage.GetBodyRanges() } incomingMessages = append(incomingMessages, incomingMessage) } - tsIndex += len(dataMessage.Attachments) } // If there's a body but no attachments, pass along as a text message if dataMessage.Body != nil && (dataMessage.Attachments == nil || !HackyCaptionToggle) { - timestamp := dataMessage.GetTimestamp() - if tsIndex != 0 { - timestamp = (dataMessage.GetTimestamp() * 1000) + uint64(tsIndex+1) - } incomingMessage := IncomingSignalMessageText{ IncomingSignalMessageBase: IncomingSignalMessageBase{ SenderUUID: senderUUID, RecipientUUID: recipientUUID, GroupID: gidPointer, - Timestamp: timestamp, + Timestamp: dataMessage.GetTimestamp(), + PartIndex: partIndex, Quote: quoteData, ExpiresIn: expiresIn, }, Content: dataMessage.GetBody(), ContentRanges: dataMessage.GetBodyRanges(), } + partIndex++ incomingMessages = append(incomingMessages, incomingMessage) } @@ -903,6 +896,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa RecipientUUID: recipientUUID, GroupID: gidPointer, Timestamp: dataMessage.GetTimestamp(), + PartIndex: partIndex, Quote: quoteData, ExpiresIn: expiresIn, }, @@ -914,6 +908,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa Emoji: dataMessage.GetSticker().GetEmoji(), } incomingMessages = append(incomingMessages, incomingMessage) + partIndex++ } } @@ -978,6 +973,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa RecipientUUID: recipientUUID, GroupID: gidPointer, Timestamp: dataMessage.GetTimestamp(), + PartIndex: partIndex, }, DisplayName: contactCard.GetName().GetDisplayName(), Organization: contactCard.GetOrganization(), @@ -1017,6 +1013,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa addressString := strings.Join(addressParts, ", ") incomingMessage.Addresses = append(incomingMessage.Addresses, addressString) } + partIndex++ incomingMessages = append(incomingMessages, incomingMessage) } } diff --git a/pkg/signalmeow/sending.go b/pkg/signalmeow/sending.go index 0ca75b9f..de90b706 100644 --- a/pkg/signalmeow/sending.go +++ b/pkg/signalmeow/sending.go @@ -26,6 +26,7 @@ import ( "strings" "time" + "github.com/google/uuid" "google.golang.org/protobuf/proto" "go.mau.fi/mautrix-signal/pkg/libsignalgo" @@ -437,14 +438,14 @@ func DataMessageForAttachment(attachmentPointer *AttachmentPointer, caption stri return wrapDataMessageInContent(dm) } -func DataMessageForReaction(reaction string, targetMessageSender string, targetMessageTimestamp uint64, removing bool) *SignalContent { +func DataMessageForReaction(reaction string, targetMessageSender uuid.UUID, targetMessageTimestamp uint64, removing bool) *SignalContent { timestamp := currentMessageTimestamp() dm := &signalpb.DataMessage{ Timestamp: ×tamp, Reaction: &signalpb.DataMessage_Reaction{ Emoji: proto.String(reaction), Remove: proto.Bool(removing), - TargetAuthorAci: proto.String(targetMessageSender), + TargetAuthorAci: proto.String(targetMessageSender.String()), TargetSentTimestamp: proto.Uint64(targetMessageTimestamp), }, } @@ -462,9 +463,9 @@ func DataMessageForDelete(targetMessageTimestamp uint64) *SignalContent { return wrapDataMessageInContent(dm) } -func AddQuoteToDataMessage(content *SignalContent, quotedMessageSender string, quotedMessageTimestamp uint64) { +func AddQuoteToDataMessage(content *SignalContent, quotedMessageSender uuid.UUID, quotedMessageTimestamp uint64) { content.DataMessage.Quote = &signalpb.DataMessage_Quote{ - AuthorAci: proto.String(quotedMessageSender), + AuthorAci: proto.String(quotedMessageSender.String()), Id: proto.Uint64(quotedMessageTimestamp), Type: signalpb.DataMessage_Quote_NORMAL.Enum(), @@ -547,7 +548,7 @@ func SendGroupMessage(ctx context.Context, device *Device, gid GroupIdentifier, return result, nil } -func SendMessage(ctx context.Context, device *Device, recipientUuid string, message *SignalContent) SendMessageResult { +func SendMessage(ctx context.Context, device *Device, recipientID string, message *SignalContent) SendMessageResult { // Assemble the content to send content := (*signalpb.Content)(message) dataMessage := content.DataMessage @@ -559,12 +560,12 @@ func SendMessage(ctx context.Context, device *Device, recipientUuid string, mess } // Send to the recipient - sentUnidentified, err := sendContent(ctx, device, recipientUuid, messageTimestamp, content, 0) + sentUnidentified, err := sendContent(ctx, device, recipientID, messageTimestamp, content, 0) if err != nil { return SendMessageResult{ WasSuccessful: false, FailedSendResult: &FailedSendResult{ - RecipientUuid: recipientUuid, + RecipientUuid: recipientID, Error: err, }, } @@ -572,7 +573,7 @@ func SendMessage(ctx context.Context, device *Device, recipientUuid string, mess result := SendMessageResult{ WasSuccessful: true, SuccessfulSendResult: &SuccessfulSendResult{ - RecipientUuid: recipientUuid, + RecipientUuid: recipientID, Unidentified: sentUnidentified, }, } @@ -589,7 +590,7 @@ func SendMessage(ctx context.Context, device *Device, recipientUuid string, mess syncContent = syncMessageFromSoloDataMessage(dataMessage, *result.SuccessfulSendResult) } if content.ReceiptMessage != nil && *content.ReceiptMessage.Type == signalpb.ReceiptMessage_READ { - syncContent = syncMessageFromReadReceiptMessage(content.ReceiptMessage, recipientUuid) + syncContent = syncMessageFromReadReceiptMessage(content.ReceiptMessage, recipientID) } if syncContent != nil { _, selfSendErr := sendContent(ctx, device, device.Data.AciUuid, messageTimestamp, syncContent, 0) diff --git a/portal.go b/portal.go index 4f777cbe..0ab09db4 100644 --- a/portal.go +++ b/portal.go @@ -114,7 +114,10 @@ func (portal *Portal) IsEncrypted() bool { func (portal *Portal) MarkEncrypted() { portal.Encrypted = true - portal.Update() + err := portal.Update(context.TODO()) + if err != nil { + portal.log.Err(err).Msg("Failed to update portal in database after marking as encrypted") + } } func (portal *Portal) ReceiveMatrixEvent(user bridge.User, evt *event.Event) { @@ -146,7 +149,7 @@ func (portal *Portal) IsPrivateChat() bool { func (portal *Portal) MainIntent() *appservice.IntentAPI { if portal.IsPrivateChat() { - return portal.bridge.GetPuppetBySignalID(portal.ChatID).DefaultIntent() + return portal.bridge.GetPuppetBySignalID(portal.UserID()).DefaultIntent() } return portal.bridge.Bot @@ -168,14 +171,14 @@ func (portal *Portal) getBridgeInfo() (string, CustomBridgeInfoContent) { ExternalURL: "https://signal.org/", }, Channel: event.BridgeInfoSection{ - ID: portal.Key().ChatID, + ID: portal.ChatID, DisplayName: portal.Name, AvatarURL: portal.AvatarURL.CUString(), }, } var bridgeInfoStateKey string - bridgeInfoStateKey = fmt.Sprintf("fi.mau.signal://signal/%s", portal.Key().ChatID) - bridgeInfo.Channel.ExternalURL = fmt.Sprintf("https://signal.me/#p/%s", portal.Key().ChatID) + bridgeInfoStateKey = fmt.Sprintf("fi.mau.signal://signal/%s", portal.ChatID) + bridgeInfo.Channel.ExternalURL = fmt.Sprintf("https://signal.me/#p/%s", portal.ChatID) var roomType string if portal.IsPrivateChat() { roomType = "dm" @@ -203,8 +206,21 @@ func (portal *Portal) UpdateBridgeInfo() { // ** bridge.ChildOverride methods (for SignalBridge in main.go) ** +func (br *SignalBridge) GetAllPortalsWithMXID() []*Portal { + portals, err := br.dbPortalsToPortals(br.DB.Portal.GetAllWithMXID(context.TODO())) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get all portals with mxid") + return nil + } + return portals +} + func (br *SignalBridge) GetAllIPortals() (iportals []bridge.Portal) { - portals := br.getAllPortals() + portals, err := br.dbPortalsToPortals(br.DB.Portal.GetAllWithMXID(context.TODO())) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get all portals with mxid") + return nil + } iportals = make([]bridge.Portal, len(portals)) for i, portal := range portals { iportals[i] = portal @@ -212,11 +228,10 @@ func (br *SignalBridge) GetAllIPortals() (iportals []bridge.Portal) { return iportals } -func (br *SignalBridge) getAllPortals() []*Portal { - return br.dbPortalsToPortals(br.DB.Portal.GetAll()) -} - -func (br *SignalBridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Portal { +func (br *SignalBridge) dbPortalsToPortals(dbPortals []*database.Portal, err error) ([]*Portal, error) { + if err != nil { + return nil, err + } br.portalsLock.Lock() defer br.portalsLock.Unlock() @@ -226,15 +241,15 @@ func (br *SignalBridge) dbPortalsToPortals(dbPortals []*database.Portal) []*Port continue } - portal, ok := br.portalsByID[dbPortal.Key()] + portal, ok := br.portalsByID[dbPortal.PortalKey] if !ok { - portal = br.loadPortal(dbPortal, nil) + portal = br.loadPortal(context.TODO(), dbPortal, nil) } output[index] = portal } - return output + return output, nil } // ** Portal Creation and Message Handling ** @@ -243,12 +258,10 @@ func (br *SignalBridge) NewPortal(dbPortal *database.Portal) *Portal { portal := &Portal{ Portal: dbPortal, bridge: br, - log: br.ZLog.With().Str("chat_id", dbPortal.Key().ChatID).Logger(), + log: br.ZLog.With().Str("chat_id", dbPortal.ChatID).Logger(), signalMessages: make(chan portalSignalMessage, br.Config.Bridge.PortalMessageBuffer), matrixMessages: make(chan portalMatrixMessage, br.Config.Bridge.PortalMessageBuffer), - - //commands: make(map[string]*discordgo.ApplicationCommand), } go portal.messageLoop() @@ -258,13 +271,10 @@ func (br *SignalBridge) NewPortal(dbPortal *database.Portal) *Portal { func (portal *Portal) messageLoop() { for { - portal.log.Debug().Msg("Waiting for message") select { case msg := <-portal.matrixMessages: - portal.log.Debug().Msg("Got message from matrix") portal.handleMatrixMessages(msg) case msg := <-portal.signalMessages: - portal.log.Debug().Msg("Got message from signal") portal.handleSignalMessages(msg) } } @@ -278,20 +288,22 @@ func (portal *Portal) handleMatrixMessages(msg portalMatrixMessage) { msg.user.BridgeState.Send(status.BridgeState{StateEvent: status.StateBadCredentials, Message: "You have been logged out of Signal, please reconnect"}) return } + log := portal.log.With().Str("event_id", msg.evt.ID.String()).Logger() + ctx := log.WithContext(context.TODO()) switch msg.evt.Type { case event.EventMessage, event.EventSticker: - portal.handleMatrixMessage(msg.user, msg.evt) + portal.handleMatrixMessage(ctx, msg.user, msg.evt) case event.EventRedaction: - portal.handleMatrixRedaction(msg.user, msg.evt) + portal.handleMatrixRedaction(ctx, msg.user, msg.evt) case event.EventReaction: - portal.handleMatrixReaction(msg.user, msg.evt) + portal.handleMatrixReaction(ctx, msg.user, msg.evt) default: - portal.log.Warn().Str("type", msg.evt.Type.String()).Msg("Unhandled matrix message type") + log.Warn().Str("type", msg.evt.Type.String()).Msg("Unhandled matrix message type") } } -func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { +func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *User, evt *event.Event) { evtTS := time.UnixMilli(evt.Timestamp) timings := messageTimings{ initReceive: evt.Mautrix.ReceivedAt.Sub(evtTS), @@ -303,24 +315,7 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { start := time.Now() messageAge := timings.totalReceive - origEvtID := evt.ID ms := metricSender{portal: portal, timings: &timings} - var dbMsg *database.Message - if retryMeta := evt.Content.AsMessage().MessageSendRetry; retryMeta != nil { - origEvtID = retryMeta.OriginalEventID - dbMsg = portal.bridge.DB.Message.GetByMXID(origEvtID) - if dbMsg != nil { - //portal.log.Debugfln("Ignoring retry request %s (#%d, age: %s) for %s/%s from %s as message was already sent", evt.ID, retryMeta.RetryCount, messageAge, origEvtID, dbMsg.JID, evt.Sender) - go ms.sendMessageMetrics(evt, nil, "", true) - return - } else if dbMsg != nil { - //portal.log.Debugfln("Got retry request %s (#%d, age: %s) for %s/%s from %s", evt.ID, retryMeta.RetryCount, messageAge, origEvtID, dbMsg.JID, evt.Sender) - } else { - //portal.log.Debugfln("Got retry request %s (#%d, age: %s) for %s from %s (original message not known)", evt.ID, retryMeta.RetryCount, messageAge, origEvtID, evt.Sender) - } - } else { - //portal.log.Debugfln("Received message %s from %s (age: %s)", evt.ID, evt.Sender, messageAge) - } portal.log.Debug().Msgf("Received message %s from %s (age: %s)", evt.ID, evt.Sender, messageAge) errorAfter := portal.bridge.Config.Bridge.MessageHandlingTimeout.ErrorAfter @@ -346,7 +341,6 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { }() } - ctx := context.Background() if deadline > 0 { var cancel context.CancelFunc ctx, cancel = context.WithTimeout(ctx, deadline) @@ -385,27 +379,28 @@ func (portal *Portal) handleMatrixMessage(sender *User, evt *event.Event) { timings.totalSend = time.Since(start) go ms.sendMessageMetrics(evt, err, "Error sending", true) if err == nil { - portal.storeMessageInDB(evt.ID, sender.SignalID, timestamp) + portal.storeMessageInDB(ctx, evt.ID, sender.SignalID, timestamp, 0) if portal.ExpirationTime > 0 { - portal.addDisappearingMessage(evt.ID, int64(portal.ExpirationTime), true) + portal.addDisappearingMessage(ctx, evt.ID, int64(portal.ExpirationTime), true) } } } -func (portal *Portal) handleMatrixRedaction(sender *User, evt *event.Event) { +func (portal *Portal) handleMatrixRedaction(ctx context.Context, sender *User, evt *event.Event) { + log := zerolog.Ctx(ctx) // Find the original signal message based on eventID - dbMessage := portal.bridge.DB.Message.GetByMXID(evt.Redacts) - if dbMessage == nil { - portal.log.Info().Msgf("Could not find original message for redaction %s", evt.ID) + dbMessage, err := portal.bridge.DB.Message.GetByMXID(ctx, evt.Redacts) + if err != nil { + log.Err(err).Msg("Failed to get redaction target message") } // Might be a reaction redaction, find the original message for the reaction - dbReaction := portal.bridge.DB.Reaction.GetByMXID(evt.Redacts, evt.RoomID) - if dbReaction == nil { - portal.log.Info().Msgf("Could not find original reaction for redaction %s", evt.ID) + dbReaction, err := portal.bridge.DB.Reaction.GetByMXID(ctx, evt.Redacts) + if err != nil { + log.Err(err).Msg("Failed to get redaction target reaction") } if dbMessage == nil && dbReaction == nil { portal.sendMessageStatusCheckpointFailed(evt, errors.New("could not find original message or reaction")) - portal.log.Error().Msgf("Could not find original message or reaction for redaction %s", evt.ID) + log.Warn().Msg("No target message or reaction found for redaction") return } @@ -415,51 +410,84 @@ func (portal *Portal) handleMatrixRedaction(sender *User, evt *event.Event) { // If this is a message redaction, send a redaction to Signal if dbMessage != nil { - msg := signalmeow.DataMessageForDelete(dedupedTimestamp(dbMessage)) - err := portal.sendSignalMessage(context.Background(), msg, sender, evt.ID) + msg := signalmeow.DataMessageForDelete(dbMessage.Timestamp) + err = portal.sendSignalMessage(ctx, msg, sender, evt.ID) if err != nil { portal.sendMessageStatusCheckpointFailed(evt, err) - portal.log.Error().Msgf("Failed to send redaction %s", evt.ID) + log.Err(err).Msg("Failed to send message redaction to Signal") return } - dbMessage.Delete(nil) + err = dbMessage.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete redacted message from database") + } else if otherParts, err := portal.bridge.DB.Message.GetAllPartsBySignalID(ctx, dbMessage.Sender, dbMessage.Timestamp, portal.Receiver); err != nil { + log.Err(err).Msg("Failed to get other parts of redacted message from database") + } else if len(otherParts) > 0 { + // If there are other parts of the message, send a redaction for each of them + for _, otherPart := range otherParts { + _, err = portal.MainIntent().RedactEvent(portal.MXID, otherPart.MXID, mautrix.ReqRedact{ + Reason: "Other part of Signal message redacted", + TxnID: "mxsg_partredact_" + otherPart.MXID.String(), + }) + if err != nil { + log.Err(err). + Str("part_event_id", otherPart.MXID.String()). + Int("part_index", otherPart.PartIndex). + Msg("Failed to redact other part of redacted message") + } + err = otherPart.Delete(ctx) + if err != nil { + log.Err(err). + Str("part_event_id", otherPart.MXID.String()). + Int("part_index", otherPart.PartIndex). + Msg("Failed to delete other part of redacted message from database") + } + } + } + } - // If this is a reaction redaction, send a reaction to Signal with remove == true if dbReaction != nil { msg := signalmeow.DataMessageForReaction(dbReaction.Emoji, dbReaction.MsgAuthor, dbReaction.MsgTimestamp, true) - err := portal.sendSignalMessage(context.Background(), msg, sender, evt.ID) + err = portal.sendSignalMessage(ctx, msg, sender, evt.ID) if err != nil { portal.sendMessageStatusCheckpointFailed(evt, err) - portal.log.Error().Msgf("Failed to send reaction %s", evt.ID) + log.Err(err).Msg("Failed to send reaction redaction to Signal") return } - dbReaction.Delete(nil) + err = dbReaction.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete redacted reaction from database") + } } portal.sendMessageStatusCheckpointSuccess(evt) } -func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { - // Find the original signal message based on eventID - relatedEventID := evt.Content.AsReaction().RelatesTo.EventID - dbMessage := portal.bridge.DB.Message.GetByMXID(relatedEventID) +func (portal *Portal) handleMatrixReaction(ctx context.Context, sender *User, evt *event.Event) { + log := zerolog.Ctx(ctx) if !sender.IsLoggedIn() { - portal.log.Error().Msgf("Cannot relay reaction from non-logged-in user. Ignoring") + log.Error().Msg("Cannot relay reaction from non-logged-in user. Ignoring") return } - if dbMessage == nil { + // Find the original signal message based on eventID + relatedEventID := evt.Content.AsReaction().RelatesTo.EventID + dbMessage, err := portal.bridge.DB.Message.GetByMXID(ctx, relatedEventID) + if err != nil { + portal.sendMessageStatusCheckpointFailed(evt, err) + log.Err(err).Msg("Failed to get reaction target message") + return + } else if dbMessage == nil { portal.sendMessageStatusCheckpointFailed(evt, errors.New("could not find original message for reaction")) - portal.log.Error().Msgf("Could not find original message for reaction %s", evt.ID) + log.Warn().Msg("No target message found for reaction") return } emoji := evt.Content.AsReaction().RelatesTo.Key signalEmoji := variationselector.FullyQualify(emoji) // Signal seems to require fully qualified emojis targetAuthorUUID := dbMessage.Sender - targetTimestamp := dedupedTimestamp(dbMessage) + targetTimestamp := dbMessage.Timestamp msg := signalmeow.DataMessageForReaction(signalEmoji, targetAuthorUUID, targetTimestamp, false) - - err := portal.sendSignalMessage(context.Background(), msg, sender, evt.ID) + err = portal.sendSignalMessage(context.Background(), msg, sender, evt.ID) if err != nil { portal.sendMessageStatusCheckpointFailed(evt, err) portal.log.Error().Msgf("Failed to send reaction %s", evt.ID) @@ -468,27 +496,30 @@ func (portal *Portal) handleMatrixReaction(sender *User, evt *event.Event) { // Signal only allows one reaction from each user // Check if there's an existing reaction in the database for this sender and redact/delete it - dbReaction := portal.bridge.DB.Reaction.GetBySignalID( - portal.ChatID, - portal.Receiver, - sender.SignalID, + dbReaction, err := portal.bridge.DB.Reaction.GetBySignalID( + ctx, targetAuthorUUID, targetTimestamp, + sender.SignalID, + portal.Receiver, ) - if dbReaction != nil { - portal.log.Debug().Msgf("Deleting existing reaction with author %s, target %s, targettime: %d", sender.SignalID, targetAuthorUUID, targetTimestamp) - // Send a redaction to redact the existing reaction - intent := portal.MainIntent() - _, err := intent.RedactEvent(portal.MXID, dbReaction.MXID) + if err != nil { + log.Err(err).Msg("Failed to get existing reaction from database") + } else if dbReaction != nil { + log.Debug().Str("existing_event_id", dbReaction.MXID.String()).Msg("Redacting existing reaction after sending new one") + _, err = portal.MainIntent().RedactEvent(portal.MXID, dbReaction.MXID) if err != nil { - portal.sendMessageStatusCheckpointFailed(evt, err) - portal.log.Warn().Msgf("Failed to redact existing reaction: %v", err) + log.Err(err).Msg("Failed to redact existing reaction") + } + // TODO update instead of deleting + err = dbReaction.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete reaction from database") } - dbReaction.Delete(nil) } // Store our new reaction in the database - portal.storeReactionInDB(evt.ID, sender.SignalID, targetAuthorUUID, targetTimestamp, signalEmoji) + portal.storeReactionInDB(ctx, evt.ID, sender.SignalID, targetAuthorUUID, targetTimestamp, signalEmoji) portal.sendMessageStatusCheckpointSuccess(evt) } @@ -828,15 +859,18 @@ func (portal *Portal) convertMatrixMessage(ctx context.Context, sender *User, ev // Include a quote if this is a reply replyID := content.RelatesTo.GetReplyTo() if replyID != "" { - originalMessage := portal.bridge.DB.Message.GetByMXID(replyID) - if originalMessage == nil { - return nil, fmt.Errorf("%v %s", "Reply not found", replyID) + originalMessage, err := portal.bridge.DB.Message.GetByMXID(ctx, replyID) + if err != nil { + return nil, fmt.Errorf("failed to get reply target: %w", err) + } else if originalMessage != nil { + signalmeow.AddQuoteToDataMessage( + outgoingMessage, + originalMessage.Sender, + originalMessage.Timestamp, + ) + } else { + zerolog.Ctx(ctx).Warn().Str("reply_event_id", replyID.String()).Msg("Reply target not found") } - signalmeow.AddQuoteToDataMessage( - outgoingMessage, - originalMessage.Sender, - dedupedTimestamp(originalMessage), - ) } return outgoingMessage, nil } @@ -902,22 +936,32 @@ func (portal *Portal) sendMessageStatusCheckpointFailed(evt *event.Event, err er } func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) { - if old_message := portal.bridge.DB.Message.GetBySignalID( + log := portal.log.With(). + Str("action", "handle signal message"). + Str("sender", portalMessage.sender.SignalID.String()). + Uint64("timestamp", portalMessage.message.Base().Timestamp). + Int("part_index", portalMessage.message.Base().PartIndex). + Logger() + ctx := log.WithContext(context.TODO()) + if existingMessage, err := portal.bridge.DB.Message.GetBySignalID( + ctx, portalMessage.sender.SignalID, portalMessage.message.Base().Timestamp, - portal.ChatID, + portalMessage.message.Base().PartIndex, portal.Receiver, - ); old_message != nil { - portal.log.Info().Msgf("Ignoring message %d by %s as it was already handled", old_message.Timestamp, old_message.Sender) + ); err != nil { + log.Err(err).Msg("Failed to check if message was already handled") + return + } else if existingMessage != nil { + log.Debug().Msg("Ignoring duplicate message") return } if portal.MXID == "" { - portal.log.Debug().Msg("Creating Matrix room from incoming message") + log.Debug().Msg("Creating Matrix room from incoming message") if err := portal.CreateMatrixRoom(portalMessage.user, nil); err != nil { - portal.log.Error().Err(err).Msg("Failed to create portal room") + log.Error().Err(err).Msg("Failed to create portal room") return } - portal.log.Info().Msgf("Created matrix room: %s", portal.MXID) ensureGroupPuppetsAreJoinedToPortal(context.Background(), portalMessage.user, portal) signalmeow.SendContactSyncRequest(context.TODO(), portalMessage.user.SignalDevice) } @@ -930,31 +974,23 @@ func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) { var err error if portalMessage.message.MessageType() == signalmeow.IncomingSignalMessageTypeText { - err = portal.handleSignalTextMessage(portalMessage, intent) + err = portal.handleSignalTextMessage(ctx, portalMessage, intent) if err != nil { portal.log.Error().Err(err).Msg("Failed to handle text message") return } } else if portalMessage.message.MessageType() == signalmeow.IncomingSignalMessageTypeAttachment { - err = portal.handleSignalAttachmentMessage(portalMessage, intent) + err = portal.handleSignalAttachmentMessage(ctx, portalMessage, intent) if err != nil { portal.log.Error().Err(err).Msg("Failed to handle attachment message") return } } else if portalMessage.message.MessageType() == signalmeow.IncomingSignalMessageTypeReaction { - err := portal.handleSignalReactionMessage(portalMessage, intent) - if err != nil { - portal.log.Error().Err(err).Msg("Failed to handle reaction message") - return - } + portal.handleSignalReactionMessage(ctx, portalMessage, intent) } else if portalMessage.message.MessageType() == signalmeow.IncomingSignalMessageTypeDelete { - err := portal.handleSignalDeleteMessage(portalMessage, intent) - if err != nil { - portal.log.Error().Err(err).Msg("Failed to handle redaction message") - return - } + portal.handleSignalDeleteMessage(ctx, portalMessage, intent) } else if portalMessage.message.MessageType() == signalmeow.IncomingSignalMessageTypeSticker { - err := portal.handleSignalStickerMessage(portalMessage, intent) + err := portal.handleSignalStickerMessage(ctx, portalMessage, intent) if err != nil { portal.log.Error().Err(err).Msg("Failed to handle sticker message") return @@ -966,12 +1002,7 @@ func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) { return } } else if portalMessage.message.MessageType() == signalmeow.IncomingSignalMessageTypeReceipt { - portal.log.Debug().Msg("Received receipt message") - err := portal.handleSignalReceiptMessage(portalMessage, intent) - if err != nil { - portal.log.Error().Err(err).Msg("Failed to handle receipt message") - return - } + portal.handleSignalReceiptMessage(ctx, portalMessage, intent) } else if portalMessage.message.MessageType() == signalmeow.IncomingSignalMessageTypeCall { err := portal.handleSignalCallMessage(portalMessage, intent) if err != nil { @@ -996,91 +1027,102 @@ func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) { } } -func (portal *Portal) storeMessageInDB(eventID id.EventID, senderSignalID string, timestamp uint64) { +func (portal *Portal) storeMessageInDB(ctx context.Context, eventID id.EventID, senderSignalID uuid.UUID, timestamp uint64, partIndex int) { dbMessage := portal.bridge.DB.Message.New() dbMessage.MXID = eventID - dbMessage.MXRoom = portal.MXID + dbMessage.RoomID = portal.MXID dbMessage.Sender = senderSignalID dbMessage.Timestamp = timestamp + dbMessage.PartIndex = partIndex dbMessage.SignalChatID = portal.ChatID dbMessage.SignalReceiver = portal.Receiver - dbMessage.Insert(nil) + err := dbMessage.Insert(ctx) + if err != nil { + portal.log.Err(err).Msg("Failed to insert message into database") + } } func (portal *Portal) storeReactionInDB( + ctx context.Context, eventID id.EventID, - senderSignalID string, - msgAuthor string, + senderSignalID, + msgAuthor uuid.UUID, msgTimestamp uint64, emoji string, ) { dbReaction := portal.bridge.DB.Reaction.New() dbReaction.MXID = eventID - dbReaction.MXRoom = portal.MXID + dbReaction.RoomID = portal.MXID dbReaction.SignalChatID = portal.ChatID dbReaction.SignalReceiver = portal.Receiver dbReaction.Author = senderSignalID dbReaction.MsgAuthor = msgAuthor dbReaction.MsgTimestamp = msgTimestamp dbReaction.Emoji = emoji - dbReaction.Insert(nil) + err := dbReaction.Insert(ctx) + if err != nil { + portal.log.Err(err).Msg("Failed to insert reaction into database") + } } -func (portal *Portal) addSignalQuote(content *event.MessageEventContent, quote *signalmeow.IncomingSignalMessageQuoteData) { - if quote != nil { - originalMessage := portal.bridge.DB.Message.GetBySignalID( - quote.QuotedSender, quote.QuotedTimestamp, portal.ChatID, portal.Receiver, - ) - if originalMessage == nil { - portal.log.Warn().Msgf("Couldn't find message with Signal ID %s/%d", quote.QuotedSender, quote.QuotedTimestamp) - return - } - eventID := originalMessage.MXID - if eventID != "" { - content.RelatesTo = &event.RelatesTo{ - InReplyTo: &event.InReplyTo{ - EventID: eventID, - }, - } - mentionMXID := portal.bridge.FormatPuppetMXID(originalMessage.Sender) - user := portal.bridge.GetUserBySignalID(originalMessage.Sender) - if user != nil { - mentionMXID = user.MXID - } - if !slices.Contains(content.Mentions.UserIDs, mentionMXID) { - content.Mentions.UserIDs = append(content.Mentions.UserIDs, mentionMXID) - } - } else { - portal.log.Warn().Msgf("Couldn't find event ID for message with Signal ID %s/%d", quote.QuotedSender, quote.QuotedTimestamp) - } +func (portal *Portal) addSignalQuote(ctx context.Context, content *event.MessageEventContent, quote *signalmeow.IncomingSignalMessageQuoteData) { + if quote == nil { + return + } + quotedSender, err := uuid.Parse(quote.QuotedSender) + if err != nil { + return + } + originalMessage, err := portal.bridge.DB.Message.GetBySignalID( + ctx, quotedSender, quote.QuotedTimestamp, 0, portal.Receiver, + ) + if err != nil { + zerolog.Ctx(ctx).Err(err).Str("quoted_sender", quote.QuotedSender).Uint64("quoted_timestamp", quote.QuotedTimestamp).Msg("Failed to get quoted message from database") + return + } else if originalMessage == nil { + zerolog.Ctx(ctx).Warn().Str("quoted_sender", quote.QuotedSender).Uint64("quoted_timestamp", quote.QuotedTimestamp).Msg("Quote target message not found") + return + } + content.RelatesTo = &event.RelatesTo{ + InReplyTo: &event.InReplyTo{ + EventID: originalMessage.MXID, + }, + } + mentionMXID := portal.bridge.FormatPuppetMXID(originalMessage.Sender) + user := portal.bridge.GetUserBySignalID(originalMessage.Sender) + if user != nil { + mentionMXID = user.MXID + } + if !slices.Contains(content.Mentions.UserIDs, mentionMXID) { + content.Mentions.UserIDs = append(content.Mentions.UserIDs, mentionMXID) } } -func (portal *Portal) addDisappearingMessage(eventID id.EventID, expireInSeconds int64, startTimerNow bool) { - portal.bridge.disappearingMessagesManager.AddDisappearingMessage(eventID, portal.MXID, expireInSeconds, startTimerNow) +func (portal *Portal) addDisappearingMessage(ctx context.Context, eventID id.EventID, expireInSeconds int64, startTimerNow bool) { + portal.bridge.disappearingMessagesManager.AddDisappearingMessage(ctx, eventID, portal.MXID, time.Duration(expireInSeconds)*time.Second, startTimerNow) } var signalFormatParams *signalfmt.FormatParams var matrixFormatParams *matrixfmt.HTMLParser -func (portal *Portal) handleSignalTextMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { +func (portal *Portal) handleSignalTextMessage(ctx context.Context, portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { timestamp := portalMessage.message.Base().Timestamp msg := (portalMessage.message).(signalmeow.IncomingSignalMessageText) content := signalfmt.Parse(msg.Content, msg.ContentRanges, signalFormatParams) - portal.addSignalQuote(content, msg.Quote) - resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0) + portal.addSignalQuote(ctx, content, msg.Quote) + resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, int64(timestamp)) if err != nil { return err } if resp.EventID == "" { return errors.New("Didn't receive event ID from Matrix") } - portal.storeMessageInDB(resp.EventID, portalMessage.sender.SignalID, timestamp) - portal.addDisappearingMessage(resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) + portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex) + portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) return err } -func (portal *Portal) handleSignalStickerMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { +func (portal *Portal) handleSignalStickerMessage(ctx context.Context, portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { timestamp := portalMessage.message.Base().Timestamp msg := (portalMessage.message).(signalmeow.IncomingSignalMessageSticker) content := &event.MessageEventContent{ @@ -1095,21 +1137,21 @@ func (portal *Portal) handleSignalStickerMessage(portalMessage portalSignalMessa Mentions: &event.Mentions{}, } - portal.addSignalQuote(content, msg.Quote) + portal.addSignalQuote(ctx, content, msg.Quote) err := portal.uploadMediaToMatrix(intent, msg.Sticker, content) if err != nil { portal.log.Error().Err(err).Msg("Failed to upload media") } - resp, err := portal.sendMatrixMessage(intent, event.EventSticker, content, nil, 0) + resp, err := portal.sendMatrixMessage(intent, event.EventSticker, content, nil, int64(timestamp)) if err != nil { return err } if resp.EventID == "" { return errors.New("Didn't receive event ID from Matrix") } - portal.storeMessageInDB(resp.EventID, portalMessage.sender.SignalID, timestamp) - portal.addDisappearingMessage(resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) + portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex) + portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) return err } @@ -1156,53 +1198,56 @@ func (portal *Portal) handleSignalUnhandledMessage(portalMessage portalSignalMes return nil } -func (portal *Portal) handleSignalReceiptMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { +func (portal *Portal) handleSignalReceiptMessage(ctx context.Context, portalMessage portalSignalMessage, intent *appservice.IntentAPI) { receiptMessage := (portalMessage.message).(signalmeow.IncomingSignalMessageReceipt) - messageSender := receiptMessage.OriginalSender + log := zerolog.Ctx(ctx) + messageSender, err := uuid.Parse(receiptMessage.OriginalSender) + // TODO handle err timestamp := receiptMessage.OriginalTimestamp - dbMessages := portal.dbMessagesBySenderAndTimestamp(messageSender, timestamp) - if len(dbMessages) == 0 { - return fmt.Errorf("Couldn't find message with Signal ID %s/%d", messageSender, timestamp) + lastPart, err := portal.bridge.DB.Message.GetLastPartBySignalID(ctx, messageSender, timestamp, portal.Receiver) + if err != nil { + log.Err(err).Msg("Failed to get receipt target message") + return + } else if lastPart == nil { + log.Err(err).Msg("Receipt target message not found") + return } - // If there are multiple dbMessages, we want the last one - dbMessage := dbMessages[len(dbMessages)-1] if receiptMessage.ReceiptType == signalmeow.IncomingSignalMessageReceiptTypeRead { - portal.log.Debug().Msgf("Received read receipt") + log.Debug().Msg("Received read receipt") // Don't process read receipts for messages older than the latest one we've seen if receiptMessage.OriginalTimestamp <= portal.latestReadTimestamp { - portal.log.Debug().Msgf("Ignoring read receipt for timestamp %d", receiptMessage.OriginalTimestamp) - return nil + log.Debug().Msgf("Ignoring read receipt for timestamp %d", receiptMessage.OriginalTimestamp) + return } portal.latestReadTimestamp = receiptMessage.OriginalTimestamp - portal.log.Debug().Msgf("Marking message %s as read", dbMessage.MXID) - err := portal.SetReadMarkers(dbMessage, portalMessage.sender) + log.Debug().Msgf("Marking message %s as read", lastPart.MXID) + err := portal.SetReadMarkers(lastPart, portalMessage.sender) if err != nil { - err = fmt.Errorf("Failed to set read markers: %w", err) - portal.log.Error().Err(err).Msgf("Failed to set read markers for message %s", dbMessage.MXID) - return err + log.Error().Err(err).Msgf("Failed to set read markers for message %s", lastPart.MXID) + return } + // TODO only schedule disappearing when user reads from other device portal.ScheduleDisappearing() - } else if receiptMessage.ReceiptType == signalmeow.IncomingSignalMessageReceiptTypeDelivery { - portal.log.Debug().Msgf("Received delivery receipt") + log.Debug().Msg("Received delivery receipt") // Only send delivery MSS for DMs, not groups if portal.IsPrivateChat() { time := jsontime.UMInt(int64(receiptMessage.Timestamp)) portal.bridge.SendRawMessageCheckpoint(&status.MessageCheckpoint{ - EventID: dbMessage.MXID, + EventID: lastPart.MXID, RoomID: portal.MXID, Step: status.MsgStepRemote, Timestamp: time, Status: status.MsgStatusDelivered, ReportedBy: status.MsgReportedByBridge, }) - portal.sendStatusEvent(dbMessage.MXID, "", nil, &[]id.UserID{portal.MainIntent().UserID}) + portal.sendStatusEvent(lastPart.MXID, "", nil, &[]id.UserID{portal.MainIntent().UserID}) } } - return nil + return } func (portal *Portal) SetReadMarkers(dbMessage *database.Message, sender *Puppet) error { @@ -1307,16 +1352,25 @@ func (portal *Portal) HandleMatrixTyping(newTyping []id.UserID) { // mautrix-go ReadReceiptHandlingPortal interface func (portal *Portal) HandleMatrixReadReceipt(sender bridge.User, eventID id.EventID, receipt event.ReadReceipt) { - portal.log.Debug().Msgf("Received read receipt for event %s", eventID) + log := portal.log.With(). + Str("action", "handle matrix read receipt"). + Str("event_id", eventID.String()). + Str("sender", sender.GetMXID().String()). + Logger() + log.Debug().Msg("Received read receipt") portal.ScheduleDisappearing() // Find event in the DB - dbMessage := portal.bridge.DB.Message.GetByMXID(eventID) - if dbMessage == nil { - portal.log.Info().Msgf("Read receipt: Couldn't find message with event ID %s", eventID) + dbMessage, err := portal.bridge.DB.Message.GetByMXID(context.TODO(), eventID) + if err != nil { + log.Err(err).Msg("Failed to get read receipt target message") + return + } else if dbMessage == nil { + log.Warn().Msg("Read receipt target message not found") return } - msg := signalmeow.ReadReceptMessageForTimestamps([]uint64{dedupedTimestamp(dbMessage)}) + // TODO find all messages that haven't been marked as read by the user + msg := signalmeow.ReadReceptMessageForTimestamps([]uint64{dbMessage.Timestamp}) receiptDestination := dbMessage.Sender receiptSender := sender.(*User) @@ -1324,16 +1378,17 @@ func (portal *Portal) HandleMatrixReadReceipt(sender bridge.User, eventID id.Eve // who sent the original message, not the portal's ChatID ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - result := signalmeow.SendMessage(ctx, receiptSender.SignalDevice, receiptDestination, msg) + result := signalmeow.SendMessage(ctx, receiptSender.SignalDevice, receiptDestination.String(), msg) if !result.WasSuccessful { - err := result.FailedSendResult.Error - portal.log.Error().Msgf("Error sending read receipt to Signal %s: %s", receiptDestination, err) - return + log.Err(result.FailedSendResult.Error). + Str("receipt_destination", receiptDestination.String()). + Msg("Failed to send read receipt to Signal") + } else { + log.Debug().Str("receipt_destination", receiptDestination.String()).Msg("Sent read receipt to Signal") } - portal.log.Debug().Msgf("Sent read receipt for event %s to Signal %s", eventID, receiptDestination) } -func (portal *Portal) handleSignalAttachmentMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { +func (portal *Portal) handleSignalAttachmentMessage(ctx context.Context, portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { timestamp := portalMessage.message.Base().Timestamp msg := (portalMessage.message).(signalmeow.IncomingSignalMessageAttachment) content := signalfmt.Parse(msg.Caption, msg.CaptionRanges, signalFormatParams) @@ -1368,7 +1423,7 @@ func (portal *Portal) handleSignalAttachmentMessage(portalMessage portalSignalMe portal.log.Debug().Msgf("Received file attachment: %s", msg.ContentType) content.MsgType = event.MsgFile } - portal.addSignalQuote(content, msg.Quote) + portal.addSignalQuote(ctx, content, msg.Quote) err := portal.uploadMediaToMatrix(intent, msg.Attachment, content) if err != nil { failureMessage := "Failed to bridge media: " @@ -1382,154 +1437,123 @@ func (portal *Portal) handleSignalAttachmentMessage(portalMessage portalSignalMe portal.log.Error().Err(err).Msg(failureMessage) portal.MainIntent().SendNotice(portal.MXID, failureMessage) } - resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, 0) + resp, err := portal.sendMatrixMessage(intent, event.EventMessage, content, nil, int64(timestamp)) if err != nil { return err } if resp.EventID == "" { return errors.New("Didn't receive event ID from Matrix") } - portal.storeMessageInDB(resp.EventID, portalMessage.sender.SignalID, timestamp) - portal.addDisappearingMessage(resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) + portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex) + portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) return err } -func dedupedTimestamp(msg *database.Message) uint64 { - // TODO: This is a hack to represent multiple attachments from one Signal message - // as multiple Matrix messages. Currently signalmeow/receiving.go will give the - // 2nd and subsequent attachments a fake timestamp that is the Signal message timestamp - // but multiplied by 1000 and with their index added. This method returns the timestamp - // of the Signal message it's based on even if it's a fake timestamp. - if msg.Timestamp > 1700000000000*1000 { - return msg.Timestamp / 1000 - } - return msg.Timestamp -} - -func (portal *Portal) dbMessagesBySenderAndTimestamp(sender string, timestamp uint64) []*database.Message { - var messages []*database.Message - firstMessage := portal.bridge.DB.Message.FindBySenderAndTimestamp(sender, timestamp) - if firstMessage != nil { - messages = append(messages, firstMessage) - - // Check for subsequent messages with the same timestamp (see dedupedTimestamp) - i := uint64(1) - for { - nextMessage := portal.bridge.DB.Message.FindBySenderAndTimestamp(sender, timestamp*1000+i) - if nextMessage == nil { - break - } - messages = append(messages, nextMessage) - i++ - } - } - portal.log.Debug().Msgf("Found %d messages with sender %s and timestamp %d", len(messages), sender, timestamp) - return messages -} - -func (portal *Portal) dbReactionsBySignalID(chatID, receiver, author, msgAuthor string, msgTimestamp uint64) []*database.Reaction { - var reactions []*database.Reaction - firstReaction := portal.bridge.DB.Reaction.GetBySignalID(chatID, receiver, author, msgAuthor, msgTimestamp) - if firstReaction != nil { - reactions = append(reactions, firstReaction) - - // Check for subsequent reactions with the same timestamp (see dedupedTimestamp) - i := uint64(1) - for { - nextReaction := portal.bridge.DB.Reaction.GetBySignalID(chatID, receiver, author, msgAuthor, msgTimestamp*1000+i) - if nextReaction == nil { - break - } - reactions = append(reactions, nextReaction) - i++ - } - } - return reactions -} - -func (portal *Portal) handleSignalReactionMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { +func (portal *Portal) handleSignalReactionMessage(ctx context.Context, portalMessage portalSignalMessage, intent *appservice.IntentAPI) { msg := (portalMessage.message).(signalmeow.IncomingSignalMessageReaction) - portal.log.Debug().Msgf("Reaction message received from %s (group: %v) at %v", msg.SenderUUID, msg.GroupID, msg.Timestamp) - portal.log.Debug().Msgf("Incoming Reaction details: remove: %v, target author: %v, target timestamp: %d", msg.Remove, msg.TargetAuthorUUID, msg.TargetMessageTimestamp) - matrixEmoji := variationselector.Add(msg.Emoji) // Add variation selector for Matrix - // Find the event ID of the message that was reacted to (or messages, if this is a group of images) - dbMessages := portal.dbMessagesBySenderAndTimestamp(msg.TargetAuthorUUID, msg.TargetMessageTimestamp) - if len(dbMessages) == 0 { - portal.log.Warn().Msgf("Couldn't find message with Signal ID %s/%d", msg.TargetAuthorUUID, msg.TargetMessageTimestamp) - return fmt.Errorf("couldn't find message with Signal ID %s/%d", msg.TargetAuthorUUID, msg.TargetMessageTimestamp) - } - - for _, dbMessage := range dbMessages { - dbReaction := portal.bridge.DB.Reaction.GetBySignalID( - portal.ChatID, - portal.Receiver, - msg.SenderUUID, - msg.TargetAuthorUUID, - dbMessage.Timestamp, - ) - // If there's an existing reaction, delete it - if dbReaction != nil { - portal.log.Debug().Msgf("Deleting existing reaction with author %s, target %s, targettime: %d", msg.SenderUUID, msg.TargetAuthorUUID, msg.TargetMessageTimestamp) - // Send a redaction to redact the existing reaction - _, err := intent.RedactEvent(portal.MXID, dbReaction.MXID) - if err != nil { - portal.log.Warn().Msgf("Failed to redact existing reaction: %v", err) - } - dbReaction.Delete(nil) - } else if msg.Remove { - portal.log.Warn().Msgf("Couldn't find reaction to remove with author %s, target %s, targettime: %d", msg.SenderUUID, msg.TargetAuthorUUID, msg.TargetMessageTimestamp) - } - - if !msg.Remove { - // Create a new message event with the reaction - content := &event.ReactionEventContent{ - RelatesTo: event.RelatesTo{ - Type: event.RelAnnotation, - Key: matrixEmoji, - EventID: dbMessage.MXID, - }, - } - resp, err := portal.sendMatrixReaction(intent, event.EventReaction, content, nil, 0) - if err != nil { - portal.log.Err(err).Msgf("Failed to send reaction: %v", err) - continue - } - - // Store our new reaction in the DB - portal.storeReactionInDB( - resp.EventID, - portalMessage.sender.SignalID, - msg.TargetAuthorUUID, - dbMessage.Timestamp, - msg.Emoji, // Store without variation selector, as they come from Signal - ) + log := zerolog.Ctx(ctx) + log.Debug(). + Str("target_message_sender", msg.TargetAuthorUUID). + Uint64("target_message_timestamp", msg.TargetMessageTimestamp). + Msg("Received reaction from Signal") + parsedTargetAuthor, err := uuid.Parse(msg.TargetAuthorUUID) + // TODO handle err + senderUUID, err := uuid.Parse(msg.SenderUUID) + // TODO handle err + dbMessage, err := portal.bridge.DB.Message.GetBySignalID(ctx, parsedTargetAuthor, msg.TargetMessageTimestamp, 0, portal.Receiver) + if err != nil { + log.Err(err).Msg("Failed to get reaction target message") + return + } else if dbMessage == nil { + log.Warn().Msg("Reaction target message not found") + return + } + existingReaction, err := portal.bridge.DB.Reaction.GetBySignalID( + ctx, + parsedTargetAuthor, + msg.TargetMessageTimestamp, + senderUUID, + portal.Receiver, + ) + if err != nil { + log.Err(err).Msg("Failed to get existing reaction from database") + return + } + if existingReaction != nil { + _, err = intent.RedactEvent(portal.MXID, existingReaction.MXID, mautrix.ReqRedact{ + TxnID: "mxsg_unreact_" + existingReaction.MXID.String(), + }) + if err != nil { + log.Err(err).Msg("Failed to redact reaction") + } + // TODO only delete when removing reaction, update row in db when changing + err = existingReaction.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete reaction from database") } + if msg.Remove { + return + } + } else if msg.Remove { + log.Warn().Msg("Reaction removal target reaction not found") + return } - return nil + // Create a new message event with the reaction + content := &event.ReactionEventContent{ + RelatesTo: event.RelatesTo{ + Type: event.RelAnnotation, + Key: matrixEmoji, + EventID: dbMessage.MXID, + }, + } + resp, err := portal.sendMatrixReaction(intent, event.EventReaction, content, nil, 0) + if err != nil { + portal.log.Err(err).Msgf("Failed to send reaction: %v", err) + return + } + + // Store our new reaction in the DB + portal.storeReactionInDB( + ctx, + resp.EventID, + portalMessage.sender.SignalID, + parsedTargetAuthor, + dbMessage.Timestamp, + msg.Emoji, // Store without variation selector, as they come from Signal + ) } -func (portal *Portal) handleSignalDeleteMessage(portalMessage portalSignalMessage, intent *appservice.IntentAPI) error { +func (portal *Portal) handleSignalDeleteMessage(ctx context.Context, portalMessage portalSignalMessage, intent *appservice.IntentAPI) { msg := (portalMessage.message).(signalmeow.IncomingSignalMessageDelete) - portal.log.Debug().Msgf("Delete message received from %s (group: %v) at %v", msg.SenderUUID, msg.GroupID, msg.Timestamp) + senderUUID, err := uuid.Parse(msg.SenderUUID) + // TODO handle err + + log := zerolog.Ctx(ctx) // Find the event ID of the message to delete - dbMessages := portal.dbMessagesBySenderAndTimestamp(msg.SenderUUID, msg.TargetMessageTimestamp) - if len(dbMessages) == 0 { - portal.log.Warn().Msgf("Couldn't find message with Signal ID %s/%d", msg.SenderUUID, msg.TargetMessageTimestamp) - return fmt.Errorf("couldn't find message with Signal ID %s/%d", msg.SenderUUID, msg.TargetMessageTimestamp) + messages, err := portal.bridge.DB.Message.GetAllPartsBySignalID(ctx, senderUUID, msg.TargetMessageTimestamp, portal.Receiver) + if err != nil { + log.Err(err).Msg("Failed to get messages to delete") + return + } else if len(messages) == 0 { + log.Warn().Msg("Didn't find any messages to delete") + return } - for _, dbMessage := range dbMessages { - _, err := intent.RedactEvent(portal.MXID, dbMessage.MXID) + for _, targetMsg := range messages { + _, err = intent.RedactEvent(portal.MXID, targetMsg.MXID) if err != nil { - portal.log.Warn().Msgf("Failed to redact existing reaction: %v", err) - return err + log.Err(err).Msg("Failed to redact message") + continue + } + err = targetMsg.Delete(ctx) + if err != nil { + log.Err(err).Msg("Failed to delete message from database") + continue } - dbMessage.Delete(nil) } - - return nil + return } func (portal *Portal) sendMainIntentMessage(content *event.MessageEventContent) (*mautrix.RespSendEvent, error) { @@ -1637,32 +1661,6 @@ func (portal *Portal) sendMatrixEventContent(intent *appservice.IntentAPI, event } } -func (portal *Portal) getMessagePuppet(user *User, senderUUID string) (puppet *Puppet) { - if portal.IsPrivateChat() { - puppet = portal.bridge.GetPuppetBySignalID(portal.ChatID) - } else if senderUUID != "" { - puppet = portal.bridge.GetPuppetBySignalID(senderUUID) - } - if puppet == nil { - return nil - } - return puppet -} - -func (portal *Portal) getMessageIntent(user *User, senderUUID string) *appservice.IntentAPI { - puppet := portal.getMessagePuppet(user, senderUUID) - if puppet == nil { - portal.log.Debug().Msg("Not handling: puppet is nil") - return nil - } - intent := puppet.IntentFor(portal) - //if !intent.IsCustomPuppet && portal.IsPrivateChat() { //&& info.Sender.User == portal.Key.Receiver.User && portal.Key.Receiver != portal.Key.JID { - // portal.log.Debugfln("Not handling: user doesn't have double puppeting enabled") - // return nil - //} - return intent -} - func (portal *Portal) getEncryptionEventContent() (evt *event.EncryptionEventContent) { evt = &event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1} if rot := portal.bridge.Config.Bridge.Encryption.Rotation; rot.EnableCustom { @@ -1767,7 +1765,10 @@ func (portal *Portal) CreateMatrixRoom(user *User, meta *any) error { portal.bridge.portalsLock.Lock() portal.bridge.portalsByMXID[portal.MXID] = portal portal.bridge.portalsLock.Unlock() - portal.Update() + err = portal.Update(context.TODO()) + if err != nil { + portal.log.Err(err).Msg("Failed to save created portal mxid") + } portal.log.Info().Msgf("Created matrix room %s", portal.MXID) if portal.Encrypted && portal.IsPrivateChat() { @@ -1794,14 +1795,6 @@ func (portal *Portal) CreateMatrixRoom(user *User, meta *any) error { user.UpdateDirectChats(chats) } - _, err = portal.MainIntent().SendMessageEvent(portal.MXID, portalCreationDummyEvent, struct{}{}) - if err != nil { - portal.log.Error().Err(err).Msg("Failed to send dummy event to mark portal creation") - } else { - portal.log.Debug().Msg("Sent dummy event to mark portal creation") - portal.Update() - } - return nil } @@ -1814,15 +1807,15 @@ var ( portalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType} ) -func (br *SignalBridge) loadPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal { +func (br *SignalBridge) loadPortal(ctx context.Context, dbPortal *database.Portal, key *database.PortalKey) *Portal { if dbPortal == nil { if key == nil { return nil } dbPortal = br.DB.Portal.New() - dbPortal.SetPortalKey(*key) - err := dbPortal.Insert() + dbPortal.PortalKey = *key + err := dbPortal.Insert(ctx) if err != nil { br.ZLog.Err(err).Msg("Failed to insert new portal") return nil @@ -1831,7 +1824,7 @@ func (br *SignalBridge) loadPortal(dbPortal *database.Portal, key *database.Port portal := br.NewPortal(dbPortal) - br.portalsByID[portal.Key()] = portal + br.portalsByID[portal.PortalKey] = portal if portal.MXID != "" { br.portalsByMXID[portal.MXID] = portal } @@ -1845,7 +1838,12 @@ func (br *SignalBridge) GetPortalByMXID(mxid id.RoomID) *Portal { portal, ok := br.portalsByMXID[mxid] if !ok { - return br.loadPortal(br.DB.Portal.GetByMXID(mxid), nil) + dbPortal, err := br.DB.Portal.GetByMXID(context.TODO(), mxid) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get portal from database") + return nil + } + return br.loadPortal(context.TODO(), dbPortal, nil) } return portal @@ -1855,12 +1853,17 @@ func (br *SignalBridge) GetPortalByChatID(key database.PortalKey) *Portal { br.portalsLock.Lock() defer br.portalsLock.Unlock() // If this PortalKey is for a group, Receiver should be empty - if !isUUID(key.ChatID) { - key.Receiver = "" + if key.UserID() == uuid.Nil { + key.Receiver = uuid.Nil } portal, ok := br.portalsByID[key] if !ok { - return br.loadPortal(br.DB.Portal.GetByChatID(key), &key) + dbPortal, err := br.DB.Portal.GetByChatID(context.TODO(), key) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get portal from database") + return nil + } + return br.loadPortal(context.TODO(), dbPortal, &key) } return portal } @@ -1871,7 +1874,7 @@ func (portal *Portal) getBridgeInfoStateKey() string { // ** DisappearingPortal interface ** func (portal *Portal) ScheduleDisappearing() { - portal.bridge.disappearingMessagesManager.ScheduleDisappearingForRoom(portal.MXID) + portal.bridge.disappearingMessagesManager.ScheduleDisappearingForRoom(context.TODO(), portal.MXID) } func (portal *Portal) HandleNewDisappearingMessageTime(newTimer uint32) { @@ -1903,12 +1906,12 @@ func (portal *Portal) addRelaybotFormat(userID id.UserID, content *event.Message } func (portal *Portal) Delete() { - err := portal.Portal.Delete() + err := portal.Portal.Delete(context.TODO()) if err != nil { portal.log.Err(err).Msg("Failed to delete portal from db") } portal.bridge.portalsLock.Lock() - delete(portal.bridge.portalsByID, portal.Key()) + delete(portal.bridge.portalsByID, portal.PortalKey) if len(portal.MXID) > 0 { delete(portal.bridge.portalsByMXID, portal.MXID) } @@ -1917,41 +1920,44 @@ func (portal *Portal) Delete() { } func (portal *Portal) Cleanup(puppetsOnly bool) { - if len(portal.MXID) == 0 { + portal.bridge.CleanupRoom(&portal.log, portal.MainIntent(), portal.MXID, puppetsOnly) +} + +func (br *SignalBridge) CleanupRoom(log *zerolog.Logger, intent *appservice.IntentAPI, mxid id.RoomID, puppetsOnly bool) { + if len(mxid) == 0 { return } - intent := portal.MainIntent() - if portal.bridge.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { - err := intent.BeeperDeleteRoom(portal.MXID) + if br.SpecVersions.Supports(mautrix.BeeperFeatureRoomYeeting) { + err := intent.BeeperDeleteRoom(mxid) if err == nil || errors.Is(err, mautrix.MNotFound) { return } - portal.log.Warn().Err(err).Msg("Failed to delete room using beeper yeet endpoint, falling back to normal behavior") + log.Warn().Err(err).Msg("Failed to delete room using beeper yeet endpoint, falling back to normal behavior") } - members, err := intent.JoinedMembers(portal.MXID) + members, err := intent.JoinedMembers(mxid) if err != nil { - portal.log.Err(err).Msg("Failed to get portal members for cleanup") + log.Err(err).Msg("Failed to get portal members for cleanup") return } for member := range members.Joined { if member == intent.UserID { continue } - puppet := portal.bridge.GetPuppetByMXID(member) + puppet := br.GetPuppetByMXID(member) if puppet != nil { - _, err = puppet.DefaultIntent().LeaveRoom(portal.MXID) + _, err = puppet.DefaultIntent().LeaveRoom(mxid) if err != nil { - portal.log.Err(err).Msg("Failed to leave as puppet while cleaning up portal") + log.Err(err).Msg("Failed to leave as puppet while cleaning up portal") } } else if !puppetsOnly { - _, err = intent.KickUser(portal.MXID, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) + _, err = intent.KickUser(mxid, &mautrix.ReqKickUser{UserID: member, Reason: "Deleting portal"}) if err != nil { - portal.log.Err(err).Msg("Failed to kick user while cleaning up portal") + log.Err(err).Msg("Failed to kick user while cleaning up portal") } } } - _, err = intent.LeaveRoom(portal.MXID) + _, err = intent.LeaveRoom(mxid) if err != nil { - portal.log.Err(err).Msg("Failed to leave room while cleaning up portal") + log.Err(err).Msg("Failed to leave room while cleaning up portal") } } diff --git a/provisioning.go b/provisioning.go index ae995c93..1eb9b9b7 100644 --- a/provisioning.go +++ b/provisioning.go @@ -30,6 +30,7 @@ import ( "sync" "time" + "github.com/google/uuid" "github.com/gorilla/mux" "github.com/rs/zerolog" "maunium.net/go/mautrix/id" @@ -184,7 +185,7 @@ func (prov *ProvisioningAPI) resolveIdentifier(user *User, phoneNum string) (int } portal := user.GetPortalByChatID(contact.UUID) - puppet := prov.bridge.GetPuppetBySignalID(contact.UUID) + puppet := prov.bridge.GetPuppetBySignalIDString(contact.UUID) return http.StatusOK, &ResolveIdentifierResponse{ RoomID: portal.MXID.String(), @@ -511,9 +512,13 @@ func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Requ // Update user with SignalID if resp.ProvisioningData.AciUuid != "" { - user.SignalID = resp.ProvisioningData.AciUuid + user.SignalID, err = uuid.Parse(resp.ProvisioningData.AciUuid) + // TODO handle err user.SignalUsername = resp.ProvisioningData.Number - user.Update() + err = user.Update(r.Context()) + if err != nil { + prov.log.Err(err).Msg("Failed to save user after login") + } } return case <-time.After(45 * time.Second): @@ -586,7 +591,7 @@ func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.R jsonResponse(w, http.StatusOK, Response{ Success: true, Status: "prekeys_registered", - UUID: user.SignalID, + UUID: user.SignalID.String(), Number: user.SignalUsername, }) diff --git a/puppet.go b/puppet.go index a0395c57..45cc79a4 100644 --- a/puppet.go +++ b/puppet.go @@ -17,10 +17,12 @@ package main import ( + "context" "fmt" "regexp" "sync" + "github.com/google/uuid" "github.com/rs/zerolog" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge" @@ -66,7 +68,7 @@ func (puppet *Puppet) CustomIntent() *appservice.IntentAPI { func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI { if puppet != nil { - if puppet.customIntent == nil || portal.Key().ChatID == puppet.SignalID { + if puppet.customIntent == nil || portal.UserID() == puppet.SignalID { return puppet.DefaultIntent() } return puppet.customIntent @@ -88,13 +90,13 @@ func (br *SignalBridge) NewPuppet(dbPuppet *database.Puppet) *Puppet { return &Puppet{ Puppet: dbPuppet, bridge: br, - log: br.ZLog.With().Str("signal_user_id", dbPuppet.SignalID).Logger(), + log: br.ZLog.With().Str("signal_user_id", dbPuppet.SignalID.String()).Logger(), MXID: br.FormatPuppetMXID(dbPuppet.SignalID), } } -func (br *SignalBridge) ParsePuppetMXID(mxid id.UserID) (string, bool) { +func (br *SignalBridge) ParsePuppetMXID(mxid id.UserID) (uuid.UUID, bool) { if userIDRegex == nil { pattern := fmt.Sprintf( "^@%s:%s$", @@ -109,10 +111,14 @@ func (br *SignalBridge) ParsePuppetMXID(mxid id.UserID) (string, bool) { match := userIDRegex.FindStringSubmatch(string(mxid)) if len(match) == 2 { - return match[1], true + parsed, err := uuid.Parse(match[1]) + if err != nil { + return uuid.Nil, false + } + return parsed, true } - return "", false + return uuid.Nil, false } func (br *SignalBridge) GetPuppetByMXID(mxid id.UserID) *Puppet { @@ -124,26 +130,37 @@ func (br *SignalBridge) GetPuppetByMXID(mxid id.UserID) *Puppet { return br.GetPuppetBySignalID(signalID) } -func (br *SignalBridge) GetPuppetBySignalID(id string) *Puppet { +func (br *SignalBridge) GetPuppetBySignalIDString(id string) *Puppet { + parsed, err := uuid.Parse(id) + if err != nil { + return nil + } + return br.GetPuppetBySignalID(parsed) +} + +func (br *SignalBridge) GetPuppetBySignalID(id uuid.UUID) *Puppet { br.puppetsLock.Lock() defer br.puppetsLock.Unlock() - if id == "" { + if id == uuid.Nil { br.ZLog.Warn().Msg("Trying to get puppet with empty signal_user_id") return nil } puppet, ok := br.puppets[id] if !ok { - dbPuppet := br.DB.Puppet.GetBySignalID(id) - if dbPuppet == nil { - br.ZLog.Info().Str("signal_user_id", id).Msg("Puppet not found in database, creating new entry") + dbPuppet, err := br.DB.Puppet.GetBySignalID(context.TODO(), id) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get puppet from database") + return nil + } else if dbPuppet == nil { + br.ZLog.Info().Str("signal_user_id", id.String()).Msg("Puppet not found in database, creating new entry") dbPuppet = br.DB.Puppet.New() dbPuppet.SignalID = id //dbPuppet.Number = - err := dbPuppet.Insert() + err = dbPuppet.Insert(context.TODO()) if err != nil { - br.ZLog.Error().Err(err).Str("signal_user_id", id).Msg("Error creating new puppet") + br.ZLog.Error().Err(err).Str("signal_user_id", id.String()).Msg("Error creating new puppet") return nil } } @@ -152,8 +169,8 @@ func (br *SignalBridge) GetPuppetBySignalID(id string) *Puppet { if puppet.CustomMXID != "" { br.puppetsByCustomMXID[puppet.CustomMXID] = puppet } - if puppet.Number != nil { - br.puppetsByNumber[*puppet.Number] = puppet + if puppet.Number != "" { + br.puppetsByNumber[puppet.Number] = puppet } } return puppet @@ -165,8 +182,11 @@ func (br *SignalBridge) GetPuppetByNumber(number string) *Puppet { puppet, ok := br.puppetsByNumber[number] if !ok { - dbPuppet := br.DB.Puppet.GetByNumber(number) - if dbPuppet == nil { + dbPuppet, err := br.DB.Puppet.GetByNumber(context.TODO(), number) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get puppet from database") + return nil + } else if dbPuppet == nil { return nil } @@ -175,8 +195,8 @@ func (br *SignalBridge) GetPuppetByNumber(number string) *Puppet { if puppet.CustomMXID != "" { br.puppetsByCustomMXID[puppet.CustomMXID] = puppet } - if puppet.Number != nil { - br.puppetsByNumber[*puppet.Number] = puppet + if puppet.Number != "" { + br.puppetsByNumber[puppet.Number] = puppet } } return puppet @@ -188,23 +208,26 @@ func (br *SignalBridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet { puppet, ok := br.puppetsByCustomMXID[mxid] if !ok { - dbPuppet := br.DB.Puppet.GetByCustomMXID(mxid) - if dbPuppet == nil { + dbPuppet, err := br.DB.Puppet.GetByCustomMXID(context.TODO(), mxid) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get puppet from database") + return nil + } else if dbPuppet == nil { return nil } puppet = br.NewPuppet(dbPuppet) br.puppets[puppet.SignalID] = puppet br.puppetsByCustomMXID[puppet.CustomMXID] = puppet - if puppet.Number != nil { - br.puppetsByNumber[*puppet.Number] = puppet + if puppet.Number != "" { + br.puppetsByNumber[puppet.Number] = puppet } } return puppet } func (br *SignalBridge) GetAllPuppetsWithCustomMXID() []*Puppet { - puppets, err := br.DB.Puppet.GetAllWithCustomMXID() + puppets, err := br.DB.Puppet.GetAllWithCustomMXID(context.TODO()) if err != nil { br.ZLog.Error().Err(err).Msg("Failed to get all puppets with custom MXID") return nil @@ -212,9 +235,9 @@ func (br *SignalBridge) GetAllPuppetsWithCustomMXID() []*Puppet { return br.dbPuppetsToPuppets(puppets) } -func (br *SignalBridge) FormatPuppetMXID(did string) id.UserID { +func (br *SignalBridge) FormatPuppetMXID(u uuid.UUID) id.UserID { return id.NewUserID( - br.Config.Bridge.FormatUsername(did), + br.Config.Bridge.FormatUsername(u.String()), br.Config.Homeserver.Domain, ) } @@ -232,8 +255,8 @@ func (br *SignalBridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Pupp if !ok { puppet = br.NewPuppet(dbPuppet) br.puppets[dbPuppet.SignalID] = puppet - if dbPuppet.Number != nil { - br.puppetsByNumber[*dbPuppet.Number] = puppet + if dbPuppet.Number != "" { + br.puppetsByNumber[dbPuppet.Number] = puppet } if dbPuppet.CustomMXID != "" { br.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet diff --git a/user.go b/user.go index 1f9660c0..96a4e551 100644 --- a/user.go +++ b/user.go @@ -90,12 +90,15 @@ func (user *User) SetManagementRoom(roomID id.RoomID) { existing, ok := user.bridge.managementRooms[roomID] if ok { existing.ManagementRoom = "" - existing.Update() + err := existing.Update(context.TODO()) + if err != nil { + existing.log.Err(err).Msg("Failed to update user when removing management room") + } } user.ManagementRoom = roomID user.bridge.managementRooms[user.ManagementRoom] = user - err := user.Update() + err := user.Update(context.TODO()) if err != nil { user.log.Error().Err(err).Msg("Error setting management room") } @@ -119,14 +122,14 @@ func (user *User) GetIGhost() bridge.Ghost { // ** User creation and fetching ** -func (br *SignalBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { +func (br *SignalBridge) loadUser(ctx context.Context, dbUser *database.User, mxid *id.UserID) *User { if dbUser == nil { if mxid == nil { return nil } dbUser = br.DB.User.New() dbUser.MXID = *mxid - err := dbUser.Insert() + err := dbUser.Insert(ctx) if err != nil { br.ZLog.Err(err).Msg("Error creating user %s") return nil @@ -135,7 +138,7 @@ func (br *SignalBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { user := br.NewUser(dbUser) br.usersByMXID[user.MXID] = user - if user.SignalID != "" { + if user.SignalID != uuid.Nil { br.usersBySignalID[user.SignalID] = user } if user.ManagementRoom != "" { @@ -143,11 +146,12 @@ func (br *SignalBridge) loadUser(dbUser *database.User, mxid *id.UserID) *User { br.managementRooms[user.ManagementRoom] = user br.managementRoomsLock.Unlock() } + // TODO this is completely wrong and shouldn't be here at all // Ensure a puppet is created for this user newPuppet := br.GetPuppetBySignalID(user.SignalID) if newPuppet != nil && newPuppet.CustomMXID == "" { newPuppet.CustomMXID = user.MXID - err := newPuppet.Update() + err := newPuppet.Update(ctx) if err != nil { br.ZLog.Err(err).Msg("Error updating puppet for user %s") } @@ -164,18 +168,28 @@ func (br *SignalBridge) GetUserByMXID(userID id.UserID) *User { user, ok := br.usersByMXID[userID] if !ok { - return br.loadUser(br.DB.User.GetByMXID(userID), &userID) + dbUser, err := br.DB.User.GetByMXID(context.TODO(), userID) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get user from database") + return nil + } + return br.loadUser(context.TODO(), dbUser, &userID) } return user } -func (br *SignalBridge) GetUserBySignalID(id string) *User { +func (br *SignalBridge) GetUserBySignalID(id uuid.UUID) *User { br.usersLock.Lock() defer br.usersLock.Unlock() user, ok := br.usersBySignalID[id] if !ok { - return br.loadUser(br.DB.User.GetBySignalID(id), nil) + dbUser, err := br.DB.User.GetBySignalID(context.TODO(), id) + if err != nil { + br.ZLog.Err(err).Msg("Failed to get user from database") + return nil + } + return br.loadUser(context.TODO(), dbUser, nil) } return user } @@ -274,7 +288,7 @@ func (user *User) GetMXID() id.UserID { return user.MXID } func (user *User) GetRemoteID() string { - return user.SignalID + return user.SignalID.String() } func (user *User) GetRemoteName() string { @@ -287,13 +301,17 @@ func (br *SignalBridge) getAllLoggedInUsers() []*User { br.usersLock.Lock() defer br.usersLock.Unlock() - dbUsers := br.DB.User.AllLoggedIn() + dbUsers, err := br.DB.User.GetAllLoggedIn(context.TODO()) + if err != nil { + br.ZLog.Err(err).Msg("Error getting all logged in users") + return nil + } users := make([]*User, len(dbUsers)) for idx, dbUser := range dbUsers { user, ok := br.usersByMXID[dbUser.MXID] if !ok { - user = br.loadUser(dbUser, nil) + user = br.loadUser(context.TODO(), dbUser, nil) } users[idx] = user } @@ -488,11 +506,11 @@ func (user *User) populateSignalDevice() *signalmeow.Device { user.Lock() defer user.Unlock() - if user.SignalID == "" { + if user.SignalID == uuid.Nil { return nil } - device, err := user.bridge.MeowStore.DeviceByAci(user.SignalID) + device, err := user.bridge.MeowStore.DeviceByAci(user.SignalID.String()) if err != nil { user.log.Err(err).Msgf("problem looking up aci %s", user.SignalID) return nil @@ -508,7 +526,7 @@ func (user *User) populateSignalDevice() *signalmeow.Device { } func updatePuppetWithSignalContact(ctx context.Context, user *User, puppet *Puppet, newContactAvatar *signalmeow.ContactAvatar) error { - contact, newProfileAvatar, err := user.SignalDevice.ContactByIDWithProfileAvatar(puppet.SignalID) + contact, newProfileAvatar, err := user.SignalDevice.ContactByIDWithProfileAvatar(puppet.SignalID.String()) if err != nil { user.log.Err(err).Msg("updatePuppetWithSignalContact: error retrieving contact") return err @@ -525,7 +543,7 @@ func updatePuppetWithSignalContact(ctx context.Context, user *User, puppet *Pupp return err } puppet.NameSet = true - err = puppet.Update() + err = puppet.Update(ctx) if err != nil { user.log.Err(err).Msg("updatePuppetWithSignalContact: error updating puppet with new name") return err @@ -553,7 +571,7 @@ func updatePuppetWithSignalContact(ctx context.Context, user *User, puppet *Pupp user.log.Err(err).Msg("updatePuppetWithSignalContact: error clearing avatar url") return err } - err = puppet.Update() + err = puppet.Update(ctx) if err != nil { user.log.Err(err).Msg("updatePuppetWithSignalContact: error updating puppet while clearing avatar") return err @@ -578,7 +596,7 @@ func updatePuppetWithSignalContact(ctx context.Context, user *User, puppet *Pupp user.log.Err(err).Msg("updatePuppetWithSignalContact: error setting avatar url") return err } - err = puppet.Update() + err = puppet.Update(ctx) if err != nil { user.log.Err(err).Msg("updatePuppetWithSignalContact: error updating puppet with new avatar") return err @@ -607,10 +625,15 @@ func ensureGroupPuppetsAreJoinedToPortal(ctx context.Context, user *User, portal return err } for _, member := range group.Members { - if member.UserId == user.SignalID { + parsedUserID, err := uuid.Parse(member.UserId) + if err != nil { + // TODO log? continue } - memberPuppet := portal.bridge.GetPuppetBySignalID(member.UserId) + if parsedUserID == user.SignalID { + continue + } + memberPuppet := portal.bridge.GetPuppetBySignalID(parsedUserID) if memberPuppet == nil { user.log.Err(err).Msgf("no puppet found for signalID %s", member.UserId) continue @@ -629,8 +652,12 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign m := incomingMessage.Base() var chatID string var senderPuppet *Puppet + parsedSenderUUID, err := uuid.Parse(m.SenderUUID) + if err != nil { + return err + } - isSyncMessage := m.SenderUUID == user.SignalID + isSyncMessage := parsedSenderUUID == user.SignalID // Get and update the puppet for this message user.tryAutomaticDoublePuppeting() @@ -640,7 +667,7 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign chatID = m.RecipientUUID senderPuppet = user.bridge.GetPuppetByCustomMXID(user.MXID) if senderPuppet == nil { - senderPuppet = user.bridge.GetPuppetBySignalID(m.SenderUUID) + senderPuppet = user.bridge.GetPuppetBySignalID(parsedSenderUUID) if senderPuppet == nil { err := fmt.Errorf("no puppet found for me (%s)", user.MXID) user.log.Err(err).Msg("error getting puppet") @@ -650,7 +677,7 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign } else { user.log.Debug().Msgf("Message received from %s (group: %v)", m.SenderUUID, m.GroupID) chatID = m.SenderUUID - senderPuppet = user.bridge.GetPuppetBySignalID(m.SenderUUID) + senderPuppet = user.bridge.GetPuppetBySignalID(parsedSenderUUID) if senderPuppet == nil { err := fmt.Errorf("no puppet found for sender: %s", m.SenderUUID) user.log.Err(err).Msg("error getting puppet") @@ -678,9 +705,16 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign if incomingMessage.MessageType() == signalmeow.IncomingSignalMessageTypeReceipt { receiptMessage := incomingMessage.(signalmeow.IncomingSignalMessageReceipt) timestamp := receiptMessage.OriginalTimestamp - sender := receiptMessage.OriginalSender - dbMessage := user.bridge.DB.Message.FindBySenderAndTimestamp(sender, timestamp) - if dbMessage == nil { + sender, err := uuid.Parse(receiptMessage.OriginalSender) + if err != nil { + user.log.Err(err).Msg("Failed to parse sender UUID in receipt") + return nil + } + dbMessage, err := user.bridge.DB.Message.GetBySignalIDWithUnknownReceiver(context.TODO(), sender, timestamp, 0, user.SignalID) + if err != nil { + user.log.Err(err).Msg("Failed to get receipt target message from database") + return nil + } else if dbMessage == nil { user.log.Warn().Msgf("Receipt received for unknown message %v %d", user.SignalID, timestamp) return nil } @@ -701,7 +735,7 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign portal.log.Debug().Msgf("Updating expiration time to %d (DM)", expireTimerMessage.NewExpireTimer) if portal.ExpirationTime != int(expireTimerMessage.NewExpireTimer) { portal.ExpirationTime = int(expireTimerMessage.NewExpireTimer) - err := portal.Update() + err := portal.Update(context.TODO()) if err != nil { user.log.Err(err).Msg("error updating exipration time in portal") } @@ -768,7 +802,7 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign user.log.Err(err).Msg("error setting room avatar") } portal.AvatarSet = err == nil - err = portal.Update() + err = portal.Update(context.TODO()) if err != nil { user.log.Err(err).Msg("error updating portal") } @@ -796,7 +830,7 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign func (user *User) GetPortalByChatID(signalID string) *Portal { pk := database.PortalKey{ ChatID: signalID, - Receiver: user.SignalUsername, + Receiver: user.SignalID, } return user.bridge.GetPortalByChatID(pk) } @@ -894,10 +928,14 @@ func (user *User) UpdateDirectChats(chats map[id.UserID][]id.RoomID) { func (user *User) getDirectChats() map[id.UserID][]id.RoomID { chats := map[id.UserID][]id.RoomID{} - privateChats := user.bridge.DB.Portal.FindPrivateChatsOf(user.SignalID) + privateChats, err := user.bridge.DB.Portal.FindPrivateChatsOf(context.TODO(), user.SignalID) + if err != nil { + user.log.Err(err).Msg("Failed to get private chats") + return chats + } for _, portal := range privateChats { if portal.MXID != "" { - puppetMXID := user.bridge.FormatPuppetMXID(portal.Key().Receiver) + puppetMXID := user.bridge.FormatPuppetMXID(portal.UserID()) chats[puppetMXID] = []id.RoomID{portal.MXID} }