Skip to content

Commit

Permalink
feat: support file upload in router
Browse files Browse the repository at this point in the history
  • Loading branch information
pedraumcosta committed Mar 26, 2024
1 parent 95ebc94 commit c41faf8
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1650,9 +1650,16 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) {
return variables, false
}

func (s *Source) Load(ctx context.Context, input []byte, writer io.Writer) (err error) {
func (s *Source) Load(
ctx context.Context, input []byte, files []httpclient.File, writer io.Writer,
) (err error) {
input = s.compactAndUnNullVariables(input)
return httpclient.Do(s.httpClient, ctx, input, writer)

if files == nil {
return httpclient.Do(s.httpClient, ctx, input, writer)
}

return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, writer)
}

type GraphQLSubscriptionClient interface {
Expand Down
26 changes: 26 additions & 0 deletions v2/pkg/engine/datasource/httpclient/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package httpclient

type File interface {
Path() string
Name() string
}

type internalFile struct {
path string
name string
}

func NewFile(path string, name string) File {
return &internalFile{
path: path,
name: name,
}
}

func (f *internalFile) Path() string {
return f.path
}

func (f *internalFile) Name() string {
return f.name
}
191 changes: 191 additions & 0 deletions v2/pkg/engine/datasource/httpclient/nethttpclient.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
package httpclient

import (
"bufio"
"bytes"
"compress/flate"
"compress/gzip"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"os"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -219,3 +224,189 @@ func respBodyReader(res *http.Response) (io.Reader, error) {
return res.Body, nil
}
}

func DoMultipartForm(
client *http.Client, ctx context.Context, requestInput []byte, files []File, out io.Writer,
) (err error) {
if files == nil || len(files) == 0 {
return errors.New("no files provided")
}

url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput)

formValues := map[string]io.Reader{
"operations": bytes.NewReader(body),
}

var fileMap string
for i, file := range files {
if len(fileMap) == 0 {
if len(files) == 1 {
fileMap = fmt.Sprintf(`"%d" : ["variables.file"]`, i)
} else {
fileMap = fmt.Sprintf(`"%d" : ["variables.file%d"]`, i, i+1)
}
} else {
fileMap = fmt.Sprintf(`%s, "%d" : ["variables.file%d"]`, fileMap, i, i+1)
}
key := fmt.Sprintf("%d", i)
temporaryFile, err := os.Open(file.Path())
if err != nil {
return err
}
formValues[key] = bufio.NewReader(temporaryFile)
}
formValues["map"] = strings.NewReader("{ " + fileMap + " }")

multipartBody, contentType, err := multipartBytes(formValues, files)
if err != nil {
return err
}

request, err := http.NewRequestWithContext(ctx, string(method), string(url), &multipartBody)
if err != nil {
return err
}

if headers != nil {
err = jsonparser.ObjectEach(headers, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error {
_, err = jsonparser.ArrayEach(value, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
if err != nil {
return
}
if len(value) == 0 {
return
}
request.Header.Add(string(key), string(value))
})
return err
})
if err != nil {
return err
}
}

if queryParams != nil {
query := request.URL.Query()
_, err = jsonparser.ArrayEach(queryParams, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
var (
parameterName, parameterValue []byte
)
jsonparser.EachKey(value, func(i int, bytes []byte, valueType jsonparser.ValueType, err error) {
switch i {
case 0:
parameterName = bytes
case 1:
parameterValue = bytes
}
}, queryParamsKeys...)
if len(parameterName) != 0 && len(parameterValue) != 0 {
if bytes.Equal(parameterValue[:1], literal.LBRACK) {
_, _ = jsonparser.ArrayEach(parameterValue, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
query.Add(string(parameterName), string(value))
})
} else {
query.Add(string(parameterName), string(parameterValue))
}
}
})
if err != nil {
return err
}
request.URL.RawQuery = query.Encode()
}

request.Header.Add(AcceptHeader, ContentTypeJSON)
request.Header.Add(ContentTypeHeader, contentType)
request.Header.Set(AcceptEncodingHeader, EncodingGzip)
request.Header.Add(AcceptEncodingHeader, EncodingDeflate)

response, err := client.Do(request)
if err != nil {
return err
}
defer response.Body.Close()
for _, file := range files {
err = os.Remove(file.Path())
if err != nil {
return err
}
}

respReader, err := respBodyReader(response)
if err != nil {
return err
}

if !enableTrace {
_, err = io.Copy(out, respReader)
return
}

buf := &bytes.Buffer{}
_, err = io.Copy(buf, respReader)
if err != nil {
return err
}
responseTrace := TraceHTTP{
Request: TraceHTTPRequest{
Method: request.Method,
URL: request.URL.String(),
Headers: redactHeaders(request.Header),
},
Response: TraceHTTPResponse{
StatusCode: response.StatusCode,
Status: response.Status,
Headers: redactHeaders(response.Header),
BodySize: buf.Len(),
},
}
trace, err := json.Marshal(responseTrace)
if err != nil {
return err
}
responseWithTraceExtension, err := jsonparser.Set(buf.Bytes(), trace, "extensions", "trace")
if err != nil {
return err
}
_, err = out.Write(responseWithTraceExtension)
return err
}

