Skip to content

Commit

Permalink
Add multipart/form-data transport to support file uploads (#268)
Browse files Browse the repository at this point in the history
This can be tested with a request like:

    curl --url http://localhost:8082/query \
      --form 'operations={"query":"mutation upload($fileA: Upload!, $fileB: Upload!) {\n\tuploadGizmoFile(upload: $fileA)\n\tuploadGadgetFile(upload: {\n\t\tupload: $fileB\n\t})\n}","variables":{"fileA":null,"fileB":null}}' \
      --form 'map={
     "a": ["variables.fileA"],
     "b": ["variables.fileB"]
    }' \
      --form [email protected] \
      --form [email protected]

---------

Co-authored-by: Adam Sven Johnson <[email protected]>
  • Loading branch information
benzolium and pkqk authored Jul 17, 2024
1 parent ae27879 commit ae98872
Show file tree
Hide file tree
Showing 18 changed files with 572 additions and 15 deletions.
151 changes: 147 additions & 4 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@ import (
"fmt"
"io"
"math"
"mime"
"mime/multipart"
"net/http"
"net/textproto"
"os"
"strings"
"time"

"github.com/99designs/gqlgen/graphql"
"github.com/prometheus/client_golang/prometheus"
"github.com/vektah/gqlparser/v2/ast"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
Expand Down Expand Up @@ -131,11 +135,11 @@ func (c *GraphQLClient) Request(ctx context.Context, url string, request *Reques
return err
}

var buf bytes.Buffer
if err := json.NewEncoder(&buf).Encode(request); err != nil {
buf, contentType, err := request.requestBody()
if err != nil {
return traceErr(fmt.Errorf("unable to encode request body: %w", err))
}

}
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, &buf)
if err != nil {
return traceErr(fmt.Errorf("unable to create request: %w", err))
Expand All @@ -145,7 +149,7 @@ func (c *GraphQLClient) Request(ctx context.Context, url string, request *Reques
httpReq.Header = request.Headers.Clone()
}

httpReq.Header.Set("Content-Type", "application/json; charset=utf-8")
httpReq.Header.Set("Content-Type", contentType)
httpReq.Header.Set("Accept", "application/json")

if c.UserAgent != "" {
Expand Down Expand Up @@ -246,6 +250,88 @@ func (r *Request) WithVariables(variables map[string]interface{}) *Request {
return r
}

// isMultipart returns true if the request contains a graphql.Upload object
// implying that the downstream request needs to be a multipart/form-data request
func (r *Request) isMultipart() bool {
stack := []map[string]any{r.Variables}
for len(stack) > 0 {
currentItem := stack[len(stack)-1]
stack = stack[:len(stack)-1]
for _, v := range currentItem {
switch v := v.(type) {
case graphql.Upload, *graphql.Upload, []graphql.Upload, []*graphql.Upload:
return true
case map[string]any:
stack = append(stack, v)
}
}
}
return false
}

func (r *Request) requestBody() (bytes.Buffer, string, error) {
var buf bytes.Buffer
var err error
contentType := "application/json; charset=utf-8"
if r.isMultipart() {
buf, contentType, err = multipartBody(r)
if err != nil {
return buf, "", fmt.Errorf("unable to encode multipart request body: %w", err)
}
return buf, contentType, nil
}
if err = json.NewEncoder(&buf).Encode(r); err != nil {
return buf, "", fmt.Errorf("unable to encode request body: %w", err)
}
return buf, contentType, nil
}

func multipartBody(r *Request) (bytes.Buffer, string, error) {
files, fileMap := prepareUploadsFromVariables(r.Variables)

var buf bytes.Buffer
mpw := multipart.NewWriter(&buf)
fw, err := mpw.CreateFormField("operations")
if err != nil {
return buf, "", err
}
if err = json.NewEncoder(fw).Encode(r); err != nil {
return buf, "", err
}
fw, err = mpw.CreateFormField("map")
if err != nil {
return buf, "", err
}
if err = json.NewEncoder(fw).Encode(fileMap); err != nil {
return buf, "", err
}
for fileIndex := range fileMap {
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", mime.FormatMediaType("form-data", map[string]string{
"name": fileIndex,
"filename": files[fileIndex].Filename,
}))
if ct := files[fileIndex].ContentType; ct != "" {
h.Set("Content-Type", files[fileIndex].ContentType)
} else {
h.Set("Content-Type", "application/octet-stream")
}
innerFw, fileErr := mpw.CreatePart(h)
if fileErr != nil {
return buf, "", fileErr
}
_, ioErr := io.Copy(innerFw, files[fileIndex].File)
if ioErr != nil {
return buf, "", ioErr
}
}
err = mpw.Close()
if err != nil {
return buf, "", err
}
return buf, mpw.FormDataContentType(), nil
}

