diff --git a/provisioning.go b/provisioning.go index 689177dc..10fb4bd6 100644 --- a/provisioning.go +++ b/provisioning.go @@ -17,12 +17,10 @@ package main import ( - "bufio" "context" "encoding/json" "errors" "fmt" - "net" "net/http" _ "net/http/pprof" "strconv" @@ -34,11 +32,18 @@ import ( "github.com/gorilla/mux" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" "go.mau.fi/mautrix-signal/pkg/signalmeow" ) +type provisioningContextKey int + +const ( + provisioningUserKey provisioningContextKey = iota +) + type provisioningHandle struct { id int context context.Context @@ -77,50 +82,26 @@ func (prov *ProvisioningAPI) Init() { } } -type responseWrap struct { - http.ResponseWriter - statusCode int -} - -func jsonResponse(w http.ResponseWriter, status int, response interface{}) { +func jsonResponse(w http.ResponseWriter, status int, response any) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(response) } -var _ http.Hijacker = (*responseWrap)(nil) - -func (rw *responseWrap) WriteHeader(statusCode int) { - rw.ResponseWriter.WriteHeader(statusCode) - rw.statusCode = statusCode -} - -func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hijacker, ok := rw.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, errors.New("response does not implement http.Hijacker") - } - return hijacker.Hijack() -} - func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - if strings.HasPrefix(auth, "Bearer ") { - auth = auth[len("Bearer "):] - } + auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret { - prov.log.Info().Msg("Authentication token does not match shared secret") - jsonResponse(w, http.StatusForbidden, map[string]interface{}{ - "error": "Authentication token does not match shared secret", - "errcode": "M_FORBIDDEN", + zerolog.Ctx(r.Context()).Warn().Msg("Authentication token does not match shared secret") + jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + Err: "Authentication token does not match shared secret", + ErrCode: mautrix.MForbidden.ErrCode, }) return } userID := r.URL.Query().Get("user_id") user := prov.bridge.GetUserByMXID(id.UserID(userID)) - wWrap := &responseWrap{w, 200} - h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user))) + h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), provisioningUserKey, user))) }) } @@ -199,7 +180,7 @@ func (prov *ProvisioningAPI) resolveIdentifier(user *User, phoneNum string) (int } func (prov *ProvisioningAPI) ResolveIdentifier(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) phoneNum, _ := mux.Vars(r)["phonenum"] log := prov.log.With(). @@ -233,7 +214,7 @@ func (prov *ProvisioningAPI) ResolveIdentifier(w http.ResponseWriter, r *http.Re } func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) phoneNum, _ := mux.Vars(r)["phonenum"] log := prov.log.With(). @@ -373,7 +354,7 @@ func (prov *ProvisioningAPI) loginOrSendError(ctx context.Context, w http.Respon func (prov *ProvisioningAPI) checkSessionAndReturnHandle(ctx context.Context, w http.ResponseWriter, currentSession int) *provisioningHandle { log := zerolog.Ctx(ctx).With().Str("function", "checkSessionAndReturnHandle").Logger() - user := ctx.Value("user").(*User) + user := ctx.Value(provisioningUserKey).(*User) handle := prov.existingSession(user) if handle == nil { log.Warn().Msg("no session found") @@ -402,7 +383,7 @@ func (prov *ProvisioningAPI) checkSessionAndReturnHandle(ctx context.Context, w // ** Provisioning API ** // func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) log := prov.log.With(). Str("action", "link_new"). Str("user_id", user.MXID.String()). @@ -466,7 +447,7 @@ type LinkWaitForScanRequest struct { } func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) var body LinkWaitForScanRequest err := json.NewDecoder(r.Body).Decode(&body) if err != nil { @@ -559,7 +540,7 @@ type LinkWaitForAccountRequest struct { } func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) var body LinkWaitForAccountRequest err := json.NewDecoder(r.Body).Decode(&body) if err != nil { @@ -639,7 +620,7 @@ func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.R } func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) log := prov.log.With(). Str("action", "logout"). Str("user_id", user.MXID.String()).