Skip to content

Commit

Permalink
Fix saving message part index
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Dec 31, 2023
1 parent 242d0a3 commit 4b7fda2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
1 change: 1 addition & 0 deletions pkg/signalmeow/incoming_messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type IncomingSignalMessageBase struct {
RecipientUUID string // Usually our UUID, unless this is a message we sent on another device
GroupID *GroupIdentifier // Unique identifier for the group chat, or nil for 1:1 chats
Timestamp uint64 // With SenderUUID, treated as a unique identifier for a specific Signal message
PartIndex int //
Quote *IncomingSignalMessageQuoteData // If this message is a quote (reply), this will be non-nil
ExpiresIn int64 // If this message is ephemeral, this will be non-zero (in seconds)
}
Expand Down
25 changes: 11 additions & 14 deletions pkg/signalmeow/receiving.go
Original file line number Diff line number Diff line change
Expand Up @@ -828,7 +828,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa
expiresIn = int64(dataMessage.GetExpireTimer())
}

tsIndex := 0
partIndex := 0
// If there's attachements, handle them (one at a time for now)
if dataMessage.Attachments != nil {
for index, attachmentPointer := range dataMessage.Attachments {
Expand All @@ -837,19 +837,14 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa
zlog.Err(err).Msg("fetchAndDecryptAttachment error")
continue
}
// TODO: this is a hack to make sure each attachment has a unique timestamp
// this allows us to associate up to 999 attachments with a single Signal message
var timestamp uint64 = dataMessage.GetTimestamp()
if index > 0 {
timestamp = (dataMessage.GetTimestamp() * 1000) + uint64(index)
}
// TODO: right now this will be one message per image, each with the same caption
incomingMessage := IncomingSignalMessageAttachment{
IncomingSignalMessageBase: IncomingSignalMessageBase{
SenderUUID: senderUUID,
RecipientUUID: recipientUUID,
GroupID: gidPointer,
Timestamp: timestamp,
Timestamp: dataMessage.GetTimestamp(),
PartIndex: partIndex,
Quote: quoteData,
ExpiresIn: expiresIn,
},
Expand All @@ -861,33 +856,31 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa
Height: attachmentPointer.GetHeight(),
BlurHash: attachmentPointer.GetBlurHash(),
}
partIndex++
if HackyCaptionToggle && index == 0 {
incomingMessage.Caption = dataMessage.GetBody()
incomingMessage.CaptionRanges = dataMessage.GetBodyRanges()
}
incomingMessages = append(incomingMessages, incomingMessage)
}
tsIndex += len(dataMessage.Attachments)
}

// If there's a body but no attachments, pass along as a text message
if dataMessage.Body != nil && (dataMessage.Attachments == nil || !HackyCaptionToggle) {
timestamp := dataMessage.GetTimestamp()
if tsIndex != 0 {
timestamp = (dataMessage.GetTimestamp() * 1000) + uint64(tsIndex+1)
}
incomingMessage := IncomingSignalMessageText{
IncomingSignalMessageBase: IncomingSignalMessageBase{
SenderUUID: senderUUID,
RecipientUUID: recipientUUID,
GroupID: gidPointer,
Timestamp: timestamp,
Timestamp: dataMessage.GetTimestamp(),
PartIndex: partIndex,
Quote: quoteData,
ExpiresIn: expiresIn,
},
Content: dataMessage.GetBody(),
ContentRanges: dataMessage.GetBodyRanges(),
}
partIndex++
incomingMessages = append(incomingMessages, incomingMessage)
}

Expand All @@ -903,6 +896,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa
RecipientUUID: recipientUUID,
GroupID: gidPointer,
Timestamp: dataMessage.GetTimestamp(),
PartIndex: partIndex,
Quote: quoteData,
ExpiresIn: expiresIn,
},
Expand All @@ -914,6 +908,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa
Emoji: dataMessage.GetSticker().GetEmoji(),
}
incomingMessages = append(incomingMessages, incomingMessage)
partIndex++
}
}

