diff --git a/pkg/connector/groupinfo.go b/pkg/connector/groupinfo.go index 5c8b0004..840db11f 100644 --- a/pkg/connector/groupinfo.go +++ b/pkg/connector/groupinfo.go @@ -121,18 +121,12 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden } } for _, member := range groupInfo.PendingMembers { - var aci uuid.UUID - if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = member.ServiceID.UUID + aci := maybeResolvePNItoACI(memberServiceID) + if aci == nil { + continue } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) @@ -144,18 +138,12 @@ func (s *SignalClient) getGroupInfo(ctx context.Context, groupID types.GroupIden }) } for _, member := range groupInfo.BannedMembers { - var aci uuid.UUID - if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = member.ServiceID.UUID + aci := maybeResolvePNItoACI(memberServiceID) + if aci == nil { + continue } members.Members = append(members.Members, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipBan, }) } @@ -258,39 +246,23 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint }) } for _, member := range groupChange.AddPendingMembers { - if member.ServiceID.Type != libsignalgo.ServiceIDTypeACI { + aci := maybeResolvePNItoACI(memberServiceID) + if aci == nil { continue } - var aci uuid.UUID - if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = member.ServiceID.UUID - } - mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), PowerLevel: roleToPL(member.Role), Membership: event.MembershipInvite, }) } for _, memberServiceID := range groupChange.DeletePendingMembers { - var aci uuid.UUID - if memberServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, memberServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = memberServiceID.UUID + aci := maybeResolvePNItoACI(memberServiceID) + if aci == nil { + continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipLeave, PrevMembership: event.MembershipInvite, }) @@ -309,18 +281,12 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint }) } for _, member := range groupChange.AddBannedMembers { - var aci uuid.UUID - if member.ServiceID.Type == libsignalgo.ServiceIDTypePNI { - device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, member.ServiceID.UUID) - if err != nil || device == nil { - continue - } - aci = device.ACI - } else { - aci = member.ServiceID.UUID + aci := maybeResolvePNItoACI(memberServiceID) + if aci == nil { + continue } mc = append(mc, bridgev2.ChatMember{ - EventSender: s.makeEventSender(aci), + EventSender: s.makeEventSender(*aci), Membership: event.MembershipBan, }) } @@ -340,6 +306,17 @@ func (s *SignalClient) groupChangeToChatInfoChange(ctx context.Context, rev uint return ic } +func (s *SignalClient) maybeResolvePNItoACI(serviceID libsignalgo.ServiceID) *uuid.UUID { + if serviceID.Type == libsignalgo.ServiceIDTypeACI { + return &serviceID.UUID + } + device, err := s.Client.Store.DeviceStore.DeviceByPNI(ctx, serviceID.UUID) + if err != nil || device == nil { + return nil + } + return &device.ACI +} + func (s *SignalClient) catchUpGroup(ctx context.Context, portal *bridgev2.Portal, fromRevision, toRevision uint32, ts uint64) { if fromRevision >= toRevision { return