Skip to content

Commit

Permalink
provisioning: use proper context key
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Dec 30, 2023
1 parent ac2e0aa commit 76b477a
Showing 1 changed file with 21 additions and 40 deletions.
61 changes: 21 additions & 40 deletions provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
package main

import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"strconv"
Expand All @@ -34,11 +32,18 @@ import (
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"

"go.mau.fi/mautrix-signal/pkg/signalmeow"
)

type provisioningContextKey int

const (
provisioningUserKey provisioningContextKey = iota
)

type provisioningHandle struct {
id int
context context.Context
Expand Down Expand Up @@ -77,50 +82,26 @@ func (prov *ProvisioningAPI) Init() {
}
}

type responseWrap struct {
http.ResponseWriter
statusCode int
}

func jsonResponse(w http.ResponseWriter, status int, response interface{}) {
func jsonResponse(w http.ResponseWriter, status int, response any) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(response)
}

var _ http.Hijacker = (*responseWrap)(nil)

func (rw *responseWrap) WriteHeader(statusCode int) {
rw.ResponseWriter.WriteHeader(statusCode)
rw.statusCode = statusCode
}

func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("response does not implement http.Hijacker")
}
return hijacker.Hijack()
}

func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if strings.HasPrefix(auth, "Bearer ") {
auth = auth[len("Bearer "):]
}
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
prov.log.Info().Msg("Authentication token does not match shared secret")
jsonResponse(w, http.StatusForbidden, map[string]interface{}{
"error": "Authentication token does not match shared secret",
"errcode": "M_FORBIDDEN",
zerolog.Ctx(r.Context()).Warn().Msg("Authentication token does not match shared secret")
jsonResponse(w, http.StatusForbidden, &mautrix.RespError{
Err: "Authentication token does not match shared secret",
ErrCode: mautrix.MForbidden.ErrCode,
})
return
}
userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(id.UserID(userID))
wWrap := &responseWrap{w, 200}
h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user)))
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), provisioningUserKey, user)))
})
}

Expand Down Expand Up @@ -202,7 +183,7 @@ func (prov *ProvisioningAPI) resolveIdentifier(user *User, phoneNum string) (int
}

func (prov *ProvisioningAPI) ResolveIdentifier(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
phoneNum, _ := mux.Vars(r)["phonenum"]
prov.log.Debug().Msgf("ResolveIdentifier from %v, phone number: %v", user.MXID, phoneNum)

Expand Down Expand Up @@ -230,7 +211,7 @@ func (prov *ProvisioningAPI) ResolveIdentifier(w http.ResponseWriter, r *http.Re
}

func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
phoneNum, _ := mux.Vars(r)["phonenum"]
prov.log.Debug().Msgf("StartPM from %v, phone number: %v", user.MXID, phoneNum)

Expand Down Expand Up @@ -372,7 +353,7 @@ func (prov *ProvisioningAPI) loginOrSendError(w http.ResponseWriter, user *User)
}

func (prov *ProvisioningAPI) checkSessionAndReturnHandle(w http.ResponseWriter, r *http.Request, currentSession int) *provisioningHandle {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
handle := prov.existingSession(user)
if handle == nil {
prov.log.Warn().Msgf("checkSessionAndReturnHandle: from %v, no session found", user.MXID)
Expand All @@ -398,7 +379,7 @@ func (prov *ProvisioningAPI) checkSessionAndReturnHandle(w http.ResponseWriter,
// ** Provisioning API ** //

func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)

prov.log.Debug().Msgf("LinkNew from %v, starting login", user.MXID)
handle := prov.loginOrSendError(w, user)
Expand Down Expand Up @@ -446,7 +427,7 @@ func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) {
}

func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
body := struct {
SessionID string `json:"session_id"`
}{}
Expand Down Expand Up @@ -530,7 +511,7 @@ func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Requ
}

func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
body := struct {
SessionID string `json:"session_id"`
DeviceName string `json:"device_name"`
Expand Down Expand Up @@ -606,7 +587,7 @@ func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.R
}

func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
prov.log.Debug().Msgf("Logout called from %v (but not logging out)", user.MXID)
prov.clearSession(user)

Expand Down

0 comments on commit 76b477a

Please sign in to comment.