Expand Down Expand Up @@ -978,6 +973,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa
RecipientUUID: recipientUUID,
GroupID: gidPointer,
Timestamp: dataMessage.GetTimestamp(),
PartIndex: partIndex,
},
DisplayName: contactCard.GetName().GetDisplayName(),
Organization: contactCard.GetOrganization(),
Expand Down Expand Up @@ -1017,6 +1013,7 @@ func incomingDataMessage(ctx context.Context, device *Device, dataMessage *signa
addressString := strings.Join(addressParts, ", ")
incomingMessage.Addresses = append(incomingMessage.Addresses, addressString)
}
partIndex++
incomingMessages = append(incomingMessages, incomingMessage)
}
}
Expand Down
14 changes: 8 additions & 6 deletions portal.go
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,7 @@ func (portal *Portal) handleMatrixMessage(ctx context.Context, sender *User, evt
timings.totalSend = time.Since(start)
go ms.sendMessageMetrics(evt, err, "Error sending", true)
if err == nil {
portal.storeMessageInDB(ctx, evt.ID, sender.SignalID, timestamp)
portal.storeMessageInDB(ctx, evt.ID, sender.SignalID, timestamp, 0)
if portal.ExpirationTime > 0 {
portal.addDisappearingMessage(ctx, evt.ID, int64(portal.ExpirationTime), true)
}
Expand Down Expand Up @@ -940,13 +940,14 @@ func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) {
Str("action", "handle signal message").
Str("sender", portalMessage.sender.SignalID.String()).
Uint64("timestamp", portalMessage.message.Base().Timestamp).
Int("part_index", portalMessage.message.Base().PartIndex).
Logger()
ctx := log.WithContext(context.TODO())
if existingMessage, err := portal.bridge.DB.Message.GetBySignalID(
ctx,
portalMessage.sender.SignalID,
portalMessage.message.Base().Timestamp,
0,
portalMessage.message.Base().PartIndex,
portal.Receiver,
); err != nil {
log.Err(err).Msg("Failed to check if message was already handled")
Expand Down Expand Up @@ -1026,12 +1027,13 @@ func (portal *Portal) handleSignalMessages(portalMessage portalSignalMessage) {
}
}

func (portal *Portal) storeMessageInDB(ctx context.Context, eventID id.EventID, senderSignalID uuid.UUID, timestamp uint64) {
func (portal *Portal) storeMessageInDB(ctx context.Context, eventID id.EventID, senderSignalID uuid.UUID, timestamp uint64, partIndex int) {
dbMessage := portal.bridge.DB.Message.New()
dbMessage.MXID = eventID
dbMessage.RoomID = portal.MXID
dbMessage.Sender = senderSignalID
dbMessage.Timestamp = timestamp
dbMessage.PartIndex = partIndex
dbMessage.SignalChatID = portal.ChatID
dbMessage.SignalReceiver = portal.Receiver
err := dbMessage.Insert(ctx)
Expand Down Expand Up @@ -1111,7 +1113,7 @@ func (portal *Portal) handleSignalTextMessage(ctx context.Context, portalMessage
if resp.EventID == "" {
return errors.New("Didn't receive event ID from Matrix")
}
portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp)
portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex)
portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync)
return err
}
Expand Down Expand Up @@ -1144,7 +1146,7 @@ func (portal *Portal) handleSignalStickerMessage(ctx context.Context, portalMess
if resp.EventID == "" {
return errors.New("Didn't receive event ID from Matrix")
}
portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp)
portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex)
portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync)
return err
}
Expand Down Expand Up @@ -1438,7 +1440,7 @@ func (portal *Portal) handleSignalAttachmentMessage(ctx context.Context, portalM
if resp.EventID == "" {
return errors.New("Didn't receive event ID from Matrix")
}
portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp)
portal.storeMessageInDB(ctx, resp.EventID, portalMessage.sender.SignalID, timestamp, portalMessage.message.Base().PartIndex)
portal.addDisappearingMessage(ctx, resp.EventID, portalMessage.message.Base().ExpiresIn, portalMessage.sync)
return err
}
Expand Down

0 comments on commit 4b7fda2

Please sign in to comment.