diff --git a/connector/connector.go b/connector/connector.go index 6661c26b..fd7d7261 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -19,17 +19,28 @@ package connector import ( "context" "fmt" + "strconv" + "strings" + "time" "github.com/google/uuid" + "github.com/rs/zerolog" + "golang.org/x/exp/slices" + "google.golang.org/protobuf/proto" legacydb "go.mau.fi/mautrix-signal/database" "go.mau.fi/mautrix-signal/msgconv" + "go.mau.fi/mautrix-signal/msgconv/matrixfmt" + "go.mau.fi/mautrix-signal/msgconv/signalfmt" + "go.mau.fi/mautrix-signal/pkg/libsignalgo" "go.mau.fi/mautrix-signal/pkg/signalmeow" "go.mau.fi/mautrix-signal/pkg/signalmeow/events" signalpb "go.mau.fi/mautrix-signal/pkg/signalmeow/protobuf" "go.mau.fi/mautrix-signal/pkg/signalmeow/store" + "go.mau.fi/mautrix-signal/pkg/signalmeow/types" "maunium.net/go/mautrix/bridgev2" "maunium.net/go/mautrix/bridgev2/database" + "maunium.net/go/mautrix/bridgev2/networkid" "maunium.net/go/mautrix/event" "maunium.net/go/mautrix/id" ) @@ -41,14 +52,65 @@ type SignalConnector struct { Bridge *bridgev2.Bridge } +func (s *SignalConnector) Init(bridge *bridgev2.Bridge) { + s.Bridge = bridge + s.MsgConv = &msgconv.MessageConverter{ + PortalMethods: &msgconvPortalMethods{}, + SignalFmtParams: &signalfmt.FormatParams{ + GetUserInfo: func(ctx context.Context, uuid uuid.UUID) signalfmt.UserInfo { + ghost, err := s.Bridge.GetGhostByID(ctx, makeUserID(uuid)) + if err != nil { + // TODO log? + return signalfmt.UserInfo{} + } + userInfo := signalfmt.UserInfo{ + MXID: ghost.MXID, + Name: ghost.Name, + } + userLogin := s.Bridge.GetCachedUserLoginByID(networkid.UserLoginID(uuid.String())) + if userLogin != nil { + userInfo.MXID = userLogin.UserMXID + // TODO find matrix user displayname? + } + return userInfo + }, + }, + MatrixFmtParams: &matrixfmt.HTMLParser{ + GetUUIDFromMXID: func(ctx context.Context, userID id.UserID) uuid.UUID { + parsed, ok := s.Bridge.Matrix.ParseGhostMXID(userID) + if ok { + u, _ := uuid.Parse(string(parsed)) + return u + } + user, _ := s.Bridge.GetExistingUserByMXID(ctx, userID) + // TODO log errors? + if user != nil { + preferredLogin, _ := ctx.Value(msgconvContextKey).(*msgconvContext).Portal.FindPreferredLogin(ctx, user) + if preferredLogin != nil { + u, _ := uuid.Parse(string(preferredLogin.ID)) + return u + } + } + return uuid.Nil + }, + }, + ConvertVoiceMessages: true, + ConvertGIFToAPNG: true, + MaxFileSize: 100 * 1024 * 1024, + AsyncFiles: true, + LocationFormat: "", + } +} + var _ bridgev2.NetworkConnector = (*SignalConnector)(nil) -var _ msgconv.PortalMethods = (*SignalConnector)(nil) +var _ bridgev2.NetworkAPI = (*SignalClient)(nil) +var _ msgconv.PortalMethods = (*msgconvPortalMethods)(nil) func (s *SignalConnector) NewLogin(user *bridgev2.User) *database.UserLogin { - + return nil } -func (s *SignalConnector) Prepare(ctx context.Context, login *bridgev2.UserLogin) error { +func (s *SignalConnector) PrepareLogin(ctx context.Context, login *bridgev2.UserLogin) error { aci, err := uuid.Parse(string(login.ID)) if err != nil { return fmt.Errorf("failed to parse user login ID: %w", err) @@ -60,7 +122,8 @@ func (s *SignalConnector) Prepare(ctx context.Context, login *bridgev2.UserLogin return fmt.Errorf("%w: device not found in store", bridgev2.ErrNotLoggedIn) } sc := &SignalClient{ - Main: s, + Main: s, + UserLogin: login, Client: &signalmeow.Client{ Store: device, }, @@ -71,24 +134,151 @@ func (s *SignalConnector) Prepare(ctx context.Context, login *bridgev2.UserLogin } type SignalClient struct { - Main *SignalConnector - Client *signalmeow.Client + Main *SignalConnector + UserLogin *bridgev2.UserLogin + Client *signalmeow.Client } -var _ bridgev2.NetworkAPI = (*SignalClient)(nil) +func (s *SignalClient) GetChatInfo(ctx context.Context, portal *bridgev2.Portal) (*bridgev2.PortalInfo, error) { + return &bridgev2.PortalInfo{}, nil +} func (s *SignalClient) Connect(ctx context.Context) error { - statusChan, err := s.Client.StartReceiveLoops(ctx) + _, err := s.Client.StartReceiveLoops(ctx) + if err != nil { + return err + } + // TODO status + return nil } func (s *SignalClient) IsLoggedIn() bool { return s.Client.IsLoggedIn() } +func (s *SignalClient) parsePortalID(portalID networkid.PortalID) (string, error) { + parts := strings.Split(string(portalID), "|") + if len(parts) == 1 { + if len(parts[0]) == 44 { + return parts[0], nil + } + return "", fmt.Errorf("invalid portal ID: expected group ID to be 44 characters") + } else if len(parts) == 2 { + ourACI := s.Client.Store.ACI.String() + if parts[0] == ourACI { + return parts[1], nil + } else if parts[1] == ourACI { + return parts[0], nil + } else { + return "", fmt.Errorf("invalid portal ID: expected one side to be our ACI") + } + } + return "", fmt.Errorf("invalid portal ID: unexpected number of pipe-separated parts") +} + +func (s *SignalClient) getPortalID(chatID string) networkid.PortalID { + if len(chatID) == 44 { + // Group ID + return networkid.PortalID(chatID) + } else if strings.HasPrefix(chatID, "PNI:") { + // Temporary new DM ID: always put our own ACI first, the portal will never be shared anyway + return networkid.PortalID(fmt.Sprintf("%s|%s", s.Client.Store.ACI, chatID)) + } else { + // DM ID: sort the two parts so the ID is always the same regardless of which side is receiving the message + parts := []string{s.Client.Store.ACI.String(), chatID} + slices.Sort(parts) + return networkid.PortalID(strings.Join(parts, "|")) + } +} + +func makeMessageID(sender uuid.UUID, timestamp uint64) networkid.MessageID { + return networkid.MessageID(fmt.Sprintf("%s|%d", sender, timestamp)) +} + +func makeUserID(user uuid.UUID) networkid.UserID { + return networkid.UserID(user.String()) +} + +func makeUserLoginID(user uuid.UUID) networkid.UserLoginID { + return networkid.UserLoginID(user.String()) +} + +func (s *SignalClient) makeEventSender(sender uuid.UUID) bridgev2.EventSender { + return bridgev2.EventSender{ + IsFromMe: sender == s.Client.Store.ACI, + SenderLogin: makeUserLoginID(sender), + Sender: makeUserID(sender), + } +} + +func makeMessagePartID(index int) networkid.PartID { + if index == 0 { + return "" + } + return networkid.PartID(strconv.Itoa(index)) +} + +type contextKey int + +var msgconvContextKey contextKey + +type msgconvContext struct { + Connector *SignalConnector + Intent bridgev2.MatrixAPI + Client *SignalClient + Portal *bridgev2.Portal + ReplyTo *database.Message +} + +func (s *SignalClient) convertMessage(ctx context.Context, portal *bridgev2.Portal, data *events.ChatEvent) (*bridgev2.ConvertedMessage, error) { + dataMsg := data.Event.(*signalpb.DataMessage) + converted := s.Main.MsgConv.ToMatrix(ctx, dataMsg) + var replyTo *networkid.MessageOptionalPartID + if dataMsg.GetQuote() != nil { + quoteAuthor, _ := uuid.Parse(dataMsg.Quote.GetAuthorAci()) + replyTo = &networkid.MessageOptionalPartID{ + MessageID: makeMessageID(quoteAuthor, dataMsg.Quote.GetId()), + } + } + convertedParts := make([]*bridgev2.ConvertedMessagePart, len(converted.Parts)) + for i, part := range converted.Parts { + convertedParts[i] = &bridgev2.ConvertedMessagePart{ + ID: makeMessagePartID(i), + Type: part.Type, + Content: part.Content, + Extra: part.Extra, + } + + } + return &bridgev2.ConvertedMessage{ + ID: makeMessageID(data.Info.Sender, dataMsg.GetTimestamp()), + EventSender: s.makeEventSender(data.Info.Sender), + Timestamp: time.UnixMilli(int64(converted.Timestamp)), + ReplyTo: replyTo, + Parts: convertedParts, + }, nil +} + func (s *SignalClient) handleSignalEvent(rawEvt events.SignalEvent) { switch evt := rawEvt.(type) { case *events.ChatEvent: - portal := s.Main.Bridge.GetPortalByID(evt.Info.ChatID) + switch innerEvt := evt.Event.(type) { + case *signalpb.DataMessage: + s.Main.Bridge.QueueRemoteEvent(s.UserLogin, &bridgev2.SimpleRemoteEvent[*events.ChatEvent]{ + Type: bridgev2.RemoteEventMessage, + LogContext: func(c zerolog.Context) zerolog.Context { + return c. + Uint64("message_id", innerEvt.GetTimestamp()). + Stringer("sender_id", evt.Info.Sender) + }, + PortalID: s.getPortalID(evt.Info.ChatID), + Data: evt, + + ConvertMessageFunc: s.convertMessage, + }) + case *signalpb.EditMessage: + case *signalpb.TypingMessage: + } case *events.DecryptionError: case *events.Receipt: case *events.ReadSelf: @@ -98,57 +288,146 @@ func (s *SignalClient) handleSignalEvent(rawEvt events.SignalEvent) { } } -func (s *SignalConnector) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (messageID string, err error) { - //TODO implement me - panic("implement me") +func (s *SignalClient) HandleMatrixMessage(ctx context.Context, msg *bridgev2.MatrixMessage) (message *database.Message, err error) { + mcCtx := &msgconvContext{ + Connector: s.Main, + Intent: nil, + Client: s, + Portal: msg.Portal, + ReplyTo: msg.ReplyTo, + } + ctx = context.WithValue(ctx, msgconvContextKey, mcCtx) + chatID, err := s.parsePortalID(msg.Portal.ID) + if err != nil { + return nil, err + } + var userID libsignalgo.ServiceID + var groupID types.GroupIdentifier + if len(chatID) == 44 { + groupID = types.GroupIdentifier(chatID) + } else { + userID, err = libsignalgo.ServiceIDFromString(chatID) + if err != nil { + return nil, err + } + } + converted, err := s.Main.MsgConv.ToSignal(ctx, msg.Event, msg.Content, msg.OrigSender != nil) + if err != nil { + return nil, err + } + wrappedContent := &signalpb.Content{ + DataMessage: converted, + } + if groupID != "" { + res, err := s.Client.SendGroupMessage(ctx, groupID, wrappedContent) + if err != nil { + return nil, err + } + // TODO check result + fmt.Println(res) + } else { + res := s.Client.SendMessage(ctx, userID, wrappedContent) + // TODO check result + fmt.Println(res) + } + meta := map[string]any{ + "reply_to_file": len(converted.Attachments) > 0, + } + dbMsg := &database.Message{ + ID: makeMessageID(s.Client.Store.ACI, converted.GetTimestamp()), + MXID: msg.Event.ID, + RoomID: msg.Portal.ID, + SenderID: makeUserID(s.Client.Store.ACI), + Timestamp: time.UnixMilli(int64(converted.GetTimestamp())), + Metadata: meta, + } + if msg.ReplyTo != nil { + dbMsg.RelatesToRowID = msg.ReplyTo.RowID + } + return dbMsg, nil } -func (s *SignalConnector) HandleMatrixEdit(ctx context.Context, msg *bridgev2.MatrixEdit) error { +func (s *SignalClient) HandleMatrixEdit(ctx context.Context, msg *bridgev2.MatrixEdit) error { //TODO implement me panic("implement me") } -func (s *SignalConnector) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (emojiID string, err error) { +func (s *SignalClient) HandleMatrixReaction(ctx context.Context, msg *bridgev2.MatrixReaction) (emojiID networkid.EmojiID, err error) { //TODO implement me panic("implement me") } -func (s *SignalConnector) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error { +func (s *SignalClient) HandleMatrixReactionRemove(ctx context.Context, msg *bridgev2.MatrixReactionRemove) error { //TODO implement me panic("implement me") } -func (s *SignalConnector) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { +func (s *SignalClient) HandleMatrixMessageRemove(ctx context.Context, msg *bridgev2.MatrixMessageRemove) error { //TODO implement me panic("implement me") } -func (s *SignalConnector) UploadMatrixMedia(ctx context.Context, data []byte, fileName, contentType string) (id.ContentURIString, error) { - //TODO implement me - panic("implement me") +type msgconvPortalMethods struct{} + +func (mpm *msgconvPortalMethods) UploadMatrixMedia(ctx context.Context, data []byte, fileName, contentType string) (id.ContentURIString, error) { + mcCtx := ctx.Value(msgconvContextKey).(*msgconvContext) + uri, _, err := mcCtx.Intent.UploadMedia(ctx, "", data, fileName, contentType) + return uri, err } -func (s *SignalConnector) DownloadMatrixMedia(ctx context.Context, uri id.ContentURIString) ([]byte, error) { - //TODO implement me - panic("implement me") +func (mpm *msgconvPortalMethods) DownloadMatrixMedia(ctx context.Context, uri id.ContentURIString) ([]byte, error) { + return ctx.Value(msgconvContextKey).(*msgconvContext).Connector.Bridge.Bot.DownloadMedia(ctx, uri, nil) } -func (s *SignalConnector) GetMatrixReply(ctx context.Context, msg *signalpb.DataMessage_Quote) (replyTo id.EventID, replyTargetSender id.UserID) { - //TODO implement me - panic("implement me") +func (mpm *msgconvPortalMethods) GetMatrixReply(ctx context.Context, msg *signalpb.DataMessage_Quote) (replyTo id.EventID, replyTargetSender id.UserID) { + // Matrix replies are handled in bridgev2 code + return "", "" } -func (s *SignalConnector) GetSignalReply(ctx context.Context, content *event.MessageEventContent) *signalpb.DataMessage_Quote { - //TODO implement me - panic("implement me") +func (mpm *msgconvPortalMethods) GetSignalReply(ctx context.Context, content *event.MessageEventContent) *signalpb.DataMessage_Quote { + mcCtx := ctx.Value(msgconvContextKey).(*msgconvContext) + if mcCtx.ReplyTo == nil { + return nil + } + quote := &signalpb.DataMessage_Quote{ + Id: proto.Uint64(uint64(mcCtx.ReplyTo.Timestamp.UnixMilli())), + AuthorAci: proto.String(string(mcCtx.ReplyTo.SenderID)), + Type: signalpb.DataMessage_Quote_NORMAL.Enum(), + } + if mcCtx.ReplyTo.Metadata["reply_to_file"] != false { + quote.Attachments = make([]*signalpb.DataMessage_Quote_QuotedAttachment, 1) + } + return quote } -func (s *SignalConnector) GetClient(ctx context.Context) *signalmeow.Client { - //TODO implement me - panic("implement me") +func (mpm *msgconvPortalMethods) GetClient(ctx context.Context) *signalmeow.Client { + return ctx.Value(msgconvContextKey).(*msgconvContext).Client.Client } -func (s *SignalConnector) GetData(ctx context.Context) *legacydb.Portal { - //TODO implement me - panic("implement me") +func (mpm *msgconvPortalMethods) GetData(ctx context.Context) *legacydb.Portal { + mcCtx := ctx.Value(msgconvContextKey).(*msgconvContext) + portal := mcCtx.Portal + chatID, _ := mcCtx.Client.parsePortalID(portal.ID) + pk := legacydb.PortalKey{ + ChatID: chatID, + } + if len(chatID) != 44 { + pk.Receiver = mcCtx.Client.Client.Store.ACI + } + return &legacydb.Portal{ + PortalKey: pk, + MXID: portal.MXID, + Name: portal.Name, + Topic: portal.Topic, + //AvatarPath: "", + //AvatarHash: "", + //AvatarURL: id.ContentURI{}, + NameSet: portal.NameSet, + AvatarSet: portal.AvatarSet, + TopicSet: portal.TopicSet, + //Revision: portal.Metadata["revision"].(uint32), + Encrypted: true, + //RelayUserID: portal.Relay.UserMXID, + //ExpirationTime: portal.Metadata["expiration_timer"].(uint32), + } } diff --git a/go.mod b/go.mod index 5c096660..fce64a70 100644 --- a/go.mod +++ b/go.mod @@ -15,12 +15,12 @@ require ( github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/stretchr/testify v1.9.0 github.com/tidwall/gjson v1.17.1 - go.mau.fi/util v0.4.2 + go.mau.fi/util v0.4.3-0.20240430151139-91c8483608d4 golang.org/x/crypto v0.22.0 golang.org/x/exp v0.0.0-20240409090435-93d18d7e34b8 golang.org/x/net v0.24.0 google.golang.org/protobuf v1.33.0 - maunium.net/go/mautrix v0.18.1 + maunium.net/go/mautrix v0.18.2-0.20240509095639-c18758a143e2 nhooyr.io/websocket v1.8.11 ) @@ -49,4 +49,4 @@ require ( maunium.net/go/mauflag v1.0.0 // indirect ) -replace maunium.net/go/mautrix => ../mautrix-go +//replace maunium.net/go/mautrix => ../mautrix-go diff --git a/go.sum b/go.sum index aaf41a79..b884ac9a 100644 --- a/go.sum +++ b/go.sum @@ -69,8 +69,8 @@ github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= -go.mau.fi/util v0.4.2 h1:RR3TOcRHmCF9Bx/3YG4S65MYfa+nV6/rn8qBWW4Mi30= -go.mau.fi/util v0.4.2/go.mod h1:PlAVfUUcPyHPrwnvjkJM9UFcPE7qGPDJqk+Oufa1Gtw= +go.mau.fi/util v0.4.3-0.20240430151139-91c8483608d4 h1:nNIMwMiqJmb18o4+OPDC946DNtds0sb1fbmaw0xsvPE= +go.mau.fi/util v0.4.3-0.20240430151139-91c8483608d4/go.mod h1:PlAVfUUcPyHPrwnvjkJM9UFcPE7qGPDJqk+Oufa1Gtw= go.mau.fi/zeroconfig v0.1.2 h1:DKOydWnhPMn65GbXZOafgkPm11BvFashZWLct0dGFto= go.mau.fi/zeroconfig v0.1.2/go.mod h1:NcSJkf180JT+1IId76PcMuLTNa1CzsFFZ0nBygIQM70= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= @@ -95,7 +95,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M= maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA= -maunium.net/go/mautrix v0.18.1 h1:a6mUsJixegBNTXUoqC5RQ9gsumIPzKvCubKwF+zmCt4= -maunium.net/go/mautrix v0.18.1/go.mod h1:2oHaq792cSXFGvxLvYw3Gf1L4WVVP4KZcYys5HVk/h8= +maunium.net/go/mautrix v0.18.2-0.20240509095639-c18758a143e2 h1:BbaCuWiwK5mFwwt6Hmapff6jMDlQEoaKjFt1K6PPDNE= +maunium.net/go/mautrix v0.18.2-0.20240509095639-c18758a143e2/go.mod h1:r3vg2vlvbT6nyB8GdSPgYb9C6HKo3ZW3OBEoOsTdK18= nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0= nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=