From 7e74d98371b564ce9125c2ed72e9f780f80c7496 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 6 Jun 2024 16:11:58 +0300 Subject: [PATCH] Implement Matrix reactions, edits and redactions --- go.mod | 2 +- go.sum | 4 +- pkg/connector/connector.go | 191 +++++++++++++++++++++++++++++++------ pkg/signalmeow/sending.go | 7 ++ 4 files changed, 173 insertions(+), 31 deletions(-) diff --git a/go.mod b/go.mod index 848b912d..fee3d523 100644 --- a/go.mod +++ b/go.mod @@ -21,7 +21,7 @@ require ( golang.org/x/net v0.25.0 google.golang.org/protobuf v1.34.1 gopkg.in/yaml.v3 v3.0.1 - maunium.net/go/mautrix v0.18.2-0.20240605105031-218ed06e73f6 + maunium.net/go/mautrix v0.18.2-0.20240606131110-a0e309fa55ab nhooyr.io/websocket v1.8.11 ) diff --git a/go.sum b/go.sum index dfce1298..b80546d5 100644 --- a/go.sum +++ b/go.sum @@ -95,7 +95,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.18.2-0.20240605105031-218ed06e73f6 h1:JkMk5Urz1niqsqOVWhoHculon2FSVrITM1g1iVMcxhU= -maunium.net/go/mautrix v0.18.2-0.20240605105031-218ed06e73f6/go.mod h1:P/FV8cXY262MezYX7ViuhfzeJ0nK4+M8K6ZmxEC/aEA= +maunium.net/go/mautrix v0.18.2-0.20240606131110-a0e309fa55ab h1:e0Zo3/K+quT6p+U2Gsmw8R1kinzZ0wOkPbVPwoMkrBY= +maunium.net/go/mautrix v0.18.2-0.20240606131110-a0e309fa55ab/go.mod h1:P/FV8cXY262MezYX7ViuhfzeJ0nK4+M8K6ZmxEC/aEA= nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0= nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= diff --git a/pkg/connector/connector.go b/pkg/connector/connector.go index a0fc4ee6..9075567a 100644 --- a/pkg/connector/connector.go +++ b/pkg/connector/connector.go @@ -27,6 +27,7 @@ import ( "github.com/google/uuid" "github.com/rs/zerolog" "go.mau.fi/util/dbutil" + "go.mau.fi/util/variationselector" "golang.org/x/exp/slices" "google.golang.org/protobuf/proto" @@ -381,6 +382,20 @@ func (s *SignalClient) getPortalID(chatID string) networkid.PortalID { } } +func parseMessageID(messageID networkid.MessageID) (sender uuid.UUID, timestamp uint64, err error) { + parts := strings.Split(string(messageID), "|") + if len(parts) != 2 { + err = fmt.Errorf("invalid message ID: expected two pipe-separated parts") + return + } + sender, err = uuid.Parse(parts[0]) + if err != nil { + return + } + timestamp, err = strconv.ParseUint(parts[1], 10, 64) + return +} + func makeMessageID(sender uuid.UUID, timestamp uint64) networkid.MessageID { return networkid.MessageID(fmt.Sprintf("%s|%d", sender, timestamp)) } @@ -519,31 +534,18 @@ func (s *SignalClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Ma ReplyTo: msg.ReplyTo, } ctx = context.WithValue(ctx, msgconvContextKey, mcCtx) - userID, groupID, err := s.parsePortalID(msg.Portal.ID) + converted, err := s.Main.MsgConv.ToSignal(ctx, msg.Event, msg.Content, msg.OrigSender != nil) if err != nil { return nil, err } - converted, err := s.Main.MsgConv.ToSignal(ctx, msg.Event, msg.Content, msg.OrigSender != nil) + res, err := s.sendMessage(ctx, msg.Portal.ID, &signalpb.Content{DataMessage: converted}) if err != nil { return nil, err } - wrappedContent := &signalpb.Content{ - DataMessage: converted, - } - if groupID != "" { - res, err := s.Client.SendGroupMessage(ctx, groupID, wrappedContent) - if err != nil { - return nil, err - } - // TODO check result - fmt.Println(res) - } else { - res := s.Client.SendMessage(ctx, userID, wrappedContent) - // TODO check result - fmt.Println(res) - } + // TODO check result + fmt.Println(res) meta := map[string]any{ - "reply_to_file": len(converted.Attachments) > 0, + "contains_attachments": len(converted.Attachments) > 0, } dbMsg := &database.Message{ ID: makeMessageID(s.Client.Store.ACI, converted.GetTimestamp()), @@ -560,23 +562,156 @@ func (s *SignalClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.Ma } func (s *SignalClient) HandleMatrixEdit(ctx context.Context, msg *bridgev2.MatrixEdit) error { - //TODO implement me - panic("implement me") + _, targetSentTimestamp, err := parseMessageID(msg.EditTarget.ID) + if err != nil { + return fmt.Errorf("failed to parse target message ID: %w", err) + } else if msg.EditTarget.SenderID != makeUserID(s.Client.Store.ACI) { + return fmt.Errorf("cannot edit other people's messages") + } + mcCtx := &msgconvContext{ + Connector: s.Main, + Intent: nil, + Client: s, + Portal: msg.Portal, + } + if msg.EditTarget.RelatesToRowID != 0 { + var err error + mcCtx.ReplyTo, err = s.Main.Bridge.DB.Message.GetByRowID(ctx, msg.EditTarget.RelatesToRowID) + if err != nil { + return fmt.Errorf("failed to get message reply target: %w", err) + } + } + ctx = context.WithValue(ctx, msgconvContextKey, mcCtx) + converted, err := s.Main.MsgConv.ToSignal(ctx, msg.Event, msg.Content, msg.OrigSender != nil) + if err != nil { + return err + } + res, err := s.sendMessage(ctx, msg.Portal.ID, &signalpb.Content{EditMessage: &signalpb.EditMessage{ + TargetSentTimestamp: proto.Uint64(targetSentTimestamp), + DataMessage: converted, + }}) + if err != nil { + return err + } + // TODO check result + fmt.Println(res) + msg.EditTarget.ID = makeMessageID(s.Client.Store.ACI, converted.GetTimestamp()) + msg.EditTarget.Metadata["contains_attachments"] = len(converted.Attachments) > 0 + return nil } -func (s *SignalClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (emojiID networkid.EmojiID, err error) { - //TODO implement me - panic("implement me") +func (s *SignalClient) sendMessage(ctx context.Context, portalID networkid.PortalID, content *signalpb.Content) (signalmeow.SendResult, error) { + userID, groupID, err := s.parsePortalID(portalID) + if err != nil { + return nil, err + } + if groupID != "" { + res, err := s.Client.SendGroupMessage(ctx, groupID, content) + if err != nil { + return nil, err + } + return res, nil + } else { + res := s.Client.SendMessage(ctx, userID, content) + return &res, nil + } +} + +func (s *SignalClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (reaction *database.Reaction, err error) { + senderID := makeUserID(s.Client.Store.ACI) + // emojiID is always empty because only one reaction is allowed per message+user + var emojiID networkid.EmojiID + signalEmoji := variationselector.FullyQualify(msg.Content.RelatesTo.Key) + if existing, err := msg.GetExisting(ctx, senderID, emojiID); err != nil { + return nil, fmt.Errorf("failed to check for duplicate reaction: %w", err) + } else if existing != nil && existing.Metadata["emoji"] == signalEmoji { + return nil, nil + } + targetAuthorACI, targetSentTimestamp, err := parseMessageID(msg.TargetMessage.ID) + if err != nil { + return nil, fmt.Errorf("failed to parse target message ID: %w", err) + } + wrappedContent := &signalpb.Content{ + DataMessage: &signalpb.DataMessage{ + Timestamp: proto.Uint64(uint64(msg.Event.Timestamp)), + RequiredProtocolVersion: proto.Uint32(uint32(signalpb.DataMessage_REACTIONS)), + Reaction: &signalpb.DataMessage_Reaction{ + Emoji: proto.String(signalEmoji), + Remove: proto.Bool(false), + TargetAuthorAci: proto.String(targetAuthorACI.String()), + TargetSentTimestamp: proto.Uint64(targetSentTimestamp), + }, + }, + } + res, err := s.sendMessage(ctx, msg.Portal.ID, wrappedContent) + if err != nil { + return nil, err + } + // TODO check result + fmt.Println(res) + return &database.Reaction{ + RoomID: msg.Portal.ID, + MessageID: msg.TargetMessage.ID, + MessagePartID: msg.TargetMessage.PartID, + SenderID: senderID, + EmojiID: emojiID, + MXID: msg.Event.ID, + Timestamp: time.UnixMilli(msg.Event.Timestamp), + Metadata: map[string]any{ + "emoji": signalEmoji, + }, + }, nil } func (s *SignalClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error { - //TODO implement me - panic("implement me") + emoji, _ := msg.TargetReaction.Metadata["emoji"].(string) + targetAuthorACI, targetSentTimestamp, err := parseMessageID(msg.TargetReaction.MessageID) + if err != nil { + return fmt.Errorf("failed to parse target message ID: %w", err) + } + wrappedContent := &signalpb.Content{ + DataMessage: &signalpb.DataMessage{ + Timestamp: proto.Uint64(uint64(msg.Event.Timestamp)), + RequiredProtocolVersion: proto.Uint32(uint32(signalpb.DataMessage_REACTIONS)), + Reaction: &signalpb.DataMessage_Reaction{ + Emoji: proto.String(emoji), + Remove: proto.Bool(true), + TargetAuthorAci: proto.String(targetAuthorACI.String()), + TargetSentTimestamp: proto.Uint64(targetSentTimestamp), + }, + }, + } + res, err := s.sendMessage(ctx, msg.Portal.ID, wrappedContent) + if err != nil { + return err + } + // TODO check result + fmt.Println(res) + return nil } func (s *SignalClient) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { - //TODO implement me - panic("implement me") + _, targetSentTimestamp, err := parseMessageID(msg.TargetMessage.ID) + if err != nil { + return fmt.Errorf("failed to parse target message ID: %w", err) + } else if msg.TargetMessage.SenderID != makeUserID(s.Client.Store.ACI) { + return fmt.Errorf("cannot delete other people's messages") + } + wrappedContent := &signalpb.Content{ + DataMessage: &signalpb.DataMessage{ + Timestamp: proto.Uint64(uint64(msg.Event.Timestamp)), + Delete: &signalpb.DataMessage_Delete{ + TargetSentTimestamp: proto.Uint64(targetSentTimestamp), + }, + }, + } + res, err := s.sendMessage(ctx, msg.Portal.ID, wrappedContent) + if err != nil { + return err + } + // TODO check result + fmt.Println(res) + return nil } type msgconvPortalMethods struct{} @@ -606,7 +741,7 @@ func (mpm *msgconvPortalMethods) GetSignalReply(ctx context.Context, content *ev AuthorAci: proto.String(string(mcCtx.ReplyTo.SenderID)), Type: signalpb.DataMessage_Quote_NORMAL.Enum(), } - if mcCtx.ReplyTo.Metadata["reply_to_file"] != false { + if mcCtx.ReplyTo.Metadata["contains_attachments"] != false { quote.Attachments = make([]*signalpb.DataMessage_Quote_QuotedAttachment, 1) } return quote diff --git a/pkg/signalmeow/sending.go b/pkg/signalmeow/sending.go index 11211e00..053e3354 100644 --- a/pkg/signalmeow/sending.go +++ b/pkg/signalmeow/sending.go @@ -304,6 +304,13 @@ type GroupMessageSendResult struct { FailedToSendTo []FailedSendResult } +type SendResult interface { + isSendResult() +} + +func (gmsr *GroupMessageSendResult) isSendResult() {} +func (smsr *SendMessageResult) isSendResult() {} + func contentFromDataMessage(dataMessage *signalpb.DataMessage) *signalpb.Content { return &signalpb.Content{ DataMessage: dataMessage,