Skip to content

Commit

Permalink
Move login command to mautrix-go
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Jun 3, 2024
1 parent c39813f commit 18116ea
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 169 deletions.
2 changes: 1 addition & 1 deletion connector/connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ var _ bridgev2.NetworkConnector = (*SignalConnector)(nil)
var _ bridgev2.NetworkAPI = (*SignalClient)(nil)
var _ msgconv.PortalMethods = (*msgconvPortalMethods)(nil)

func (s *SignalConnector) PrepareLogin(ctx context.Context, login *bridgev2.UserLogin) error {
func (s *SignalConnector) LoadUserLogin(ctx context.Context, login *bridgev2.UserLogin) error {
aci, err := uuid.Parse(string(login.ID))
if err != nil {
return fmt.Errorf("failed to parse user login ID: %w", err)
Expand Down
155 changes: 155 additions & 0 deletions connector/login.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
// mautrix-signal - A Matrix-Signal puppeting bridge.
// Copyright (C) 2024 Tulir Asokan
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

package connector

import (
"context"
"fmt"

"github.com/google/uuid"

"go.mau.fi/mautrix-signal/pkg/signalmeow"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/database"
)

func (s *SignalConnector) GetLoginFlows() []bridgev2.LoginFlow {
return []bridgev2.LoginFlow{{
Name: "QR",
Description: "Scan a QR code to pair the bridge to your Signal app",
ID: "qr",
}}
}

func (s *SignalConnector) CreateLogin(ctx context.Context, user *bridgev2.User, flowID string) (bridgev2.LoginProcess, error) {
if flowID != "qr" {
return nil, fmt.Errorf("invalid login flow ID")
}
return &QRLogin{User: user, Main: s}, nil
}

type QRLogin struct {
User *bridgev2.User
Main *SignalConnector
cancelChan context.CancelFunc
ProvChan chan signalmeow.ProvisioningResponse
}

var _ bridgev2.LoginProcessDisplayAndWait = (*QRLogin)(nil)

func (qr *QRLogin) Cancel() {
qr.cancelChan()
go func() {
for range qr.ProvChan {
}
}()
}

func (qr *QRLogin) Start(ctx context.Context) (*bridgev2.LoginStep, error) {
log := qr.Main.Bridge.Log.With().
Str("action", "login").
Stringer("user_id", qr.User.MXID).
Logger()
provCtx, cancel := context.WithCancel(log.WithContext(context.Background()))
qr.cancelChan = cancel
// Don't use the start context here: the channel will outlive the start request.
qr.ProvChan = signalmeow.PerformProvisioning(provCtx, qr.Main.Store, qr.Main.Config.DeviceName)
var resp signalmeow.ProvisioningResponse
select {
case resp = <-qr.ProvChan:
if resp.Err != nil || resp.State == signalmeow.StateProvisioningError {
return nil, resp.Err
} else if resp.State != signalmeow.StateProvisioningURLReceived {
return nil, fmt.Errorf("unexpected state %v", resp.State)
}
case <-ctx.Done():
cancel()
return nil, ctx.Err()
// TODO separate timeout here?
}
return &bridgev2.LoginStep{
Type: bridgev2.LoginStepTypeDisplayAndWait,
StepID: "fi.mau.signal.login.qr",
Instructions: "Scan the QR code on your Signal app to log in",
DisplayAndWaitParams: &bridgev2.LoginDisplayAndWaitParams{
Type: bridgev2.LoginDisplayTypeQR,
Data: resp.ProvisioningURL,
},
}, nil
}

func (qr *QRLogin) Wait(ctx context.Context) (*bridgev2.LoginStep, error) {
if qr.ProvChan == nil {
return nil, fmt.Errorf("login not started")
}
defer qr.cancelChan()

var signalID uuid.UUID
var signalPhone string
select {
case resp := <-qr.ProvChan:
if resp.Err != nil || resp.State == signalmeow.StateProvisioningError {
return nil, resp.Err
} else if resp.State != signalmeow.StateProvisioningDataReceived {
return nil, fmt.Errorf("unexpected state %v", resp.State)
} else if resp.ProvisioningData.ACI == uuid.Nil {
return nil, fmt.Errorf("no signal account ID received")
}
signalID = resp.ProvisioningData.ACI
signalPhone = resp.ProvisioningData.Number
case <-ctx.Done():
return nil, ctx.Err()
}

select {
case resp := <-qr.ProvChan:
if resp.Err != nil || resp.State == signalmeow.StateProvisioningError {
return nil, resp.Err
} else if resp.State != signalmeow.StateProvisioningPreKeysRegistered {
return nil, fmt.Errorf("unexpected state %v", resp.State)
}
case <-ctx.Done():
return nil, ctx.Err()
}

ul, err := qr.User.NewLogin(ctx, &database.UserLogin{
ID: makeUserLoginID(signalID),
Metadata: map[string]any{
"phone": signalPhone,
},
}, nil)
if err != nil {
return nil, fmt.Errorf("failed to save new login: %w", err)
}
backgroundCtx := ul.Log.WithContext(context.Background())
err = qr.Main.LoadUserLogin(backgroundCtx, ul)
if err != nil {
return nil, fmt.Errorf("failed to prepare connection after login: %w", err)
}
err = ul.Client.Connect(backgroundCtx)
if err != nil {
return nil, fmt.Errorf("failed to connect after login: %w", err)
}
return &bridgev2.LoginStep{
Type: bridgev2.LoginStepTypeComplete,
StepID: "fi.mau.signal.login.complete",
Instructions: fmt.Sprintf("Successfully logged in as %s / %s", signalPhone, signalID),
CompleteParams: &bridgev2.LoginCompleteParams{
UserLoginID: ul.ID,
},
}, nil
}
166 changes: 2 additions & 164 deletions connector/mautrix-signal-v2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,18 @@
package main

import (
"fmt"
"os"
"strings"
"time"

"github.com/google/uuid"
"github.com/skip2/go-qrcode"
"go.mau.fi/util/dbutil"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/exzerolog"
"gopkg.in/yaml.v3"

"go.mau.fi/mautrix-signal/connector"
"go.mau.fi/mautrix-signal/pkg/signalmeow"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/bridgeconfig"
"maunium.net/go/mautrix/bridgev2/database"
"maunium.net/go/mautrix/bridgev2/matrix"
"maunium.net/go/mautrix/bridgev2/networkid"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"

"go.mau.fi/mautrix-signal/connector"
)

func main() {
Expand All @@ -47,163 +37,11 @@ func main() {
exerrors.PanicIfNotNil(yaml.Unmarshal(config, &cfg))
log := exerrors.Must(cfg.Logging.Compile())
exzerolog.SetupDefaults(log)
signalmeow.SetLogger(log.With().Str("component", "signalmeow").Logger())
db := exerrors.Must(dbutil.NewFromConfig("mautrix-signal", cfg.Database, dbutil.ZeroLogger(log.With().Str("db_section", "main").Logger())))
signalConnector := connector.NewConnector()
exerrors.PanicIfNotNil(cfg.Network.Decode(signalConnector.Config))
bridge := bridgev2.NewBridge("", db, *log, matrix.NewConnector(&cfg), signalConnector)
bridge.CommandPrefix = "!signal"
bridge.Commands.AddHandlers(&bridgev2.FullHandler{
Func: fnLogin,
Name: "login",
Help: bridgev2.HelpMeta{
Section: bridgev2.HelpSectionAuth,
Description: "Log into Signal",
},
})
bridge.Start()
}

func sendQR(ce *bridgev2.CommandEvent, code string, prevQR, prevMsg id.EventID) (qr, msg id.EventID) {
content, ok := uploadQR(ce, code)
if !ok {
return prevQR, prevMsg
}
if len(prevQR) != 0 {
content.SetEdit(prevQR)
}
resp, err := ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: &content}, time.Now())
if err != nil {
ce.Log.Err(err).Msg("Failed to send QR code to user")
} else if len(prevQR) == 0 {
prevQR = resp.EventID
}
content = event.MessageEventContent{
MsgType: event.MsgNotice,
Body: fmt.Sprintf("Raw linking URI: %s", code),
Format: event.FormatHTML,
FormattedBody: fmt.Sprintf("Raw linking URI: <code>%s</code>", code),
}
if len(prevMsg) != 0 {
content.SetEdit(prevMsg)
}
resp, err = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventMessage, &event.Content{Parsed: &content}, time.Now())
if err != nil {
ce.Log.Err(err).Msg("Failed to send raw code to user")
} else if len(prevMsg) == 0 {
prevMsg = resp.EventID
}
return prevQR, prevMsg
}

func uploadQR(ce *bridgev2.CommandEvent, code string) (event.MessageEventContent, bool) {
const size = 512
qrCode, err := qrcode.Encode(code, qrcode.Low, size)
if err != nil {
ce.Log.Err(err).Msg("Failed to encode QR code")
ce.Reply("Failed to encode QR code: %v", err)
return event.MessageEventContent{}, false
}

uri, file, err := ce.Bot.UploadMedia(ce.Ctx, ce.RoomID, qrCode, "qr.png", "image/png")
if err != nil {
ce.Log.Err(err).Msg("Failed to upload QR code")
ce.Reply("Failed to upload QR code: %v", err)
return event.MessageEventContent{}, false
}
return event.MessageEventContent{
MsgType: event.MsgImage,
Info: &event.FileInfo{
MimeType: "image/png",
Width: size,
Height: size,
Size: len(qrCode),
},
Body: "qr.png",
URL: uri,
File: file,
}, true
}
func fnLogin(ce *bridgev2.CommandEvent) {
signal := ce.Bridge.Network.(*connector.SignalConnector)
// TODO configurable device name
provChan := signalmeow.PerformProvisioning(ce.Ctx, signal.Store, "Mautrix-Signal Megabridge")

resp := <-provChan
if resp.Err != nil || resp.State == signalmeow.StateProvisioningError {
ce.Reply("Error getting provisioning URL: %v", resp.Err)
return
}
var qrEventID, msgEventID id.EventID
if resp.State == signalmeow.StateProvisioningURLReceived {
qrEventID, msgEventID = sendQR(ce, resp.ProvisioningURL, qrEventID, msgEventID)
} else {
ce.Reply("Unexpected state: %v", resp.State)
return
}

// Next, get the results of finishing registration
resp = <-provChan
_, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{
Parsed: &event.RedactionEventContent{
Redacts: qrEventID,
},
}, time.Now())
_, _ = ce.Bot.SendMessage(ce.Ctx, ce.RoomID, event.EventRedaction, &event.Content{
Parsed: &event.RedactionEventContent{
Redacts: msgEventID,
},
}, time.Now())
if resp.Err != nil || resp.State == signalmeow.StateProvisioningError {
if resp.Err != nil && strings.HasSuffix(resp.Err.Error(), " EOF") {
ce.Reply("Logging in timed out, please try again.")
} else {
ce.Reply("Error finishing registration: %v", resp.Err)
}
return
}
var signalID uuid.UUID
var signalPhone string
if resp.State == signalmeow.StateProvisioningDataReceived {
signalID = resp.ProvisioningData.ACI
signalPhone = resp.ProvisioningData.Number
} else {
ce.Reply("Unexpected state: %v", resp.State)
return
}

// Finally, get the results of generating and registering prekeys
resp = <-provChan
if resp.Err != nil || resp.State == signalmeow.StateProvisioningError {
ce.Reply("Error with prekeys: %v", resp.Err)
return
} else if resp.State != signalmeow.StateProvisioningPreKeysRegistered {
ce.Reply("Unexpected state: %v", resp.State)
return
}

if signalID == uuid.Nil {
ce.Reply("Problem logging in - No SignalID received")
return
}
ul, err := ce.User.NewLogin(ce.Ctx, &database.UserLogin{
ID: networkid.UserLoginID(signalID.String()),
Metadata: map[string]any{
"phone": signalPhone,
},
}, nil)
if err != nil {
ce.Reply("Failed to save new login: %v", err)
return
}
err = ce.Bridge.Network.PrepareLogin(ce.Ctx, ul)
if err != nil {
ce.Reply("Failed to prepare connection after login: %v", err)
return
}
err = ul.Client.Connect(ce.Ctx)
if err != nil {
ce.Reply("Failed to connect after login: %v", err)
return
}
ce.Reply("Successfully logged in as %s (UUID: %s)", signalPhone, signalID)
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ require (
golang.org/x/net v0.25.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mautrix v0.18.2-0.20240529135554-248de0e6adb2
maunium.net/go/mautrix v0.18.2-0.20240603193336-a599b15466ae
nhooyr.io/websocket v1.8.11
)

Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/mautrix v0.18.2-0.20240529135554-248de0e6adb2 h1:AUKv3tqpdFerCw2X8m05BGfhtP3vH8cDuEtAGxwuUl0=
maunium.net/go/mautrix v0.18.2-0.20240529135554-248de0e6adb2/go.mod h1:Ln4XquIKL5MttTUGNUSbiEGX3XYC0P6jzT9XjLFFPdY=
maunium.net/go/mautrix v0.18.2-0.20240603193336-a599b15466ae h1:PlT6saNJNjRT3i04LNLsFAC5ewZU1HrxBSM4V/Aze7k=
maunium.net/go/mautrix v0.18.2-0.20240603193336-a599b15466ae/go.mod h1:P/FV8cXY262MezYX7ViuhfzeJ0nK4+M8K6ZmxEC/aEA=
nhooyr.io/websocket v1.8.11 h1:f/qXNc2/3DpoSZkHt1DQu6rj4zGC8JmkkLkWss0MgN0=
nhooyr.io/websocket v1.8.11/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c=
2 changes: 1 addition & 1 deletion pkg/signalmeow/provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ type ProvisioningResponse struct {

func PerformProvisioning(ctx context.Context, deviceStore store.DeviceStore, deviceName string) chan ProvisioningResponse {
log := zerolog.Ctx(ctx).With().Str("action", "perform provisioning").Logger()
c := make(chan ProvisioningResponse)
c := make(chan ProvisioningResponse, 4)
go func() {
defer close(c)

Expand Down

0 comments on commit 18116ea

Please sign in to comment.