Skip to content

Commit

Permalink
provisioning: use proper context key
Browse files Browse the repository at this point in the history
Signed-off-by: Sumner Evans <[email protected]>
  • Loading branch information
sumnerevans committed Dec 30, 2023
1 parent df22a86 commit e295ae0
Showing 1 changed file with 21 additions and 40 deletions.
61 changes: 21 additions & 40 deletions provisioning.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,10 @@
package main

import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
_ "net/http/pprof"
"strconv"
Expand All @@ -34,11 +32,18 @@ import (
"github.com/gorilla/mux"
"github.com/rs/zerolog"
"github.com/rs/zerolog/hlog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"

"go.mau.fi/mautrix-signal/pkg/signalmeow"
)

type provisioningContextKey int

const (
provisioningUserKey provisioningContextKey = iota
)

type provisioningHandle struct {
id int
context context.Context
Expand Down Expand Up @@ -77,50 +82,26 @@ func (prov *ProvisioningAPI) Init() {
}
}

type responseWrap struct {
http.ResponseWriter
statusCode int
}

func jsonResponse(w http.ResponseWriter, status int, response interface{}) {
func jsonResponse(w http.ResponseWriter, status int, response any) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(response)
}

var _ http.Hijacker = (*responseWrap)(nil)

func (rw *responseWrap) WriteHeader(statusCode int) {
rw.ResponseWriter.WriteHeader(statusCode)
rw.statusCode = statusCode
}

func (rw *responseWrap) Hijack() (net.Conn, *bufio.ReadWriter, error) {
hijacker, ok := rw.ResponseWriter.(http.Hijacker)
if !ok {
return nil, nil, errors.New("response does not implement http.Hijacker")
}
return hijacker.Hijack()
}

func (prov *ProvisioningAPI) AuthMiddleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
auth := r.Header.Get("Authorization")
if strings.HasPrefix(auth, "Bearer ") {
auth = auth[len("Bearer "):]
}
auth := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
if auth != prov.bridge.Config.Bridge.Provisioning.SharedSecret {
prov.log.Info().Msg("Authentication token does not match shared secret")
jsonResponse(w, http.StatusForbidden, map[string]interface{}{
"error": "Authentication token does not match shared secret",
"errcode": "M_FORBIDDEN",
zerolog.Ctx(r.Context()).Warn().Msg("Authentication token does not match shared secret")
jsonResponse(w, http.StatusForbidden, &mautrix.RespError{
Err: "Authentication token does not match shared secret",
ErrCode: mautrix.MForbidden.ErrCode,
})
return
}
userID := r.URL.Query().Get("user_id")
user := prov.bridge.GetUserByMXID(id.UserID(userID))
wWrap := &responseWrap{w, 200}
h.ServeHTTP(wWrap, r.WithContext(context.WithValue(r.Context(), "user", user)))
h.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), provisioningUserKey, user)))
})
}

Expand Down Expand Up @@ -199,7 +180,7 @@ func (prov *ProvisioningAPI) resolveIdentifier(user *User, phoneNum string) (int
}

func (prov *ProvisioningAPI) ResolveIdentifier(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
phoneNum, _ := mux.Vars(r)["phonenum"]

log := prov.log.With().
Expand Down Expand Up @@ -233,7 +214,7 @@ func (prov *ProvisioningAPI) ResolveIdentifier(w http.ResponseWriter, r *http.Re
}

func (prov *ProvisioningAPI) StartPM(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
phoneNum, _ := mux.Vars(r)["phonenum"]

log := prov.log.With().
Expand Down Expand Up @@ -373,7 +354,7 @@ func (prov *ProvisioningAPI) loginOrSendError(ctx context.Context, w http.Respon

func (prov *ProvisioningAPI) checkSessionAndReturnHandle(ctx context.Context, w http.ResponseWriter, currentSession int) *provisioningHandle {
log := zerolog.Ctx(ctx).With().Str("function", "checkSessionAndReturnHandle").Logger()
user := ctx.Value("user").(*User)
user := ctx.Value(provisioningUserKey).(*User)
handle := prov.existingSession(user)
if handle == nil {
log.Warn().Msg("no session found")
Expand Down Expand Up @@ -402,7 +383,7 @@ func (prov *ProvisioningAPI) checkSessionAndReturnHandle(ctx context.Context, w
// ** Provisioning API ** //

func (prov *ProvisioningAPI) LinkNew(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
log := prov.log.With().
Str("action", "link_new").
Str("user_id", user.MXID.String()).
Expand Down Expand Up @@ -466,7 +447,7 @@ type LinkWaitForScanRequest struct {
}

func (prov *ProvisioningAPI) LinkWaitForScan(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
var body LinkWaitForScanRequest
err := json.NewDecoder(r.Body).Decode(&body)
if err != nil {
Expand Down Expand Up @@ -559,7 +540,7 @@ type LinkWaitForAccountRequest struct {
}

func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
var body LinkWaitForAccountRequest
err := json.NewDecoder(r.Body).Decode(&body)
if err != nil {
Expand Down Expand Up @@ -639,7 +620,7 @@ func (prov *ProvisioningAPI) LinkWaitForAccount(w http.ResponseWriter, r *http.R
}

func (prov *ProvisioningAPI) Logout(w http.ResponseWriter, r *http.Request) {
user := r.Context().Value("user").(*User)
user := r.Context().Value(provisioningUserKey).(*User)
log := prov.log.With().
Str("action", "logout").
Str("user_id", user.MXID.String()).
Expand Down

0 comments on commit e295ae0

Please sign in to comment.