From 5ff230d0e41f73f12d5ef3dff8a2fd28b87a181a Mon Sep 17 00:00:00 2001 From: Malte E Date: Thu, 8 Aug 2024 23:13:59 +0200 Subject: [PATCH 1/4] look up aci when receiving pni --- pkg/connector/groupinfo.go | 66 +++++++++++++++++++++++++------ pkg/signalmeow/store/container.go | 10 +++++ 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/pkg/connector/groupinfo.go b/pkg/connector/groupinfo.go index 3bf809d1..43e2b78b 100644 --- a/pkg/connector/groupinfo.go +++ b/pkg/connector/groupinfo.go @@ -20,6 +20,7 @@ import ( "context" "time" + "github.com/google/uuid" "github.com/rs/zerolog" "go.mau.fi/util/ptr" "maunium.net/go/mautrix/bridgev2" @@ -120,11 +121,18 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden } } for _, member := range groupInfo.PendingMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { - continue + var aci uuid.UUID + if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = member.ServiceID.UUID } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) @@ -136,11 +144,18 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden }) } for _, member := range groupInfo.BannedMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { - continue + var aci uuid.UUID + if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = member.ServiceID.UUID } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(aci), Membership: event.MembershipBan, }) } @@ -246,18 +261,36 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { continue } + var aci uuid.UUID + if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = member.ServiceID.UUID + } + mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) } for _, memberServiceID := range groupChange.DeletePendingMembers { - if memberServiceID.Type != libsignalgo.ServiceIDTypeACI { - continue + var aci uuid.UUID + if memberServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, memberServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = memberServiceID.UUID } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(memberServiceID.UUID), + EventSender: s.makeEventSender(aci), Membership: event.MembershipLeave, PrevMembership: event.MembershipInvite, }) @@ -276,11 +309,18 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint }) } for _, member := range groupChange.AddBannedMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { - continue + var aci uuid.UUID + if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) + if err != nil { + continue + } + aci = device.ACI + } else { + aci = member.ServiceID.UUID } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(member.ServiceID.UUID), + EventSender: s.makeEventSender(aci), Membership: event.MembershipBan, }) } diff --git a/pkg/signalmeow/store/container.go b/pkg/signalmeow/store/container.go index 9edafef6..879716dc 100644 --- a/pkg/signalmeow/store/container.go +++ b/pkg/signalmeow/store/container.go @@ -19,6 +19,7 @@ var _ DeviceStore = (*Container)(nil) type DeviceStore interface { PutDevice(ctx context.Context, dd *DeviceData) error DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, error) + DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) } // Container is a wrapper for a SQL database that can contain multiple signalmeow sessions. @@ -39,6 +40,7 @@ FROM signalmeow_device ` const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1" +const deviceByPNIQuery = getAllDevicesQuery + "Where pni_uuid=$1" func (c *Container) Upgrade(ctx context.Context) error { return c.db.Upgrade(ctx) @@ -122,6 +124,14 @@ func (c *Container) DeviceByACI(ctx context.Context, aci uuid.UUID) (*Device, er return sess, err } +func (c *Container) DeviceByPNI(ctx context.Context, pni uuid.UUID) (*Device, error) { + sess, err := c.scanDevice(c.db.QueryRow(ctx, deviceByPNIQuery, pni)) + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return sess, err +} + const ( insertDeviceQuery = ` INSERT INTO signalmeow_device ( From 3dca57a706582557918deb86096e0b34dffce631 Mon Sep 17 00:00:00 2001 From: Malte E Date: Fri, 9 Aug 2024 07:57:50 +0200 Subject: [PATCH 2/4] check whether device == nil --- pkg/connector/groupinfo.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/connector/groupinfo.go b/pkg/connector/groupinfo.go index 43e2b78b..5c8b0004 100644 --- a/pkg/connector/groupinfo.go +++ b/pkg/connector/groupinfo.go @@ -124,7 +124,7 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden var aci uuid.UUID if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil { + if err != nil || device == nil { continue } aci = device.ACI @@ -147,7 +147,7 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden var aci uuid.UUID if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil { + if err != nil || device == nil { continue } aci = device.ACI @@ -264,7 +264,7 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint var aci uuid.UUID if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil { + if err != nil || device == nil { continue } aci = device.ACI @@ -282,7 +282,7 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint var aci uuid.UUID if memberServiceID.Type == libsignalgo.ServiceIDTypePNI { device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, memberServiceID.UUID) - if err != nil { + if err != nil || device == nil { continue } aci = device.ACI @@ -312,7 +312,7 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint var aci uuid.UUID if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil { + if err != nil || device == nil { continue } aci = device.ACI From 83f2bfde98c008a7eb66cf802ec599f0179b42ae Mon Sep 17 00:00:00 2001 From: Malte Date: Fri, 9 Aug 2024 09:04:51 +0200 Subject: [PATCH 3/4] function instead of boilerplate --- pkg/connector/groupinfo.go | 88 ++++++++++++++------------------------ 1 file changed, 33 insertions(+), 55 deletions(-) diff --git a/pkg/connector/groupinfo.go b/pkg/connector/groupinfo.go index 5c8b0004..db06b957 100644 --- a/pkg/connector/groupinfo.go +++ b/pkg/connector/groupinfo.go @@ -121,18 +121,12 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden } } for _, member := range groupInfo.PendingMembers { - var aci uuid.UUID - if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = member.ServiceID.UUID + aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID) + if aci == nil { + continue } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) @@ -144,18 +138,12 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden }) } for _, member := range groupInfo.BannedMembers { - var aci uuid.UUID - if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = member.ServiceID.UUID + aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID) + if aci == nil { + continue } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipBan, }) } @@ -258,39 +246,23 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint }) } for _, member := range groupChange.AddPendingMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID) + if aci == nil { continue } - var aci uuid.UUID - if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = member.ServiceID.UUID - } - mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) } for _, memberServiceID := range groupChange.DeletePendingMembers { - var aci uuid.UUID - if memberServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, memberServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = memberServiceID.UUID + aci := s.maybeResolvePNItoACI(ctx, memberServiceID) + if aci == nil { + continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipLeave, PrevMembership: event.MembershipInvite, }) @@ -309,27 +281,22 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint }) } for _, member := range groupChange.AddBannedMembers { - var aci uuid.UUID - if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = member.ServiceID.UUID + aci := s.maybeResolvePNItoACI(ctx, &member.ServiceID) + if aci == nil { + continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipBan, }) } for _, memberServiceID := range groupChange.DeleteBannedMembers { - if memberServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := s.maybeResolvePNItoACI(ctx, memberServiceID) + if aci == nil { continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(memberServiceID.UUID), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipLeave, PrevMembership: event.MembershipBan, }) @@ -340,6 +307,17 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint return ic } +func (s *SignalClient) maybeResolvePNItoACI(ctx context.Context, serviceID *libsignalgo.ServiceID) *uuid.UUID { + if serviceID.Type == libsignalgo.ServiceIDTypeACI { + return &serviceID.UUID + } + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, serviceID.UUID) + if err != nil || device == nil { + return nil + } + return &device.ACI +} + func (s *SignalClient) catchUpGroup(ctx context.Context, portal *bridgev2.Portal, fromRevision, toRevision uint32, ts uint64) { if fromRevision >= toRevision { return From bd53f191c856f6a21665163c6f77f420bb1af42f Mon Sep 17 00:00:00 2001 From: Malte E <97891689+maltee1@users.noreply.github.com> Date: Fri, 9 Aug 2024 12:01:51 +0200 Subject: [PATCH 4/4] Update pkg/signalmeow/store/container.go Co-authored-by: Tulir Asokan --- pkg/signalmeow/store/container.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/signalmeow/store/container.go b/pkg/signalmeow/store/container.go index 879716dc..a08b42d6 100644 --- a/pkg/signalmeow/store/container.go +++ b/pkg/signalmeow/store/container.go @@ -40,7 +40,7 @@ FROM signalmeow_device ` const getDeviceQuery = getAllDevicesQuery + " WHERE aci_uuid=$1" -const deviceByPNIQuery = getAllDevicesQuery + "Where pni_uuid=$1" +const deviceByPNIQuery = getAllDevicesQuery + "WHERE pni_uuid=$1" func (c *Container) Upgrade(ctx context.Context) error { return c.db.Upgrade(ctx)