diff --git a/provisioning.go b/provisioning.go index 3218702c..c95a308a 100644 --- a/provisioning.go +++ b/provisioning.go @@ -17,12 +17,10 @@ package main import ( - "bufio" "context" "encoding/json" "errors" "fmt" - "net" "net/http" _ "net/http/pprof" "strconv" @@ -34,11 +32,18 @@ import ( "github.com/gorilla/mux" "github.com/rs/zerolog" "github.com/rs/zerolog/hlog" + "maunium.net/go/mautrix" "maunium.net/go/mautrix/id" "go.mau.fi/mautrix-signal/pkg/signalmeow" ) +type provisioningContextKey int + +const ( + provisioningUserKey provisioningContextKey = iota +) + type provisioningHandle struct { id int context context.Context @@ -77,50 +82,26 @@ func (prov *ProvisioningAPI) Init() { } } -type responseWrap struct { - http.ResponseWriter - statusCode int -} - -func jsonResponse(w http.ResponseWriter, status int, response interface{}) { +func jsonResponse(w http.ResponseWriter, status int, response any) { w.Header().Add("Content-Type", "application/json") w.WriteHeader(status) _ = json.NewEncoder(w).Encode(response) } -var _ http.Hijacker = (*responseWrap)(nil) - -func (rw *responseWrap) WriteHeader(statusCode int) { - rw.ResponseWriter.WriteHeader(statusCode) - rw.statusCode = statusCode -} - -func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) { - hijacker, ok := rw.ResponseWriter.(http.Hijacker) - if !ok { - return nil, nil, errors.New("response does not implement http.Hijacker") - } - return hijacker.Hijack() -} - func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - auth := r.Header.Get("Authorization") - if strings.HasPrefix(auth, "Bearer ") { - auth = auth[len("Bearer "):] - } + auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret { - prov.log.Info().Msg("Authentication token does not match shared secret") - jsonResponse(w, http.StatusForbidden, map[string]interface{}{ - "error": "Authentication token does not match shared secret", - "errcode": "M_FORBIDDEN", + zerolog.Ctx(r.Context()).Warn().Msg("Authentication token does not match shared secret") + jsonResponse(w, http.StatusForbidden, &mautrix.RespError{ + Err: "Authentication token does not match shared secret", + ErrCode: mautrix.MForbidden.ErrCode, }) return } userID := r.URL.Query().Get("user_id") user := prov.bridge.GetUserByMXID(id.UserID(userID)) - wWrap := &responseWrap{w, 200} - h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user))) + h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), provisioningUserKey, user))) }) } @@ -202,7 +183,7 @@ func (prov *ProvisioningAPI) resolveIdentifier(user *User, phoneNum string) (int } func (prov *ProvisioningAPI) ResolveIdentifier(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) phoneNum, _ := mux.Vars(r)["phonenum"] prov.log.Debug().Msgf("ResolveIdentifier from %v, phone number: %v", user.MXID, phoneNum) @@ -230,7 +211,7 @@ func (prov *ProvisioningAPI) ResolveIdentifier(w http.ResponseWriter, r *http.Re } func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) phoneNum, _ := mux.Vars(r)["phonenum"] prov.log.Debug().Msgf("StartPM from %v, phone number: %v", user.MXID, phoneNum) @@ -372,7 +353,7 @@ func (prov *ProvisioningAPI) loginOrSendError(w http.ResponseWriter, user *User) } func (prov *ProvisioningAPI) checkSessionAndReturnHandle(w http.ResponseWriter, r *http.Request, currentSession int) *provisioningHandle { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) handle := prov.existingSession(user) if handle == nil { prov.log.Warn().Msgf("checkSessionAndReturnHandle: from %v, no session found", user.MXID) @@ -398,7 +379,7 @@ func (prov *ProvisioningAPI) checkSessionAndReturnHandle(w http.ResponseWriter, // ** Provisioning API ** // func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) prov.log.Debug().Msgf("LinkNew from %v, starting login", user.MXID) handle := prov.loginOrSendError(w, user) @@ -446,7 +427,7 @@ func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) { } func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) body := struct { SessionID string `json:"session_id"` }{} @@ -530,7 +511,7 @@ func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Requ } func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) body := struct { SessionID string `json:"session_id"` DeviceName string `json:"device_name"` @@ -606,7 +587,7 @@ func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.R } func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) { - user := r.Context().Value("user").(*User) + user := r.Context().Value(provisioningUserKey).(*User) prov.log.Debug().Msgf("Logout called from %v (but not logging out)", user.MXID) prov.clearSession(user)