// Response is a GraphQL response
type Response struct {
Errors GraphqlErrors `json:"errors"`
Expand Down Expand Up @@ -275,3 +361,60 @@ func (e GraphqlErrors) Error() string {
func GenerateUserAgent(operation string) string {
return fmt.Sprintf("Bramble/%s (%s)", Version, operation)
}

func prepareUploadsFromVariables(variables map[string]any) (map[string]graphql.Upload, map[string][]string) {
type stackItem struct {
path string
data map[string]interface{}
}

stack := []stackItem{{path: "variables", data: variables}}

index := 0
fileMap := map[string][]string{}
files := map[string]graphql.Upload{}
for len(stack) > 0 {
currentItem := stack[len(stack)-1]
stack = stack[:len(stack)-1]

for key, value := range currentItem.data {
currentPath := currentItem.path + "." + key

switch v := value.(type) {
case graphql.Upload, *graphql.Upload:
currentItem.data[key] = nil
fileIndex := fmt.Sprintf("file%d", index)
fileMap[fileIndex] = []string{currentPath}
index += 1
switch v := v.(type) {
case graphql.Upload:
files[fileIndex] = v
case *graphql.Upload:
files[fileIndex] = *v
}
case []graphql.Upload:
currentItem.data[key] = make([]*struct{}, len(v))
for i, file := range v {
elemPath := fmt.Sprintf("%s.%d", currentPath, i)
fileIndex := fmt.Sprintf("file%d", index)
fileMap[fileIndex] = []string{elemPath}
index += 1
files[fileIndex] = file
}
case []*graphql.Upload:
currentItem.data[key] = make([]*struct{}, len(v))
for i, file := range v {
elemPath := fmt.Sprintf("%s.%d", currentPath, i)
fileIndex := fmt.Sprintf("file%d", index)
fileMap[fileIndex] = []string{elemPath}
index += 1
files[fileIndex] = *file
}
case map[string]any:
stack = append(stack, stackItem{data: v, path: currentPath})
default:
}
}
}
return files, fileMap
}
91 changes: 91 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"net/url"
"testing"

"github.com/99designs/gqlgen/graphql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -102,3 +103,93 @@ func TestGraphqlClient(t *testing.T) {
assert.Equal(t, "response exceeded maximum size of 1 bytes", err.Error())
})
}
func TestMultipartClient(t *testing.T) {
nestedMap := map[string]any{
"node1": map[string]any{
"node11": map[string]any{
"leaf111": graphql.Upload{},
"leaf112": "someThing",
"node113": map[string]any{"leaf1131": graphql.Upload{}},
},
"leaf12": 42,
"leaf13": graphql.Upload{},
},
"node2": map[string]any{
"leaf21": false,
"node21": map[string]any{
"leaf211": &graphql.Upload{},
},
},
"node3": graphql.Upload{},
"node4": []graphql.Upload{{}, {}},
"node5": []*graphql.Upload{{}, {}},
}

t.Run("parseMultipartVariables", func(t *testing.T) {
_, fileMap := prepareUploadsFromVariables(nestedMap)
fileMapKeys := []string{}
fileMapValues := []string{}
for k, v := range fileMap {
fileMapKeys = append(fileMapKeys, k)
fileMapValues = append(fileMapValues, v...)
}
assert.ElementsMatch(t, fileMapKeys, []string{"file0", "file1", "file2", "file3", "file4", "file5", "file6", "file7", "file8"})
assert.ElementsMatch(t, fileMapValues, []string{
"variables.node1.node11.node113.leaf1131",
"variables.node1.node11.leaf111",
"variables.node1.leaf13",
"variables.node2.node21.leaf211",
"variables.node3",
"variables.node4.0",
"variables.node4.1",
"variables.node5.0",
"variables.node5.1",
})
assert.Equal(
t,
map[string]any{
"node1": map[string]any{
"node11": map[string]any{
"leaf111": nil,
"leaf112": "someThing",
"node113": map[string]any{"leaf1131": nil},
},
"leaf12": 42,
"leaf13": nil,
},
"node2": map[string]any{
"leaf21": false,
"node21": map[string]any{
"leaf211": nil,
},
},
"node3": nil,
"node4": []*struct{}{nil, nil},
"node5": []*struct{}{nil, nil},
},
nestedMap,
)
})

t.Run("multipart request", func(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{ "data": {"root": "multipart response"} }`))
}))

c := NewClient()
req := &Request{Headers: make(http.Header)}
req.Headers.Set("Content-Type", "multipart/form-data")

var res struct {
Root string
}
err := c.Request(
context.Background(),
srv.URL,
req,
&res,
)
require.NoError(t, err)
assert.Equal(t, "multipart response", res.Root)
})
}
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ type Config struct {
PollIntervalDuration time.Duration
MaxRequestsPerQuery int64 `json:"max-requests-per-query"`
MaxServiceResponseSize int64 `json:"max-service-response-size"`
MaxFileUploadSize int64 `json:"max-file-upload-size"`
Telemetry TelemetryConfig `json:"telemetry"`
Plugins []PluginConfig
// Config extensions that can be shared among plugins
Expand Down
12 changes: 11 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,16 @@ services:
retries: 5
expose:
- 8080
gqlgen-multipart-file-upload-service:
build:
context: examples/gqlgen-multipart-file-upload-service
healthcheck: &healthcheck
test: wget -qO - http://localhost:8080/health
interval: 5s
timeout: 1s
retries: 5
expose:
- 8080
graph-gophers-service:
healthcheck: *healthcheck
build:
Expand All @@ -34,7 +44,7 @@ services:
configs: [gateway]
command: ["-config", "gateway"]
environment:
- BRAMBLE_SERVICE_LIST=http://gqlgen-service:8080/query http://graph-gophers-service:8080/query http://slow-service:8080/query http://nodejs-service:8080/query
- BRAMBLE_SERVICE_LIST=http://gqlgen-service:8080/query http://gqlgen-multipart-file-upload-service:8080/query http://graph-gophers-service:8080/query http://slow-service:8080/query http://nodejs-service:8080/query
ports:
- 8082:8082
- 8083:8083
Expand Down
13 changes: 13 additions & 0 deletions docs/plugins.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,19 @@ Add `CORS` headers to queries.
}
```

## Headers

Allow headers to passthrough to downstream services.

```json
{
"name": "headers",
"config": {
"allowed-headers": ["X-Custom-Header"]
}
}
```

## JWT Auth

The JWT auth plugin validates that the request contains a valid JWT and
Expand Down
3 changes: 3 additions & 0 deletions examples/gqlgen-multipart-file-upload-service/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
gqlgen-service
generated.go
models_gen.go
11 changes: 11 additions & 0 deletions examples/gqlgen-multipart-file-upload-service/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
FROM golang:1.22-alpine3.19

ENV CGO_ENABLED=0

WORKDIR /go/src/app

COPY . .

RUN go generate .
RUN go get
CMD go run .
18 changes: 18 additions & 0 deletions examples/gqlgen-multipart-file-upload-service/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
ARTIFACT=gqlgen-service
DEF=gqlgen.yml schema.graphql
GEN=models_gen.go generated.go

build: $(ARTIFACT)

.PHONY: clean
clean:
rm -f $(ARTIFACT) $(GEN)

.PHONY: generate
generate: $(GEN)

$(GEN): $(DEF)
go generate

gqlgen-service: $(GEN) $(wildcard *.go)
go build
Loading

0 comments on commit ae98872

Please sign in to comment.