From 77cf7e79854ae7bffecbc0f9f48fee06bacf6572 Mon Sep 17 00:00:00 2001 From: Malte E <97891689+maltee1@users.noreply.github.com> Date: Fri, 9 Aug 2024 12:03:50 +0200 Subject: [PATCH] groupinfo: look up ACI from local devices when receiving PNI group member (#528) --- pkg/connector/groupinfo.go | 42 ++++++++++++++++++++++--------- pkg/signalmeow/store/container.go | 10 ++++++++ 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/pkg/connector/groupinfo.go b/pkg/connector/groupinfo.go index 3bf809d1..db06b957 100644 --- a/pkg/connector/groupinfo.go +++ b/pkg/connector/groupinfo.go @@ -20,6 +20,7 @@ import ( "context" "time" + "github.com/google/uuid" "github.com/rs/zerolog" "go.mau.fi/util/ptr" "maunium.net/go/mautrix/bridgev2" @@ -120,11 +121,12 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden } } for _, member := range groupInfo.PendingMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID) + if aci == nil { continue } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(*aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) @@ -136,11 +138,12 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden }) } for _, member := range groupInfo.BannedMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID) + if aci == nil { continue } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipBan, }) } @@ -243,21 +246,23 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint }) } for _, member := range groupChange.AddPendingMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID) + if aci == nil { continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(*aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) } for _, memberServiceID := range groupChange.DeletePendingMembers { - if memberServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := s.maybeResolvePNItoACI(ctx, memberServiceID) + if aci == nil { continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(memberServiceID.UUID), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipLeave, PrevMembership: event.MembershipInvite, }) @@ -276,20 +281,22 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint }) } for _, member := range groupChange.AddBannedMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID) + if aci == nil { continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipBan, }) } for _, memberServiceID := range groupChange.DeleteBannedMembers { - if memberServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := s.maybeResolvePNItoACI(ctx, memberServiceID) + if aci == nil { continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(memberServiceID.UUID), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipLeave, PrevMembership: event.MembershipBan, }) @@ -300,6 +307,17 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint return ic } +func (s *SignalClient) maybeResolvePNItoACI(ctx context.Context, serviceID *libsignalgo.ServiceID) *uuid.UUID { + if serviceID.Type == libsignalgo.ServiceIDTypeACI { + return &serviceID.UUID + } + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, serviceID.UUID) + if err != nil || device == nil { + return nil + } + return &device.ACI +} + func (s *SignalClient) catchUpGroup(ctx context.Context, portal *bridgev2.Portal, fromRevision, toRevision uint32, ts uint64) { if fromRevision >= toRevision { return diff --git a/pkg/signalmeow/store/container.go b/pkg/signalmeow/store/container.go index 9edafef6..a08b42d6 100644 --- a/pkg/signalmeow/store/container.go +++ b/pkg/signalmeow/store/container.go @@ -19,6 +19,7 @@ var _ DeviceStore = (*Container)(nil) type DeviceStore interface { PutDevice(ctx context.Context, dd *DeviceData) error DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error) + DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) } // Container is a wrapper for a SQL database that can contain multiple signalmeow sessions. @@ -39,6 +40,7 @@ FROM signalmeow_device ` const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1" +const deviceByPNIQuery = getAllDevicesQuery + "WHERE pni_uuid=$1" func (c *Container) Upgrade(ctx context.Context) error { return c.db.Upgrade(ctx) @@ -122,6 +124,14 @@ func (c *Container) DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, er return sess, err } +func (c *Container) DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) { + sess, err := c.scanDevice(c.db.QueryRow(ctx, deviceByPNIQuery, pni)) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return sess, err +} + const ( insertDeviceQuery = ` INSERT INTO signalmeow_device (