func multipartBytes(values map[string]io.Reader, files []File) (bytes.Buffer, string, error) {
var err error
var b bytes.Buffer
var fw io.Writer
w := multipart.NewWriter(&b)

// First create the fields to control the file upload
valuesInOrder := []string{"operations", "map"}
for _, key := range valuesInOrder {
r := values[key]
if fw, err = w.CreateFormField(key); err != nil {
return b, "", err
}
if _, err = io.Copy(fw, r); err != nil {
return b, "", err
}
}

// Now create one form for each file
for i, file := range files {
key := fmt.Sprintf("%d", i)
r := values[key]
if fw, err = w.CreateFormFile(key, file.Name()); err != nil {
return b, "", err
}
if _, err = io.Copy(fw, r); err != nil {
return b, "", err
}
}

err = w.Close()
if err != nil {
return b, "", err
}

return b, w.FormDataContentType(), nil
}
3 changes: 2 additions & 1 deletion v2/pkg/engine/datasource/introspection_datasource/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package introspection_datasource
import (
"context"
"encoding/json"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"

"github.com/wundergraph/graphql-go-tools/v2/pkg/introspection"
Expand All @@ -16,7 +17,7 @@ type Source struct {
introspectionData *introspection.Data
}

func (s *Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
func (s *Source) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
var req introspectionInput
if err := json.Unmarshal(input, &req); err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"
"regexp"
"strings"
Expand Down Expand Up @@ -311,7 +312,7 @@ type PublishDataSource struct {
pubSub PubSub
}

func (s *PublishDataSource) Load(ctx context.Context, input []byte, w io.Writer) error {
func (s *PublishDataSource) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error {
topic, err := jsonparser.GetString(input, "topic")
if err != nil {
return fmt.Errorf("error getting topic from input: %w", err)
Expand All @@ -333,7 +334,7 @@ type RequestDataSource struct {
pubSub PubSub
}

func (s *RequestDataSource) Load(ctx context.Context, input []byte, w io.Writer) error {
func (s *RequestDataSource) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error {
topic, err := jsonparser.GetString(input, "topic")
if err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package staticdatasource

import (
"context"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"

"github.com/jensneuse/abstractlogger"
Expand Down Expand Up @@ -65,7 +66,7 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration {

type Source struct{}

func (Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
func (Source) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
_, err = w.Write(input)
return
}
4 changes: 4 additions & 0 deletions v2/pkg/engine/resolve/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package resolve
import (
"context"
"encoding/json"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"
"net/http"
"time"
Expand All @@ -14,6 +15,7 @@ import (
type Context struct {
ctx context.Context
Variables []byte
Files []httpclient.File
Request Request
RenameTypeNames []RenameTypeName
TracingOptions TraceOptions
Expand Down Expand Up @@ -136,6 +138,7 @@ func (c *Context) clone(ctx context.Context) *Context {
cpy := *c
cpy.ctx = ctx
cpy.Variables = append([]byte(nil), c.Variables...)
cpy.Files = append([]httpclient.File(nil), c.Files...)
cpy.Request.Header = c.Request.Header.Clone()
cpy.RenameTypeNames = append([]RenameTypeName(nil), c.RenameTypeNames...)
return &cpy
Expand All @@ -144,6 +147,7 @@ func (c *Context) clone(ctx context.Context) *Context {
func (c *Context) Free() {
c.ctx = nil
c.Variables = nil
c.Files = nil
c.Request.Header = nil
c.RenameTypeNames = nil
c.TracingOptions.DisableAll()
Expand Down
3 changes: 2 additions & 1 deletion v2/pkg/engine/resolve/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"io"

"github.com/cespare/xxhash/v2"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
)

type DataSource interface {
Load(ctx context.Context, input []byte, w io.Writer) (err error)
Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error)
}

type SubscriptionDataSource interface {
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, source DataSource, input
}
var responseContext *httpclient.ResponseContext
ctx, responseContext = httpclient.InjectResponseContext(ctx)
res.err = source.Load(ctx, input, res.out)
res.err = source.Load(ctx, input, l.ctx.Files, res.out)
res.statusCode = responseContext.StatusCode
if l.ctx.TracingOptions.Enable {
stats := GetSingleFlightStats(ctx)
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/variablesvalidation/variablesvalidation.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (v *variablesVisitor) EnterVariableDefinition(ref int) {
v.renderVariableRequiredError(varName, varTypeRef)
return
}
if v.variables.Nodes[jsonField].Kind == astjson.NodeKindNull {
if v.variables.Nodes[jsonField].Kind == astjson.NodeKindNull && varTypeName.String() != "Upload" {
v.renderVariableInvalidNullError(varName, varTypeRef)
return
}
Expand Down

0 comments on commit c41faf8

Please sign in to comment.