Skip to content

Commit

Permalink
Simplifying load interface
Browse files Browse the repository at this point in the history
  • Loading branch information
pedraumcosta committed Mar 26, 2024
1 parent ca96426 commit b9aa48c
Show file tree
Hide file tree
Showing 9 changed files with 72 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1651,15 +1651,15 @@ func (s *Source) replaceEmptyObject(variables []byte) ([]byte, bool) {
}

func (s *Source) Load(
ctx context.Context, input []byte, files []string, filesNames []string, writer io.Writer,
ctx context.Context, input []byte, files []httpclient.File, writer io.Writer,
) (err error) {
input = s.compactAndUnNullVariables(input)

if files == nil {
return httpclient.Do(s.httpClient, ctx, input, writer)
}

return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, filesNames, writer)
return httpclient.DoMultipartForm(s.httpClient, ctx, input, files, writer)
}

type GraphQLSubscriptionClient interface {
Expand Down
26 changes: 26 additions & 0 deletions v2/pkg/engine/datasource/httpclient/file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package httpclient

type File interface {
Path() string
Name() string
}

type internalFile struct {
path string
name string
}

func NewFile(path string, name string) File {
return &internalFile{
path: path,
name: name,
}
}

func (f *internalFile) Path() string {
return f.path
}

func (f *internalFile) Name() string {
return f.name
}
70 changes: 31 additions & 39 deletions v2/pkg/engine/datasource/httpclient/nethttpclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ func respBodyReader(res *http.Response) (io.Reader, error) {
}

func DoMultipartForm(
client *http.Client, ctx context.Context, requestInput []byte, files []string, filesNames []string, out io.Writer,
client *http.Client, ctx context.Context, requestInput []byte, files []File, out io.Writer,
) (err error) {
if files == nil || len(files) == 0 {
return errors.New("no files provided")
Expand All @@ -239,8 +239,8 @@ func DoMultipartForm(
}

var fileMap string
fileNamesMap := make(map[string]string, len(filesNames))
for i, filePath := range files {
//fileNamesMap := make(map[string]string, len(files))
for i, file := range files {
if len(fileMap) == 0 {
if len(files) == 1 {
fileMap = fmt.Sprintf(`"%d" : ["variables.file"]`, i)
Expand All @@ -251,16 +251,16 @@ func DoMultipartForm(
fileMap = fmt.Sprintf(`%s, "%d" : ["variables.file%d"]`, fileMap, i, i+1)
}
key := fmt.Sprintf("%d", i)
fileNamesMap[key] = filesNames[i]
file, err := os.Open(filePath)
//fileNamesMap[key] = filesNames[i]
temporaryFile, err := os.Open(file.Path())
if err != nil {
return err
}
formValues[key] = bufio.NewReader(file)
formValues[key] = bufio.NewReader(temporaryFile)
}
formValues["map"] = strings.NewReader("{ " + fileMap + " }")

multipartBody, contentType, err := multipartBytes(formValues, fileNamesMap)
multipartBody, contentType, err := multipartBytes(formValues, files)
if err != nil {
return err
}
Expand Down Expand Up @@ -328,8 +328,8 @@ func DoMultipartForm(
return err
}
defer response.Body.Close()
for _, filePath := range files {
err = os.Remove(filePath)
for _, file := range files {
err = os.Remove(file.Path())
if err != nil {
return err
}
Expand Down Expand Up @@ -375,48 +375,40 @@ func DoMultipartForm(
return err
}

func multipartBytes(values map[string]io.Reader, fileNamesMap map[string]string) (bytes.Buffer, string, error) {
// Build an string array of N elements starting from 0 to iterate over the values, this needs to be done in order
valuesInOrder := []string{"operations", "map"}
i := 0
for i < len(values)-2 { // all files are there, except for operations and map entries, so -2
valuesInOrder = append(valuesInOrder, fmt.Sprintf("%d", i))
i++
}

func multipartBytes(values map[string]io.Reader, files []File) (bytes.Buffer, string, error) {
var err error
var b bytes.Buffer
var fw io.Writer
w := multipart.NewWriter(&b)

// First create the fields to control the file upload
valuesInOrder := []string{"operations", "map"}
for _, key := range valuesInOrder {
r := values[key]
var fw io.Writer
if x, ok := r.(io.Closer); ok {
defer x.Close()
if fw, err = w.CreateFormField(key); err != nil {
return b, "", err
}
// Add a file
if x, ok := r.(*os.File); ok {
if fw, err = w.CreateFormFile(key, x.Name()); err != nil {
return b, "", err
}
} else {
// Add other fields
fileName, ok := fileNamesMap[key]
if !ok {
if fw, err = w.CreateFormField(key); err != nil {
return b, "", err
}
} else {
if fw, err = w.CreateFormFile(key, fileName); err != nil {
return b, "", err
}
}
if _, err = io.Copy(fw, r); err != nil {
return b, "", err
}
}

// Now create one form for each file
for i, file := range files {
key := fmt.Sprintf("%d", i)
r := values[key]
if fw, err = w.CreateFormFile(key, file.Name()); err != nil {
return b, "", err
}
if _, err = io.Copy(fw, r); err != nil {
return b, "", err
}
}

err = w.Close()
if err != nil {
return b, "", err
}
w.Close()

return b, w.FormDataContentType(), nil
}
3 changes: 2 additions & 1 deletion v2/pkg/engine/datasource/introspection_datasource/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package introspection_datasource
import (
"context"
"encoding/json"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"

"github.com/wundergraph/graphql-go-tools/v2/pkg/introspection"
Expand All @@ -16,7 +17,7 @@ type Source struct {
introspectionData *introspection.Data
}

func (s *Source) Load(ctx context.Context, input []byte, files []string, filesNames []string, w io.Writer) (err error) {
func (s *Source) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
var req introspectionInput
if err := json.Unmarshal(input, &req); err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"
"regexp"
"strings"
Expand Down Expand Up @@ -311,7 +312,7 @@ type PublishDataSource struct {
pubSub PubSub
}

func (s *PublishDataSource) Load(ctx context.Context, input []byte, files []string, filesNames []string, w io.Writer) error {
func (s *PublishDataSource) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error {
topic, err := jsonparser.GetString(input, "topic")
if err != nil {
return fmt.Errorf("error getting topic from input: %w", err)
Expand All @@ -333,7 +334,7 @@ type RequestDataSource struct {
pubSub PubSub
}

func (s *RequestDataSource) Load(ctx context.Context, input []byte, files []string, filesNames []string, w io.Writer) error {
func (s *RequestDataSource) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) error {
topic, err := jsonparser.GetString(input, "topic")
if err != nil {
return err
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package staticdatasource

import (
"context"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"

"github.com/jensneuse/abstractlogger"
Expand Down Expand Up @@ -65,7 +66,7 @@ func (p *Planner[T]) ConfigureSubscription() plan.SubscriptionConfiguration {

type Source struct{}

func (Source) Load(ctx context.Context, input []byte, files []string, filesNames []string, w io.Writer) (err error) {
func (Source) Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error) {
_, err = w.Write(input)
return
}
6 changes: 3 additions & 3 deletions v2/pkg/engine/resolve/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package resolve
import (
"context"
"encoding/json"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
"io"
"net/http"
"time"
Expand All @@ -14,8 +15,7 @@ import (
type Context struct {
ctx context.Context
Variables []byte
Files []string
FilesNames []string
Files []httpclient.File
Request Request
RenameTypeNames []RenameTypeName
TracingOptions TraceOptions
Expand Down Expand Up @@ -138,7 +138,7 @@ func (c *Context) clone(ctx context.Context) *Context {
cpy := *c
cpy.ctx = ctx
cpy.Variables = append([]byte(nil), c.Variables...)
cpy.Files = append([]string(nil), c.Files...)
cpy.Files = append([]httpclient.File(nil), c.Files...)
cpy.Request.Header = c.Request.Header.Clone()
cpy.RenameTypeNames = append([]RenameTypeName(nil), c.RenameTypeNames...)
return &cpy
Expand Down
3 changes: 2 additions & 1 deletion v2/pkg/engine/resolve/datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
"io"

"github.com/cespare/xxhash/v2"
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient"
)

type DataSource interface {
Load(ctx context.Context, input []byte, files []string, filesNames []string, w io.Writer) (err error)
Load(ctx context.Context, input []byte, files []httpclient.File, w io.Writer) (err error)
}

type SubscriptionDataSource interface {
Expand Down
2 changes: 1 addition & 1 deletion v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -1147,7 +1147,7 @@ func (l *Loader) executeSourceLoad(ctx context.Context, source DataSource, input
}
var responseContext *httpclient.ResponseContext
ctx, responseContext = httpclient.InjectResponseContext(ctx)
res.err = source.Load(ctx, input, l.ctx.Files, l.ctx.FilesNames, res.out)
res.err = source.Load(ctx, input, l.ctx.Files, res.out)
res.statusCode = responseContext.StatusCode
if l.ctx.TracingOptions.Enable {
stats := GetSingleFlightStats(ctx)
Expand Down

0 comments on commit b9aa48c

Please sign in to comment.