From b453b08b853b1c95bcc2dbac2f0f40b44f3113e7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 25 Dec 2023 16:10:14 +0200 Subject: [PATCH] Refactor remaining tables --- commands.go | 7 +- database/disappearingmessage.go | 26 +-- database/message.go | 15 +- database/portal.go | 53 +++--- database/puppet.go | 31 +-- database/reaction.go | 9 +- database/upgrades/00-latest.sql | 74 ++++---- database/upgrades/16-message-primary-key.sql | 28 --- database/upgrades/16-refactor-postgres.sql | 92 +++++++++ database/upgrades/17-message-primary-key.sql | 80 -------- database/upgrades/17-refactor-sqlite.sql | 187 +++++++++++++++++++ database/user.go | 19 +- disappearing.go | 8 +- main.go | 23 +-- msgconv/matrixfmt/convert_test.go | 7 +- msgconv/matrixfmt/html.go | 9 +- msgconv/signalfmt/convert.go | 11 +- msgconv/signalfmt/convert_test.go | 8 +- msgconv/signalfmt/tags.go | 6 +- pkg/signalmeow/sending.go | 19 +- portal.go | 70 +++---- provisioning.go | 8 +- puppet.go | 35 ++-- user.go | 43 +++-- 24 files changed, 534 insertions(+), 334 deletions(-) delete mode 100644 database/upgrades/16-message-primary-key.sql create mode 100644 database/upgrades/16-refactor-postgres.sql delete mode 100644 database/upgrades/17-message-primary-key.sql create mode 100644 database/upgrades/17-refactor-sqlite.sql diff --git a/commands.go b/commands.go index 0eea6cba..bef20657 100644 --- a/commands.go +++ b/commands.go @@ -20,6 +20,7 @@ import ( "context" "strings" + "github.com/google/uuid" "github.com/skip2/go-qrcode" "maunium.net/go/mautrix/bridge/commands" "maunium.net/go/mautrix/event" @@ -223,7 +224,11 @@ func fnLogin(ce *WrappedCommandEvent) { // Update user with SignalID if signalID != "" { - ce.User.SignalID = signalID + ce.User.SignalID, err = uuid.Parse(signalID) + if err != nil { + ce.Reply("Problem logging in - SignalID is not a valid UUID") + return + } ce.User.SignalUsername = signalUsername } else { ce.Reply("Problem logging in - No SignalID received") diff --git a/database/disappearingmessage.go b/database/disappearingmessage.go index 18900ae3..dd0f4a22 100644 --- a/database/disappearingmessage.go +++ b/database/disappearingmessage.go @@ -57,23 +57,23 @@ type DisappearingMessageQuery struct { type DisappearingMessage struct { qh *dbutil.QueryHelper[*DisappearingMessage] - RoomID id.RoomID - EventID id.EventID - ExpireInSeconds int64 // TODO change to time.Duration - ExpireAt time.Time + RoomID id.RoomID + EventID id.EventID + ExpireIn time.Duration + ExpireAt time.Time } func newDisappearingMessage(qh *dbutil.QueryHelper[*DisappearingMessage]) *DisappearingMessage { return &DisappearingMessage{qh: qh} } -func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireInSeconds int64, expireAt time.Time) *DisappearingMessage { +func (dmq *DisappearingMessageQuery) NewWithValues(roomID id.RoomID, eventID id.EventID, expireIn time.Duration, expireAt time.Time) *DisappearingMessage { return &DisappearingMessage{ - qh: dmq.QueryHelper, - RoomID: roomID, - EventID: eventID, - ExpireInSeconds: expireInSeconds, - ExpireAt: expireAt, + qh: dmq.QueryHelper, + RoomID: roomID, + EventID: eventID, + ExpireIn: expireIn, + ExpireAt: expireAt, } } @@ -96,7 +96,7 @@ func (msg *DisappearingMessage) Scan(row dbutil.Scannable) (*DisappearingMessage if err != nil { return nil, err } - msg.ExpireInSeconds = expireIn + msg.ExpireIn = time.Duration(expireIn) * time.Second if expireAt.Valid { msg.ExpireAt = time.Unix(expireAt.Int64, 0) } @@ -109,7 +109,7 @@ func (msg *DisappearingMessage) sqlVariables() []any { expireAt.Valid = true expireAt.Int64 = msg.ExpireAt.Unix() } - return []any{msg.RoomID, msg.EventID, msg.ExpireInSeconds, expireAt} + return []any{msg.RoomID, msg.EventID, int64(msg.ExpireIn.Seconds()), expireAt} } func (msg *DisappearingMessage) Insert(ctx context.Context) error { @@ -117,7 +117,7 @@ func (msg *DisappearingMessage) Insert(ctx context.Context) error { } func (msg *DisappearingMessage) StartExpirationTimer(ctx context.Context) error { - msg.ExpireAt = time.Now().Add(time.Duration(msg.ExpireInSeconds) * time.Second) + msg.ExpireAt = time.Now().Add(msg.ExpireIn) return msg.qh.Exec(ctx, updateDisappearingMessageQuery, msg.EventID, msg.ExpireAt.Unix()) } diff --git a/database/message.go b/database/message.go index 51fd94af..9914a270 100644 --- a/database/message.go +++ b/database/message.go @@ -21,6 +21,7 @@ import ( "fmt" "strings" + "github.com/google/uuid" "maunium.net/go/mautrix/id" "go.mau.fi/util/dbutil" @@ -79,12 +80,12 @@ type MessageQuery struct { type Message struct { qh *dbutil.QueryHelper[*Message] - Sender string + Sender uuid.UUID Timestamp uint64 PartIndex int SignalChatID string - SignalReceiver string + SignalReceiver uuid.UUID MXID id.EventID RoomID id.RoomID @@ -98,23 +99,23 @@ func (mq *MessageQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*Messag return mq.QueryOne(ctx, getMessageByMXIDQuery, mxid) } -func (mq *MessageQuery) GetBySignalIDWithUnknownReceiver(ctx context.Context, sender string, timestamp uint64, partIndex int, receiver string) (*Message, error) { +func (mq *MessageQuery) GetBySignalIDWithUnknownReceiver(ctx context.Context, sender uuid.UUID, timestamp uint64, partIndex int, receiver uuid.UUID) (*Message, error) { return mq.QueryOne(ctx, getMessagePartBySignalIDWithUnknownReceiverQuery, sender, timestamp, partIndex, receiver) } -func (mq *MessageQuery) GetBySignalID(ctx context.Context, sender string, timestamp uint64, partIndex int, receiver string) (*Message, error) { +func (mq *MessageQuery) GetBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, partIndex int, receiver uuid.UUID) (*Message, error) { return mq.QueryOne(ctx, getMessagePartBySignalIDQuery, sender, timestamp, partIndex, receiver) } -func (mq *MessageQuery) GetLastPartBySignalID(ctx context.Context, sender string, timestamp uint64, receiver string) (*Message, error) { +func (mq *MessageQuery) GetLastPartBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, receiver uuid.UUID) (*Message, error) { return mq.QueryOne(ctx, getLastMessagePartBySignalIDQuery, sender, timestamp, receiver) } -func (mq *MessageQuery) GetAllPartsBySignalID(ctx context.Context, sender string, timestamp uint64, receiver string) ([]*Message, error) { +func (mq *MessageQuery) GetAllPartsBySignalID(ctx context.Context, sender uuid.UUID, timestamp uint64, receiver uuid.UUID) ([]*Message, error) { return mq.QueryMany(ctx, getAllMessagePartsBySignalIDQuery, sender, timestamp, receiver) } -func (mq *MessageQuery) GetManyBySignalID(ctx context.Context, sender string, timestamps []uint64, receiver string) ([]*Message, error) { +func (mq *MessageQuery) GetManyBySignalID(ctx context.Context, sender uuid.UUID, timestamps []uint64, receiver uuid.UUID) ([]*Message, error) { if mq.GetDB().Dialect == dbutil.Postgres { return mq.QueryMany(ctx, getManyMessagesBySignalIDQueryPostgres, sender, receiver, timestamps) } else { diff --git a/database/portal.go b/database/portal.go index 8ba01a8d..691dd22c 100644 --- a/database/portal.go +++ b/database/portal.go @@ -19,11 +19,13 @@ package database import ( "context" "database/sql" - "fmt" + "github.com/google/uuid" "maunium.net/go/mautrix/id" "go.mau.fi/util/dbutil" + + "go.mau.fi/mautrix-signal/pkg/signalmeow" ) const ( @@ -56,11 +58,23 @@ type PortalQuery struct { } type PortalKey struct { - ChatID string // TODO use some kind of union type between *uuid.UUID and a group ID as bytes? - Receiver string // TODO change to *uuid.UUID? + ChatID string + Receiver uuid.UUID +} + +func (pk *PortalKey) UserID() uuid.UUID { + parsed, _ := uuid.Parse(pk.ChatID) + return parsed } -func NewPortalKey(chatID, receiver string) PortalKey { +func (pk *PortalKey) GroupID() signalmeow.GroupIdentifier { + if len(pk.ChatID) == 44 { + return signalmeow.GroupIdentifier(pk.ChatID) + } + return "" +} + +func NewPortalKey(chatID string, receiver uuid.UUID) PortalKey { return PortalKey{ ChatID: chatID, Receiver: receiver, @@ -96,7 +110,7 @@ func (pq *PortalQuery) GetByChatID(ctx context.Context, pk PortalKey) (*Portal, return pq.QueryOne(ctx, getPortalByChatIDQuery, pk.ChatID, pk.Receiver) } -func (pq *PortalQuery) FindPrivateChatsOf(ctx context.Context, receiver string) ([]*Portal, error) { +func (pq *PortalQuery) FindPrivateChatsOf(ctx context.Context, receiver uuid.UUID) ([]*Portal, error) { return pq.QueryMany(ctx, getPortalsByReceiver, receiver) } @@ -105,39 +119,26 @@ func (pq *PortalQuery) GetAllWithMXID(ctx context.Context) ([]*Portal, error) { } func (p *Portal) Scan(row dbutil.Scannable) (*Portal, error) { - var mxid, name, topic, avatarHash, avatarURL, relayUserID sql.NullString - var expirationTime sql.NullInt64 + var mxid sql.NullString err := row.Scan( &p.ChatID, &p.Receiver, &mxid, - &name, - &topic, - &avatarHash, - &avatarURL, + &p.Name, + &p.Topic, + &p.AvatarHash, + &p.AvatarURL, &p.NameSet, &p.AvatarSet, &p.Revision, &p.Encrypted, - &relayUserID, - &expirationTime, + &p.RelayUserID, + &p.ExpirationTime, ) if err != nil { return nil, err } p.MXID = id.RoomID(mxid.String) - p.Name = name.String - p.Topic = topic.String - p.AvatarHash = avatarHash.String - p.RelayUserID = id.UserID(relayUserID.String) - p.ExpirationTime = int(expirationTime.Int64) - if len(avatarURL.String) > 0 { - parsedAvatarURL, err := id.ParseContentURI(avatarURL.String) - if err != nil { - return nil, fmt.Errorf("failed to parse avatar URL: %w", err) - } - p.AvatarURL = parsedAvatarURL - } return p, nil } @@ -149,7 +150,7 @@ func (p *Portal) sqlVariables() []any { p.Name, p.Topic, p.AvatarHash, - p.AvatarURL.String(), + p.AvatarURL, p.NameSet, p.AvatarSet, p.Revision, diff --git a/database/puppet.go b/database/puppet.go index d90688c5..ca19a6df 100644 --- a/database/puppet.go +++ b/database/puppet.go @@ -19,8 +19,8 @@ package database import ( "context" "database/sql" - "fmt" + "github.com/google/uuid" "maunium.net/go/mautrix/id" "go.mau.fi/util/dbutil" @@ -62,7 +62,7 @@ type PuppetQuery struct { type Puppet struct { qh *dbutil.QueryHelper[*Puppet] - SignalID string // TODO change to uuid.UUID + SignalID uuid.UUID Number string Name string NameQuality int @@ -82,7 +82,7 @@ func newPuppet(qh *dbutil.QueryHelper[*Puppet]) *Puppet { return &Puppet{qh: qh} } -func (pq *PuppetQuery) GetBySignalID(ctx context.Context, signalID string) (*Puppet, error) { +func (pq *PuppetQuery) GetBySignalID(ctx context.Context, signalID uuid.UUID) (*Puppet, error) { return pq.QueryOne(ctx, getPuppetBySignalIDQuery, signalID) } @@ -99,37 +99,26 @@ func (pq *PuppetQuery) GetAllWithCustomMXID(ctx context.Context) ([]*Puppet, err } func (p *Puppet) Scan(row dbutil.Scannable) (*Puppet, error) { - var number, name, avatarHash, avatarURL, customMXID, accessToken sql.NullString + var number, customMXID sql.NullString err := row.Scan( &p.SignalID, &number, - &name, + &p.Name, &p.NameQuality, - &avatarHash, - &avatarURL, + &p.AvatarHash, + &p.AvatarURL, &p.NameSet, &p.AvatarSet, &p.ContactInfoSet, &p.IsRegistered, &customMXID, - &accessToken, + &p.AccessToken, ) if err != nil { return nil, nil } - if len(avatarURL.String) > 0 { - parsedAvatarURL, err := id.ParseContentURI(avatarURL.String) - if err != nil { - return nil, fmt.Errorf("failed to parse avatar URL: %w", err) - } - p.AvatarURL = parsedAvatarURL - } - p.Number = number.String - p.Name = name.String - p.AvatarHash = avatarHash.String p.CustomMXID = id.UserID(customMXID.String) - p.AccessToken = accessToken.String return p, nil } @@ -140,12 +129,12 @@ func (p *Puppet) sqlVariables() []any { p.Name, p.NameQuality, p.AvatarHash, - p.AvatarURL.String(), + p.AvatarURL, p.NameSet, p.AvatarSet, p.ContactInfoSet, p.IsRegistered, - p.CustomMXID.String(), + dbutil.StrPtr(p.CustomMXID), p.AccessToken, } } diff --git a/database/reaction.go b/database/reaction.go index 57dbe8d2..d9239de4 100644 --- a/database/reaction.go +++ b/database/reaction.go @@ -19,6 +19,7 @@ package database import ( "context" + "github.com/google/uuid" "maunium.net/go/mautrix/id" "go.mau.fi/util/dbutil" @@ -47,13 +48,13 @@ func newReaction(qh *dbutil.QueryHelper[*Reaction]) *Reaction { type Reaction struct { qh *dbutil.QueryHelper[*Reaction] - MsgAuthor string // TODO change to uuid.UUID + MsgAuthor uuid.UUID MsgTimestamp uint64 - Author string // TODO change to uuid.UUID + Author uuid.UUID Emoji string SignalChatID string - SignalReceiver string // TODO change to uuid.UUID + SignalReceiver uuid.UUID MXID id.EventID RoomID id.RoomID @@ -63,7 +64,7 @@ func (rq *ReactionQuery) GetByMXID(ctx context.Context, mxid id.EventID) (*React return rq.QueryOne(ctx, getReactionByMXIDQuery, mxid) } -func (rq *ReactionQuery) GetBySignalID(ctx context.Context, msgAuthor string, msgTimestamp uint64, author, signalReceiver string) (*Reaction, error) { +func (rq *ReactionQuery) GetBySignalID(ctx context.Context, msgAuthor uuid.UUID, msgTimestamp uint64, author, signalReceiver uuid.UUID) (*Reaction, error) { return rq.QueryOne(ctx, getReactionBySignalIDQuery, msgAuthor, msgTimestamp, author, signalReceiver) } diff --git a/database/upgrades/00-latest.sql b/database/upgrades/00-latest.sql index 4334f4a2..b62643ab 100644 --- a/database/upgrades/00-latest.sql +++ b/database/upgrades/00-latest.sql @@ -1,45 +1,52 @@ -- v0 -> v17: Latest revision CREATE TABLE portal ( - chat_id TEXT, -- TODO NOT NULL - receiver TEXT, -- TODO NOT NULL - mxid TEXT, -- TODO UNIQUE constraint - name TEXT, -- TODO NOT NULL - topic TEXT, -- TODO NOT NULL + chat_id TEXT NOT NULL, + receiver uuid NOT NULL, + mxid TEXT, + name TEXT NOT NULL, + topic TEXT NOT NULL, encrypted BOOLEAN NOT NULL DEFAULT false, - avatar_hash TEXT, -- TODO NOT NULL - avatar_url TEXT, -- TODO NOT NULL + avatar_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL, name_set BOOLEAN NOT NULL DEFAULT false, avatar_set BOOLEAN NOT NULL DEFAULT false, revision INTEGER NOT NULL DEFAULT 0, - expiration_time BIGINT, -- TODO NOT NULL - relay_user_id TEXT, -- TODO NOT NULL - PRIMARY KEY (chat_id, receiver) + expiration_time BIGINT NOT NULL, + relay_user_id TEXT NOT NULL, + + PRIMARY KEY (chat_id, receiver), + CONSTRAINT portal_mxid_unique UNIQUE(mxid) ); CREATE TABLE puppet ( - uuid UUID PRIMARY KEY, - number TEXT UNIQUE, - name TEXT, -- TODO NOT NULL - name_quality INTEGER NOT NULL DEFAULT 0, - avatar_hash TEXT, -- TODO NOT NULL - avatar_url TEXT, -- TODO NOT NULL + uuid uuid PRIMARY KEY, + number TEXT UNIQUE, + name TEXT NOT NULL, + name_quality INTEGER NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL, name_set BOOLEAN NOT NULL DEFAULT false, avatar_set BOOLEAN NOT NULL DEFAULT false, - is_registered BOOLEAN NOT NULL DEFAULT false, + is_registered BOOLEAN NOT NULL DEFAULT false, + contact_info_set BOOLEAN NOT NULL DEFAULT false, + + custom_mxid TEXT, + access_token TEXT NOT NULL, - custom_mxid TEXT, -- TODO UNIQUE? - access_token TEXT, - contact_info_set BOOLEAN NOT NULL DEFAULT false + CONSTRAINT puppet_custom_mxid_unique UNIQUE(custom_mxid) ); CREATE TABLE "user" ( - mxid TEXT PRIMARY KEY, - username TEXT, -- TODO rename to phone? - uuid UUID, -- TODO UNIQUE constraint - management_room TEXT -- TODO NOT NULL? + mxid TEXT PRIMARY KEY, + uuid uuid, + phone TEXT, + + management_room TEXT, + + CONSTRAINT user_uuid_unique UNIQUE(uuid) ); CREATE TABLE message ( @@ -48,13 +55,14 @@ CREATE TABLE message ( part_index INTEGER NOT NULL, signal_chat_id TEXT NOT NULL, - signal_receiver TEXT NOT NULL, + signal_receiver uuid NOT NULL, mxid TEXT NOT NULL, mx_room TEXT NOT NULL, PRIMARY KEY (sender, timestamp, part_index, signal_receiver), - FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver) ON DELETE CASCADE, + CONSTRAINT message_portal_fkey FOREIGN KEY (signal_chat_id, signal_receiver) + REFERENCES portal(chat_id, receiver) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (sender) REFERENCES puppet(uuid) ON DELETE CASCADE, CONSTRAINT message_mxid_unique UNIQUE (mxid) ); @@ -69,23 +77,21 @@ CREATE TABLE reaction ( emoji TEXT NOT NULL, signal_chat_id TEXT NOT NULL, - signal_receiver TEXT NOT NULL, + signal_receiver uuid NOT NULL, mxid TEXT NOT NULL, mx_room TEXT NOT NULL, PRIMARY KEY (msg_author, msg_timestamp, author, signal_receiver), CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) - REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE, + REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE ON UPDATE CASCADE, FOREIGN KEY (author) REFERENCES puppet(uuid) ON DELETE CASCADE, CONSTRAINT reaction_mxid_unique UNIQUE (mxid) ); CREATE TABLE disappearing_message ( - room_id TEXT, -- TODO NOT NULL - mxid TEXT, -- TODO NOT NULL - expiration_seconds BIGINT, -- TODO NOT NULL - expiration_ts BIGINT, -- TODO NOT NULL? - - PRIMARY KEY (room_id, mxid) -- TODO drop room_id + mxid TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + expiration_seconds BIGINT NOT NULL, + expiration_ts BIGINT ); diff --git a/database/upgrades/16-message-primary-key.sql b/database/upgrades/16-message-primary-key.sql deleted file mode 100644 index ab2d8a78..00000000 --- a/database/upgrades/16-message-primary-key.sql +++ /dev/null @@ -1,28 +0,0 @@ --- v16: Update message table (Postgres) --- only: postgres - --- Drop constraints so we can fix timestamps. -ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; -ALTER TABLE message DROP CONSTRAINT message_pkey; - --- Add part index to message and fix the hacky timestamps -ALTER TABLE message ADD COLUMN part_index INTEGER; -UPDATE message - SET timestamp=CASE WHEN timestamp > 1500000000000000 THEN timestamp / 1000 ELSE timestamp END, - part_index=CASE WHEN timestamp > 1500000000000000 THEN timestamp % 1000 ELSE 0 END; -ALTER TABLE message ALTER COLUMN part_index SET NOT NULL; -ALTER TABLE reaction ADD COLUMN _part_index INTEGER NOT NULL DEFAULT 0; - --- Re-add the dropped constraints (but with part index and no chat) -ALTER TABLE message ADD PRIMARY KEY (sender, timestamp, part_index, signal_receiver); -ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) - REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE; --- Also update the reaction primary key -ALTER TABLE reaction DROP CONSTRAINT reaction_pkey; -ALTER TABLE reaction ADD PRIMARY KEY (author, msg_author, msg_timestamp, signal_receiver); - --- Change unique constraint from (mxid, mx_room) to just mxid. -ALTER TABLE message DROP CONSTRAINT message_mxid_mx_room_key; -ALTER TABLE message ADD CONSTRAINT message_mxid_unique UNIQUE (mxid); -ALTER TABLE reaction DROP CONSTRAINT reaction_mxid_mx_room_key; -ALTER TABLE reaction ADD CONSTRAINT reaction_mxid_unique UNIQUE (mxid); diff --git a/database/upgrades/16-refactor-postgres.sql b/database/upgrades/16-refactor-postgres.sql new file mode 100644 index 00000000..4b46d163 --- /dev/null +++ b/database/upgrades/16-refactor-postgres.sql @@ -0,0 +1,92 @@ +-- v16: Refactor types (Postgres) +-- only: postgres + +-- Drop constraints so we can fix timestamps. +ALTER TABLE reaction DROP CONSTRAINT reaction_message_fkey; +ALTER TABLE message DROP CONSTRAINT message_pkey; + +-- Add part index to message and fix the hacky timestamps +ALTER TABLE message ADD COLUMN part_index INTEGER; +UPDATE message + SET timestamp=CASE WHEN timestamp > 1500000000000000 THEN timestamp / 1000 ELSE timestamp END, + part_index=CASE WHEN timestamp > 1500000000000000 THEN timestamp % 1000 ELSE 0 END; +ALTER TABLE message ALTER COLUMN part_index SET NOT NULL; +ALTER TABLE reaction ADD COLUMN _part_index INTEGER NOT NULL DEFAULT 0; + +-- Re-add the dropped constraints (but with part index and no chat) +ALTER TABLE message ADD PRIMARY KEY (sender, timestamp, part_index, signal_receiver); +ALTER TABLE message DROP CONSTRAINT message_signal_chat_id_signal_receiver_fkey; +ALTER TABLE message ADD CONSTRAINT message_portal_fkey + FOREIGN KEY (signal_chat_id, signal_receiver) + REFERENCES portal (chat_id, receiver) + ON DELETE CASCADE ON UPDATE CASCADE; +ALTER TABLE reaction ADD CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) + REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE ON UPDATE CASCADE; +-- Also update the reaction primary key +ALTER TABLE reaction DROP CONSTRAINT reaction_pkey; +ALTER TABLE reaction ADD PRIMARY KEY (author, msg_author, msg_timestamp, signal_receiver); + +-- Change unique constraint from (mxid, mx_room) to just mxid. +ALTER TABLE message DROP CONSTRAINT message_mxid_mx_room_key; +ALTER TABLE message ADD CONSTRAINT message_mxid_unique UNIQUE (mxid); +ALTER TABLE reaction DROP CONSTRAINT reaction_mxid_mx_room_key; +ALTER TABLE reaction ADD CONSTRAINT reaction_mxid_unique UNIQUE (mxid); + +CREATE TABLE lost_portals ( + mxid TEXT PRIMARY KEY, + chat_id TEXT, + receiver TEXT +); +INSERT INTO lost_portals SELECT mxid, chat_id, receiver FROM portal WHERE mxid<>''; + +-- Make mxid column unique (requires using nulls for missing values) +UPDATE portal SET mxid=NULL WHERE mxid=''; +ALTER TABLE portal ADD CONSTRAINT portal_mxid_unique UNIQUE(mxid); +-- Delete any portals that aren't associated with logged-in users. +DELETE FROM portal WHERE receiver NOT IN (SELECT username FROM "user" WHERE uuid<>''); +-- Change receiver to uuid instead of phone number, also add nil uuid for groups. +UPDATE portal SET receiver=(SELECT uuid FROM "user" WHERE username=receiver) WHERE receiver<>''; +UPDATE portal SET receiver='00000000-0000-0000-0000-000000000000' WHERE receiver=''; +ALTER TABLE portal ALTER COLUMN receiver TYPE uuid USING receiver::uuid; +-- Delete group v1 portal entries +DELETE FROM portal WHERE chat_id NOT LIKE '________-____-____-____-____________' AND LENGTH(chat_id) <> 44; +DELETE FROM lost_portals WHERE mxid IN (SELECT mxid FROM portal WHERE mxid<>''); + +-- Remove unnecessary nullables in portal +UPDATE portal SET name='' WHERE name IS NULL; +UPDATE portal SET topic='' WHERE topic IS NULL; +UPDATE portal SET avatar_hash='' WHERE avatar_hash IS NULL; +UPDATE portal SET avatar_url='' WHERE avatar_url IS NULL; +UPDATE portal SET expiration_time=0 WHERE expiration_time IS NULL; +UPDATE portal SET relay_user_id='' WHERE relay_user_id IS NULL; +ALTER TABLE portal ALTER COLUMN name SET NOT NULL; +ALTER TABLE portal ALTER COLUMN topic SET NOT NULL; +ALTER TABLE portal ALTER COLUMN avatar_hash SET NOT NULL; +ALTER TABLE portal ALTER COLUMN avatar_url SET NOT NULL; +ALTER TABLE portal ALTER COLUMN expiration_time SET NOT NULL; +ALTER TABLE portal ALTER COLUMN relay_user_id SET NOT NULL; + +-- Add unique constraint to custom_mxid +UPDATE puppet SET custom_mxid=NULL WHERE custom_mxid=''; +ALTER TABLE puppet ADD CONSTRAINT puppet_custom_mxid_unique UNIQUE(custom_mxid); +-- Remove unnecessary nullables in puppet +UPDATE puppet SET name='' WHERE name IS NULL; +UPDATE puppet SET avatar_hash='' WHERE avatar_hash IS NULL; +UPDATE puppet SET avatar_url='' WHERE avatar_url IS NULL; +UPDATE puppet SET access_token='' WHERE access_token IS NULL; +ALTER TABLE puppet ALTER COLUMN name SET NOT NULL; +ALTER TABLE puppet ALTER COLUMN avatar_hash SET NOT NULL; +ALTER TABLE puppet ALTER COLUMN avatar_url SET NOT NULL; +ALTER TABLE puppet ALTER COLUMN access_token SET NOT NULL; +ALTER TABLE puppet ALTER COLUMN name_quality DROP DEFAULT; + +ALTER TABLE "user" ADD CONSTRAINT user_uuid_unique UNIQUE(uuid); +ALTER TABLE "user" RENAME COLUMN username TO phone; + +-- Drop room_id from disappearing message primary key +ALTER TABLE disappearing_message DROP CONSTRAINT disappearing_message_pkey; +ALTER TABLE disappearing_message ADD PRIMARY KEY (mxid); +-- Remove unnecessary nullables in disappearing_message +ALTER TABLE disappearing_message ALTER COLUMN room_id SET NOT NULL; +UPDATE disappearing_message SET expiration_seconds=0 WHERE expiration_seconds IS NULL; +ALTER TABLE disappearing_message ALTER COLUMN expiration_seconds SET NOT NULL; diff --git a/database/upgrades/17-message-primary-key.sql b/database/upgrades/17-message-primary-key.sql deleted file mode 100644 index 2dcbea9a..00000000 --- a/database/upgrades/17-message-primary-key.sql +++ /dev/null @@ -1,80 +0,0 @@ --- v17: Update message table (SQLite) --- transaction: off --- only: sqlite - --- This is separate from v16 so that postgres can run with transaction: on --- (split upgrades by dialect don't currently allow disabling transaction in only one dialect) - -PRAGMA foreign_keys = OFF; -BEGIN; - -CREATE TABLE message_new ( - sender uuid NOT NULL, - timestamp BIGINT NOT NULL, - part_index INTEGER NOT NULL, - - signal_chat_id TEXT NOT NULL, - signal_receiver TEXT NOT NULL, - - mxid TEXT NOT NULL, - mx_room TEXT NOT NULL, - - PRIMARY KEY (sender, timestamp, part_index, signal_receiver), - FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver) ON DELETE CASCADE, - FOREIGN KEY (sender) REFERENCES puppet(uuid) ON DELETE CASCADE, - CONSTRAINT message_mxid_unique UNIQUE (mxid) -); - -CREATE TABLE reaction_new ( - msg_author uuid NOT NULL, - msg_timestamp BIGINT NOT NULL, - -- part_index is not used in reactions, but is required for the foreign key. - _part_index INTEGER NOT NULL DEFAULT 0, - - author uuid NOT NULL, - emoji TEXT NOT NULL, - - signal_chat_id TEXT NOT NULL, - signal_receiver TEXT NOT NULL, - - mxid TEXT NOT NULL, - mx_room TEXT NOT NULL, - - PRIMARY KEY (msg_author, msg_timestamp, author, signal_receiver), - CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) - REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE, - FOREIGN KEY (author) REFERENCES puppet(uuid) ON DELETE CASCADE, - CONSTRAINT reaction_mxid_unique UNIQUE (mxid) -); - - -INSERT INTO message_new -SELECT sender, - CASE WHEN timestamp > 1500000000000000 THEN timestamp / 1000 ELSE timestamp END, - CASE WHEN timestamp > 1500000000000000 THEN timestamp % 1000 ELSE 0 END, - mxid, - mx_room, - COALESCE(signal_chat_id, ''), - COALESCE(signal_receiver, '') -FROM message; - -INSERT INTO reaction_new -SELECT msg_author, - msg_timestamp, - 0, -- _part_index - author, - emoji, - mxid, - mx_room, - COALESCE(signal_chat_id, ''), - COALESCE(signal_receiver, '') -FROM reaction; - -DROP TABLE message; -DROP TABLE reaction; -ALTER TABLE message_new RENAME TO message; -ALTER TABLE reaction_new RENAME TO reaction; - -PRAGMA foreign_key_check; -COMMIT; -PRAGMA foreign_keys = ON; diff --git a/database/upgrades/17-refactor-sqlite.sql b/database/upgrades/17-refactor-sqlite.sql new file mode 100644 index 00000000..4bdd1402 --- /dev/null +++ b/database/upgrades/17-refactor-sqlite.sql @@ -0,0 +1,187 @@ +-- v17: Refactor types (SQLite) +-- transaction: off +-- only: sqlite + +-- This is separate from v16 so that postgres can run with transaction: on +-- (split upgrades by dialect don't currently allow disabling transaction in only one dialect) + +PRAGMA foreign_keys = OFF; +BEGIN; + +CREATE TABLE message_new ( + sender uuid NOT NULL, + timestamp BIGINT NOT NULL, + part_index INTEGER NOT NULL, + + signal_chat_id TEXT NOT NULL, + signal_receiver TEXT NOT NULL, + + mxid TEXT NOT NULL, + mx_room TEXT NOT NULL, + + PRIMARY KEY (sender, timestamp, part_index, signal_receiver), + CONSTRAINT message_portal_fkey FOREIGN KEY (signal_chat_id, signal_receiver) REFERENCES portal(chat_id, receiver) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (sender) REFERENCES puppet(uuid) ON DELETE CASCADE, + CONSTRAINT message_mxid_unique UNIQUE (mxid) +); + +CREATE TABLE reaction_new ( + msg_author uuid NOT NULL, + msg_timestamp BIGINT NOT NULL, + -- part_index is not used in reactions, but is required for the foreign key. + _part_index INTEGER NOT NULL DEFAULT 0, + + author uuid NOT NULL, + emoji TEXT NOT NULL, + + signal_chat_id TEXT NOT NULL, + signal_receiver TEXT NOT NULL, + + mxid TEXT NOT NULL, + mx_room TEXT NOT NULL, + + PRIMARY KEY (msg_author, msg_timestamp, author, signal_receiver), + CONSTRAINT reaction_message_fkey FOREIGN KEY (msg_author, msg_timestamp, _part_index, signal_receiver) + REFERENCES message (sender, timestamp, part_index, signal_receiver) ON DELETE CASCADE ON UPDATE CASCADE, + FOREIGN KEY (author) REFERENCES puppet(uuid) ON DELETE CASCADE, + CONSTRAINT reaction_mxid_unique UNIQUE (mxid) +); + + +INSERT INTO message_new +SELECT sender, + CASE WHEN timestamp > 1500000000000000 THEN timestamp / 1000 ELSE timestamp END, + CASE WHEN timestamp > 1500000000000000 THEN timestamp % 1000 ELSE 0 END, + mxid, + mx_room, + COALESCE(signal_chat_id, ''), + COALESCE(signal_receiver, '') +FROM message; + +INSERT INTO reaction_new +SELECT msg_author, + msg_timestamp, + 0, -- _part_index + author, + emoji, + mxid, + mx_room, + COALESCE(signal_chat_id, ''), + COALESCE(signal_receiver, '') +FROM reaction; + +DROP TABLE message; +DROP TABLE reaction; +ALTER TABLE message_new RENAME TO message; +ALTER TABLE reaction_new RENAME TO reaction; + +PRAGMA foreign_key_check; +COMMIT; + +PRAGMA foreign_keys = ON; + +BEGIN; +CREATE TABLE lost_portals ( + mxid TEXT PRIMARY KEY, + chat_id TEXT, + receiver TEXT +); +INSERT INTO lost_portals SELECT mxid, chat_id, receiver FROM portal WHERE mxid<>''; +DELETE FROM portal WHERE receiver NOT IN (SELECT username FROM "user" WHERE uuid<>''); +UPDATE portal SET receiver=(SELECT uuid FROM "user" WHERE username=receiver) WHERE receiver<>''; +UPDATE portal SET receiver='00000000-0000-0000-0000-000000000000' WHERE receiver=''; +DELETE FROM portal WHERE chat_id NOT LIKE '________-____-____-____-____________' AND LENGTH(chat_id) <> 44; +DELETE FROM lost_portals WHERE mxid IN (SELECT mxid FROM portal WHERE mxid<>''); +COMMIT; + +PRAGMA foreign_keys = OFF; + +BEGIN; + +CREATE TABLE portal_new ( + chat_id TEXT NOT NULL, + receiver uuid NOT NULL, + mxid TEXT, + name TEXT NOT NULL, + topic TEXT NOT NULL, + encrypted BOOLEAN NOT NULL DEFAULT false, + avatar_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL, + name_set BOOLEAN NOT NULL DEFAULT false, + avatar_set BOOLEAN NOT NULL DEFAULT false, + revision INTEGER NOT NULL DEFAULT 0, + + expiration_time BIGINT NOT NULL, + relay_user_id TEXT NOT NULL, + + PRIMARY KEY (chat_id, receiver), + CONSTRAINT portal_mxid_unique UNIQUE(mxid) +); + +INSERT INTO portal_new + SELECT chat_id, receiver, CASE WHEN mxid='' THEN NULL ELSE mxid END, + COALESCE(name, ''), COALESCE(topic, ''), encrypted, COALESCE(avatar_hash, ''), COALESCE(avatar_url, ''), + name_set, avatar_set, revision, COALESCE(expiration_time, 0), COALESCE(relay_user_id, '') + FROM portal; +DROP TABLE portal; +ALTER TABLE portal_new RENAME TO portal; + +CREATE TABLE puppet_new ( + uuid uuid PRIMARY KEY, + number TEXT UNIQUE, + name TEXT NOT NULL, + name_quality INTEGER NOT NULL, + avatar_hash TEXT NOT NULL, + avatar_url TEXT NOT NULL, + name_set BOOLEAN NOT NULL DEFAULT false, + avatar_set BOOLEAN NOT NULL DEFAULT false, + + is_registered BOOLEAN NOT NULL DEFAULT false, + contact_info_set BOOLEAN NOT NULL DEFAULT false, + + custom_mxid TEXT, + access_token TEXT NOT NULL, + + CONSTRAINT puppet_custom_mxid_unique UNIQUE(custom_mxid) +); + +INSERT INTO puppet_new + SELECT uuid, number, COALESCE(name, ''), COALESCE(name_quality, 0), COALESCE(avatar_hash, ''), + COALESCE(avatar_url, ''), name_set, avatar_set, is_registered, contact_info_set, + CASE WHEN custom_mxid='' THEN NULL ELSE custom_mxid END, COALESCE(access_token, '') + FROM puppet; +DROP TABLE puppet; +ALTER TABLE puppet_new RENAME TO puppet; + +CREATE TABLE user_new ( + mxid TEXT PRIMARY KEY, + uuid uuid, + phone TEXT, + + management_room TEXT, + + CONSTRAINT user_uuid_unique UNIQUE(uuid) +); + +INSERT INTO user_new + SELECT mxid, uuid, phone, management_room + FROM user; +DROP TABLE user; +ALTER TABLE user_new RENAME TO user; + +CREATE TABLE disappearing_message_new ( + mxid TEXT NOT NULL PRIMARY KEY, + room_id TEXT NOT NULL, + expiration_seconds BIGINT NOT NULL, + expiration_ts BIGINT +); + +INSERT INTO disappearing_message_new + SELECT mxid, room_id, COALESCE(expiration_seconds, 0), expiration_ts + FROM disappearing_message; +DROP TABLE disappearing_message; +ALTER TABLE disappearing_message_new RENAME TO disappearing_message; + +PRAGMA foreign_key_check; +COMMIT; +PRAGMA foreign_keys = ON; diff --git a/database/user.go b/database/user.go index 16eb6471..a97cf852 100644 --- a/database/user.go +++ b/database/user.go @@ -20,6 +20,7 @@ import ( "context" "database/sql" + "github.com/google/uuid" "maunium.net/go/mautrix/id" "go.mau.fi/util/dbutil" @@ -40,8 +41,8 @@ type User struct { qh *dbutil.QueryHelper[*User] MXID id.UserID - SignalUsername string // TODO rename to phone - SignalID string // TODO change to *uuid.UUID + SignalUsername string + SignalID uuid.UUID ManagementRoom id.RoomID } @@ -57,7 +58,7 @@ func (uq *UserQuery) GetByPhone(ctx context.Context, phone string) (*User, error return uq.QueryOne(ctx, getUserByPhoneQuery, phone) } -func (uq *UserQuery) GetBySignalID(ctx context.Context, uuid string) (*User, error) { +func (uq *UserQuery) GetBySignalID(ctx context.Context, uuid uuid.UUID) (*User, error) { return uq.QueryOne(ctx, getUserByUUIDQuery, uuid) } @@ -66,7 +67,10 @@ func (uq *UserQuery) GetAllLoggedIn(ctx context.Context) ([]*User, error) { } func (u *User) sqlVariables() []any { - return []any{u.MXID, dbutil.StrPtr(u.SignalUsername), dbutil.StrPtr(u.SignalID), dbutil.StrPtr(u.ManagementRoom)} + var nu uuid.NullUUID + nu.UUID = u.SignalID + nu.Valid = u.SignalID != uuid.Nil + return []any{u.MXID, dbutil.StrPtr(u.SignalUsername), nu, dbutil.StrPtr(u.ManagementRoom)} } func (u *User) Insert(ctx context.Context) error { @@ -80,18 +84,19 @@ func (u *User) Update(ctx context.Context) error { } func (u *User) Scan(row dbutil.Scannable) (*User, error) { - var username, managementRoom, signalID sql.NullString + var username, managementRoom sql.NullString + var signalID uuid.NullUUID err := row.Scan( &u.MXID, &username, &signalID, - &managementRoom, + &u.ManagementRoom, ) if err != nil { return nil, err } u.SignalUsername = username.String - u.SignalID = signalID.String + u.SignalID = signalID.UUID u.ManagementRoom = id.RoomID(managementRoom.String) return u, nil } diff --git a/disappearing.go b/disappearing.go index 6fd420de..8bedbbb6 100644 --- a/disappearing.go +++ b/disappearing.go @@ -126,15 +126,15 @@ func (dmm *DisappearingMessagesManager) redactExpiredMessages(ctx context.Contex } } -func (dmm *DisappearingMessagesManager) AddDisappearingMessage(ctx context.Context, eventID id.EventID, roomID id.RoomID, expireInSeconds int64, startTimerNow bool) { - if expireInSeconds == 0 { +func (dmm *DisappearingMessagesManager) AddDisappearingMessage(ctx context.Context, eventID id.EventID, roomID id.RoomID, expireIn time.Duration, startTimerNow bool) { + if expireIn == 0 { return } var expireAt time.Time if startTimerNow { - expireAt = time.Now().Add(time.Duration(expireInSeconds) * time.Second) + expireAt = time.Now().Add(expireIn) } - disappearingMessage := dmm.DB.DisappearingMessage.NewWithValues(roomID, eventID, expireInSeconds, expireAt) + disappearingMessage := dmm.DB.DisappearingMessage.NewWithValues(roomID, eventID, expireIn, expireAt) err := disappearingMessage.Insert(ctx) if err != nil { zerolog.Ctx(ctx).Err(err).Str("event_id", eventID.String()). diff --git a/main.go b/main.go index 9153f226..8c97d8c8 100644 --- a/main.go +++ b/main.go @@ -23,6 +23,7 @@ import ( "os" "sync" + "github.com/google/uuid" "github.com/rs/zerolog" "maunium.net/go/mautrix" "maunium.net/go/mautrix/bridge" @@ -63,7 +64,7 @@ type SignalBridge struct { provisioning *ProvisioningAPI usersByMXID map[id.UserID]*User - usersBySignalID map[string]*User + usersBySignalID map[uuid.UUID]*User usersLock sync.Mutex managementRooms map[id.RoomID]*User @@ -73,7 +74,7 @@ type SignalBridge struct { portalsByID map[database.PortalKey]*Portal portalsLock sync.Mutex - puppets map[string]*Puppet + puppets map[uuid.UUID]*Puppet puppetsByCustomMXID map[id.UserID]*Puppet puppetsByNumber map[string]*Puppet puppetsLock sync.Mutex @@ -119,12 +120,12 @@ func (br *SignalBridge) Init() { br.MatrixHandler.TrackEventDuration = br.Metrics.TrackMatrixEvent signalFormatParams = &signalfmt.FormatParams{ - GetUserInfo: func(uuid string) signalfmt.UserInfo { - puppet := br.GetPuppetBySignalID(uuid) + GetUserInfo: func(u uuid.UUID) signalfmt.UserInfo { + puppet := br.GetPuppetBySignalID(u) if puppet == nil { return signalfmt.UserInfo{} } - user := br.GetUserBySignalID(uuid) + user := br.GetUserBySignalID(u) if user != nil { return signalfmt.UserInfo{ MXID: user.MXID, @@ -138,17 +139,17 @@ func (br *SignalBridge) Init() { }, } matrixFormatParams = &matrixfmt.HTMLParser{ - GetUUIDFromMXID: func(userID id.UserID) string { + GetUUIDFromMXID: func(userID id.UserID) uuid.UUID { parsed, ok := br.ParsePuppetMXID(userID) if ok { return parsed } // TODO only get if exists user := br.GetUserByMXID(userID) - if user != nil && user.SignalID != "" { + if user != nil && user.SignalID != uuid.Nil { return user.SignalID } - return "" + return uuid.Nil }, } @@ -213,7 +214,7 @@ func (br *SignalBridge) CreatePrivatePortal(roomID id.RoomID, brInviter bridge.U br.Log.Debugln("CreatePrivatePortal", roomID, brInviter, brGhost) inviter := brInviter.(*User) puppet := brGhost.(*Puppet) - key := database.NewPortalKey(inviter.SignalID, puppet.SignalID) + key := database.NewPortalKey(inviter.SignalID.String(), puppet.SignalID) portal := br.GetPortalByChatID(key) if len(portal.MXID) == 0 { @@ -293,14 +294,14 @@ func (br *SignalBridge) createPrivatePortalFromInvite(roomID id.RoomID, inviter func main() { br := &SignalBridge{ usersByMXID: make(map[id.UserID]*User), - usersBySignalID: make(map[string]*User), + usersBySignalID: make(map[uuid.UUID]*User), managementRooms: make(map[id.RoomID]*User), portalsByMXID: make(map[id.RoomID]*Portal), portalsByID: make(map[database.PortalKey]*Portal), - puppets: make(map[string]*Puppet), + puppets: make(map[uuid.UUID]*Puppet), puppetsByCustomMXID: make(map[id.UserID]*Puppet), puppetsByNumber: make(map[string]*Puppet), } diff --git a/msgconv/matrixfmt/convert_test.go b/msgconv/matrixfmt/convert_test.go index 719e4be0..0ee25141 100644 --- a/msgconv/matrixfmt/convert_test.go +++ b/msgconv/matrixfmt/convert_test.go @@ -4,6 +4,7 @@ import ( "fmt" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" @@ -13,11 +14,11 @@ import ( ) var formatParams = &matrixfmt.HTMLParser{ - GetUUIDFromMXID: func(id id.UserID) string { + GetUUIDFromMXID: func(id id.UserID) uuid.UUID { if id.Homeserver() == "signal" { - return id.Localpart() + return uuid.MustParse(id.Localpart()) } - return "" + return uuid.Nil }, } diff --git a/msgconv/matrixfmt/html.go b/msgconv/matrixfmt/html.go index 28daec6c..074caff5 100644 --- a/msgconv/matrixfmt/html.go +++ b/msgconv/matrixfmt/html.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" + "github.com/google/uuid" "golang.org/x/exp/slices" "golang.org/x/net/html" "maunium.net/go/mautrix/event" @@ -223,7 +224,7 @@ func (ctx Context) WithWhitespace() Context { // HTMLParser is a somewhat customizable Matrix HTML parser. type HTMLParser struct { - GetUUIDFromMXID func(id.UserID) string + GetUUIDFromMXID func(id.UserID) uuid.UUID } // TaggedString is a string that also contains a HTML tag. @@ -355,8 +356,8 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) *EntityStri // Mention not allowed, use name as-is return str } - uuid := parser.GetUUIDFromMXID(mxid) - if uuid == "" { + u := parser.GetUUIDFromMXID(mxid) + if u == uuid.Nil { // Don't include the link for mentions of non-Signal users, the name is enough return str } @@ -365,7 +366,7 @@ func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) *EntityStri MXID: mxid, Name: str.String.String(), }, - UUID: uuid, + UUID: u, }) } if str.String.String() == href { diff --git a/msgconv/signalfmt/convert.go b/msgconv/signalfmt/convert.go index d88ab747..317374d8 100644 --- a/msgconv/signalfmt/convert.go +++ b/msgconv/signalfmt/convert.go @@ -20,6 +20,7 @@ import ( "html" "strings" + "github.com/google/uuid" "golang.org/x/exp/maps" "golang.org/x/exp/slices" "maunium.net/go/mautrix/event" @@ -34,7 +35,7 @@ type UserInfo struct { } type FormatParams struct { - GetUserInfo func(uuid string) UserInfo + GetUserInfo func(uuid uuid.UUID) UserInfo } type formatContext struct { @@ -87,7 +88,11 @@ func Parse(message string, ranges []*signalpb.BodyRange, params *FormatParams) * case *signalpb.BodyRange_Style_: br.Value = Style(rv.Style) case *signalpb.BodyRange_MentionAci: - userInfo := params.GetUserInfo(rv.MentionAci) + parsed, err := uuid.Parse(rv.MentionAci) + if err != nil { + continue + } + userInfo := params.GetUserInfo(parsed) if userInfo.MXID == "" { continue } @@ -96,7 +101,7 @@ func Parse(message string, ranges []*signalpb.BodyRange, params *FormatParams) * // Maybe use NewUTF16String and do index replacements for the plaintext body too, // or just replace the plaintext body by parsing the generated HTML. content.Body = strings.Replace(content.Body, "\uFFFC", userInfo.Name, 1) - br.Value = Mention{UserInfo: userInfo, UUID: rv.MentionAci} + br.Value = Mention{UserInfo: userInfo, UUID: parsed} } lrt.Add(br) } diff --git a/msgconv/signalfmt/convert_test.go b/msgconv/signalfmt/convert_test.go index 1c066c63..eaed12dc 100644 --- a/msgconv/signalfmt/convert_test.go +++ b/msgconv/signalfmt/convert_test.go @@ -29,11 +29,11 @@ import ( signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" ) -var realUser = uuid.NewString() +var realUser = uuid.New() func TestParse(t *testing.T) { formatParams := &signalfmt.FormatParams{ - GetUserInfo: func(uuid string) signalfmt.UserInfo { + GetUserInfo: func(uuid uuid.UUID) signalfmt.UserInfo { if uuid == realUser { return signalfmt.UserInfo{ MXID: "@test:example.com", @@ -41,7 +41,7 @@ func TestParse(t *testing.T) { } } else { return signalfmt.UserInfo{ - MXID: id.UserID("@signal_" + uuid + ":example.com"), + MXID: id.UserID("@signal_" + uuid.String() + ":example.com"), Name: "Signal User", } } @@ -79,7 +79,7 @@ func TestParse(t *testing.T) { Start: proto.Uint32(6), Length: proto.Uint32(1), AssociatedValue: &signalpb.BodyRange_MentionAci{ - MentionAci: realUser, + MentionAci: realUser.String(), }, }}, body: "Hello Matrix User", diff --git a/msgconv/signalfmt/tags.go b/msgconv/signalfmt/tags.go index b81709f9..043bb437 100644 --- a/msgconv/signalfmt/tags.go +++ b/msgconv/signalfmt/tags.go @@ -19,6 +19,8 @@ package signalfmt import ( "fmt" + "github.com/google/uuid" + signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" ) @@ -30,7 +32,7 @@ type BodyRangeValue interface { type Mention struct { UserInfo - UUID string + UUID uuid.UUID } func (m Mention) String() string { @@ -39,7 +41,7 @@ func (m Mention) String() string { func (m Mention) Proto() signalpb.BodyRangeAssociatedValue { return &signalpb.BodyRange_MentionAci{ - MentionAci: m.UUID, + MentionAci: m.UUID.String(), } } diff --git a/pkg/signalmeow/sending.go b/pkg/signalmeow/sending.go index 08c4b17f..ddfc481b 100644 --- a/pkg/signalmeow/sending.go +++ b/pkg/signalmeow/sending.go @@ -26,6 +26,7 @@ import ( "strings" "time" + "github.com/google/uuid" "google.golang.org/protobuf/proto" "go.mau.fi/mautrix-signal/pkg/libsignalgo" @@ -435,14 +436,14 @@ func DataMessageForAttachment(attachmentPointer *AttachmentPointer, caption stri return wrapDataMessageInContent(dm) } -func DataMessageForReaction(reaction string, targetMessageSender string, targetMessageTimestamp uint64, removing bool) *SignalContent { +func DataMessageForReaction(reaction string, targetMessageSender uuid.UUID, targetMessageTimestamp uint64, removing bool) *SignalContent { timestamp := currentMessageTimestamp() dm := &signalpb.DataMessage{ Timestamp: ×tamp, Reaction: &signalpb.DataMessage_Reaction{ Emoji: proto.String(reaction), Remove: proto.Bool(removing), - TargetAuthorAci: proto.String(targetMessageSender), + TargetAuthorAci: proto.String(targetMessageSender.String()), TargetSentTimestamp: proto.Uint64(targetMessageTimestamp), }, } @@ -460,12 +461,12 @@ func DataMessageForDelete(targetMessageTimestamp uint64) *SignalContent { return wrapDataMessageInContent(dm) } -func AddQuoteToDataMessage(content *SignalContent, quotedMessageSender string, quotedMessageTimestamp uint64) { +func AddQuoteToDataMessage(content *SignalContent, quotedMessageSender uuid.UUID, quotedMessageTimestamp uint64) { // Note: We're supposed to send the quoted message content too as a fallback, // but it only seems to be necessary to quote image messages on iOS and Desktop. // Android seems to render every quote fine, and iOS and Desktop render text quotes fine. content.DataMessage.Quote = &signalpb.DataMessage_Quote{ - AuthorAci: proto.String(quotedMessageSender), + AuthorAci: proto.String(quotedMessageSender.String()), Id: proto.Uint64(quotedMessageTimestamp), Type: signalpb.DataMessage_Quote_NORMAL.Enum(), } @@ -543,7 +544,7 @@ func SendGroupMessage(ctx context.Context, device *Device, gid GroupIdentifier, return result, nil } -func SendMessage(ctx context.Context, device *Device, recipientUuid string, message *SignalContent) SendMessageResult { +func SendMessage(ctx context.Context, device *Device, recipientID string, message *SignalContent) SendMessageResult { // Assemble the content to send content := (*signalpb.Content)(message) dataMessage := content.DataMessage @@ -555,12 +556,12 @@ func SendMessage(ctx context.Context, device *Device, recipientUuid string, mess } // Send to the recipient - sentUnidentified, err := sendContent(ctx, device, recipientUuid, messageTimestamp, content, 0) + sentUnidentified, err := sendContent(ctx, device, recipientID, messageTimestamp, content, 0) if err != nil { return SendMessageResult{ WasSuccessful: false, FailedSendResult: &FailedSendResult{ - RecipientUuid: recipientUuid, + RecipientUuid: recipientID, Error: err, }, } @@ -568,7 +569,7 @@ func SendMessage(ctx context.Context, device *Device, recipientUuid string, mess result := SendMessageResult{ WasSuccessful: true, SuccessfulSendResult: &SuccessfulSendResult{ - RecipientUuid: recipientUuid, + RecipientUuid: recipientID, Unidentified: sentUnidentified, }, } @@ -585,7 +586,7 @@ func SendMessage(ctx context.Context, device *Device, recipientUuid string, mess syncContent = syncMessageFromSoloDataMessage(dataMessage, *result.SuccessfulSendResult) } if content.ReceiptMessage != nil && *content.ReceiptMessage.Type == signalpb.ReceiptMessage_READ { - syncContent = syncMessageFromReadReceiptMessage(content.ReceiptMessage, recipientUuid) + syncContent = syncMessageFromReadReceiptMessage(content.ReceiptMessage, recipientID) } if syncContent != nil { _, selfSendErr := sendContent(ctx, device, device.Data.AciUuid, messageTimestamp, syncContent, 0) diff --git a/portal.go b/portal.go index d1e673a9..fb8ecaba 100644 --- a/portal.go +++ b/portal.go @@ -139,7 +139,7 @@ func (portal *Portal) IsPrivateChat() bool { func (portal *Portal) MainIntent() *appservice.IntentAPI { if portal.IsPrivateChat() { - return portal.bridge.GetPuppetBySignalID(portal.ChatID).DefaultIntent() + return portal.bridge.GetPuppetBySignalID(portal.UserID()).DefaultIntent() } return portal.bridge.Bot @@ -873,7 +873,7 @@ func (portal *Portal) sendMessageStatusCheckpointFailed(evt *event.Event, err er func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) { log := portal.log.With(). Str("action", "handle signal message"). - Str("sender", portalMessage.sender.SignalID). + Str("sender", portalMessage.sender.SignalID.String()). Uint64("timestamp", portalMessage.message.Base().Timestamp). Logger() ctx := log.WithContext(context.TODO()) @@ -961,7 +961,7 @@ func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) { } } -func (portal *Portal) storeMessageInDB(ctx context.Context, eventID id.EventID, senderSignalID string, timestamp uint64) { +func (portal *Portal) storeMessageInDB(ctx context.Context, eventID id.EventID, senderSignalID uuid.UUID, timestamp uint64) { dbMessage := portal.bridge.DB.Message.New() dbMessage.MXID = eventID dbMessage.RoomID = portal.MXID @@ -978,8 +978,8 @@ func (portal *Portal) storeMessageInDB(ctx context.Context, eventID id.EventID, func (portal *Portal) storeReactionInDB( ctx context.Context, eventID id.EventID, - senderSignalID string, - msgAuthor string, + senderSignalID, + msgAuthor uuid.UUID, msgTimestamp uint64, emoji string, ) { @@ -1003,7 +1003,7 @@ func (portal *Portal) addSignalQuote(ctx context.Context, content *event.Message return } originalMessage, err := portal.bridge.DB.Message.GetBySignalID( - ctx, quote.QuotedSender, quote.QuotedTimestamp, 0, portal.Receiver, + ctx, uuid.MustParse(quote.QuotedSender), quote.QuotedTimestamp, 0, portal.Receiver, ) if err != nil { zerolog.Ctx(ctx).Err(err).Str("quoted_sender", quote.QuotedSender).Uint64("quoted_timestamp", quote.QuotedTimestamp).Msg("Failed to get quoted message from database") @@ -1028,7 +1028,7 @@ func (portal *Portal) addSignalQuote(ctx context.Context, content *event.Message } func (portal *Portal) addDisappearingMessage(ctx context.Context, eventID id.EventID, expireInSeconds int64, startTimerNow bool) { - portal.bridge.disappearingMessagesManager.AddDisappearingMessage(ctx, eventID, portal.MXID, expireInSeconds, startTimerNow) + portal.bridge.disappearingMessagesManager.AddDisappearingMessage(ctx, eventID, portal.MXID, time.Duration(expireInSeconds)*time.Second, startTimerNow) } var signalFormatParams *signalfmt.FormatParams @@ -1130,7 +1130,8 @@ func (portal *Portal) handleSignalUnhandledMessage(portalMessage portalSignalMes func (portal *Portal) handleSignalReceiptMessage(ctx context.Context, portalMessage portalSignalMessage, intent *appservice.IntentAPI) { receiptMessage := (portalMessage.message).(signalmeow.IncomingSignalMessageReceipt) log := zerolog.Ctx(ctx) - messageSender := receiptMessage.OriginalSender + messageSender, err := uuid.Parse(receiptMessage.OriginalSender) + // TODO handle err timestamp := receiptMessage.OriginalTimestamp lastPart, err := portal.bridge.DB.Message.GetLastPartBySignalID(ctx, messageSender, timestamp, portal.Receiver) if err != nil { @@ -1306,13 +1307,13 @@ func (portal *Portal) HandleMatrixReadReceipt(sender bridge.User, eventID id.Eve // who sent the original message, not the portal's ChatID ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - result := signalmeow.SendMessage(ctx, receiptSender.SignalDevice, receiptDestination, msg) + result := signalmeow.SendMessage(ctx, receiptSender.SignalDevice, receiptDestination.String(), msg) if !result.WasSuccessful { log.Err(result.FailedSendResult.Error). - Str("receipt_destination", receiptDestination). + Str("receipt_destination", receiptDestination.String()). Msg("Failed to send read receipt to Signal") } else { - log.Debug().Str("receipt_destination", receiptDestination).Msg("Sent read receipt to Signal") + log.Debug().Str("receipt_destination", receiptDestination.String()).Msg("Sent read receipt to Signal") } } @@ -1386,7 +1387,11 @@ func (portal *Portal) handleSignalReactionMessage(ctx context.Context, portalMes Str("target_message_sender", msg.TargetAuthorUUID). Uint64("target_message_timestamp", msg.TargetMessageTimestamp). Msg("Received reaction from Signal") - dbMessage, err := portal.bridge.DB.Message.GetBySignalID(ctx, msg.TargetAuthorUUID, msg.TargetMessageTimestamp, 0, portal.Receiver) + parsedTargetAuthor, err := uuid.Parse(msg.TargetAuthorUUID) + // TODO handle err + senderUUID, err := uuid.Parse(msg.SenderUUID) + // TODO handle err + dbMessage, err := portal.bridge.DB.Message.GetBySignalID(ctx, parsedTargetAuthor, msg.TargetMessageTimestamp, 0, portal.Receiver) if err != nil { log.Err(err).Msg("Failed to get reaction target message") return @@ -1396,9 +1401,9 @@ func (portal *Portal) handleSignalReactionMessage(ctx context.Context, portalMes } existingReaction, err := portal.bridge.DB.Reaction.GetBySignalID( ctx, - msg.TargetAuthorUUID, + parsedTargetAuthor, msg.TargetMessageTimestamp, - msg.SenderUUID, + senderUUID, portal.Receiver, ) if err != nil { @@ -1443,7 +1448,7 @@ func (portal *Portal) handleSignalReactionMessage(ctx context.Context, portalMes ctx, resp.EventID, portalMessage.sender.SignalID, - msg.TargetAuthorUUID, + parsedTargetAuthor, dbMessage.Timestamp, msg.Emoji, // Store without variation selector, as they come from Signal ) @@ -1452,9 +1457,12 @@ func (portal *Portal) handleSignalReactionMessage(ctx context.Context, portalMes func (portal *Portal) handleSignalDeleteMessage(ctx context.Context, portalMessage portalSignalMessage, intent *appservice.IntentAPI) { msg := (portalMessage.message).(signalmeow.IncomingSignalMessageDelete) + senderUUID, err := uuid.Parse(msg.SenderUUID) + // TODO handle err + log := zerolog.Ctx(ctx) // Find the event ID of the message to delete - messages, err := portal.bridge.DB.Message.GetAllPartsBySignalID(ctx, msg.SenderUUID, msg.TargetMessageTimestamp, portal.Receiver) + messages, err := portal.bridge.DB.Message.GetAllPartsBySignalID(ctx, senderUUID, msg.TargetMessageTimestamp, portal.Receiver) if err != nil { log.Err(err).Msg("Failed to get messages to delete") return @@ -1580,32 +1588,6 @@ func (portal *Portal) sendMatrixEventContent(intent *appservice.IntentAPI, event } } -func (portal *Portal) getMessagePuppet(user *User, senderUUID string) (puppet *Puppet) { - if portal.IsPrivateChat() { - puppet = portal.bridge.GetPuppetBySignalID(portal.ChatID) - } else if senderUUID != "" { - puppet = portal.bridge.GetPuppetBySignalID(senderUUID) - } - if puppet == nil { - return nil - } - return puppet -} - -func (portal *Portal) getMessageIntent(user *User, senderUUID string) *appservice.IntentAPI { - puppet := portal.getMessagePuppet(user, senderUUID) - if puppet == nil { - portal.log.Debug().Msg("Not handling: puppet is nil") - return nil - } - intent := puppet.IntentFor(portal) - //if !intent.IsCustomPuppet && portal.IsPrivateChat() { //&& info.Sender.User == portal.Key.Receiver.User && portal.Key.Receiver != portal.Key.JID { - // portal.log.Debugfln("Not handling: user doesn't have double puppeting enabled") - // return nil - //} - return intent -} - func (portal *Portal) getEncryptionEventContent() (evt *event.EncryptionEventContent) { evt = &event.EncryptionEventContent{Algorithm: id.AlgorithmMegolmV1} if rot := portal.bridge.Config.Bridge.Encryption.Rotation; rot.EnableCustom { @@ -1798,8 +1780,8 @@ func (br *SignalBridge) GetPortalByChatID(key database.PortalKey) *Portal { br.portalsLock.Lock() defer br.portalsLock.Unlock() // If this PortalKey is for a group, Receiver should be empty - if !isUUID(key.ChatID) { - key.Receiver = "" + if key.UserID() == uuid.Nil { + key.Receiver = uuid.Nil } portal, ok := br.portalsByID[key] if !ok { diff --git a/provisioning.go b/provisioning.go index b726925d..1eb9b9b7 100644 --- a/provisioning.go +++ b/provisioning.go @@ -30,6 +30,7 @@ import ( "sync" "time" + "github.com/google/uuid" "github.com/gorilla/mux" "github.com/rs/zerolog" "maunium.net/go/mautrix/id" @@ -184,7 +185,7 @@ func (prov *ProvisioningAPI) resolveIdentifier(user *User, phoneNum string) (int } portal := user.GetPortalByChatID(contact.UUID) - puppet := prov.bridge.GetPuppetBySignalID(contact.UUID) + puppet := prov.bridge.GetPuppetBySignalIDString(contact.UUID) return http.StatusOK, &ResolveIdentifierResponse{ RoomID: portal.MXID.String(), @@ -511,7 +512,8 @@ func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Requ // Update user with SignalID if resp.ProvisioningData.AciUuid != "" { - user.SignalID = resp.ProvisioningData.AciUuid + user.SignalID, err = uuid.Parse(resp.ProvisioningData.AciUuid) + // TODO handle err user.SignalUsername = resp.ProvisioningData.Number err = user.Update(r.Context()) if err != nil { @@ -589,7 +591,7 @@ func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.R jsonResponse(w, http.StatusOK, Response{ Success: true, Status: "prekeys_registered", - UUID: user.SignalID, + UUID: user.SignalID.String(), Number: user.SignalUsername, }) diff --git a/puppet.go b/puppet.go index 52953f50..45cc79a4 100644 --- a/puppet.go +++ b/puppet.go @@ -22,6 +22,7 @@ import ( "regexp" "sync" + "github.com/google/uuid" "github.com/rs/zerolog" "maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/bridge" @@ -67,7 +68,7 @@ func (puppet *Puppet) CustomIntent() *appservice.IntentAPI { func (puppet *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI { if puppet != nil { - if puppet.customIntent == nil || portal.ChatID == puppet.SignalID { + if puppet.customIntent == nil || portal.UserID() == puppet.SignalID { return puppet.DefaultIntent() } return puppet.customIntent @@ -89,13 +90,13 @@ func (br *SignalBridge) NewPuppet(dbPuppet *database.Puppet) *Puppet { return &Puppet{ Puppet: dbPuppet, bridge: br, - log: br.ZLog.With().Str("signal_user_id", dbPuppet.SignalID).Logger(), + log: br.ZLog.With().Str("signal_user_id", dbPuppet.SignalID.String()).Logger(), MXID: br.FormatPuppetMXID(dbPuppet.SignalID), } } -func (br *SignalBridge) ParsePuppetMXID(mxid id.UserID) (string, bool) { +func (br *SignalBridge) ParsePuppetMXID(mxid id.UserID) (uuid.UUID, bool) { if userIDRegex == nil { pattern := fmt.Sprintf( "^@%s:%s$", @@ -110,10 +111,14 @@ func (br *SignalBridge) ParsePuppetMXID(mxid id.UserID) (string, bool) { match := userIDRegex.FindStringSubmatch(string(mxid)) if len(match) == 2 { - return match[1], true + parsed, err := uuid.Parse(match[1]) + if err != nil { + return uuid.Nil, false + } + return parsed, true } - return "", false + return uuid.Nil, false } func (br *SignalBridge) GetPuppetByMXID(mxid id.UserID) *Puppet { @@ -125,11 +130,19 @@ func (br *SignalBridge) GetPuppetByMXID(mxid id.UserID) *Puppet { return br.GetPuppetBySignalID(signalID) } -func (br *SignalBridge) GetPuppetBySignalID(id string) *Puppet { +func (br *SignalBridge) GetPuppetBySignalIDString(id string) *Puppet { + parsed, err := uuid.Parse(id) + if err != nil { + return nil + } + return br.GetPuppetBySignalID(parsed) +} + +func (br *SignalBridge) GetPuppetBySignalID(id uuid.UUID) *Puppet { br.puppetsLock.Lock() defer br.puppetsLock.Unlock() - if id == "" { + if id == uuid.Nil { br.ZLog.Warn().Msg("Trying to get puppet with empty signal_user_id") return nil } @@ -141,13 +154,13 @@ func (br *SignalBridge) GetPuppetBySignalID(id string) *Puppet { br.ZLog.Err(err).Msg("Failed to get puppet from database") return nil } else if dbPuppet == nil { - br.ZLog.Info().Str("signal_user_id", id).Msg("Puppet not found in database, creating new entry") + br.ZLog.Info().Str("signal_user_id", id.String()).Msg("Puppet not found in database, creating new entry") dbPuppet = br.DB.Puppet.New() dbPuppet.SignalID = id //dbPuppet.Number = err = dbPuppet.Insert(context.TODO()) if err != nil { - br.ZLog.Error().Err(err).Str("signal_user_id", id).Msg("Error creating new puppet") + br.ZLog.Error().Err(err).Str("signal_user_id", id.String()).Msg("Error creating new puppet") return nil } } @@ -222,9 +235,9 @@ func (br *SignalBridge) GetAllPuppetsWithCustomMXID() []*Puppet { return br.dbPuppetsToPuppets(puppets) } -func (br *SignalBridge) FormatPuppetMXID(did string) id.UserID { +func (br *SignalBridge) FormatPuppetMXID(u uuid.UUID) id.UserID { return id.NewUserID( - br.Config.Bridge.FormatUsername(did), + br.Config.Bridge.FormatUsername(u.String()), br.Config.Homeserver.Domain, ) } diff --git a/user.go b/user.go index a9036d68..da1df5bf 100644 --- a/user.go +++ b/user.go @@ -137,7 +137,7 @@ func (br *SignalBridge) loadUser(ctx context.Context, dbUser *database.User, mxi user := br.NewUser(dbUser) br.usersByMXID[user.MXID] = user - if user.SignalID != "" { + if user.SignalID != uuid.Nil { br.usersBySignalID[user.SignalID] = user } if user.ManagementRoom != "" { @@ -177,7 +177,7 @@ func (br *SignalBridge) GetUserByMXID(userID id.UserID) *User { return user } -func (br *SignalBridge) GetUserBySignalID(id string) *User { +func (br *SignalBridge) GetUserBySignalID(id uuid.UUID) *User { br.usersLock.Lock() defer br.usersLock.Unlock() @@ -286,7 +286,7 @@ func (user *User) GetMXID() id.UserID { return user.MXID } func (user *User) GetRemoteID() string { - return user.SignalID + return user.SignalID.String() } func (user *User) GetRemoteName() string { @@ -504,11 +504,11 @@ func (user *User) populateSignalDevice() *signalmeow.Device { user.Lock() defer user.Unlock() - if user.SignalID == "" { + if user.SignalID == uuid.Nil { return nil } - device, err := user.bridge.MeowStore.DeviceByAci(user.SignalID) + device, err := user.bridge.MeowStore.DeviceByAci(user.SignalID.String()) if err != nil { user.log.Err(err).Msgf("problem looking up aci %s", user.SignalID) return nil @@ -524,7 +524,7 @@ func (user *User) populateSignalDevice() *signalmeow.Device { } func updatePuppetWithSignalContact(ctx context.Context, user *User, puppet *Puppet, newContactAvatar *signalmeow.ContactAvatar) error { - contact, newProfileAvatar, err := user.SignalDevice.ContactByIDWithProfileAvatar(puppet.SignalID) + contact, newProfileAvatar, err := user.SignalDevice.ContactByIDWithProfileAvatar(puppet.SignalID.String()) if err != nil { user.log.Err(err).Msg("updatePuppetWithSignalContact: error retrieving contact") return err @@ -623,10 +623,15 @@ func ensureGroupPuppetsAreJoinedToPortal(ctx context.Context, user *User, portal return err } for _, member := range group.Members { - if member.UserId == user.SignalID { + parsedUserID, err := uuid.Parse(member.UserId) + if err != nil { + // TODO log? + continue + } + if parsedUserID == user.SignalID { continue } - memberPuppet := portal.bridge.GetPuppetBySignalID(member.UserId) + memberPuppet := portal.bridge.GetPuppetBySignalID(parsedUserID) if memberPuppet == nil { user.log.Err(err).Msgf("no puppet found for signalID %s", member.UserId) continue @@ -645,8 +650,12 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign m := incomingMessage.Base() var chatID string var senderPuppet *Puppet + parsedSenderUUID, err := uuid.Parse(m.SenderUUID) + if err != nil { + return err + } - isSyncMessage := m.SenderUUID == user.SignalID + isSyncMessage := parsedSenderUUID == user.SignalID // Get and update the puppet for this message user.tryAutomaticDoublePuppeting() @@ -656,7 +665,7 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign chatID = m.RecipientUUID senderPuppet = user.bridge.GetPuppetByCustomMXID(user.MXID) if senderPuppet == nil { - senderPuppet = user.bridge.GetPuppetBySignalID(m.SenderUUID) + senderPuppet = user.bridge.GetPuppetBySignalID(parsedSenderUUID) if senderPuppet == nil { err := fmt.Errorf("no puppet found for me (%s)", user.MXID) user.log.Err(err).Msg("error getting puppet") @@ -666,7 +675,7 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign } else { user.log.Debug().Msgf("Message received from %s (group: %v)", m.SenderUUID, m.GroupID) chatID = m.SenderUUID - senderPuppet = user.bridge.GetPuppetBySignalID(m.SenderUUID) + senderPuppet = user.bridge.GetPuppetBySignalID(parsedSenderUUID) if senderPuppet == nil { err := fmt.Errorf("no puppet found for sender: %s", m.SenderUUID) user.log.Err(err).Msg("error getting puppet") @@ -694,8 +703,12 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign if incomingMessage.MessageType() == signalmeow.IncomingSignalMessageTypeReceipt { receiptMessage := incomingMessage.(signalmeow.IncomingSignalMessageReceipt) timestamp := receiptMessage.OriginalTimestamp - sender := receiptMessage.OriginalSender - dbMessage, err := user.bridge.DB.Message.GetBySignalIDWithUnknownReceiver(context.TODO(), sender, timestamp, 0, user.SignalUsername) + sender, err := uuid.Parse(receiptMessage.OriginalSender) + if err != nil { + user.log.Err(err).Msg("Failed to parse sender UUID in receipt") + return nil + } + dbMessage, err := user.bridge.DB.Message.GetBySignalIDWithUnknownReceiver(context.TODO(), sender, timestamp, 0, user.SignalID) if err != nil { user.log.Err(err).Msg("Failed to get receipt target message from database") return nil @@ -819,7 +832,7 @@ func (user *User) incomingMessageHandler(incomingMessage signalmeow.IncomingSign func (user *User) GetPortalByChatID(signalID string) *Portal { pk := database.PortalKey{ ChatID: signalID, - Receiver: user.SignalUsername, + Receiver: user.SignalID, } return user.bridge.GetPortalByChatID(pk) } @@ -924,7 +937,7 @@ func (user *User) getDirectChats() map[id.UserID][]id.RoomID { } for _, portal := range privateChats { if portal.MXID != "" { - puppetMXID := user.bridge.FormatPuppetMXID(portal.ChatID) + puppetMXID := user.bridge.FormatPuppetMXID(portal.UserID()) chats[puppetMXID] = []id.RoomID{portal.MXID} }