From 4b7fda258884e957e1bf7f40ee80379da94f35d9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 31 Dec 2023 09:20:33 +0100 Subject: [PATCH] Fix saving message part index --- pkg/signalmeow/incoming_messages.go | 1 + pkg/signalmeow/receiving.go | 25 +++++++++++-------------- portal.go | 14 ++++++++------ 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pkg/signalmeow/incoming_messages.go b/pkg/signalmeow/incoming_messages.go index 78b723c4..fd766ddd 100644 --- a/pkg/signalmeow/incoming_messages.go +++ b/pkg/signalmeow/incoming_messages.go @@ -28,6 +28,7 @@ type IncomingSignalMessageBase struct { RecipientUUID string // Usually our UUID, unless this is a message we sent on another device GroupID *GroupIdentifier // Unique identifier for the group chat, or nil for 1:1 chats Timestamp uint64 // With SenderUUID, treated as a unique identifier for a specific Signal message + PartIndex int // Quote *IncomingSignalMessageQuoteData // If this message is a quote (reply), this will be non-nil ExpiresIn int64 // If this message is ephemeral, this will be non-zero (in seconds) } diff --git a/pkg/signalmeow/receiving.go b/pkg/signalmeow/receiving.go index 57e51f8b..e0d82eb7 100644 --- a/pkg/signalmeow/receiving.go +++ b/pkg/signalmeow/receiving.go @@ -828,7 +828,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa expiresIn = int64(dataMessage.GetExpireTimer()) } - tsIndex := 0 + partIndex := 0 // If there's attachements, handle them (one at a time for now) if dataMessage.Attachments != nil { for index, attachmentPointer := range dataMessage.Attachments { @@ -837,19 +837,14 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa zlog.Err(err).Msg("fetchAndDecryptAttachment error") continue } - // TODO: this is a hack to make sure each attachment has a unique timestamp - // this allows us to associate up to 999 attachments with a single Signal message - var timestamp uint64 = dataMessage.GetTimestamp() - if index > 0 { - timestamp = (dataMessage.GetTimestamp() * 1000) + uint64(index) - } // TODO: right now this will be one message per image, each with the same caption incomingMessage := IncomingSignalMessageAttachment{ IncomingSignalMessageBase: IncomingSignalMessageBase{ SenderUUID: senderUUID, RecipientUUID: recipientUUID, GroupID: gidPointer, - Timestamp: timestamp, + Timestamp: dataMessage.GetTimestamp(), + PartIndex: partIndex, Quote: quoteData, ExpiresIn: expiresIn, }, @@ -861,33 +856,31 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa Height: attachmentPointer.GetHeight(), BlurHash: attachmentPointer.GetBlurHash(), } + partIndex++ if HackyCaptionToggle && index == 0 { incomingMessage.Caption = dataMessage.GetBody() incomingMessage.CaptionRanges = dataMessage.GetBodyRanges() } incomingMessages = append(incomingMessages, incomingMessage) } - tsIndex += len(dataMessage.Attachments) } // If there's a body but no attachments, pass along as a text message if dataMessage.Body != nil && (dataMessage.Attachments == nil || !HackyCaptionToggle) { - timestamp := dataMessage.GetTimestamp() - if tsIndex != 0 { - timestamp = (dataMessage.GetTimestamp() * 1000) + uint64(tsIndex+1) - } incomingMessage := IncomingSignalMessageText{ IncomingSignalMessageBase: IncomingSignalMessageBase{ SenderUUID: senderUUID, RecipientUUID: recipientUUID, GroupID: gidPointer, - Timestamp: timestamp, + Timestamp: dataMessage.GetTimestamp(), + PartIndex: partIndex, Quote: quoteData, ExpiresIn: expiresIn, }, Content: dataMessage.GetBody(), ContentRanges: dataMessage.GetBodyRanges(), } + partIndex++ incomingMessages = append(incomingMessages, incomingMessage) } @@ -903,6 +896,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa RecipientUUID: recipientUUID, GroupID: gidPointer, Timestamp: dataMessage.GetTimestamp(), + PartIndex: partIndex, Quote: quoteData, ExpiresIn: expiresIn, }, @@ -914,6 +908,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa Emoji: dataMessage.GetSticker().GetEmoji(), } incomingMessages = append(incomingMessages, incomingMessage) + partIndex++ } } @@ -978,6 +973,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa RecipientUUID: recipientUUID, GroupID: gidPointer, Timestamp: dataMessage.GetTimestamp(), + PartIndex: partIndex, }, DisplayName: contactCard.GetName().GetDisplayName(), Organization: contactCard.GetOrganization(), @@ -1017,6 +1013,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa addressString := strings.Join(addressParts, ", ") incomingMessage.Addresses = append(incomingMessage.Addresses, addressString) } + partIndex++ incomingMessages = append(incomingMessages, incomingMessage) } } diff --git a/portal.go b/portal.go index d1d98cde..cf51c581 100644 --- a/portal.go +++ b/portal.go @@ -379,7 +379,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *User, evt timings.totalSend = time.Since(start) go ms.sendMessageMetrics(evt, err, "Error sending", true) if err == nil { - portal.storeMessageInDB(ctx, evt.ID, sender.SignalID, timestamp) + portal.storeMessageInDB(ctx, evt.ID, sender.SignalID, timestamp, 0) if portal.ExpirationTime > 0 { portal.addDisappearingMessage(ctx, evt.ID, int64(portal.ExpirationTime), true) } @@ -940,13 +940,14 @@ func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) { Str("action", "handle signal message"). Str("sender", portalMessage.sender.SignalID.String()). Uint64("timestamp", portalMessage.message.Base().Timestamp). + Int("part_index", portalMessage.message.Base().PartIndex). Logger() ctx := log.WithContext(context.TODO()) if existingMessage, err := portal.bridge.DB.Message.GetBySignalID( ctx, portalMessage.sender.SignalID, portalMessage.message.Base().Timestamp, - 0, + portalMessage.message.Base().PartIndex, portal.Receiver, ); err != nil { log.Err(err).Msg("Failed to check if message was already handled") @@ -1026,12 +1027,13 @@ func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) { } } -func (portal *Portal) storeMessageInDB(ctx context.Context, eventID id.EventID, senderSignalID uuid.UUID, timestamp uint64) { +func (portal *Portal) storeMessageInDB(ctx context.Context, eventID id.EventID, senderSignalID uuid.UUID, timestamp uint64, partIndex int) { dbMessage := portal.bridge.DB.Message.New() dbMessage.MXID = eventID dbMessage.RoomID = portal.MXID dbMessage.Sender = senderSignalID dbMessage.Timestamp = timestamp + dbMessage.PartIndex = partIndex dbMessage.SignalChatID = portal.ChatID dbMessage.SignalReceiver = portal.Receiver err := dbMessage.Insert(ctx) @@ -1111,7 +1113,7 @@ func (portal *Portal) handleSignalTextMessage(ctx context.Context, portalMessage if resp.EventID == "" { return errors.New("Didn't receive event ID from Matrix") } - portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp) + portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex) portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) return err } @@ -1144,7 +1146,7 @@ func (portal *Portal) handleSignalStickerMessage(ctx context.Context, portalMess if resp.EventID == "" { return errors.New("Didn't receive event ID from Matrix") } - portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp) + portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex) portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) return err } @@ -1438,7 +1440,7 @@ func (portal *Portal) handleSignalAttachmentMessage(ctx context.Context, portalM if resp.EventID == "" { return errors.New("Didn't receive event ID from Matrix") } - portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp) + portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex) portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync) return err }