Skip to content

Commit

Permalink
Support file upload
Browse files Browse the repository at this point in the history
  • Loading branch information
pedraumcosta committed Mar 21, 2024
1 parent 0e9a903 commit 11cd5f0
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1762,9 +1762,14 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) {
return variables, false
}

func (s *Source) Load(ctx context.Context, input []byte, writer io.Writer) (err error) {
func (s *Source) Load(ctx context.Context, input []byte, files [][]byte, writer io.Writer) (err error) {
input = s.compactAndUnNullVariables(input)
return httpclient.Do(s.httpClient, ctx, input, writer)

if files == nil {
return httpclient.Do(s.httpClient, ctx, input, writer)
}

return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, writer)
}

type GraphQLSubscriptionClient interface {
Expand Down
153 changes: 153 additions & 0 deletions v2/pkg/engine/datasource/httpclient/nethttpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"encoding/json"
"errors"
"io"
"mime/multipart"
"net/http"
"os"
"slices"
"strings"
"time"
Expand Down Expand Up @@ -208,3 +210,154 @@ func respBodyReader(res *http.Response) (io.Reader, error) {
return res.Body, nil
}
}

func DoMultipartForm(client *http.Client, ctx context.Context, requestInput []byte, files [][]byte, out io.Writer) (err error) {
if files == nil || files[0] == nil {
return errors.New("no files provided")
}

url, method, body, headers, queryParams, enableTrace := requestInputParams(requestInput)

formValues := map[string]io.Reader{
"operations": bytes.NewReader(body),
// TODO pedro prepare for multipart next
"map": strings.NewReader(`{ "0": ["variables.file"] }`),
"0": bytes.NewReader(files[0]),
}
multipartBody, contentType, err := multipartBytes(formValues)
if err != nil {
return errors.New("error creating multipart bytes")
}

request, err := http.NewRequestWithContext(ctx, string(method), string(url), &multipartBody)
if err != nil {
return err
}

if headers != nil {
err = jsonparser.ObjectEach(headers, func(key []byte, value []byte, dataType jsonparser.ValueType, offset int) error {
_, err := jsonparser.ArrayEach(value, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
if err != nil {
return
}
if len(value) == 0 {
return
}
request.Header.Add(string(key), string(value))
})
return err
})
if err != nil {
return err
}
}

if queryParams != nil {
query := request.URL.Query()
_, err = jsonparser.ArrayEach(queryParams, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
var (
parameterName, parameterValue []byte
)
jsonparser.EachKey(value, func(i int, bytes []byte, valueType jsonparser.ValueType, err error) {
switch i {
case 0:
parameterName = bytes
case 1:
parameterValue = bytes
}
}, queryParamsKeys...)
if len(parameterName) != 0 && len(parameterValue) != 0 {
if bytes.Equal(parameterValue[:1], literal.LBRACK) {
_, _ = jsonparser.ArrayEach(parameterValue, func(value []byte, dataType jsonparser.ValueType, offset int, err error) {
query.Add(string(parameterName), string(value))
})
} else {
query.Add(string(parameterName), string(parameterValue))
}
}
})
if err != nil {
return err
}
request.URL.RawQuery = query.Encode()
}

request.Header.Add(AcceptHeader, ContentTypeJSON)
request.Header.Add(ContentTypeHeader, contentType)
request.Header.Set(AcceptEncodingHeader, EncodingGzip)
request.Header.Add(AcceptEncodingHeader, EncodingDeflate)

response, err := client.Do(request)
if err != nil {
return err
}
defer response.Body.Close()

respReader, err := respBodyReader(response)
if err != nil {
return err
}

if !enableTrace {
_, err = io.Copy(out, respReader)
return
}

buf := &bytes.Buffer{}
_, err = io.Copy(buf, respReader)
if err != nil {
return err
}
responseTrace := TraceHTTP{
Request: TraceHTTPRequest{
Method: request.Method,
URL: request.URL.String(),
Headers: redactHeaders(request.Header),
},
Response: TraceHTTPResponse{
StatusCode: response.StatusCode,
Status: response.Status,
Headers: redactHeaders(response.Header),
BodySize: buf.Len(),
},
}
trace, err := json.Marshal(responseTrace)
if err != nil {
return err
}
responseWithTraceExtension, err := jsonparser.Set(buf.Bytes(), trace, "extensions", "trace")
if err != nil {
return err
}
_, err = out.Write(responseWithTraceExtension)
return err
}
func multipartBytes(values map[string]io.Reader) (bytes.Buffer, string, error) {
var err error
var b bytes.Buffer
w := multipart.NewWriter(&b)
for key, r := range values {
var fw io.Writer
if x, ok := r.(io.Closer); ok {
defer x.Close()
}
// Add a file
if x, ok := r.(*os.File); ok {
if fw, err = w.CreateFormFile(key, x.Name()); err != nil {
return b, "", err
}
} else {
// Add other fields
if fw, err = w.CreateFormField(key); err != nil {
return b, "", err
}
}
if _, err = io.Copy(fw, r); err != nil {
return b, "", err
}

}
w.Close()

return b, w.FormDataContentType(), nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type Source struct {
introspectionData *introspection.Data
}

func (s *Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
func (s *Source) Load(ctx context.Context, input []byte, files [][]byte, w io.Writer) (err error) {
var req introspectionInput
if err := json.Unmarshal(input, &req); err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ type PublishDataSource struct {
pubSub PubSub
}

func (s *PublishDataSource) Load(ctx context.Context, input []byte, w io.Writer) error {
func (s *PublishDataSource) Load(ctx context.Context, input []byte, files [][]byte, w io.Writer) error {
topic, err := jsonparser.GetString(input, "topic")
if err != nil {
return fmt.Errorf("error getting topic from input: %w", err)
Expand All @@ -312,7 +312,7 @@ type RequestDataSource struct {
pubSub PubSub
}

func (s *RequestDataSource) Load(ctx context.Context, input []byte, w io.Writer) error {
func (s *RequestDataSource) Load(ctx context.Context, input []byte, files [][]byte, w io.Writer) error {
topic, err := jsonparser.GetString(input, "topic")
if err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func (p *Planner) ConfigureSubscription() plan.SubscriptionConfiguration {

type Source struct{}

func (Source) Load(ctx context.Context, input []byte, w io.Writer) (err error) {
func (Source) Load(ctx context.Context, input []byte, files [][]byte, w io.Writer) (err error) {
_, err = w.Write(input)
return
}
3 changes: 3 additions & 0 deletions v2/pkg/engine/resolve/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
type Context struct {
ctx context.Context
Variables []byte
Files [][]byte
Request Request
RenameTypeNames []RenameTypeName
TracingOptions TraceOptions
Expand Down Expand Up @@ -136,6 +137,7 @@ func (c *Context) clone(ctx context.Context) *Context {
cpy := *c
cpy.ctx = ctx
cpy.Variables = append([]byte(nil), c.Variables...)
cpy.Files = append([][]byte(nil), c.Files...)
cpy.Request.Header = c.Request.Header.Clone()
cpy.RenameTypeNames = append([]RenameTypeName(nil), c.RenameTypeNames...)
return &cpy
Expand All @@ -144,6 +146,7 @@ func (c *Context) clone(ctx context.Context) *Context {
func (c *Context) Free() {
c.ctx = nil
c.Variables = nil
c.Files = nil
c.Request.Header = nil
c.RenameTypeNames = nil
c.TracingOptions.DisableAll()
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/engine/resolve/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import (
)

type DataSource interface {
Load(ctx context.Context, input []byte, w io.Writer) (err error)
Load(ctx context.Context, input []byte, files [][]byte, w io.Writer) (err error)
}

type SubscriptionDataSource interface {
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -1089,7 +1089,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, source DataSource, input
if l.info != nil && l.info.OperationType == ast.OperationTypeMutation {
ctx = context.WithValue(ctx, disallowSingleFlightContextKey{}, true)
}
err = source.Load(ctx, input, out)
err = source.Load(ctx, input, l.ctx.Files, out)
if l.ctx.TracingOptions.Enable {
stats := GetSingleFlightStats(ctx)
if stats != nil {
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/variablesvalidation/variablesvalidation.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (v *variablesVisitor) EnterVariableDefinition(ref int) {
v.renderVariableRequiredError(varName, varTypeRef)
return
}
if v.variables.Nodes[jsonField].Kind == astjson.NodeKindNull {
if v.variables.Nodes[jsonField].Kind == astjson.NodeKindNull && varTypeName.String() != "Upload" {
v.renderVariableInvalidNullError(varName, varTypeRef)
return
}
Expand Down

0 comments on commit 11cd5f0

Please sign in to comment.