From b38be4bcb348991e0e6c1bc98382f9d8b47da730 Mon Sep 17 00:00:00 2001 From: Jinquan Wang <35188480+wangjq4214@users.noreply.github.com> Date: Mon, 4 Mar 2024 15:49:14 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E2=9C=A8=20v3=20(feature):=20client=20refa?= =?UTF-8?q?ctor=20(#1986)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ v3: Move the client module to the client folder and fix the error * ✨ v3: add xml encoder and decoder * 🚧 v3: design plugin and hook mechanism, complete simple get request * 🚧 v3: reset add some field * 🚧 v3: add doc and fix some error * 🚧 v3: add header merge * 🚧 v3: add query param * 🚧 v3: change to fasthttp's header and args * ✨ v3: add body and ua setting * 🚧 v3: add cookie support * 🚧 v3: add path param support * βœ… v3: fix error test case * 🚧 v3: add formdata and file support * 🚧 v3: referer support * 🚧 v3: reponse unmarshal * ✨ v3: finish API design * πŸ”₯ v3: remove plugin mechanism * 🚧 v3: add timeout * 🚧 v3: change path params pattern and add unit test for core * ✏️ v3: error spell * βœ… v3: improve test coverage * βœ… perf: change test func name to fit project format * 🚧 v3: handle error * 🚧 v3: add unit test and fix error * ⚑️ chore: change func to improve performance * βœ… v3: add some unit test * βœ… v3: fix error test * πŸ› fix: add cookie to response * βœ… v3: add unit test * ✨ v3: export raw field * πŸ› fix: fix data race * πŸ”’οΈ chore: change package * πŸ› fix: data race * πŸ› fix: test fail * ✨ feat: move core to req * πŸ› fix: connection reuse * πŸ› fix: data race * πŸ› fix: data race * πŸ”€ fix: change to testify * βœ… fix: fail test in windows * ✨ feat: response body save to file * ✨ feat: support tls config * πŸ› fix: add err check * 🎨 perf: fix some static check * ✨ feat: add proxy support * ✨ feat: add retry feature * πŸ› fix: static check error * 🎨 refactor: move som code * docs: change readme * ✨ feat: extend axios API * perf: change field to export field * βœ… chore: disable startup message * πŸ› fix: fix test error * chore: fix error test * chore: fix test case * feat: add some test to client * chore: add test case * chore: add test case * ✨ feat: add peek for client * βœ… chore: add test case * ⚑️ feat: lazy generate rand string * 🚧 perf: add config test case * πŸ› fix: fix merge error * :bug: fix utils error * :sparkles: add redirection * πŸ”₯ chore: delete deps * perf: fix spell error * 🎨 perf: spell error * ✨ feat: add logger * ✨ feat: add cookie jar * ✨ feat: logger with level * 🎨 perf: change the field name * perf: add jar test * fix proxy test * improve test coverage * fix proxy tests * add cookiejar support from pending fasthttp PR * fix some lint errors. * add benchmark for SetValWithStruct * optimize * update * fix proxy middleware * use panicf instead of errorf and fix panic on default logger * update * update * cleanup comments * cleanup comments * fix golang-lint errors * Update helper_test.go * add more test cases * add hostclient pool * make it more thread safe -> there is still something which is shared between the requests * fixed some golangci-lint errors * fix Test_Request_FormData test * create new test suite * just create client for once * use random port instead of 3000 * remove client pooling and fix test suite * fix data races on logger tests * fix proxy tests * fix global tests * remove unused code * fix logger test * fix proxy tests * fix linter * use lock instead of rlock * fix cookiejar data-race * fix(client): race conditions * fix(client): race conditions * apply some reviews * change client property name * apply review * add parallel benchmark for simple request * apply review * apply review * fix log tests * fix linter * fix(client): return error in SetProxyURL instead of panic --------- Co-authored-by: Muhammed Efe Γ‡etin Co-authored-by: RenΓ© Werner Co-authored-by: Joey Co-authored-by: RenΓ© --- client.go | 1021 -------------------- client/README.md | 35 + client/client.go | 775 +++++++++++++++ client/client_test.go | 1642 ++++++++++++++++++++++++++++++++ client/cookiejar.go | 245 +++++ client/cookiejar_test.go | 213 +++++ client/core.go | 272 ++++++ client/core_test.go | 248 +++++ client/helper_test.go | 157 +++ client/hooks.go | 328 +++++++ client/hooks_test.go | 652 +++++++++++++ client/request.go | 985 +++++++++++++++++++ client/request_test.go | 1623 +++++++++++++++++++++++++++++++ client/response.go | 184 ++++ client/response_test.go | 418 ++++++++ client_test.go | 1337 -------------------------- listen_test.go | 27 +- log/default.go | 2 + middleware/proxy/proxy.go | 8 - middleware/proxy/proxy_test.go | 115 ++- redirect_test.go | 30 +- 21 files changed, 7885 insertions(+), 2432 deletions(-) delete mode 100644 client.go create mode 100644 client/README.md create mode 100644 client/client.go create mode 100644 client/client_test.go create mode 100644 client/cookiejar.go create mode 100644 client/cookiejar_test.go create mode 100644 client/core.go create mode 100644 client/core_test.go create mode 100644 client/helper_test.go create mode 100644 client/hooks.go create mode 100644 client/hooks_test.go create mode 100644 client/request.go create mode 100644 client/request_test.go create mode 100644 client/response.go create mode 100644 client/response_test.go delete mode 100644 client_test.go diff --git a/client.go b/client.go deleted file mode 100644 index 8825f9d815e..00000000000 --- a/client.go +++ /dev/null @@ -1,1021 +0,0 @@ -package fiber - -import ( - "bytes" - "crypto/tls" - "encoding/json" - "encoding/xml" - "errors" - "fmt" - "io" - "mime/multipart" - "os" - "path/filepath" - "strconv" - "sync" - "time" - - "github.com/gofiber/utils/v2" - "github.com/valyala/fasthttp" -) - -// Request represents HTTP request. -// -// It is forbidden copying Request instances. Create new instances -// and use CopyTo instead. -// -// Request instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Request = fasthttp.Request - -// Response represents HTTP response. -// -// It is forbidden copying Response instances. Create new instances -// and use CopyTo instead. -// -// Response instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Response = fasthttp.Response - -// Args represents query arguments. -// -// It is forbidden copying Args instances. Create new instances instead -// and use CopyTo(). -// -// Args instance MUST NOT be used from concurrently running goroutines. -// Copy from fasthttp -type Args = fasthttp.Args - -// RetryIfFunc signature of retry if function -// Request argument passed to RetryIfFunc, if there are any request errors. -// Copy from fasthttp -type RetryIfFunc = fasthttp.RetryIfFunc - -var defaultClient Client - -// Client implements http client. -// -// It is safe calling Client methods from concurrently running goroutines. -type Client struct { - mutex sync.RWMutex - // UserAgent is used in User-Agent request header. - UserAgent string - - // NoDefaultUserAgentHeader when set to true, causes the default - // User-Agent header to be excluded from the Request. - NoDefaultUserAgentHeader bool - - // When set by an external client of Fiber it will use the provided implementation of a - // JSONMarshal - // - // Allowing for flexibility in using another json library for encoding - JSONEncoder utils.JSONMarshal - - // When set by an external client of Fiber it will use the provided implementation of a - // JSONUnmarshal - // - // Allowing for flexibility in using another json library for decoding - JSONDecoder utils.JSONUnmarshal -} - -// Get returns an agent with http method GET. -func Get(url string) *Agent { return defaultClient.Get(url) } - -// Get returns an agent with http method GET. -func (c *Client) Get(url string) *Agent { - return c.createAgent(MethodGet, url) -} - -// Head returns an agent with http method HEAD. -func Head(url string) *Agent { return defaultClient.Head(url) } - -// Head returns an agent with http method GET. -func (c *Client) Head(url string) *Agent { - return c.createAgent(MethodHead, url) -} - -// Post sends POST request to the given URL. -func Post(url string) *Agent { return defaultClient.Post(url) } - -// Post sends POST request to the given URL. -func (c *Client) Post(url string) *Agent { - return c.createAgent(MethodPost, url) -} - -// Put sends PUT request to the given URL. -func Put(url string) *Agent { return defaultClient.Put(url) } - -// Put sends PUT request to the given URL. -func (c *Client) Put(url string) *Agent { - return c.createAgent(MethodPut, url) -} - -// Patch sends PATCH request to the given URL. -func Patch(url string) *Agent { return defaultClient.Patch(url) } - -// Patch sends PATCH request to the given URL. -func (c *Client) Patch(url string) *Agent { - return c.createAgent(MethodPatch, url) -} - -// Delete sends DELETE request to the given URL. -func Delete(url string) *Agent { return defaultClient.Delete(url) } - -// Delete sends DELETE request to the given URL. -func (c *Client) Delete(url string) *Agent { - return c.createAgent(MethodDelete, url) -} - -func (c *Client) createAgent(method, url string) *Agent { - a := AcquireAgent() - a.req.Header.SetMethod(method) - a.req.SetRequestURI(url) - - c.mutex.RLock() - a.Name = c.UserAgent - a.NoDefaultUserAgentHeader = c.NoDefaultUserAgentHeader - a.jsonDecoder = c.JSONDecoder - a.jsonEncoder = c.JSONEncoder - if a.jsonDecoder == nil { - a.jsonDecoder = json.Unmarshal - } - c.mutex.RUnlock() - - if err := a.Parse(); err != nil { - a.errs = append(a.errs, err) - } - - return a -} - -// Agent is an object storing all request data for client. -// Agent instance MUST NOT be used from concurrently running goroutines. -type Agent struct { - // Name is used in User-Agent request header. - Name string - - // NoDefaultUserAgentHeader when set to true, causes the default - // User-Agent header to be excluded from the Request. - NoDefaultUserAgentHeader bool - - // HostClient is an embedded fasthttp HostClient - *fasthttp.HostClient - - req *Request - resp *Response - dest []byte - args *Args - timeout time.Duration - errs []error - formFiles []*FormFile - debugWriter io.Writer - mw multipartWriter - jsonEncoder utils.JSONMarshal - jsonDecoder utils.JSONUnmarshal - maxRedirectsCount int - boundary string - reuse bool - parsed bool -} - -// Parse initializes URI and HostClient. -func (a *Agent) Parse() error { - if a.parsed { - return nil - } - a.parsed = true - - uri := a.req.URI() - - var isTLS bool - scheme := uri.Scheme() - if bytes.Equal(scheme, []byte(schemeHTTPS)) { - isTLS = true - } else if !bytes.Equal(scheme, []byte(schemeHTTP)) { - return fmt.Errorf("unsupported protocol %q. http and https are supported", scheme) - } - - name := a.Name - if name == "" && !a.NoDefaultUserAgentHeader { - name = defaultUserAgent - } - - a.HostClient = &fasthttp.HostClient{ - Addr: fasthttp.AddMissingPort(string(uri.Host()), isTLS), - Name: name, - NoDefaultUserAgentHeader: a.NoDefaultUserAgentHeader, - IsTLS: isTLS, - } - - return nil -} - -/************************** Header Setting **************************/ - -// Set sets the given 'key: value' header. -// -// Use Add for setting multiple header values under the same key. -func (a *Agent) Set(k, v string) *Agent { - a.req.Header.Set(k, v) - - return a -} - -// SetBytesK sets the given 'key: value' header. -// -// Use AddBytesK for setting multiple header values under the same key. -func (a *Agent) SetBytesK(k []byte, v string) *Agent { - a.req.Header.SetBytesK(k, v) - - return a -} - -// SetBytesV sets the given 'key: value' header. -// -// Use AddBytesV for setting multiple header values under the same key. -func (a *Agent) SetBytesV(k string, v []byte) *Agent { - a.req.Header.SetBytesV(k, v) - - return a -} - -// SetBytesKV sets the given 'key: value' header. -// -// Use AddBytesKV for setting multiple header values under the same key. -func (a *Agent) SetBytesKV(k, v []byte) *Agent { - a.req.Header.SetBytesKV(k, v) - - return a -} - -// Add adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use Set for setting a single header for the given key. -func (a *Agent) Add(k, v string) *Agent { - a.req.Header.Add(k, v) - - return a -} - -// AddBytesK adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesK for setting a single header for the given key. -func (a *Agent) AddBytesK(k []byte, v string) *Agent { - a.req.Header.AddBytesK(k, v) - - return a -} - -// AddBytesV adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesV for setting a single header for the given key. -func (a *Agent) AddBytesV(k string, v []byte) *Agent { - a.req.Header.AddBytesV(k, v) - - return a -} - -// AddBytesKV adds the given 'key: value' header. -// -// Multiple headers with the same key may be added with this function. -// Use SetBytesKV for setting a single header for the given key. -func (a *Agent) AddBytesKV(k, v []byte) *Agent { - a.req.Header.AddBytesKV(k, v) - - return a -} - -// ConnectionClose sets 'Connection: close' header. -func (a *Agent) ConnectionClose() *Agent { - a.req.Header.SetConnectionClose() - - return a -} - -// UserAgent sets User-Agent header value. -func (a *Agent) UserAgent(userAgent string) *Agent { - a.req.Header.SetUserAgent(userAgent) - - return a -} - -// UserAgentBytes sets User-Agent header value. -func (a *Agent) UserAgentBytes(userAgent []byte) *Agent { - a.req.Header.SetUserAgentBytes(userAgent) - - return a -} - -// Cookie sets one 'key: value' cookie. -func (a *Agent) Cookie(key, value string) *Agent { - a.req.Header.SetCookie(key, value) - - return a -} - -// CookieBytesK sets one 'key: value' cookie. -func (a *Agent) CookieBytesK(key []byte, value string) *Agent { - a.req.Header.SetCookieBytesK(key, value) - - return a -} - -// CookieBytesKV sets one 'key: value' cookie. -func (a *Agent) CookieBytesKV(key, value []byte) *Agent { - a.req.Header.SetCookieBytesKV(key, value) - - return a -} - -// Cookies sets multiple 'key: value' cookies. -func (a *Agent) Cookies(kv ...string) *Agent { - for i := 1; i < len(kv); i += 2 { - a.req.Header.SetCookie(kv[i-1], kv[i]) - } - - return a -} - -// CookiesBytesKV sets multiple 'key: value' cookies. -func (a *Agent) CookiesBytesKV(kv ...[]byte) *Agent { - for i := 1; i < len(kv); i += 2 { - a.req.Header.SetCookieBytesKV(kv[i-1], kv[i]) - } - - return a -} - -// Referer sets Referer header value. -func (a *Agent) Referer(referer string) *Agent { - a.req.Header.SetReferer(referer) - - return a -} - -// RefererBytes sets Referer header value. -func (a *Agent) RefererBytes(referer []byte) *Agent { - a.req.Header.SetRefererBytes(referer) - - return a -} - -// ContentType sets Content-Type header value. -func (a *Agent) ContentType(contentType string) *Agent { - a.req.Header.SetContentType(contentType) - - return a -} - -// ContentTypeBytes sets Content-Type header value. -func (a *Agent) ContentTypeBytes(contentType []byte) *Agent { - a.req.Header.SetContentTypeBytes(contentType) - - return a -} - -/************************** End Header Setting **************************/ - -/************************** URI Setting **************************/ - -// Host sets host for the URI. -func (a *Agent) Host(host string) *Agent { - a.req.URI().SetHost(host) - - return a -} - -// HostBytes sets host for the URI. -func (a *Agent) HostBytes(host []byte) *Agent { - a.req.URI().SetHostBytes(host) - - return a -} - -// QueryString sets URI query string. -func (a *Agent) QueryString(queryString string) *Agent { - a.req.URI().SetQueryString(queryString) - - return a -} - -// QueryStringBytes sets URI query string. -func (a *Agent) QueryStringBytes(queryString []byte) *Agent { - a.req.URI().SetQueryStringBytes(queryString) - - return a -} - -// BasicAuth sets URI username and password. -func (a *Agent) BasicAuth(username, password string) *Agent { - a.req.URI().SetUsername(username) - a.req.URI().SetPassword(password) - - return a -} - -// BasicAuthBytes sets URI username and password. -func (a *Agent) BasicAuthBytes(username, password []byte) *Agent { - a.req.URI().SetUsernameBytes(username) - a.req.URI().SetPasswordBytes(password) - - return a -} - -/************************** End URI Setting **************************/ - -/************************** Request Setting **************************/ - -// BodyString sets request body. -func (a *Agent) BodyString(bodyString string) *Agent { - a.req.SetBodyString(bodyString) - - return a -} - -// Body sets request body. -func (a *Agent) Body(body []byte) *Agent { - a.req.SetBody(body) - - return a -} - -// BodyStream sets request body stream and, optionally body size. -// -// If bodySize is >= 0, then the bodyStream must provide exactly bodySize bytes -// before returning io.EOF. -// -// If bodySize < 0, then bodyStream is read until io.EOF. -// -// bodyStream.Close() is called after finishing reading all body data -// if it implements io.Closer. -// -// Note that GET and HEAD requests cannot have body. -func (a *Agent) BodyStream(bodyStream io.Reader, bodySize int) *Agent { - a.req.SetBodyStream(bodyStream, bodySize) - - return a -} - -// JSON sends a JSON request. -func (a *Agent) JSON(v any, ctype ...string) *Agent { - if a.jsonEncoder == nil { - a.jsonEncoder = json.Marshal - } - - if len(ctype) > 0 { - a.req.Header.SetContentType(ctype[0]) - } else { - a.req.Header.SetContentType(MIMEApplicationJSON) - } - - if body, err := a.jsonEncoder(v); err != nil { - a.errs = append(a.errs, err) - } else { - a.req.SetBody(body) - } - - return a -} - -// XML sends an XML request. -func (a *Agent) XML(v any) *Agent { - a.req.Header.SetContentType(MIMEApplicationXML) - - if body, err := xml.Marshal(v); err != nil { - a.errs = append(a.errs, err) - } else { - a.req.SetBody(body) - } - - return a -} - -// Form sends form request with body if args is non-nil. -// -// It is recommended obtaining args via AcquireArgs and release it -// manually in performance-critical code. -func (a *Agent) Form(args *Args) *Agent { - a.req.Header.SetContentType(MIMEApplicationForm) - - if args != nil { - a.req.SetBody(args.QueryString()) - } - - return a -} - -// FormFile represents multipart form file -type FormFile struct { - // Fieldname is form file's field name - Fieldname string - // Name is form file's name - Name string - // Content is form file's content - Content []byte - // autoRelease indicates if returns the object - // acquired via AcquireFormFile to the pool. - autoRelease bool -} - -// FileData appends files for multipart form request. -// -// It is recommended obtaining formFile via AcquireFormFile and release it -// manually in performance-critical code. -func (a *Agent) FileData(formFiles ...*FormFile) *Agent { - a.formFiles = append(a.formFiles, formFiles...) - - return a -} - -// SendFile reads file and appends it to multipart form request. -func (a *Agent) SendFile(filename string, fieldname ...string) *Agent { - content, err := os.ReadFile(filepath.Clean(filename)) - if err != nil { - a.errs = append(a.errs, err) - return a - } - - ff := AcquireFormFile() - if len(fieldname) > 0 && fieldname[0] != "" { - ff.Fieldname = fieldname[0] - } else { - ff.Fieldname = "file" + strconv.Itoa(len(a.formFiles)+1) - } - ff.Name = filepath.Base(filename) - ff.Content = append(ff.Content, content...) - ff.autoRelease = true - - a.formFiles = append(a.formFiles, ff) - - return a -} - -// SendFiles reads files and appends them to multipart form request. -// -// Examples: -// -// SendFile("/path/to/file1", "fieldname1", "/path/to/file2") -func (a *Agent) SendFiles(filenamesAndFieldnames ...string) *Agent { - pairs := len(filenamesAndFieldnames) - if pairs&1 == 1 { - filenamesAndFieldnames = append(filenamesAndFieldnames, "") - } - - for i := 0; i < pairs; i += 2 { - a.SendFile(filenamesAndFieldnames[i], filenamesAndFieldnames[i+1]) - } - - return a -} - -// Boundary sets boundary for multipart form request. -func (a *Agent) Boundary(boundary string) *Agent { - a.boundary = boundary - - return a -} - -// MultipartForm sends multipart form request with k-v and files. -// -// It is recommended obtaining args via AcquireArgs and release it -// manually in performance-critical code. -func (a *Agent) MultipartForm(args *Args) *Agent { - if a.mw == nil { - a.mw = multipart.NewWriter(a.req.BodyWriter()) - } - - if a.boundary != "" { - if err := a.mw.SetBoundary(a.boundary); err != nil { - a.errs = append(a.errs, err) - return a - } - } - - a.req.Header.SetMultipartFormBoundary(a.mw.Boundary()) - - if args != nil { - args.VisitAll(func(key, value []byte) { - if err := a.mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)); err != nil { - a.errs = append(a.errs, err) - } - }) - } - - for _, ff := range a.formFiles { - w, err := a.mw.CreateFormFile(ff.Fieldname, ff.Name) - if err != nil { - a.errs = append(a.errs, err) - continue - } - if _, err = w.Write(ff.Content); err != nil { - a.errs = append(a.errs, err) - } - } - - if err := a.mw.Close(); err != nil { - a.errs = append(a.errs, err) - } - - return a -} - -/************************** End Request Setting **************************/ - -/************************** Agent Setting **************************/ - -// Debug mode enables logging request and response detail -func (a *Agent) Debug(w ...io.Writer) *Agent { - a.debugWriter = os.Stdout - if len(w) > 0 { - a.debugWriter = w[0] - } - - return a -} - -// Timeout sets request timeout duration. -func (a *Agent) Timeout(timeout time.Duration) *Agent { - a.timeout = timeout - - return a -} - -// Reuse enables the Agent instance to be used again after one request. -// -// If agent is reusable, then it should be released manually when it is no -// longer used. -func (a *Agent) Reuse() *Agent { - a.reuse = true - - return a -} - -// InsecureSkipVerify controls whether the Agent verifies the server -// certificate chain and host name. -func (a *Agent) InsecureSkipVerify() *Agent { - if a.HostClient.TLSConfig == nil { - a.HostClient.TLSConfig = &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We explicitly let the user set insecure mode here - } else { - a.HostClient.TLSConfig.InsecureSkipVerify = true - } - - return a -} - -// TLSConfig sets tls config. -func (a *Agent) TLSConfig(config *tls.Config) *Agent { - a.HostClient.TLSConfig = config - - return a -} - -// MaxRedirectsCount sets max redirect count for GET and HEAD. -func (a *Agent) MaxRedirectsCount(count int) *Agent { - a.maxRedirectsCount = count - - return a -} - -// JSONEncoder sets custom json encoder. -func (a *Agent) JSONEncoder(jsonEncoder utils.JSONMarshal) *Agent { - a.jsonEncoder = jsonEncoder - - return a -} - -// JSONDecoder sets custom json decoder. -func (a *Agent) JSONDecoder(jsonDecoder utils.JSONUnmarshal) *Agent { - a.jsonDecoder = jsonDecoder - - return a -} - -// Request returns Agent request instance. -func (a *Agent) Request() *Request { - return a.req -} - -// SetResponse sets custom response for the Agent instance. -// -// It is recommended obtaining custom response via AcquireResponse and release it -// manually in performance-critical code. -func (a *Agent) SetResponse(customResp *Response) *Agent { - a.resp = customResp - - return a -} - -// Dest sets custom dest. -// -// The contents of dest will be replaced by the response body, if the dest -// is too small a new slice will be allocated. -func (a *Agent) Dest(dest []byte) *Agent { - a.dest = dest - - return a -} - -// RetryIf controls whether a retry should be attempted after an error. -// -// By default, will use isIdempotent function from fasthttp -func (a *Agent) RetryIf(retryIf RetryIfFunc) *Agent { - a.HostClient.RetryIf = retryIf - return a -} - -/************************** End Agent Setting **************************/ - -// Bytes returns the status code, bytes body and errors of url. -// -// it's not safe to use Agent after calling [Agent.Bytes] -func (a *Agent) Bytes() (int, []byte, []error) { - defer a.release() - return a.bytes() -} - -func (a *Agent) bytes() (code int, body []byte, errs []error) { //nolint:nonamedreturns,revive // We want to overwrite the body in a deferred func. TODO: Check if we really need to do this. We eventually want to get rid of all named returns. - if errs = append(errs, a.errs...); len(errs) > 0 { - return code, body, errs - } - - var ( - req = a.req - resp *Response - nilResp bool - ) - - if a.resp == nil { - resp = AcquireResponse() - nilResp = true - } else { - resp = a.resp - } - - defer func() { - if a.debugWriter != nil { - printDebugInfo(req, resp, a.debugWriter) - } - - if len(errs) == 0 { - code = resp.StatusCode() - } - - body = append(a.dest, resp.Body()...) //nolint:gocritic // We want to append to the returned slice here - - if nilResp { - ReleaseResponse(resp) - } - }() - - if a.timeout > 0 { - if err := a.HostClient.DoTimeout(req, resp, a.timeout); err != nil { - errs = append(errs, err) - return code, body, errs - } - } else if a.maxRedirectsCount > 0 && (string(req.Header.Method()) == MethodGet || string(req.Header.Method()) == MethodHead) { - if err := a.HostClient.DoRedirects(req, resp, a.maxRedirectsCount); err != nil { - errs = append(errs, err) - return code, body, errs - } - } else if err := a.HostClient.Do(req, resp); err != nil { - errs = append(errs, err) - } - - return code, body, errs -} - -func printDebugInfo(req *Request, resp *Response, w io.Writer) { - msg := fmt.Sprintf("Connected to %s(%s)\r\n\r\n", req.URI().Host(), resp.RemoteAddr()) - _, _ = w.Write(utils.UnsafeBytes(msg)) //nolint:errcheck // This will never fail - _, _ = req.WriteTo(w) //nolint:errcheck // This will never fail - _, _ = resp.WriteTo(w) //nolint:errcheck // This will never fail -} - -// String returns the status code, string body and errors of url. -// -// it's not safe to use Agent after calling [Agent.String] -func (a *Agent) String() (int, string, []error) { - defer a.release() - code, body, errs := a.bytes() - // TODO: There might be a data race here on body. Maybe use utils.CopyBytes on it? - - return code, utils.UnsafeString(body), errs -} - -// Struct returns the status code, bytes body and errors of URL. -// And bytes body will be unmarshalled to given v. -// -// it's not safe to use Agent after calling [Agent.Struct] -func (a *Agent) Struct(v any) (int, []byte, []error) { - defer a.release() - - code, body, errs := a.bytes() - if len(errs) > 0 { - return code, body, errs - } - - // TODO: This should only be done once - if a.jsonDecoder == nil { - a.jsonDecoder = json.Unmarshal - } - - if err := a.jsonDecoder(body, v); err != nil { - errs = append(errs, err) - } - - return code, body, errs -} - -func (a *Agent) release() { - if !a.reuse { - ReleaseAgent(a) - } else { - a.errs = a.errs[:0] - } -} - -func (a *Agent) reset() { - a.HostClient = nil - a.req.Reset() - a.resp = nil - a.dest = nil - a.timeout = 0 - a.args = nil - a.errs = a.errs[:0] - a.debugWriter = nil - a.mw = nil - a.reuse = false - a.parsed = false - a.maxRedirectsCount = 0 - a.boundary = "" - a.Name = "" - a.NoDefaultUserAgentHeader = false - for i, ff := range a.formFiles { - if ff.autoRelease { - ReleaseFormFile(ff) - } - a.formFiles[i] = nil - } - a.formFiles = a.formFiles[:0] -} - -var ( - clientPool sync.Pool - agentPool = sync.Pool{ - New: func() any { - return &Agent{req: &Request{}} - }, - } - responsePool sync.Pool - argsPool sync.Pool - formFilePool sync.Pool -) - -// AcquireClient returns an empty Client instance from client pool. -// -// The returned Client instance may be passed to ReleaseClient when it is -// no longer needed. This allows Client recycling, reduces GC pressure -// and usually improves performance. -func AcquireClient() *Client { - v := clientPool.Get() - if v == nil { - return &Client{} - } - c, ok := v.(*Client) - if !ok { - panic(errors.New("failed to type-assert to *Client")) - } - return c -} - -// ReleaseClient returns c acquired via AcquireClient to client pool. -// -// It is forbidden accessing req and/or it's members after returning -// it to client pool. -func ReleaseClient(c *Client) { - c.UserAgent = "" - c.NoDefaultUserAgentHeader = false - c.JSONEncoder = nil - c.JSONDecoder = nil - - clientPool.Put(c) -} - -// AcquireAgent returns an empty Agent instance from Agent pool. -// -// The returned Agent instance may be passed to ReleaseAgent when it is -// no longer needed. This allows Agent recycling, reduces GC pressure -// and usually improves performance. -func AcquireAgent() *Agent { - a, ok := agentPool.Get().(*Agent) - if !ok { - panic(errors.New("failed to type-assert to *Agent")) - } - return a -} - -// ReleaseAgent returns an acquired via AcquireAgent to Agent pool. -// -// It is forbidden accessing req and/or it's members after returning -// it to Agent pool. -func ReleaseAgent(a *Agent) { - a.reset() - agentPool.Put(a) -} - -// AcquireResponse returns an empty Response instance from response pool. -// -// The returned Response instance may be passed to ReleaseResponse when it is -// no longer needed. This allows Response recycling, reduces GC pressure -// and usually improves performance. -// Copy from fasthttp -func AcquireResponse() *Response { - v := responsePool.Get() - if v == nil { - return &Response{} - } - r, ok := v.(*Response) - if !ok { - panic(errors.New("failed to type-assert to *Response")) - } - return r -} - -// ReleaseResponse return resp acquired via AcquireResponse to response pool. -// -// It is forbidden accessing resp and/or it's members after returning -// it to response pool. -// Copy from fasthttp -func ReleaseResponse(resp *Response) { - resp.Reset() - responsePool.Put(resp) -} - -// AcquireArgs returns an empty Args object from the pool. -// -// The returned Args may be returned to the pool with ReleaseArgs -// when no longer needed. This allows reducing GC load. -// Copy from fasthttp -func AcquireArgs() *Args { - v := argsPool.Get() - if v == nil { - return &Args{} - } - a, ok := v.(*Args) - if !ok { - panic(errors.New("failed to type-assert to *Args")) - } - return a -} - -// ReleaseArgs returns the object acquired via AcquireArgs to the pool. -// -// String not access the released Args object, otherwise data races may occur. -// Copy from fasthttp -func ReleaseArgs(a *Args) { - a.Reset() - argsPool.Put(a) -} - -// AcquireFormFile returns an empty FormFile object from the pool. -// -// The returned FormFile may be returned to the pool with ReleaseFormFile -// when no longer needed. This allows reducing GC load. -func AcquireFormFile() *FormFile { - v := formFilePool.Get() - if v == nil { - return &FormFile{} - } - ff, ok := v.(*FormFile) - if !ok { - panic(errors.New("failed to type-assert to *FormFile")) - } - return ff -} - -// ReleaseFormFile returns the object acquired via AcquireFormFile to the pool. -// -// String not access the released FormFile object, otherwise data races may occur. -func ReleaseFormFile(ff *FormFile) { - ff.Fieldname = "" - ff.Name = "" - ff.Content = ff.Content[:0] - ff.autoRelease = false - - formFilePool.Put(ff) -} - -const ( - defaultUserAgent = "fiber" -) - -type multipartWriter interface { - Boundary() string - SetBoundary(boundary string) error - CreateFormFile(fieldname, filename string) (io.Writer, error) - WriteField(fieldname, value string) error - Close() error -} diff --git a/client/README.md b/client/README.md new file mode 100644 index 00000000000..a992fcc6fb7 --- /dev/null +++ b/client/README.md @@ -0,0 +1,35 @@ +

Fiber Client

+

Easy-to-use HTTP client based on fasthttp (inspired by resty and axios)

+

Features section describes in detail about Resty capabilities

+ +## Features + +> The characteristics have not yet been written. + +- GET, POST, PUT, DELETE, HEAD, PATCH, OPTIONS, etc. +- Simple and chainable methods for settings and request +- Request Body can be `string`, `[]byte`, `map`, `slice` + - Auto detects `Content-Type` + - Buffer processing for `files` + - Native `*fasthttp.Request` instance can be accessed during middleware and request execution via `Request.RawRequest` + - Request Body can be read multiple time via `Request.RawRequest.GetBody()` +- Response object gives you more possibility + - Access as `[]byte` by `response.Body()` or access as `string` by `response.String()` +- Automatic marshal and unmarshal for JSON and XML content type + - Default is JSON, if you supply struct/map without header Content-Type + - For auto-unmarshal, refer to - + - Success scenario Request.SetResult() and Response.Result(). + - Error scenario Request.SetError() and Response.Error(). + - Supports RFC7807 - application/problem+json & application/problem+xml + - Provide an option to override JSON Marshal/Unmarshal and XML Marshal/Unmarshal + +## Usage + +The following samples will assist you to become as comfortable as possible with `Fiber Client` library. + +```go +// Import Fiber Client into your code and refer it as `client`. +import "github.com/gofiber/fiber/client" +``` + +### Simple GET diff --git a/client/client.go b/client/client.go new file mode 100644 index 00000000000..22cbd197276 --- /dev/null +++ b/client/client.go @@ -0,0 +1,775 @@ +package client + +import ( + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "encoding/xml" + "errors" + "fmt" + "io" + urlpkg "net/url" + "os" + "path/filepath" + "sync" + "time" + + "github.com/gofiber/fiber/v3/log" + + "github.com/gofiber/utils/v2" + + "github.com/valyala/fasthttp" +) + +var ( + ErrInvalidProxyURL = errors.New("invalid proxy url scheme") + ErrFailedToAppendCert = errors.New("failed to append certificate") +) + +// The Client is used to create a Fiber Client with +// client-level settings that apply to all requests +// raise from the client. +// +// Fiber Client also provides an option to override +// or merge most of the client settings at the request. +type Client struct { + mu sync.RWMutex + + fasthttp *fasthttp.Client + + baseURL string + userAgent string + referer string + header *Header + params *QueryParam + cookies *Cookie + path *PathParam + + debug bool + + timeout time.Duration + + // user defined request hooks + userRequestHooks []RequestHook + + // client package defined request hooks + builtinRequestHooks []RequestHook + + // user defined response hooks + userResponseHooks []ResponseHook + + // client package defined response hooks + builtinResponseHooks []ResponseHook + + jsonMarshal utils.JSONMarshal + jsonUnmarshal utils.JSONUnmarshal + xmlMarshal utils.XMLMarshal + xmlUnmarshal utils.XMLUnmarshal + + cookieJar *CookieJar + + // proxy + proxyURL string + + // retry + retryConfig *RetryConfig + + // logger + logger log.CommonLogger +} + +// R raise a request from the client. +func (c *Client) R() *Request { + return AcquireRequest().SetClient(c) +} + +// RequestHook Request returns user-defined request hooks. +func (c *Client) RequestHook() []RequestHook { + return c.userRequestHooks +} + +// AddRequestHook Add user-defined request hooks. +func (c *Client) AddRequestHook(h ...RequestHook) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.userRequestHooks = append(c.userRequestHooks, h...) + return c +} + +// ResponseHook return user-define response hooks. +func (c *Client) ResponseHook() []ResponseHook { + return c.userResponseHooks +} + +// AddResponseHook Add user-defined response hooks. +func (c *Client) AddResponseHook(h ...ResponseHook) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.userResponseHooks = append(c.userResponseHooks, h...) + return c +} + +// JSONMarshal returns json marshal function in Core. +func (c *Client) JSONMarshal() utils.JSONMarshal { + return c.jsonMarshal +} + +// SetJSONMarshal Set json encoder. +func (c *Client) SetJSONMarshal(f utils.JSONMarshal) *Client { + c.jsonMarshal = f + return c +} + +// JSONUnmarshal returns json unmarshal function in Core. +func (c *Client) JSONUnmarshal() utils.JSONUnmarshal { + return c.jsonUnmarshal +} + +// Set json decoder. +func (c *Client) SetJSONUnmarshal(f utils.JSONUnmarshal) *Client { + c.jsonUnmarshal = f + return c +} + +// XMLMarshal returns xml marshal function in Core. +func (c *Client) XMLMarshal() utils.XMLMarshal { + return c.xmlMarshal +} + +// SetXMLMarshal Set xml encoder. +func (c *Client) SetXMLMarshal(f utils.XMLMarshal) *Client { + c.xmlMarshal = f + return c +} + +// XMLUnmarshal returns xml unmarshal function in Core. +func (c *Client) XMLUnmarshal() utils.XMLUnmarshal { + return c.xmlUnmarshal +} + +// SetXMLUnmarshal Set xml decoder. +func (c *Client) SetXMLUnmarshal(f utils.XMLUnmarshal) *Client { + c.xmlUnmarshal = f + return c +} + +// TLSConfig returns tlsConfig in client. +// If client don't have tlsConfig, this function will init it. +func (c *Client) TLSConfig() *tls.Config { + if c.fasthttp.TLSConfig == nil { + c.fasthttp.TLSConfig = &tls.Config{ + MinVersion: tls.VersionTLS12, + } + } + + return c.fasthttp.TLSConfig +} + +// SetTLSConfig sets tlsConfig in client. +func (c *Client) SetTLSConfig(config *tls.Config) *Client { + c.fasthttp.TLSConfig = config + return c +} + +// SetCertificates method sets client certificates into client. +func (c *Client) SetCertificates(certs ...tls.Certificate) *Client { + config := c.TLSConfig() + config.Certificates = append(config.Certificates, certs...) + return c +} + +// SetRootCertificate adds one or more root certificates into client. +func (c *Client) SetRootCertificate(path string) *Client { + cleanPath := filepath.Clean(path) + file, err := os.Open(cleanPath) + if err != nil { + c.logger.Panicf("client: %v", err) + } + defer func() { + if err := file.Close(); err != nil { + c.logger.Panicf("client: failed to close file: %v", err) + } + }() + + pem, err := io.ReadAll(file) + if err != nil { + c.logger.Panicf("client: %v", err) + } + + config := c.TLSConfig() + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + + if !config.RootCAs.AppendCertsFromPEM(pem) { + c.logger.Panicf("client: %v", ErrFailedToAppendCert) + } + + return c +} + +// SetRootCertificateFromString method adds one or more root certificates into client. +func (c *Client) SetRootCertificateFromString(pem string) *Client { + config := c.TLSConfig() + + if config.RootCAs == nil { + config.RootCAs = x509.NewCertPool() + } + + if !config.RootCAs.AppendCertsFromPEM([]byte(pem)) { + c.logger.Panicf("client: %v", ErrFailedToAppendCert) + } + + return c +} + +// SetProxyURL sets proxy url in client. It will apply via core to hostclient. +func (c *Client) SetProxyURL(proxyURL string) error { + pURL, err := urlpkg.Parse(proxyURL) + if err != nil { + return fmt.Errorf("client: %w", err) + } + + if pURL.Scheme != "http" && pURL.Scheme != "https" { + return fmt.Errorf("client: %w", ErrInvalidProxyURL) + } + + c.proxyURL = pURL.String() + + return nil +} + +// RetryConfig returns retry config in client. +func (c *Client) RetryConfig() *RetryConfig { + return c.retryConfig +} + +// SetRetryConfig sets retry config in client which is impl by addon/retry package. +func (c *Client) SetRetryConfig(config *RetryConfig) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.retryConfig = config + return c +} + +// BaseURL returns baseurl in Client instance. +func (c *Client) BaseURL() string { + return c.baseURL +} + +// SetBaseURL Set baseUrl which is prefix of real url. +func (c *Client) SetBaseURL(url string) *Client { + c.baseURL = url + return c +} + +// Header method returns header value via key, +// this method will visit all field in the header, +// then sort them. +func (c *Client) Header(key string) []string { + return c.header.PeekMultiple(key) +} + +// AddHeader method adds a single header field and its value in the client instance. +// These headers will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level header options. +func (c *Client) AddHeader(key, val string) *Client { + c.header.Add(key, val) + return c +} + +// SetHeader method sets a single header field and its value in the client instance. +// These headers will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level header options. +func (c *Client) SetHeader(key, val string) *Client { + c.header.Set(key, val) + return c +} + +// AddHeaders method adds multiple headers field and its values at one go in the client instance. +// These headers will be applied to all requests raised from this client instance. Also it can be +// overridden at request level headers options. +func (c *Client) AddHeaders(h map[string][]string) *Client { + c.header.AddHeaders(h) + return c +} + +// SetHeaders method sets multiple headers field and its values at one go in the client instance. +// These headers will be applied to all requests raised from this client instance. Also it can be +// overridden at request level headers options. +func (c *Client) SetHeaders(h map[string]string) *Client { + c.header.SetHeaders(h) + return c +} + +// Param method returns params value via key, +// this method will visit all field in the query param. +func (c *Client) Param(key string) []string { + res := []string{} + tmp := c.params.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + + return res +} + +// AddParam method adds a single query param field and its value in the client instance. +// These params will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level param options. +func (c *Client) AddParam(key, val string) *Client { + c.params.Add(key, val) + return c +} + +// SetParam method sets a single query param field and its value in the client instance. +// These params will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level param options. +func (c *Client) SetParam(key, val string) *Client { + c.params.Set(key, val) + return c +} + +// AddParams method adds multiple query params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) AddParams(m map[string][]string) *Client { + c.params.AddParams(m) + return c +} + +// SetParams method sets multiple params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) SetParams(m map[string]string) *Client { + c.params.SetParams(m) + return c +} + +// SetParamsWithStruct method sets multiple params field and its values at one go in the client instance. +// These params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level params options. +func (c *Client) SetParamsWithStruct(v any) *Client { + c.params.SetParamsWithStruct(v) + return c +} + +// DelParams method deletes single or multiple params field and its values in client. +func (c *Client) DelParams(key ...string) *Client { + for _, v := range key { + c.params.Del(v) + } + return c +} + +// SetUserAgent method sets userAgent field and its value in the client instance. +// This ua will be applied to all requests raised from this client instance. +// Also it can be overridden at request level ua options. +func (c *Client) SetUserAgent(ua string) *Client { + c.userAgent = ua + return c +} + +// SetReferer method sets referer field and its value in the client instance. +// This referer will be applied to all requests raised from this client instance. +// Also it can be overridden at request level referer options. +func (c *Client) SetReferer(r string) *Client { + c.referer = r + return c +} + +// PathParam returns the path param be set in request instance. +// if path param doesn't exist, return empty string. +func (c *Client) PathParam(key string) string { + if val, ok := (*c.path)[key]; ok { + return val + } + + return "" +} + +// SetPathParam method sets a single path param field and its value in the client instance. +// These path params will be applied to all requests raised from this client instance. +// Also it can be overridden at request level path params options. +func (c *Client) SetPathParam(key, val string) *Client { + c.path.SetParam(key, val) + return c +} + +// SetPathParams method sets multiple path params field and its values at one go in the client instance. +// These path params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level path params options. +func (c *Client) SetPathParams(m map[string]string) *Client { + c.path.SetParams(m) + return c +} + +// SetPathParamsWithStruct method sets multiple path params field and its values at one go in the client instance. +// These path params will be applied to all requests raised from this client instance. Also it can be +// overridden at request level path params options. +func (c *Client) SetPathParamsWithStruct(v any) *Client { + c.path.SetParamsWithStruct(v) + return c +} + +// DelPathParams method deletes single or multiple path params field and its values in client. +func (c *Client) DelPathParams(key ...string) *Client { + c.path.DelParams(key...) + return c +} + +// Cookie returns the cookie be set in request instance. +// if cookie doesn't exist, return empty string. +func (c *Client) Cookie(key string) string { + if val, ok := (*c.cookies)[key]; ok { + return val + } + return "" +} + +// SetCookie method sets a single cookie field and its value in the client instance. +// These cookies will be applied to all requests raised from this client instance. +// Also it can be overridden at request level cookie options. +func (c *Client) SetCookie(key, val string) *Client { + c.cookies.SetCookie(key, val) + return c +} + +// SetCookies method sets multiple cookies field and its values at one go in the client instance. +// These cookies will be applied to all requests raised from this client instance. Also it can be +// overridden at request level cookie options. +func (c *Client) SetCookies(m map[string]string) *Client { + c.cookies.SetCookies(m) + return c +} + +// SetCookiesWithStruct method sets multiple cookies field and its values at one go in the client instance. +// These cookies will be applied to all requests raised from this client instance. Also it can be +// overridden at request level cookies options. +func (c *Client) SetCookiesWithStruct(v any) *Client { + c.cookies.SetCookiesWithStruct(v) + return c +} + +// DelCookies method deletes single or multiple cookies field and its values in client. +func (c *Client) DelCookies(key ...string) *Client { + c.cookies.DelCookies(key...) + return c +} + +// SetTimeout method sets timeout val in client instance. +// This value will be applied to all requests raised from this client instance. +// Also, it can be overridden at request level timeout options. +func (c *Client) SetTimeout(t time.Duration) *Client { + c.timeout = t + return c +} + +// Debug enable log debug level output. +func (c *Client) Debug() *Client { + c.debug = true + return c +} + +// DisableDebug disenable log debug level output. +func (c *Client) DisableDebug() *Client { + c.debug = false + return c +} + +// SetCookieJar sets cookie jar in client instance. +func (c *Client) SetCookieJar(cookieJar *CookieJar) *Client { + c.cookieJar = cookieJar + return c +} + +// Get provide an API like axios which send get request. +func (c *Client) Get(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Get(url) +} + +// Post provide an API like axios which send post request. +func (c *Client) Post(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Post(url) +} + +// Head provide a API like axios which send head request. +func (c *Client) Head(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Head(url) +} + +// Put provide an API like axios which send put request. +func (c *Client) Put(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Put(url) +} + +// Delete provide an API like axios which send delete request. +func (c *Client) Delete(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Delete(url) +} + +// Options provide an API like axios which send options request. +func (c *Client) Options(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Options(url) +} + +// Patch provide an API like axios which send patch request. +func (c *Client) Patch(url string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Patch(url) +} + +// Custom provide an API like axios which send custom request. +func (c *Client) Custom(url, method string, cfg ...Config) (*Response, error) { + req := AcquireRequest().SetClient(c) + setConfigToRequest(req, cfg...) + + return req.Custom(url, method) +} + +// SetDial sets dial function in client. +func (c *Client) SetDial(dial fasthttp.DialFunc) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.fasthttp.Dial = dial + return c +} + +// SetLogger sets logger instance in client. +func (c *Client) SetLogger(logger log.CommonLogger) *Client { + c.mu.Lock() + defer c.mu.Unlock() + + c.logger = logger + return c +} + +// Logger returns logger instance of client. +func (c *Client) Logger() log.CommonLogger { + return c.logger +} + +// Reset clear Client object +func (c *Client) Reset() { + c.fasthttp = &fasthttp.Client{} + c.baseURL = "" + c.timeout = 0 + c.userAgent = "" + c.referer = "" + c.proxyURL = "" + c.retryConfig = nil + c.debug = false + + if c.cookieJar != nil { + c.cookieJar.Release() + c.cookieJar = nil + } + + c.path.Reset() + c.cookies.Reset() + c.header.Reset() + c.params.Reset() +} + +// Config for easy to set the request parameters, it should be +// noted that when setting the request body will use JSON as +// the default serialization mechanism, while the priority of +// Body is higher than FormData, and the priority of FormData +// is higher than File. +type Config struct { + Ctx context.Context //nolint:containedctx // It's needed to be stored in the config. + + UserAgent string + Referer string + Header map[string]string + Param map[string]string + Cookie map[string]string + PathParam map[string]string + + Timeout time.Duration + MaxRedirects int + + Body any + FormData map[string]string + File []*File +} + +// setConfigToRequest Set the parameters passed via Config to Request. +func setConfigToRequest(req *Request, config ...Config) { + if len(config) == 0 { + return + } + cfg := config[0] + + if cfg.Ctx != nil { + req.SetContext(cfg.Ctx) + } + + if cfg.UserAgent != "" { + req.SetUserAgent(cfg.UserAgent) + } + + if cfg.Referer != "" { + req.SetReferer(cfg.Referer) + } + + if cfg.Header != nil { + req.SetHeaders(cfg.Header) + } + + if cfg.Param != nil { + req.SetParams(cfg.Param) + } + + if cfg.Cookie != nil { + req.SetCookies(cfg.Cookie) + } + + if cfg.PathParam != nil { + req.SetPathParams(cfg.PathParam) + } + + if cfg.Timeout != 0 { + req.SetTimeout(cfg.Timeout) + } + + if cfg.MaxRedirects != 0 { + req.SetMaxRedirects(cfg.MaxRedirects) + } + + if cfg.Body != nil { + req.SetJSON(cfg.Body) + return + } + + if cfg.FormData != nil { + req.SetFormDatas(cfg.FormData) + return + } + + if cfg.File != nil && len(cfg.File) != 0 { + req.AddFiles(cfg.File...) + return + } +} + +var ( + defaultClient *Client + replaceMu = sync.Mutex{} + defaultUserAgent = "fiber" +) + +// init acquire a default client. +func init() { + defaultClient = NewClient() +} + +// NewClient creates and returns a new Client object. +func NewClient() *Client { + // FOllOW-UP performance optimization + // trie to use a pool to reduce the cost of memory allocation + // for the fiber client and the fasthttp client + // if possible also for other structs -> request header, cookie, query param, path param... + return &Client{ + fasthttp: &fasthttp.Client{}, + header: &Header{ + RequestHeader: &fasthttp.RequestHeader{}, + }, + params: &QueryParam{ + Args: fasthttp.AcquireArgs(), + }, + cookies: &Cookie{}, + path: &PathParam{}, + + userRequestHooks: []RequestHook{}, + builtinRequestHooks: []RequestHook{parserRequestURL, parserRequestHeader, parserRequestBody}, + userResponseHooks: []ResponseHook{}, + builtinResponseHooks: []ResponseHook{parserResponseCookie, logger}, + jsonMarshal: json.Marshal, + jsonUnmarshal: json.Unmarshal, + xmlMarshal: xml.Marshal, + xmlUnmarshal: xml.Unmarshal, + logger: log.DefaultLogger(), + } +} + +// C get default client. +func C() *Client { + return defaultClient +} + +// Replace the defaultClient, the returned function can undo. +func Replace(c *Client) func() { + replaceMu.Lock() + defer replaceMu.Unlock() + + oldClient := defaultClient + defaultClient = c + + return func() { + replaceMu.Lock() + defer replaceMu.Unlock() + + defaultClient = oldClient + } +} + +// Get send a get request use defaultClient, a convenient method. +func Get(url string, cfg ...Config) (*Response, error) { + return C().Get(url, cfg...) +} + +// Post send a post request use defaultClient, a convenient method. +func Post(url string, cfg ...Config) (*Response, error) { + return C().Post(url, cfg...) +} + +// Head send a head request use defaultClient, a convenient method. +func Head(url string, cfg ...Config) (*Response, error) { + return C().Head(url, cfg...) +} + +// Put send a put request use defaultClient, a convenient method. +func Put(url string, cfg ...Config) (*Response, error) { + return C().Put(url, cfg...) +} + +// Delete send a delete request use defaultClient, a convenient method. +func Delete(url string, cfg ...Config) (*Response, error) { + return C().Delete(url, cfg...) +} + +// Options send a options request use defaultClient, a convenient method. +func Options(url string, cfg ...Config) (*Response, error) { + return C().Options(url, cfg...) +} + +// Patch send a patch request use defaultClient, a convenient method. +func Patch(url string, cfg ...Config) (*Response, error) { + return C().Patch(url, cfg...) +} diff --git a/client/client_test.go b/client/client_test.go new file mode 100644 index 00000000000..4fd2e484a63 --- /dev/null +++ b/client/client_test.go @@ -0,0 +1,1642 @@ +package client + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net" + "os" + "reflect" + "sync" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/addon/retry" + "github.com/gofiber/fiber/v3/internal/tlstest" + "github.com/gofiber/utils/v2" + "github.com/stretchr/testify/require" + "github.com/valyala/bytebufferpool" +) + +func startTestServerWithPort(t *testing.T, beforeStarting func(app *fiber.App)) (*fiber.App, string) { + t.Helper() + + app := fiber.New() + + if beforeStarting != nil { + beforeStarting(app) + } + + addrChan := make(chan string) + errChan := make(chan error, 1) + go func() { + err := app.Listen(":0", fiber.ListenConfig{ + DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + addrChan <- addr.String() + }, + }) + if err != nil { + errChan <- err + } + }() + + select { + case addr := <-addrChan: + return app, addr + case err := <-errChan: + t.Fatalf("Failed to start test server: %v", err) + } + + return nil, "" +} + +func Test_Client_Add_Hook(t *testing.T) { + t.Parallel() + + t.Run("add request hooks", func(t *testing.T) { + t.Parallel() + + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + client := NewClient().AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook1") + return nil + }) + + require.Len(t, client.RequestHook(), 1) + + client.AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook2") + return nil + }, func(_ *Client, _ *Request) error { + buf.WriteString("hook3") + return nil + }) + + require.Len(t, client.RequestHook(), 3) + }) + + t.Run("add response hooks", func(t *testing.T) { + t.Parallel() + client := NewClient().AddResponseHook(func(_ *Client, _ *Response, _ *Request) error { + return nil + }) + + require.Len(t, client.ResponseHook(), 1) + + client.AddResponseHook(func(_ *Client, _ *Response, _ *Request) error { + return nil + }, func(_ *Client, _ *Response, _ *Request) error { + return nil + }) + + require.Len(t, client.ResponseHook(), 3) + }) +} + +func Test_Client_Add_Hook_CheckOrder(t *testing.T) { + t.Parallel() + + buf := bytebufferpool.Get() + defer bytebufferpool.Put(buf) + + client := NewClient(). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook1") + return nil + }). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook2") + return nil + }). + AddRequestHook(func(_ *Client, _ *Request) error { + buf.WriteString("hook3") + return nil + }) + + for _, hook := range client.RequestHook() { + require.NoError(t, hook(client, &Request{})) + } + + require.Equal(t, "hook1hook2hook3", buf.String()) +} + +func Test_Client_Marshal(t *testing.T) { + t.Parallel() + + t.Run("set json marshal", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetJSONMarshal(func(_ any) ([]byte, error) { + return []byte("hello"), nil + }) + val, err := client.JSONMarshal()(nil) + + require.NoError(t, err) + require.Equal(t, []byte("hello"), val) + }) + + t.Run("set json marshal error", func(t *testing.T) { + t.Parallel() + + emptyErr := errors.New("empty json") + client := NewClient(). + SetJSONMarshal(func(_ any) ([]byte, error) { + return nil, emptyErr + }) + + val, err := client.JSONMarshal()(nil) + require.Nil(t, val) + require.ErrorIs(t, err, emptyErr) + }) + + t.Run("set json unmarshal", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetJSONUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty json") + }) + + err := client.JSONUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty json"), err) + }) + + t.Run("set json unmarshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetJSONUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty json") + }) + + err := client.JSONUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty json"), err) + }) + + t.Run("set xml marshal", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLMarshal(func(_ any) ([]byte, error) { + return []byte("hello"), nil + }) + val, err := client.XMLMarshal()(nil) + + require.NoError(t, err) + require.Equal(t, []byte("hello"), val) + }) + + t.Run("set xml marshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLMarshal(func(_ any) ([]byte, error) { + return nil, errors.New("empty xml") + }) + + val, err := client.XMLMarshal()(nil) + require.Nil(t, val) + require.Equal(t, errors.New("empty xml"), err) + }) + + t.Run("set xml unmarshal", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty xml") + }) + + err := client.XMLUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty xml"), err) + }) + + t.Run("set xml unmarshal error", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetXMLUnmarshal(func(_ []byte, _ any) error { + return errors.New("empty xml") + }) + + err := client.XMLUnmarshal()(nil, nil) + require.Equal(t, errors.New("empty xml"), err) + }) +} + +func Test_Client_SetBaseURL(t *testing.T) { + t.Parallel() + + client := NewClient().SetBaseURL("http://example.com") + + require.Equal(t, "http://example.com", client.BaseURL()) +} + +func Test_Client_Invalid_URL(t *testing.T) { + t.Parallel() + + app, dial, start := createHelperServer(t) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + go start() + + _, err := NewClient().SetDial(dial). + R(). + Get("http//example") + + require.ErrorIs(t, err, ErrURLFormat) +} + +func Test_Client_Unsupported_Protocol(t *testing.T) { + t.Parallel() + + _, err := NewClient(). + R(). + Get("ftp://example.com") + + require.ErrorIs(t, err, ErrURLFormat) +} + +func Test_Client_ConcurrencyRequests(t *testing.T) { + t.Parallel() + + app, dial, start := createHelperServer(t) + app.All("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname() + " " + c.Method()) + }) + go start() + + client := NewClient().SetDial(dial) + + wg := sync.WaitGroup{} + for i := 0; i < 5; i++ { + for _, method := range []string{"GET", "POST", "PUT", "DELETE", "PATCH"} { + wg.Add(1) + go func(m string) { + defer wg.Done() + resp, err := client.Custom("http://example.com", m) + require.NoError(t, err) + require.Equal(t, "example.com "+m, utils.UnsafeString(resp.RawResponse.Body())) + }(method) + } + } + + wg.Wait() +} + +func Test_Get(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + }) + + return app, addr + } + + t.Run("global get function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + resp, err := Get("http://" + addr) + require.NoError(t, err) + require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body())) + }) + + t.Run("client get", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + resp, err := NewClient().Get("http://" + addr) + require.NoError(t, err) + require.Equal(t, "0.0.0.0", utils.UnsafeString(resp.RawResponse.Body())) + }) +} + +func Test_Head(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Head("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + }) + + return app, addr + } + + t.Run("global head function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + resp, err := Head("http://" + addr) + require.NoError(t, err) + require.Equal(t, "7", resp.Header(fiber.HeaderContentLength)) + require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) + }) + + t.Run("client head", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + resp, err := NewClient().Head("http://" + addr) + require.NoError(t, err) + require.Equal(t, "7", resp.Header(fiber.HeaderContentLength)) + require.Equal(t, "", utils.UnsafeString(resp.RawResponse.Body())) + }) +} + +func Test_Post(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusCreated). + SendString(c.FormValue("foo")) + }) + }) + + return app, addr + } + + t.Run("global post function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := Post("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) + + t.Run("client post", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Post("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} + +func Test_Put(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Put("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + }) + + return app, addr + } + + t.Run("global put function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := Put("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) + + t.Run("client put", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Put("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} + +func Test_Delete(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). + SendString("deleted") + }) + }) + + return app, addr + } + + t.Run("global delete function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + time.Sleep(1 * time.Second) + + for i := 0; i < 5; i++ { + resp, err := Delete("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) + + t.Run("client delete", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Delete("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) +} + +func Test_Options(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Options("/", func(c fiber.Ctx) error { + c.Set(fiber.HeaderAllow, "GET, POST, PUT, DELETE, PATCH") + return c.Status(fiber.StatusNoContent).SendString("") + }) + }) + + return app, addr + } + + t.Run("global options function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := Options("http://" + addr) + + require.NoError(t, err) + require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow)) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) + + t.Run("client options", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Options("http://" + addr) + + require.NoError(t, err) + require.Equal(t, "GET, POST, PUT, DELETE, PATCH", resp.Header(fiber.HeaderAllow)) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + } + }) +} + +func Test_Patch(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Patch("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + }) + + return app, addr + } + + t.Run("global patch function", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + time.Sleep(1 * time.Second) + + for i := 0; i < 5; i++ { + resp, err := Patch("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) + + t.Run("client patch", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := NewClient().Patch("http://"+addr, Config{ + FormData: map[string]string{ + "foo": "bar", + }, + }) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + } + }) +} + +func Test_Client_UserAgent(t *testing.T) { + t.Parallel() + + setupApp := func() (*fiber.App, string) { + app, addr := startTestServerWithPort(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + }) + }) + + return app, addr + } + + t.Run("default", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + resp, err := Get("http://" + addr) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, defaultUserAgent, resp.String()) + } + }) + + t.Run("custom", func(t *testing.T) { + t.Parallel() + + app, addr := setupApp() + defer func() { + require.NoError(t, app.Shutdown()) + }() + + for i := 0; i < 5; i++ { + c := NewClient(). + SetUserAgent("ua") + + resp, err := c.Get("http://" + addr) + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "ua", resp.String()) + } + }) +} + +func Test_Client_Header(t *testing.T) { + t.Parallel() + + t.Run("add header", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.AddHeader("foo", "bar").AddHeader("foo", "fiber") + + res := req.Header("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set header", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.AddHeader("foo", "bar").SetHeader("foo", "fiber") + + res := req.Header("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add headers", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetHeader("foo", "bar"). + AddHeaders(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Header("foo") + require.Len(t, res, 3) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) + + res = req.Header("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetHeader("foo", "bar"). + SetHeaders(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Header("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.Header("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set header case insensitive", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetHeader("foo", "bar"). + AddHeader("FOO", "fiber") + + res := req.Header("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) +} + +func Test_Client_Header_With_Server(t *testing.T) { + handler := func(c fiber.Ctx) error { + c.Request().Header.VisitAll(func(key, value []byte) { + if k := string(key); k == "K1" || k == "K2" { + _, _ = c.Write(key) //nolint:errcheck // It is fine to ignore the error here + _, _ = c.Write(value) //nolint:errcheck // It is fine to ignore the error here + } + }) + return nil + } + + wrapAgent := func(c *Client) { + c.SetHeader("k1", "v1"). + AddHeader("k1", "v11"). + AddHeaders(map[string][]string{ + "k1": {"v22", "v33"}, + }). + SetHeaders(map[string]string{ + "k2": "v2", + }). + AddHeader("k2", "v22") + } + + testClient(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") +} + +func Test_Client_Cookie(t *testing.T) { + t.Parallel() + + t.Run("set cookie", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetCookie("foo", "bar") + require.Equal(t, "bar", req.Cookie("foo")) + + req.SetCookie("foo", "bar1") + require.Equal(t, "bar1", req.Cookie("foo")) + }) + + t.Run("set cookies", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.SetCookies(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) + + t.Run("set cookies with struct", func(t *testing.T) { + t.Parallel() + type args struct { + CookieInt int `cookie:"int"` + CookieString string `cookie:"string"` + } + + req := NewClient().SetCookiesWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.Cookie("int")) + require.Equal(t, "foo", req.Cookie("string")) + }) + + t.Run("del cookies", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.DelCookies("foo") + require.Equal(t, "", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) +} + +func Test_Client_Cookie_With_Server(t *testing.T) { + t.Parallel() + + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) + } + + wrapAgent := func(c *Client) { + c.SetCookie("k1", "v1"). + SetCookies(map[string]string{ + "k2": "v2", + "k3": "v3", + "k4": "v4", + }).DelCookies("k4") + } + + testClient(t, handler, wrapAgent, "v1v2v3") +} + +func Test_Client_CookieJar(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") +} + +func Test_Client_CookieJar_Response(t *testing.T) { + t.Parallel() + + t.Run("without expiration", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k4", + Value: "v4", + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + require.Len(t, jar.getCookiesByHost("example.com"), 3) + }) + + t.Run("with expiration", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k4", + Value: "v4", + Expires: time.Now().Add(1 * time.Nanosecond), + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + jar.SetKeyValue("example", "k3", "v3") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + require.Len(t, jar.getCookiesByHost("example.com"), 2) + }) + + t.Run("override cookie value", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "k1", + Value: "v2", + }) + return c.SendString( + c.Cookies("k1") + c.Cookies("k2")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + jar.SetKeyValue("example.com", "k2", "v2") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1v2") + + for _, cookie := range jar.getCookiesByHost("example.com") { + if string(cookie.Key()) == "k1" { + require.Equal(t, "v2", string(cookie.Value())) + } + } + }) + + t.Run("different domain", func(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.SendString(c.Cookies("k1")) + } + + jar := AcquireCookieJar() + defer ReleaseCookieJar(jar) + + jar.SetKeyValue("example.com", "k1", "v1") + + wrapAgent := func(c *Client) { + c.SetCookieJar(jar) + } + testClient(t, handler, wrapAgent, "v1") + + require.Len(t, jar.getCookiesByHost("example.com"), 1) + require.Empty(t, jar.getCookiesByHost("example")) + }) +} + +func Test_Client_Referer(t *testing.T) { + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.Referer()) + } + + wrapAgent := func(c *Client) { + c.SetReferer("http://referer.com") + } + + testClient(t, handler, wrapAgent, "http://referer.com") +} + +func Test_Client_QueryParam(t *testing.T) { + t.Parallel() + + t.Run("add param", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.AddParam("foo", "bar").AddParam("foo", "fiber") + + res := req.Param("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.AddParam("foo", "bar").SetParam("foo", "fiber") + + res := req.Param("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetParam("foo", "bar"). + AddParams(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Param("foo") + require.Len(t, res, 3) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) + + res = req.Param("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Param("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.Param("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + p := NewClient() + p.SetParamsWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Empty(t, p.Param("unexport")) + + require.Len(t, p.Param("TInt"), 1) + require.Equal(t, "5", p.Param("TInt")[0]) + + require.Len(t, p.Param("TString"), 1) + require.Equal(t, "string", p.Param("TString")[0]) + + require.Len(t, p.Param("TFloat"), 1) + require.Equal(t, "3.1", p.Param("TFloat")[0]) + + require.Len(t, p.Param("TBool"), 1) + + tslice := p.Param("TSlice") + require.Len(t, tslice, 2) + require.Equal(t, "foo", tslice[0]) + require.Equal(t, "bar", tslice[1]) + + tint := p.Param("TSlice") + require.Len(t, tint, 2) + require.Equal(t, "foo", tint[0]) + require.Equal(t, "bar", tint[1]) + }) + + t.Run("del params", func(t *testing.T) { + t.Parallel() + req := NewClient() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelParams("foo", "bar") + + res := req.Param("foo") + require.Empty(t, res) + + res = req.Param("bar") + require.Empty(t, res) + }) +} + +func Test_Client_QueryParam_With_Server(t *testing.T) { + handler := func(c fiber.Ctx) error { + _, _ = c.WriteString(c.Query("k1")) //nolint:errcheck // It is fine to ignore the error here + _, _ = c.WriteString(c.Query("k2")) //nolint:errcheck // It is fine to ignore the error here + + return nil + } + + wrapAgent := func(c *Client) { + c.SetParam("k1", "v1"). + AddParam("k2", "v2") + } + + testClient(t, handler, wrapAgent, "v1v2") +} + +func Test_Client_PathParam(t *testing.T) { + t.Parallel() + + t.Run("set path param", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetPathParam("foo", "bar") + require.Equal(t, "bar", req.PathParam("foo")) + + req.SetPathParam("foo", "bar1") + require.Equal(t, "bar1", req.PathParam("foo")) + }) + + t.Run("set path params", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.SetPathParams(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) + + t.Run("set path params with struct", func(t *testing.T) { + t.Parallel() + type args struct { + CookieInt int `path:"int"` + CookieString string `path:"string"` + } + + req := NewClient().SetPathParamsWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.PathParam("int")) + require.Equal(t, "foo", req.PathParam("string")) + }) + + t.Run("del path params", func(t *testing.T) { + t.Parallel() + req := NewClient(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.DelPathParams("foo") + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) +} + +func Test_Client_PathParam_With_Server(t *testing.T) { + app, dial, start := createHelperServer(t) + + app.Get("/:test", func(c fiber.Ctx) error { + return c.SendString(c.Params("test")) + }) + + go start() + + resp, err := NewClient().SetDial(dial). + SetPathParam("path", "test"). + Get("http://example.com/:path") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test", resp.String()) +} + +func Test_Client_TLS(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.NoError(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls", resp.String()) +} + +func Test_Client_TLS_Error(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + clientTLSConf.MaxVersion = tls.VersionTLS12 + serverTLSConf.MinVersion = tls.VersionTLS13 + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.Error(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Nil(t, resp) +} + +func Test_Client_TLS_Empty_TLSConfig(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("tls") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.Get("https://" + ln.Addr().String()) + + require.Error(t, err) + require.NotEqual(t, clientTLSConf, client.TLSConfig()) + require.Nil(t, resp) +} + +func Test_Client_SetCertificates(t *testing.T) { + t.Parallel() + + serverTLSConf, _, err := tlstest.GetTLSConfigs() + require.NoError(t, err) + + client := NewClient().SetCertificates(serverTLSConf.Certificates...) + require.Len(t, client.TLSConfig().Certificates, 1) +} + +func Test_Client_SetRootCertificate(t *testing.T) { + t.Parallel() + + client := NewClient().SetRootCertificate("../.github/testdata/ssl.pem") + require.NotNil(t, client.TLSConfig().RootCAs) +} + +func Test_Client_SetRootCertificateFromString(t *testing.T) { + t.Parallel() + + file, err := os.Open("../.github/testdata/ssl.pem") + defer func() { require.NoError(t, file.Close()) }() + require.NoError(t, err) + + pem, err := io.ReadAll(file) + require.NoError(t, err) + + client := NewClient().SetRootCertificateFromString(string(pem)) + require.NotNil(t, client.TLSConfig().RootCAs) +} + +func Test_Client_R(t *testing.T) { + t.Parallel() + + client := NewClient() + req := client.R() + + require.Equal(t, "Request", reflect.TypeOf(req).Elem().Name()) + require.Equal(t, client, req.Client()) +} + +func Test_Replace(t *testing.T) { + app, dial, start := createHelperServer(t) + + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(string(c.Request().Header.Peek("k1"))) + }) + + go start() + + C().SetDial(dial) + resp, err := Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) + + r := NewClient().SetDial(dial).SetHeader("k1", "v1") + clean := Replace(r) + resp, err = Get("http://example.com") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "v1", resp.String()) + + clean() + + C().SetDial(dial) + resp, err = Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) + + C().SetDial(nil) +} + +func Test_Set_Config_To_Request(t *testing.T) { + t.Parallel() + + t.Run("set ctx", func(t *testing.T) { + t.Parallel() + key := struct{}{} + + ctx := context.Background() + ctx = context.WithValue(ctx, key, "v1") + + req := AcquireRequest() + + setConfigToRequest(req, Config{Ctx: ctx}) + + require.Equal(t, "v1", req.Context().Value(key)) + }) + + t.Run("set useragent", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{UserAgent: "agent"}) + + require.Equal(t, "agent", req.UserAgent()) + }) + + t.Run("set referer", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Referer: "referer"}) + + require.Equal(t, "referer", req.Referer()) + }) + + t.Run("set header", func(t *testing.T) { + req := AcquireRequest() + + setConfigToRequest(req, Config{Header: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Header("k1")[0]) + }) + + t.Run("set params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Param: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Param("k1")[0]) + }) + + t.Run("set cookies", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Cookie: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.Cookie("k1")) + }) + + t.Run("set pathparam", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{PathParam: map[string]string{ + "k1": "v1", + }}) + + require.Equal(t, "v1", req.PathParam("k1")) + }) + + t.Run("set timeout", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Timeout: 1 * time.Second}) + + require.Equal(t, 1*time.Second, req.Timeout()) + }) + + t.Run("set maxredirects", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{MaxRedirects: 1}) + + require.Equal(t, 1, req.MaxRedirects()) + }) + + t.Run("set body", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{Body: "test"}) + + require.Equal(t, "test", req.body) + }) + + t.Run("set file", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + + setConfigToRequest(req, Config{File: []*File{ + { + name: "test", + path: "path", + }, + }}) + + require.Equal(t, "path", req.File("test").path) + }) +} + +func Test_Client_SetProxyURL(t *testing.T) { + t.Parallel() + + app, dial, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + go start() + + t.Cleanup(func() { + require.NoError(t, app.Shutdown()) + }) + + time.Sleep(1 * time.Second) + + t.Run("success", func(t *testing.T) { + t.Parallel() + client := NewClient().SetDial(dial) + err := client.SetProxyURL("http://test.com") + + require.NoError(t, err) + + _, err = client.Get("http://localhost:3000") + + require.NoError(t, err) + }) + + t.Run("wrong url", func(t *testing.T) { + t.Parallel() + client := NewClient() + + err := client.SetProxyURL(":this is not a url") + + require.Error(t, err) + }) + + t.Run("error", func(t *testing.T) { + t.Parallel() + client := NewClient() + + err := client.SetProxyURL("htgdftp://test.com") + + require.Error(t, err) + }) +} + +func Test_Client_SetRetryConfig(t *testing.T) { + t.Parallel() + + retryConfig := &retry.Config{ + InitialInterval: 1 * time.Second, + MaxRetryCount: 3, + } + + core, client, req := newCore(), NewClient(), AcquireRequest() + req.SetURL("http://example.com") + client.SetRetryConfig(retryConfig) + _, err := core.execute(context.Background(), client, req) + + require.NoError(t, err) + require.Equal(t, retryConfig.InitialInterval, client.RetryConfig().InitialInterval) + require.Equal(t, retryConfig.MaxRetryCount, client.RetryConfig().MaxRetryCount) +} + +func Benchmark_Client_Request(b *testing.B) { + app, dial, start := createHelperServer(b) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + go start() + + client := NewClient().SetDial(dial) + + b.ResetTimer() + b.ReportAllocs() + + var err error + var resp *Response + for i := 0; i < b.N; i++ { + resp, err = client.Get("http://example.com") + resp.Close() + } + require.NoError(b, err) +} + +func Benchmark_Client_Request_Parallel(b *testing.B) { + app, dial, start := createHelperServer(b) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + go start() + + client := NewClient().SetDial(dial) + + b.ResetTimer() + b.ReportAllocs() + + b.RunParallel(func(pb *testing.PB) { + var err error + var resp *Response + for pb.Next() { + resp, err = client.Get("http://example.com") + resp.Close() + } + require.NoError(b, err) + }) +} diff --git a/client/cookiejar.go b/client/cookiejar.go new file mode 100644 index 00000000000..c66d5f3b7c0 --- /dev/null +++ b/client/cookiejar.go @@ -0,0 +1,245 @@ +// The code has been taken from https://github.com/valyala/fasthttp/pull/526 originally. +package client + +import ( + "bytes" + "errors" + "net" + "sync" + "time" + + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +var cookieJarPool = sync.Pool{ + New: func() any { + return &CookieJar{} + }, +} + +// AcquireCookieJar returns an empty CookieJar object from pool. +func AcquireCookieJar() *CookieJar { + jar, ok := cookieJarPool.Get().(*CookieJar) + if !ok { + panic(errors.New("failed to type-assert to *CookieJar")) + } + + return jar +} + +// ReleaseCookieJar returns CookieJar to the pool. +func ReleaseCookieJar(c *CookieJar) { + c.Release() + cookieJarPool.Put(c) +} + +// CookieJar manages cookie storage. It is used by the client to store cookies. +type CookieJar struct { + mu sync.Mutex + hostCookies map[string][]*fasthttp.Cookie +} + +// Get returns the cookies stored from a specific domain. +// If there were no cookies related with host returned slice will be nil. +// +// CookieJar keeps a copy of the cookies, so the returned cookies can be released safely. +func (cj *CookieJar) Get(uri *fasthttp.URI) []*fasthttp.Cookie { + if uri == nil { + return nil + } + + return cj.getByHostAndPath(uri.Host(), uri.Path()) +} + +// get returns the cookies stored from a specific host and path. +func (cj *CookieJar) getByHostAndPath(host, path []byte) []*fasthttp.Cookie { + if cj.hostCookies == nil { + return nil + } + + var ( + err error + cookies []*fasthttp.Cookie + hostStr = utils.UnsafeString(host) + ) + + // port must not be included. + hostStr, _, err = net.SplitHostPort(hostStr) + if err != nil { + hostStr = utils.UnsafeString(host) + } + // get cookies deleting expired ones + cookies = cj.getCookiesByHost(hostStr) + + newCookies := make([]*fasthttp.Cookie, 0, len(cookies)) + for i := 0; i < len(cookies); i++ { + cookie := cookies[i] + if len(path) > 1 && len(cookie.Path()) > 1 && !bytes.HasPrefix(cookie.Path(), path) { + continue + } + newCookies = append(newCookies, cookie) + } + + return newCookies +} + +// getCookiesByHost returns the cookies stored from a specific host. +// If cookies are expired they will be deleted. +func (cj *CookieJar) getCookiesByHost(host string) []*fasthttp.Cookie { + cj.mu.Lock() + defer cj.mu.Unlock() + + now := time.Now() + cookies := cj.hostCookies[host] + + for i := 0; i < len(cookies); i++ { + c := cookies[i] + if !c.Expire().Equal(fasthttp.CookieExpireUnlimited) && c.Expire().Before(now) { // release cookie if expired + cookies = append(cookies[:i], cookies[i+1:]...) + fasthttp.ReleaseCookie(c) + i-- + } + } + + return cookies +} + +// Set sets cookies for a specific host. +// The host is get from uri.Host(). +// If the cookie key already exists it will be replaced by the new cookie value. +// +// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. +func (cj *CookieJar) Set(uri *fasthttp.URI, cookies ...*fasthttp.Cookie) { + if uri == nil { + return + } + + cj.SetByHost(uri.Host(), cookies...) +} + +// SetByHost sets cookies for a specific host. +// If the cookie key already exists it will be replaced by the new cookie value. +// +// CookieJar keeps a copy of the cookies, so the parsed cookies can be released safely. +func (cj *CookieJar) SetByHost(host []byte, cookies ...*fasthttp.Cookie) { + hostStr := utils.UnsafeString(host) + + cj.mu.Lock() + defer cj.mu.Unlock() + + if cj.hostCookies == nil { + cj.hostCookies = make(map[string][]*fasthttp.Cookie) + } + + hostCookies, ok := cj.hostCookies[hostStr] + if !ok { + // If the key does not exist in the map, then we must make a copy for the key to avoid unsafe usage. + hostStr = string(host) + } + + for _, cookie := range cookies { + c := searchCookieByKeyAndPath(cookie.Key(), cookie.Path(), hostCookies) + if c == nil { + // If the cookie does not exist in the slice, let's acquire new cookie and store it. + c = fasthttp.AcquireCookie() + hostCookies = append(hostCookies, c) + } + c.CopyTo(cookie) // override cookie properties + } + cj.hostCookies[hostStr] = hostCookies +} + +// SetKeyValue sets a cookie by key and value for a specific host. +// +// This function prevents extra allocations by making repeated cookies +// not being duplicated. +func (cj *CookieJar) SetKeyValue(host, key, value string) { + c := fasthttp.AcquireCookie() + c.SetKey(key) + c.SetValue(value) + + cj.SetByHost(utils.UnsafeBytes(host), c) +} + +// SetKeyValueBytes sets a cookie by key and value for a specific host. +// +// This function prevents extra allocations by making repeated cookies +// not being duplicated. +func (cj *CookieJar) SetKeyValueBytes(host string, key, value []byte) { + c := fasthttp.AcquireCookie() + c.SetKeyBytes(key) + c.SetValueBytes(value) + + cj.SetByHost(utils.UnsafeBytes(host), c) +} + +// dumpCookiesToReq dumps the stored cookies to the request. +func (cj *CookieJar) dumpCookiesToReq(req *fasthttp.Request) { + uri := req.URI() + + cookies := cj.getByHostAndPath(uri.Host(), uri.Path()) + for _, cookie := range cookies { + req.Header.SetCookieBytesKV(cookie.Key(), cookie.Value()) + } +} + +// parseCookiesFromResp parses the response cookies and stores them. +func (cj *CookieJar) parseCookiesFromResp(host, path []byte, resp *fasthttp.Response) { + hostStr := utils.UnsafeString(host) + + cj.mu.Lock() + defer cj.mu.Unlock() + + if cj.hostCookies == nil { + cj.hostCookies = make(map[string][]*fasthttp.Cookie) + } + cookies, ok := cj.hostCookies[hostStr] + if !ok { + // If the key does not exist in the map then + // we must make a copy for the key to avoid unsafe usage. + hostStr = string(host) + } + + now := time.Now() + resp.Header.VisitAllCookie(func(key, value []byte) { + isCreated := false + c := searchCookieByKeyAndPath(key, path, cookies) + if c == nil { + c, isCreated = fasthttp.AcquireCookie(), true + } + + _ = c.ParseBytes(value) //nolint:errcheck // ignore error + if c.Expire().Equal(fasthttp.CookieExpireUnlimited) || c.Expire().After(now) { + cookies = append(cookies, c) + } else if isCreated { + fasthttp.ReleaseCookie(c) + } + }) + cj.hostCookies[hostStr] = cookies +} + +// Release releases all cookie values. +func (cj *CookieJar) Release() { + // FOllOW-UP performance optimization + // currently a race condition is found because the reset method modifies a value which is not a copy but a reference -> solution should be to make a copy + // for _, v := range cj.hostCookies { + // for _, c := range v { + // fasthttp.ReleaseCookie(c) + // } + // } + cj.hostCookies = nil +} + +// searchCookieByKeyAndPath searches for a cookie by key and path. +func searchCookieByKeyAndPath(key, path []byte, cookies []*fasthttp.Cookie) *fasthttp.Cookie { + for _, c := range cookies { + if bytes.Equal(key, c.Key()) { + if len(path) <= 1 || bytes.HasPrefix(c.Path(), path) { + return c + } + } + } + + return nil +} diff --git a/client/cookiejar_test.go b/client/cookiejar_test.go new file mode 100644 index 00000000000..3b6fdcda83d --- /dev/null +++ b/client/cookiejar_test.go @@ -0,0 +1,213 @@ +package client + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" +) + +func checkKeyValue(t *testing.T, cj *CookieJar, cookie *fasthttp.Cookie, uri *fasthttp.URI, n int) { + t.Helper() + + cs := cj.Get(uri) + require.GreaterOrEqual(t, len(cs), n) + + c := cs[n-1] + require.NotNil(t, c) + + require.Equal(t, string(c.Key()), string(cookie.Key())) + require.Equal(t, string(c.Value()), string(cookie.Value())) +} + +func TestCookieJarGet(t *testing.T) { + t.Parallel() + + url := []byte("http://fasthttp.com/") + url1 := []byte("http://fasthttp.com/make") + url11 := []byte("http://fasthttp.com/hola") + url2 := []byte("http://fasthttp.com/make/fasthttp") + url3 := []byte("http://fasthttp.com/make/fasthttp/great") + prefix := []byte("/") + prefix1 := []byte("/make") + prefix2 := []byte("/make/fasthttp") + prefix3 := []byte("/make/fasthttp/great") + cj := &CookieJar{} + + c1 := &fasthttp.Cookie{} + c1.SetKey("k") + c1.SetValue("v") + c1.SetPath("/make/") + + c2 := &fasthttp.Cookie{} + c2.SetKey("kk") + c2.SetValue("vv") + c2.SetPath("/make/fasthttp") + + c3 := &fasthttp.Cookie{} + c3.SetKey("kkk") + c3.SetValue("vvv") + c3.SetPath("/make/fasthttp/great") + + uri := fasthttp.AcquireURI() + require.NoError(t, uri.Parse(nil, url)) + + uri1 := fasthttp.AcquireURI() + require.NoError(t, uri1.Parse(nil, url1)) + + uri11 := fasthttp.AcquireURI() + require.NoError(t, uri11.Parse(nil, url11)) + + uri2 := fasthttp.AcquireURI() + require.NoError(t, uri2.Parse(nil, url2)) + + uri3 := fasthttp.AcquireURI() + require.NoError(t, uri3.Parse(nil, url3)) + + cj.Set(uri1, c1, c2, c3) + + cookies := cj.Get(uri1) + require.Len(t, cookies, 3) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix1)) + } + + cookies = cj.Get(uri11) + require.Empty(t, cookies) + + cookies = cj.Get(uri2) + require.Len(t, cookies, 2) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix2)) + } + + cookies = cj.Get(uri3) + require.Len(t, cookies, 1) + + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix3)) + } + + cookies = cj.Get(uri) + require.Len(t, cookies, 3) + for _, cookie := range cookies { + require.True(t, bytes.HasPrefix(cookie.Path(), prefix)) + } +} + +func TestCookieJarGetExpired(t *testing.T) { + t.Parallel() + + url1 := []byte("http://fasthttp.com/make/") + uri1 := fasthttp.AcquireURI() + require.NoError(t, uri1.Parse(nil, url1)) + + c1 := &fasthttp.Cookie{} + c1.SetKey("k") + c1.SetValue("v") + c1.SetExpire(time.Now().Add(-time.Hour)) + + cj := &CookieJar{} + cj.Set(uri1, c1) + + cookies := cj.Get(uri1) + require.Empty(t, cookies) +} + +func TestCookieJarSet(t *testing.T) { + t.Parallel() + + url := []byte("http://fasthttp.com/hello/world") + cj := &CookieJar{} + + cookie := &fasthttp.Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + + uri := fasthttp.AcquireURI() + require.NoError(t, uri.Parse(nil, url)) + + cj.Set(uri, cookie) + checkKeyValue(t, cj, cookie, uri, 1) +} + +func TestCookieJarSetRepeatedCookieKeys(t *testing.T) { + t.Parallel() + + host := "fast.http" + cj := &CookieJar{} + + uri := fasthttp.AcquireURI() + uri.SetHost(host) + + cookie := &fasthttp.Cookie{} + cookie.SetKey("k") + cookie.SetValue("v") + + cookie2 := &fasthttp.Cookie{} + cookie2.SetKey("k") + cookie2.SetValue("v2") + + cookie3 := &fasthttp.Cookie{} + cookie3.SetKey("key") + cookie3.SetValue("value") + + cj.Set(uri, cookie, cookie2, cookie3) + + cookies := cj.Get(uri) + require.Len(t, cookies, 2) + require.Equal(t, cookies[0], cookie2) + require.True(t, bytes.Equal(cookies[0].Value(), cookie2.Value())) +} + +func TestCookieJarSetKeyValue(t *testing.T) { + t.Parallel() + + host := "fast.http" + cj := &CookieJar{} + + uri := fasthttp.AcquireURI() + uri.SetHost(host) + + cj.SetKeyValue(host, "k", "v") + cj.SetKeyValue(host, "key", "value") + cj.SetKeyValue(host, "k", "vv") + cj.SetKeyValue(host, "key", "value2") + + cookies := cj.Get(uri) + require.Len(t, cookies, 2) +} + +func TestCookieJarGetFromResponse(t *testing.T) { + t.Parallel() + + res := fasthttp.AcquireResponse() + host := []byte("fast.http") + uri := fasthttp.AcquireURI() + uri.SetHostBytes(host) + + c := &fasthttp.Cookie{} + c.SetKey("key") + c.SetValue("val") + + c2 := &fasthttp.Cookie{} + c2.SetKey("k") + c2.SetValue("v") + + c3 := &fasthttp.Cookie{} + c3.SetKey("kk") + c3.SetValue("vv") + + res.Header.SetStatusCode(200) + res.Header.SetCookie(c) + res.Header.SetCookie(c2) + res.Header.SetCookie(c3) + + cj := &CookieJar{} + cj.parseCookiesFromResp(host, nil, res) + + cookies := cj.Get(uri) + require.Len(t, cookies, 3) +} diff --git a/client/core.go b/client/core.go new file mode 100644 index 00000000000..315d12d474f --- /dev/null +++ b/client/core.go @@ -0,0 +1,272 @@ +package client + +import ( + "context" + "errors" + "net" + "strconv" + "strings" + "sync" + "sync/atomic" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/fiber/v3/addon/retry" + "github.com/valyala/fasthttp" +) + +var boundary = "--FiberFormBoundary" + +// RequestHook is a function that receives Agent and Request, +// it can change the data in Request and Agent. +// +// Called before a request is sent. +type RequestHook func(*Client, *Request) error + +// ResponseHook is a function that receives Agent, Response and Request, +// it can change the data is Response or deal with some effects. +// +// Called after a response has been received. +type ResponseHook func(*Client, *Response, *Request) error + +// RetryConfig is an alias for config in the `addon/retry` package. +type RetryConfig = retry.Config + +// addMissingPort will add the corresponding port number for host. +func addMissingPort(addr string, isTLS bool) string { //revive:disable-line:flag-parameter // Accepting a bool param named isTLS if fine here + n := strings.Index(addr, ":") + if n >= 0 { + return addr + } + port := 80 + if isTLS { + port = 443 + } + return net.JoinHostPort(addr, strconv.Itoa(port)) +} + +// `core` stores middleware and plugin definitions, +// and defines the execution process +type core struct { + client *Client + req *Request + ctx context.Context //nolint:containedctx // It's needed to be stored in the core. +} + +// getRetryConfig returns the retry configuration of the client. +func (c *core) getRetryConfig() *RetryConfig { + c.client.mu.RLock() + defer c.client.mu.RUnlock() + + cfg := c.client.RetryConfig() + if cfg == nil { + return nil + } + + return &RetryConfig{ + InitialInterval: cfg.InitialInterval, + MaxBackoffTime: cfg.MaxBackoffTime, + Multiplier: cfg.Multiplier, + MaxRetryCount: cfg.MaxRetryCount, + } +} + +// execFunc is the core function of the client. +// It sends the request and receives the response. +func (c *core) execFunc() (*Response, error) { + resp := AcquireResponse() + resp.setClient(c.client) + resp.setRequest(c.req) + + // To avoid memory allocation reuse of data structures such as errch. + done := int32(0) + errCh, reqv := acquireErrChan(), fasthttp.AcquireRequest() + defer func() { + releaseErrChan(errCh) + }() + + c.req.RawRequest.CopyTo(reqv) + cfg := c.getRetryConfig() + + var err error + go func() { + respv := fasthttp.AcquireResponse() + defer func() { + fasthttp.ReleaseRequest(reqv) + fasthttp.ReleaseResponse(respv) + }() + + if cfg != nil { + err = retry.NewExponentialBackoff(*cfg).Retry(func() error { + if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { + return c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects) + } + + return c.client.fasthttp.Do(reqv, respv) + }) + } else { + if c.req.maxRedirects > 0 && (string(reqv.Header.Method()) == fiber.MethodGet || string(reqv.Header.Method()) == fiber.MethodHead) { + err = c.client.fasthttp.DoRedirects(reqv, respv, c.req.maxRedirects) + } else { + err = c.client.fasthttp.Do(reqv, respv) + } + } + + if atomic.CompareAndSwapInt32(&done, 0, 1) { + if err != nil { + errCh <- err + return + } + respv.CopyTo(resp.RawResponse) + errCh <- nil + } + }() + + select { + case err := <-errCh: + if err != nil { + // When get error should release Response + ReleaseResponse(resp) + return nil, err + } + return resp, nil + case <-c.ctx.Done(): + atomic.SwapInt32(&done, 1) + ReleaseResponse(resp) + return nil, ErrTimeoutOrCancel + } +} + +// preHooks Exec request hook +func (c *core) preHooks() error { + c.client.mu.Lock() + defer c.client.mu.Unlock() + + for _, f := range c.client.userRequestHooks { + err := f(c.client, c.req) + if err != nil { + return err + } + } + + for _, f := range c.client.builtinRequestHooks { + err := f(c.client, c.req) + if err != nil { + return err + } + } + + return nil +} + +// afterHooks Exec response hooks +func (c *core) afterHooks(resp *Response) error { + c.client.mu.Lock() + defer c.client.mu.Unlock() + + for _, f := range c.client.builtinResponseHooks { + err := f(c.client, resp, c.req) + if err != nil { + return err + } + } + + for _, f := range c.client.userResponseHooks { + err := f(c.client, resp, c.req) + if err != nil { + return err + } + } + + return nil +} + +// timeout deals with timeout +func (c *core) timeout() context.CancelFunc { + var cancel context.CancelFunc + + if c.req.timeout > 0 { + c.ctx, cancel = context.WithTimeout(c.ctx, c.req.timeout) + } else if c.client.timeout > 0 { + c.ctx, cancel = context.WithTimeout(c.ctx, c.client.timeout) + } + + return cancel +} + +// execute will exec each hooks and plugins. +func (c *core) execute(ctx context.Context, client *Client, req *Request) (*Response, error) { + // keep a reference, because pass param is boring + c.ctx = ctx + c.client = client + c.req = req + + // The built-in hooks will be executed only + // after the user-defined hooks are executed. + err := c.preHooks() + if err != nil { + return nil, err + } + + cancel := c.timeout() + if cancel != nil { + defer cancel() + } + + // Do http request + resp, err := c.execFunc() + if err != nil { + return nil, err + } + + // The built-in hooks will be executed only + // before the user-defined hooks are executed. + err = c.afterHooks(resp) + if err != nil { + resp.Close() + return nil, err + } + + return resp, nil +} + +var errChanPool = &sync.Pool{ + New: func() any { + return make(chan error, 1) + }, +} + +// acquireErrChan returns an empty error chan from the pool. +// +// The returned error chan may be returned to the pool with releaseErrChan when no longer needed. +// This allows reducing GC load. +func acquireErrChan() chan error { + ch, ok := errChanPool.Get().(chan error) + if !ok { + panic(errors.New("failed to type-assert to chan error")) + } + + return ch +} + +// releaseErrChan returns the object acquired via acquireErrChan to the pool. +// +// Do not access the released core object, otherwise data races may occur. +func releaseErrChan(ch chan error) { + errChanPool.Put(ch) +} + +// newCore returns an empty core object. +func newCore() *core { + c := &core{} + + return c +} + +var ( + ErrTimeoutOrCancel = errors.New("timeout or cancel") + ErrURLFormat = errors.New("the url is a mistake") + ErrNotSupportSchema = errors.New("the protocol is not support, only http or https") + ErrFileNoName = errors.New("the file should have name") + ErrBodyType = errors.New("the body type should be []byte") + ErrNotSupportSaveMethod = errors.New("file path and io.Writer are supported") +) diff --git a/client/core_test.go b/client/core_test.go new file mode 100644 index 00000000000..1b8ea42b9d0 --- /dev/null +++ b/client/core_test.go @@ -0,0 +1,248 @@ +package client + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp/fasthttputil" +) + +func Test_AddMissing_Port(t *testing.T) { + t.Parallel() + + type args struct { + addr string + isTLS bool + } + tests := []struct { + name string + args args + want string + }{ + { + name: "do anything", + args: args{ + addr: "example.com:1234", + }, + want: "example.com:1234", + }, + { + name: "add 80 port", + args: args{ + addr: "example.com", + }, + want: "example.com:80", + }, + { + name: "add 443 port", + args: args{ + addr: "example.com", + isTLS: true, + }, + want: "example.com:443", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, addMissingPort(tt.args.addr, tt.args.isTLS)) + }) + } +} + +func Test_Exec_Func(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New() + + app.Get("/normal", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + app.Get("/return-error", func(_ fiber.Ctx) error { + return errors.New("the request is error") + }) + + app.Get("/hang-up", func(c fiber.Ctx) error { + time.Sleep(time.Second) + return c.SendString(c.Hostname() + " hang up") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + }() + + time.Sleep(300 * time.Millisecond) + + t.Run("normal request", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + core.ctx = context.Background() + core.client = client + core.req = req + + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + req.RawRequest.SetRequestURI("http://example.com/normal") + + resp, err := core.execFunc() + require.NoError(t, err) + require.Equal(t, 200, resp.RawResponse.StatusCode()) + require.Equal(t, "example.com", string(resp.RawResponse.Body())) + }) + + t.Run("the request return an error", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + core.ctx = context.Background() + core.client = client + core.req = req + + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + req.RawRequest.SetRequestURI("http://example.com/return-error") + + resp, err := core.execFunc() + + require.NoError(t, err) + require.Equal(t, 500, resp.RawResponse.StatusCode()) + require.Equal(t, "the request is error", string(resp.RawResponse.Body())) + }) + + t.Run("the request timeout", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + core.ctx = ctx + core.client = client + core.req = req + + client.SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + req.RawRequest.SetRequestURI("http://example.com/hang-up") + + _, err := core.execFunc() + + require.Equal(t, ErrTimeoutOrCancel, err) + }) +} + +func Test_Execute(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New() + + app.Get("/normal", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + app.Get("/return-error", func(_ fiber.Ctx) error { + return errors.New("the request is error") + }) + + app.Get("/hang-up", func(c fiber.Ctx) error { + time.Sleep(time.Second) + return c.SendString(c.Hostname() + " hang up") + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + }() + + t.Run("add user request hooks", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + client.AddRequestHook(func(_ *Client, _ *Request) error { + require.Equal(t, "http://example.com", req.URL()) + return nil + }) + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com") + + resp, err := core.execute(context.Background(), client, req) + require.NoError(t, err) + require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) + }) + + t.Run("add user response hooks", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + client.AddResponseHook(func(_ *Client, _ *Response, req *Request) error { + require.Equal(t, "http://example.com", req.URL()) + return nil + }) + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com") + + resp, err := core.execute(context.Background(), client, req) + require.NoError(t, err) + require.Equal(t, "Cannot GET /", string(resp.RawResponse.Body())) + }) + + t.Run("no timeout", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com/hang-up") + + resp, err := core.execute(context.Background(), client, req) + require.NoError(t, err) + require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) + }) + + t.Run("client timeout", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + client.SetTimeout(500 * time.Millisecond) + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com/hang-up") + + _, err := core.execute(context.Background(), client, req) + require.Equal(t, ErrTimeoutOrCancel, err) + }) + + t.Run("request timeout", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com/hang-up"). + SetTimeout(300 * time.Millisecond) + + _, err := core.execute(context.Background(), client, req) + require.Equal(t, ErrTimeoutOrCancel, err) + }) + + t.Run("request timeout has higher level", func(t *testing.T) { + t.Parallel() + core, client, req := newCore(), NewClient(), AcquireRequest() + client.SetTimeout(30 * time.Millisecond) + + client.SetDial(func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }) + req.SetURL("http://example.com/hang-up"). + SetTimeout(3000 * time.Millisecond) + + resp, err := core.execute(context.Background(), client, req) + require.NoError(t, err) + require.Equal(t, "example.com hang up", string(resp.RawResponse.Body())) + }) +} diff --git a/client/helper_test.go b/client/helper_test.go new file mode 100644 index 00000000000..67380f34705 --- /dev/null +++ b/client/helper_test.go @@ -0,0 +1,157 @@ +package client + +import ( + "net" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp/fasthttputil" +) + +type testServer struct { + app *fiber.App + ch chan struct{} + ln *fasthttputil.InmemoryListener + tb testing.TB +} + +func startTestServer(tb testing.TB, beforeStarting func(app *fiber.App)) *testServer { + tb.Helper() + + ln := fasthttputil.NewInmemoryListener() + app := fiber.New() + + if beforeStarting != nil { + beforeStarting(app) + } + + ch := make(chan struct{}) + go func() { + if err := app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true}); err != nil { + tb.Fatal(err) + } + + close(ch) + }() + + return &testServer{ + app: app, + ch: ch, + ln: ln, + tb: tb, + } +} + +func (ts *testServer) stop() { + ts.tb.Helper() + + if err := ts.app.Shutdown(); err != nil { + ts.tb.Fatal(err) + } + + select { + case <-ts.ch: + case <-time.After(time.Second): + ts.tb.Fatalf("timeout when waiting for server close") + } +} + +func (ts *testServer) dial() func(addr string) (net.Conn, error) { + ts.tb.Helper() + + return func(_ string) (net.Conn, error) { + return ts.ln.Dial() //nolint:wrapcheck // not needed + } +} + +func createHelperServer(tb testing.TB) (*fiber.App, func(addr string) (net.Conn, error), func()) { + tb.Helper() + + ln := fasthttputil.NewInmemoryListener() + + app := fiber.New() + + return app, func(_ string) (net.Conn, error) { + return ln.Dial() //nolint:wrapcheck // not needed + }, func() { + require.NoError(tb, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) + } +} + +func testRequest(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted string, count ...int) { + t.Helper() + + app, ln, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + client := NewClient().SetDial(ln) + + for i := 0; i < c; i++ { + req := AcquireRequest().SetClient(client) + wrapAgent(req) + + resp, err := req.Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, excepted, resp.String()) + resp.Close() + } +} + +func testRequestFail(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Request), excepted error, count ...int) { + t.Helper() + + app, ln, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + client := NewClient().SetDial(ln) + + for i := 0; i < c; i++ { + req := AcquireRequest().SetClient(client) + wrapAgent(req) + + _, err := req.Get("http://example.com") + + require.Equal(t, excepted.Error(), err.Error()) + } +} + +func testClient(t *testing.T, handler fiber.Handler, wrapAgent func(agent *Client), excepted string, count ...int) { //nolint: unparam // maybe needed + t.Helper() + + app, ln, start := createHelperServer(t) + app.Get("/", handler) + go start() + + c := 1 + if len(count) > 0 { + c = count[0] + } + + for i := 0; i < c; i++ { + client := NewClient().SetDial(ln) + wrapAgent(client) + + resp, err := client.Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, excepted, resp.String()) + resp.Close() + } +} diff --git a/client/hooks.go b/client/hooks.go new file mode 100644 index 00000000000..0ecc970d53c --- /dev/null +++ b/client/hooks.go @@ -0,0 +1,328 @@ +package client + +import ( + "errors" + "fmt" + "io" + "math/rand" + "mime/multipart" + "os" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" + + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +var ( + protocolCheck = regexp.MustCompile(`^https?://.*$`) + + headerAccept = "Accept" + + applicationJSON = "application/json" + applicationXML = "application/xml" + applicationForm = "application/x-www-form-urlencoded" + multipartFormData = "multipart/form-data" + + letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + letterIdxBits = 6 // 6 bits to represent a letter index + letterIdxMask = 1<= 0; { + if remain == 0 { + cache, remain = src.Int63(), letterIdxMax + } + + if idx := int(cache & int64(letterIdxMask)); idx < length { + b[i] = letterBytes[idx] + i-- + } + cache >>= int64(letterIdxBits) + remain-- + } + + return utils.UnsafeString(b) +} + +// parserRequestURL will set the options for the hostclient +// and normalize the url. +// The baseUrl will be merge with request uri. +// Query params and path params deal in this function. +func parserRequestURL(c *Client, req *Request) error { + splitURL := strings.Split(req.url, "?") + // I don't want to judge splitURL length. + splitURL = append(splitURL, "") + + // Determine whether to superimpose baseurl based on + // whether the URL starts with the protocol + uri := splitURL[0] + if !protocolCheck.MatchString(uri) { + uri = c.baseURL + uri + if !protocolCheck.MatchString(uri) { + return ErrURLFormat + } + } + + // set path params + req.path.VisitAll(func(key, val string) { + uri = strings.ReplaceAll(uri, ":"+key, val) + }) + c.path.VisitAll(func(key, val string) { + uri = strings.ReplaceAll(uri, ":"+key, val) + }) + + // set uri to request and other related setting + req.RawRequest.SetRequestURI(uri) + + // merge query params + hashSplit := strings.Split(splitURL[1], "#") + hashSplit = append(hashSplit, "") + args := fasthttp.AcquireArgs() + defer func() { + fasthttp.ReleaseArgs(args) + }() + + args.Parse(hashSplit[0]) + c.params.VisitAll(func(key, value []byte) { + args.AddBytesKV(key, value) + }) + req.params.VisitAll(func(key, value []byte) { + args.AddBytesKV(key, value) + }) + req.RawRequest.URI().SetQueryStringBytes(utils.CopyBytes(args.QueryString())) + req.RawRequest.URI().SetHash(hashSplit[1]) + + return nil +} + +// parserRequestHeader will make request header up. +// It will merge headers from client and request. +// Header should be set automatically based on data. +// User-Agent should be set. +func parserRequestHeader(c *Client, req *Request) error { + // set method + req.RawRequest.Header.SetMethod(req.Method()) + // merge header + c.header.VisitAll(func(key, value []byte) { + req.RawRequest.Header.AddBytesKV(key, value) + }) + + req.header.VisitAll(func(key, value []byte) { + req.RawRequest.Header.AddBytesKV(key, value) + }) + + // according to data set content-type + switch req.bodyType { + case jsonBody: + req.RawRequest.Header.SetContentType(applicationJSON) + req.RawRequest.Header.Set(headerAccept, applicationJSON) + case xmlBody: + req.RawRequest.Header.SetContentType(applicationXML) + case formBody: + req.RawRequest.Header.SetContentType(applicationForm) + case filesBody: + req.RawRequest.Header.SetContentType(multipartFormData) + // set boundary + if req.boundary == boundary { + req.boundary += randString(16) + } + req.RawRequest.Header.SetMultipartFormBoundary(req.boundary) + default: + } + + // set useragent + req.RawRequest.Header.SetUserAgent(defaultUserAgent) + if c.userAgent != "" { + req.RawRequest.Header.SetUserAgent(c.userAgent) + } + if req.userAgent != "" { + req.RawRequest.Header.SetUserAgent(req.userAgent) + } + + // set referer + req.RawRequest.Header.SetReferer(c.referer) + if req.referer != "" { + req.RawRequest.Header.SetReferer(req.referer) + } + + // set cookie + // add cookie form jar to req + if c.cookieJar != nil { + c.cookieJar.dumpCookiesToReq(req.RawRequest) + } + + c.cookies.VisitAll(func(key, val string) { + req.RawRequest.Header.SetCookie(key, val) + }) + + req.cookies.VisitAll(func(key, val string) { + req.RawRequest.Header.SetCookie(key, val) + }) + + return nil +} + +// parserRequestBody automatically serializes the data according to +// the data type and stores it in the body of the rawRequest +func parserRequestBody(c *Client, req *Request) error { + switch req.bodyType { + case jsonBody: + body, err := c.jsonMarshal(req.body) + if err != nil { + return err + } + req.RawRequest.SetBody(body) + case xmlBody: + body, err := c.xmlMarshal(req.body) + if err != nil { + return err + } + req.RawRequest.SetBody(body) + case formBody: + req.RawRequest.SetBody(req.formData.QueryString()) + case filesBody: + return parserRequestBodyFile(req) + case rawBody: + if body, ok := req.body.([]byte); ok { + req.RawRequest.SetBody(body) + } else { + return ErrBodyType + } + case noBody: + return nil + } + + return nil +} + +// parserRequestBodyFile parses request body if body type is file +// this is an addition of parserRequestBody. +func parserRequestBodyFile(req *Request) error { + mw := multipart.NewWriter(req.RawRequest.BodyWriter()) + err := mw.SetBoundary(req.boundary) + if err != nil { + return fmt.Errorf("set boundary error: %w", err) + } + defer func() { + err := mw.Close() + if err != nil { + return + } + }() + + // add formdata + req.formData.VisitAll(func(key, value []byte) { + if err != nil { + return + } + err = mw.WriteField(utils.UnsafeString(key), utils.UnsafeString(value)) + }) + if err != nil { + return fmt.Errorf("write formdata error: %w", err) + } + + // add file + b := make([]byte, 512) + for i, v := range req.files { + if v.name == "" && v.path == "" { + return ErrFileNoName + } + + // if name is not exist, set name + if v.name == "" && v.path != "" { + v.path = filepath.Clean(v.path) + v.name = filepath.Base(v.path) + } + + // if field name is not exist, set it + if v.fieldName == "" { + v.fieldName = "file" + strconv.Itoa(i+1) + } + + // check the reader + if v.reader == nil { + v.reader, err = os.Open(v.path) + if err != nil { + return fmt.Errorf("open file error: %w", err) + } + } + + // write file + w, err := mw.CreateFormFile(v.fieldName, v.name) + if err != nil { + return fmt.Errorf("create file error: %w", err) + } + + for { + n, err := v.reader.Read(b) + if err != nil && !errors.Is(err, io.EOF) { + return fmt.Errorf("read file error: %w", err) + } + + if errors.Is(err, io.EOF) { + break + } + + _, err = w.Write(b[:n]) + if err != nil { + return fmt.Errorf("write file error: %w", err) + } + } + + err = v.reader.Close() + if err != nil { + return fmt.Errorf("close file error: %w", err) + } + } + + return nil +} + +// parserResponseHeader will parse the response header and store it in the response +func parserResponseCookie(c *Client, resp *Response, req *Request) error { + var err error + resp.RawResponse.Header.VisitAllCookie(func(key, value []byte) { + cookie := fasthttp.AcquireCookie() + err = cookie.ParseBytes(value) + if err != nil { + return + } + cookie.SetKeyBytes(key) + + resp.cookie = append(resp.cookie, cookie) + }) + + if err != nil { + return err + } + + // store cookies to jar + if c.cookieJar != nil { + c.cookieJar.parseCookiesFromResp(req.RawRequest.URI().Host(), req.RawRequest.URI().Path(), resp.RawResponse) + } + + return nil +} + +// logger is a response hook that logs the request and response +func logger(c *Client, resp *Response, req *Request) error { + if !c.debug { + return nil + } + + c.logger.Debugf("%s\n", req.RawRequest.String()) + c.logger.Debugf("%s\n", resp.RawResponse.String()) + + return nil +} diff --git a/client/hooks_test.go b/client/hooks_test.go new file mode 100644 index 00000000000..a555bba833b --- /dev/null +++ b/client/hooks_test.go @@ -0,0 +1,652 @@ +package client + +import ( + "bytes" + "encoding/xml" + "fmt" + "io" + "net" + "net/url" + "strings" + "testing" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +func Test_Rand_String(t *testing.T) { + t.Parallel() + tests := []struct { + name string + args int + }{ + { + name: "test generate", + args: 16, + }, + { + name: "test generate smaller string", + args: 8, + }, + { + name: "test generate larger string", + args: 32, + }, + } + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := randString(tt.args) + require.Len(t, got, tt.args) + }) + } +} + +func Test_Parser_Request_URL(t *testing.T) { + t.Parallel() + + t.Run("client baseurl should be set", func(t *testing.T) { + t.Parallel() + client := NewClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api", req.RawRequest.URI().String()) + }) + + t.Run("request url should be set", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest().SetURL("http://example.com/api") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api", req.RawRequest.URI().String()) + }) + + t.Run("the request url will override baseurl with protocol", func(t *testing.T) { + t.Parallel() + client := NewClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("http://example.com/api/v1") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String()) + }) + + t.Run("the request url should be append after baseurl without protocol", func(t *testing.T) { + t.Parallel() + client := NewClient().SetBaseURL("http://example.com/api") + req := AcquireRequest().SetURL("/v1") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/v1", req.RawRequest.URI().String()) + }) + + t.Run("the url is error", func(t *testing.T) { + t.Parallel() + client := NewClient().SetBaseURL("example.com/api") + req := AcquireRequest().SetURL("/v1") + + err := parserRequestURL(client, req) + require.Equal(t, ErrURLFormat, err) + }) + + t.Run("the path param from client", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetBaseURL("http://example.com/api/:id"). + SetPathParam("id", "5") + req := AcquireRequest() + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/5", req.RawRequest.URI().String()) + }) + + t.Run("the path param from request", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetBaseURL("http://example.com/api/:id/:name"). + SetPathParam("id", "5") + req := AcquireRequest(). + SetURL("/{key}"). + SetPathParams(map[string]string{ + "name": "fiber", + "key": "val", + }). + DelPathParams("key") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/5/fiber/%7Bkey%7D", req.RawRequest.URI().String()) + }) + + t.Run("the path param from request and client", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetBaseURL("http://example.com/api/:id/:name"). + SetPathParam("id", "5") + req := AcquireRequest(). + SetURL("/:key"). + SetPathParams(map[string]string{ + "name": "fiber", + "key": "val", + "id": "12", + }) + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, "http://example.com/api/12/fiber/val", req.RawRequest.URI().String()) + }) + + t.Run("query params from client should be set", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetParam("foo", "bar") + req := AcquireRequest().SetURL("http://example.com/api/v1") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, []byte("foo=bar"), req.RawRequest.URI().QueryString()) + }) + + t.Run("query params from request should be set", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetURL("http://example.com/api/v1"). + SetParam("bar", "foo") + + err := parserRequestURL(client, req) + require.NoError(t, err) + require.Equal(t, []byte("bar=foo"), req.RawRequest.URI().QueryString()) + }) + + t.Run("query params should be merged", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetParam("bar", "foo1") + req := AcquireRequest(). + SetURL("http://example.com/api/v1?bar=foo2"). + SetParam("bar", "foo") + + err := parserRequestURL(client, req) + require.NoError(t, err) + + values, err := url.ParseQuery(string(req.RawRequest.URI().QueryString())) + require.NoError(t, err) + + flag1, flag2, flag3 := false, false, false + for _, v := range values["bar"] { + if v == "foo1" { + flag1 = true + } else if v == "foo2" { + flag2 = true + } else if v == "foo" { + flag3 = true + } + } + require.True(t, flag1) + require.True(t, flag2) + require.True(t, flag3) + }) +} + +func Test_Parser_Request_Header(t *testing.T) { + t.Parallel() + + t.Run("client header should be set", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetHeaders(map[string]string{ + fiber.HeaderContentType: "application/json", + }) + + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("application/json"), req.RawRequest.Header.ContentType()) + }) + + t.Run("request header should be set", func(t *testing.T) { + t.Parallel() + client := NewClient() + + req := AcquireRequest(). + SetHeaders(map[string]string{ + fiber.HeaderContentType: "application/json, utf-8", + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) + }) + + t.Run("request header should override client header", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetHeader(fiber.HeaderContentType, "application/xml") + + req := AcquireRequest(). + SetHeader(fiber.HeaderContentType, "application/json, utf-8") + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("application/json, utf-8"), req.RawRequest.Header.ContentType()) + }) + + t.Run("auto set json header", func(t *testing.T) { + t.Parallel() + type jsonData struct { + Name string `json:"name"` + } + client := NewClient() + req := AcquireRequest(). + SetJSON(jsonData{ + Name: "foo", + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte(applicationJSON), req.RawRequest.Header.ContentType()) + }) + + t.Run("auto set xml header", func(t *testing.T) { + t.Parallel() + type xmlData struct { + XMLName xml.Name `xml:"body"` + Name string `xml:"name"` + } + client := NewClient() + req := AcquireRequest(). + SetXML(xmlData{ + Name: "foo", + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte(applicationXML), req.RawRequest.Header.ContentType()) + }) + + t.Run("auto set form data header", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "foo": "bar", + "ball": "cricle and square", + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, applicationForm, string(req.RawRequest.Header.ContentType())) + }) + + t.Run("auto set file header", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). + SetFormData("foo", "bar") + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Header.MultipartFormBoundary()), "--FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Header.ContentType()), multipartFormData)) + }) + + t.Run("ua should have default value", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("fiber"), req.RawRequest.Header.UserAgent()) + }) + + t.Run("ua in client should be set", func(t *testing.T) { + t.Parallel() + client := NewClient().SetUserAgent("foo") + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("foo"), req.RawRequest.Header.UserAgent()) + }) + + t.Run("ua in request should have higher level", func(t *testing.T) { + t.Parallel() + client := NewClient().SetUserAgent("foo") + req := AcquireRequest().SetUserAgent("bar") + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("bar"), req.RawRequest.Header.UserAgent()) + }) + + t.Run("referer in client should be set", func(t *testing.T) { + t.Parallel() + client := NewClient().SetReferer("https://example.com") + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) + }) + + t.Run("referer in request should have higher level", func(t *testing.T) { + t.Parallel() + client := NewClient().SetReferer("http://example.com") + req := AcquireRequest().SetReferer("https://example.com") + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, []byte("https://example.com"), req.RawRequest.Header.Referer()) + }) + + t.Run("client cookie should be set", func(t *testing.T) { + t.Parallel() + client := NewClient(). + SetCookie("foo", "bar"). + SetCookies(map[string]string{ + "bar": "foo", + "bar1": "foo1", + }). + DelCookies("bar1") + + req := AcquireRequest() + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "foo", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1"))) + }) + + t.Run("request cookie should be set", func(t *testing.T) { + t.Parallel() + type cookies struct { + Foo string `cookie:"foo"` + Bar int `cookie:"bar"` + } + + client := NewClient() + + req := AcquireRequest(). + SetCookiesWithStruct(&cookies{ + Foo: "bar", + Bar: 67, + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "", string(req.RawRequest.Header.Cookie("bar1"))) + }) + + t.Run("request cookie will override client cookie", func(t *testing.T) { + t.Parallel() + type cookies struct { + Foo string `cookie:"foo"` + Bar int `cookie:"bar"` + } + + client := NewClient(). + SetCookie("foo", "bar"). + SetCookies(map[string]string{ + "bar": "foo", + "bar1": "foo1", + }) + + req := AcquireRequest(). + SetCookiesWithStruct(&cookies{ + Foo: "bar", + Bar: 67, + }) + + err := parserRequestHeader(client, req) + require.NoError(t, err) + require.Equal(t, "bar", string(req.RawRequest.Header.Cookie("foo"))) + require.Equal(t, "67", string(req.RawRequest.Header.Cookie("bar"))) + require.Equal(t, "foo1", string(req.RawRequest.Header.Cookie("bar1"))) + }) +} + +func Test_Parser_Request_Body(t *testing.T) { + t.Parallel() + + t.Run("json body", func(t *testing.T) { + t.Parallel() + type jsonData struct { + Name string `json:"name"` + } + client := NewClient() + req := AcquireRequest(). + SetJSON(jsonData{ + Name: "foo", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.Equal(t, []byte("{\"name\":\"foo\"}"), req.RawRequest.Body()) + }) + + t.Run("xml body", func(t *testing.T) { + t.Parallel() + type xmlData struct { + XMLName xml.Name `xml:"body"` + Name string `xml:"name"` + } + client := NewClient() + req := AcquireRequest(). + SetXML(xmlData{ + Name: "foo", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.Equal(t, []byte("foo"), req.RawRequest.Body()) + }) + + t.Run("form data body", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "ball": "cricle and square", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.Equal(t, "ball=cricle+and+square", string(req.RawRequest.Body())) + }) + + t.Run("form data body error", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetFormDatas(map[string]string{ + "": "", + }) + + err := parserRequestBody(client, req) + require.NoError(t, err) + }) + + t.Run("file body", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "world")) + }) + + t.Run("file and form data", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + AddFileWithReader("hello", io.NopCloser(strings.NewReader("world"))). + SetFormData("foo", "bar") + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "----FiberFormBoundary")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "world")) + require.True(t, strings.Contains(string(req.RawRequest.Body()), "bar")) + }) + + t.Run("raw body", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetRawBody([]byte("hello world")) + + err := parserRequestBody(client, req) + require.NoError(t, err) + require.Equal(t, []byte("hello world"), req.RawRequest.Body()) + }) + + t.Run("raw body error", func(t *testing.T) { + t.Parallel() + client := NewClient() + req := AcquireRequest(). + SetRawBody([]byte("hello world")) + + req.body = nil + + err := parserRequestBody(client, req) + require.ErrorIs(t, err, ErrBodyType) + }) +} + +type dummyLogger struct { + buf *bytes.Buffer +} + +func (*dummyLogger) Trace(_ ...any) {} + +func (*dummyLogger) Debug(_ ...any) {} + +func (*dummyLogger) Info(_ ...any) {} + +func (*dummyLogger) Warn(_ ...any) {} + +func (*dummyLogger) Error(_ ...any) {} + +func (*dummyLogger) Fatal(_ ...any) {} + +func (*dummyLogger) Panic(_ ...any) {} + +func (*dummyLogger) Tracef(_ string, _ ...any) {} + +func (l *dummyLogger) Debugf(format string, v ...any) { + _, _ = l.buf.WriteString(fmt.Sprintf(format, v...)) //nolint:errcheck // not needed +} + +func (*dummyLogger) Infof(_ string, _ ...any) {} + +func (*dummyLogger) Warnf(_ string, _ ...any) {} + +func (*dummyLogger) Errorf(_ string, _ ...any) {} + +func (*dummyLogger) Fatalf(_ string, _ ...any) {} + +func (*dummyLogger) Panicf(_ string, _ ...any) {} + +func (*dummyLogger) Tracew(_ string, _ ...any) {} + +func (*dummyLogger) Debugw(_ string, _ ...any) {} + +func (*dummyLogger) Infow(_ string, _ ...any) {} + +func (*dummyLogger) Warnw(_ string, _ ...any) {} + +func (*dummyLogger) Errorw(_ string, _ ...any) {} + +func (*dummyLogger) Fatalw(_ string, _ ...any) {} + +func (*dummyLogger) Panicw(_ string, _ ...any) {} + +func Test_Client_Logger_Debug(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("response") + }) + + addrChan := make(chan string) + go func() { + require.NoError(t, app.Listen(":0", fiber.ListenConfig{ + DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + addrChan <- addr.String() + }, + })) + }() + + defer func(app *fiber.App) { + require.NoError(t, app.Shutdown()) + }(app) + + var buf bytes.Buffer + logger := &dummyLogger{buf: &buf} + + client := NewClient() + client.Debug().SetLogger(logger) + + addr := <-addrChan + resp, err := client.Get("http://" + addr) + require.NoError(t, err) + defer resp.Close() + + require.NoError(t, err) + require.Contains(t, buf.String(), "Host: "+addr) + require.Contains(t, buf.String(), "Content-Length: 8") +} + +func Test_Client_Logger_DisableDebug(t *testing.T) { + t.Parallel() + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("response") + }) + + addrChan := make(chan string) + go func() { + require.NoError(t, app.Listen(":0", fiber.ListenConfig{ + DisableStartupMessage: true, + ListenerAddrFunc: func(addr net.Addr) { + addrChan <- addr.String() + }, + })) + }() + + defer func(app *fiber.App) { + require.NoError(t, app.Shutdown()) + }(app) + + var buf bytes.Buffer + logger := &dummyLogger{buf: &buf} + + client := NewClient() + client.DisableDebug().SetLogger(logger) + + addr := <-addrChan + resp, err := client.Get("http://" + addr) + require.NoError(t, err) + defer resp.Close() + + require.NoError(t, err) + require.Empty(t, buf.String()) +} diff --git a/client/request.go b/client/request.go new file mode 100644 index 00000000000..0bf2fb321cb --- /dev/null +++ b/client/request.go @@ -0,0 +1,985 @@ +package client + +import ( + "bytes" + "context" + "errors" + "io" + "path/filepath" + "reflect" + "strconv" + "sync" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +// WithStruct Implementing this interface allows data to +// be stored from a struct via reflect. +type WithStruct interface { + Add(name, obj string) + Del(name string) +} + +// Types of request bodies. +type bodyType int + +// Enumeration definition of the request body type. +const ( + noBody bodyType = iota + jsonBody + xmlBody + formBody + filesBody + rawBody +) + +var ErrClientNil = errors.New("client can not be nil") + +// Request is a struct which contains the request data. +type Request struct { + url string + method string + userAgent string + boundary string + referer string + ctx context.Context //nolint:containedctx // It's needed to be stored in the request. + header *Header + params *QueryParam + cookies *Cookie + path *PathParam + + timeout time.Duration + maxRedirects int + + client *Client + + body any + formData *FormData + files []*File + bodyType bodyType + + RawRequest *fasthttp.Request +} + +// Method returns http method in request. +func (r *Request) Method() string { + return r.method +} + +// SetMethod will set method for Request object, +// user should use request method to set method. +func (r *Request) SetMethod(method string) *Request { + r.method = method + return r +} + +// URL returns request url in Request instance. +func (r *Request) URL() string { + return r.url +} + +// SetURL will set url for Request object. +func (r *Request) SetURL(url string) *Request { + r.url = url + return r +} + +// Client get Client instance in Request. +func (r *Request) Client() *Client { + return r.client +} + +// SetClient method sets client in request instance. +func (r *Request) SetClient(c *Client) *Request { + if c == nil { + panic(ErrClientNil) + } + + r.client = c + return r +} + +// Context returns the Context if its already set in request +// otherwise it creates new one using `context.Background()`. +func (r *Request) Context() context.Context { + if r.ctx == nil { + return context.Background() + } + return r.ctx +} + +// SetContext sets the context.Context for current Request. It allows +// to interrupt the request execution if ctx.Done() channel is closed. +// See https://blog.golang.org/context article and the "context" package +// documentation. +func (r *Request) SetContext(ctx context.Context) *Request { + r.ctx = ctx + return r +} + +// Header method returns header value via key, +// this method will visit all field in the header, +// then sort them. +func (r *Request) Header(key string) []string { + return r.header.PeekMultiple(key) +} + +// AddHeader method adds a single header field and its value in the request instance. +// It will override header which set in client instance. +func (r *Request) AddHeader(key, val string) *Request { + r.header.Add(key, val) + return r +} + +// SetHeader method sets a single header field and its value in the request instance. +// It will override header which set in client instance. +func (r *Request) SetHeader(key, val string) *Request { + r.header.Del(key) + r.header.Set(key, val) + return r +} + +// AddHeaders method adds multiple header fields and its values at one go in the request instance. +// It will override header which set in client instance. +func (r *Request) AddHeaders(h map[string][]string) *Request { + r.header.AddHeaders(h) + return r +} + +// SetHeaders method sets multiple header fields and its values at one go in the request instance. +// It will override header which set in client instance. +func (r *Request) SetHeaders(h map[string]string) *Request { + r.header.SetHeaders(h) + return r +} + +// Param method returns params value via key, +// this method will visit all field in the query param. +func (r *Request) Param(key string) []string { + var res []string + tmp := r.params.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + + return res +} + +// AddParam method adds a single param field and its value in the request instance. +// It will override param which set in client instance. +func (r *Request) AddParam(key, val string) *Request { + r.params.Add(key, val) + return r +} + +// SetParam method sets a single param field and its value in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParam(key, val string) *Request { + r.params.Set(key, val) + return r +} + +// AddParams method adds multiple param fields and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) AddParams(m map[string][]string) *Request { + r.params.AddParams(m) + return r +} + +// SetParams method sets multiple param fields and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParams(m map[string]string) *Request { + r.params.SetParams(m) + return r +} + +// SetParamsWithStruct method sets multiple param fields and its values at one go in the request instance. +// It will override param which set in client instance. +func (r *Request) SetParamsWithStruct(v any) *Request { + r.params.SetParamsWithStruct(v) + return r +} + +// DelParams method deletes single or multiple param fields ant its values. +func (r *Request) DelParams(key ...string) *Request { + for _, v := range key { + r.params.Del(v) + } + return r +} + +// UserAgent returns user agent in request instance. +func (r *Request) UserAgent() string { + return r.userAgent +} + +// SetUserAgent method sets user agent in request. +// It will override user agent which set in client instance. +func (r *Request) SetUserAgent(ua string) *Request { + r.userAgent = ua + return r +} + +// Boundary returns boundary in multipart boundary. +func (r *Request) Boundary() string { + return r.boundary +} + +// SetBoundary method sets multipart boundary. +func (r *Request) SetBoundary(b string) *Request { + r.boundary = b + + return r +} + +// Referer returns referer in request instance. +func (r *Request) Referer() string { + return r.referer +} + +// SetReferer method sets referer in request. +// It will override referer which set in client instance. +func (r *Request) SetReferer(referer string) *Request { + r.referer = referer + return r +} + +// Cookie returns the cookie be set in request instance. +// if cookie doesn't exist, return empty string. +func (r *Request) Cookie(key string) string { + if val, ok := (*r.cookies)[key]; ok { + return val + } + return "" +} + +// SetCookie method sets a single cookie field and its value in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookie(key, val string) *Request { + r.cookies.SetCookie(key, val) + return r +} + +// SetCookies method sets multiple cookie fields and its values at one go in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookies(m map[string]string) *Request { + r.cookies.SetCookies(m) + return r +} + +// SetCookiesWithStruct method sets multiple cookie fields and its values at one go in the request instance. +// It will override cookie which set in client instance. +func (r *Request) SetCookiesWithStruct(v any) *Request { + r.cookies.SetCookiesWithStruct(v) + return r +} + +// DelCookies method deletes single or multiple cookie fields ant its values. +func (r *Request) DelCookies(key ...string) *Request { + r.cookies.DelCookies(key...) + return r +} + +// PathParam returns the path param be set in request instance. +// if path param doesn't exist, return empty string. +func (r *Request) PathParam(key string) string { + if val, ok := (*r.path)[key]; ok { + return val + } + + return "" +} + +// SetPathParam method sets a single path param field and its value in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParam(key, val string) *Request { + r.path.SetParam(key, val) + return r +} + +// SetPathParams method sets multiple path param fields and its values at one go in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParams(m map[string]string) *Request { + r.path.SetParams(m) + return r +} + +// SetPathParamsWithStruct method sets multiple path param fields and its values at one go in the request instance. +// It will override path param which set in client instance. +func (r *Request) SetPathParamsWithStruct(v any) *Request { + r.path.SetParamsWithStruct(v) + return r +} + +// DelPathParams method deletes single or multiple path param fields ant its values. +func (r *Request) DelPathParams(key ...string) *Request { + r.path.DelParams(key...) + return r +} + +// ResetPathParams deletes all path params. +func (r *Request) ResetPathParams() *Request { + r.path.Reset() + return r +} + +// SetJSON method sets json body in request. +func (r *Request) SetJSON(v any) *Request { + r.body = v + r.bodyType = jsonBody + return r +} + +// SetXML method sets xml body in request. +func (r *Request) SetXML(v any) *Request { + r.body = v + r.bodyType = xmlBody + return r +} + +// SetRawBody method sets body with raw data in request. +func (r *Request) SetRawBody(v []byte) *Request { + r.body = v + r.bodyType = rawBody + return r +} + +// resetBody will clear body object and set bodyType +// if body type is formBody and filesBody, the new body type will be ignored. +func (r *Request) resetBody(t bodyType) { + r.body = nil + + // Set form data after set file ignore. + if r.bodyType == filesBody && t == formBody { + return + } + r.bodyType = t +} + +// FormData method returns form data value via key, +// this method will visit all field in the form data. +func (r *Request) FormData(key string) []string { + var res []string + tmp := r.formData.PeekMulti(key) + for _, v := range tmp { + res = append(res, utils.UnsafeString(v)) + } + + return res +} + +// AddFormData method adds a single form data field and its value in the request instance. +func (r *Request) AddFormData(key, val string) *Request { + r.formData.AddData(key, val) + r.resetBody(formBody) + return r +} + +// SetFormData method sets a single form data field and its value in the request instance. +func (r *Request) SetFormData(key, val string) *Request { + r.formData.SetData(key, val) + r.resetBody(formBody) + return r +} + +// AddFormDatas method adds multiple form data fields and its values in the request instance. +func (r *Request) AddFormDatas(m map[string][]string) *Request { + r.formData.AddDatas(m) + r.resetBody(formBody) + return r +} + +// SetFormDatas method sets multiple form data fields and its values in the request instance. +func (r *Request) SetFormDatas(m map[string]string) *Request { + r.formData.SetDatas(m) + r.resetBody(formBody) + return r +} + +// SetFormDatasWithStruct method sets multiple form data fields +// and its values in the request instance via struct. +func (r *Request) SetFormDatasWithStruct(v any) *Request { + r.formData.SetDatasWithStruct(v) + r.resetBody(formBody) + return r +} + +// DelFormDatas method deletes multiple form data fields and its value in the request instance. +func (r *Request) DelFormDatas(key ...string) *Request { + r.formData.DelDatas(key...) + r.resetBody(formBody) + return r +} + +// File returns file ptr store in request obj by name. +// If name field is empty, it will try to match path. +func (r *Request) File(name string) *File { + for _, v := range r.files { + if v.name == "" { + if filepath.Base(v.path) == name { + return v + } + } else if v.name == name { + return v + } + } + + return nil +} + +// FileByPath returns file ptr store in request obj by path. +func (r *Request) FileByPath(path string) *File { + for _, v := range r.files { + if v.path == path { + return v + } + } + + return nil +} + +// AddFile method adds single file field +// and its value in the request instance via file path. +func (r *Request) AddFile(path string) *Request { + r.files = append(r.files, AcquireFile(SetFilePath(path))) + r.resetBody(filesBody) + return r +} + +// AddFileWithReader method adds single field +// and its value in the request instance via reader. +func (r *Request) AddFileWithReader(name string, reader io.ReadCloser) *Request { + r.files = append(r.files, AcquireFile(SetFileName(name), SetFileReader(reader))) + r.resetBody(filesBody) + return r +} + +// AddFiles method adds multiple file fields +// and its value in the request instance via File instance. +func (r *Request) AddFiles(files ...*File) *Request { + r.files = append(r.files, files...) + r.resetBody(filesBody) + return r +} + +// Timeout returns the length of timeout in request. +func (r *Request) Timeout() time.Duration { + return r.timeout +} + +// SetTimeout method sets timeout field and its values at one go in the request instance. +// It will override timeout which set in client instance. +func (r *Request) SetTimeout(t time.Duration) *Request { + r.timeout = t + return r +} + +// MaxRedirects returns the max redirects count in request. +func (r *Request) MaxRedirects() int { + return r.maxRedirects +} + +// SetMaxRedirects method sets the maximum number of redirects at one go in the request instance. +// It will override max redirect which set in client instance. +func (r *Request) SetMaxRedirects(count int) *Request { + r.maxRedirects = count + return r +} + +// checkClient method checks whether the client has been set in request. +func (r *Request) checkClient() { + if r.client == nil { + r.SetClient(defaultClient) + } +} + +// Get Send get request. +func (r *Request) Get(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodGet).Send() +} + +// Post Send post request. +func (r *Request) Post(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodPost).Send() +} + +// Head Send head request. +func (r *Request) Head(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodHead).Send() +} + +// Put Send put request. +func (r *Request) Put(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodPut).Send() +} + +// Delete Send Delete request. +func (r *Request) Delete(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodDelete).Send() +} + +// Options Send Options request. +func (r *Request) Options(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodOptions).Send() +} + +// Patch Send patch request. +func (r *Request) Patch(url string) (*Response, error) { + return r.SetURL(url).SetMethod(fiber.MethodPatch).Send() +} + +// Custom Send custom request. +func (r *Request) Custom(url, method string) (*Response, error) { + return r.SetURL(url).SetMethod(method).Send() +} + +// Send a request. +func (r *Request) Send() (*Response, error) { + r.checkClient() + + return newCore().execute(r.Context(), r.Client(), r) +} + +// Reset clear Request object, used by ReleaseRequest method. +func (r *Request) Reset() { + r.url = "" + r.method = fiber.MethodGet + r.userAgent = "" + r.referer = "" + r.ctx = nil + r.body = nil + r.timeout = 0 + r.maxRedirects = 0 + r.bodyType = noBody + r.boundary = boundary + + for len(r.files) != 0 { + t := r.files[0] + r.files = r.files[1:] + ReleaseFile(t) + } + + r.formData.Reset() + r.path.Reset() + r.cookies.Reset() + r.header.Reset() + r.params.Reset() + r.RawRequest.Reset() +} + +// Header is a wrapper which wrap http.Header, +// the header in client and request will store in it. +type Header struct { + *fasthttp.RequestHeader +} + +// PeekMultiple methods returns multiple field in header with same key. +func (h *Header) PeekMultiple(key string) []string { + var res []string + byteKey := []byte(key) + h.RequestHeader.VisitAll(func(key, value []byte) { + if bytes.EqualFold(key, byteKey) { + res = append(res, utils.UnsafeString(value)) + } + }) + + return res +} + +// AddHeaders receive a map and add each value to header. +func (h *Header) AddHeaders(r map[string][]string) { + for k, v := range r { + for _, vv := range v { + h.Add(k, vv) + } + } +} + +// SetHeaders will override all headers. +func (h *Header) SetHeaders(r map[string]string) { + for k, v := range r { + h.Del(k) + h.Set(k, v) + } +} + +// QueryParam is a wrapper which wrap url.Values, +// the query string and formdata in client and request will store in it. +type QueryParam struct { + *fasthttp.Args +} + +// AddParams receive a map and add each value to param. +func (p *QueryParam) AddParams(r map[string][]string) { + for k, v := range r { + for _, vv := range v { + p.Add(k, vv) + } + } +} + +// SetParams will override all params. +func (p *QueryParam) SetParams(r map[string]string) { + for k, v := range r { + p.Set(k, v) + } +} + +// SetParamsWithStruct will override all params with struct or pointer of struct. +// Now nested structs are not currently supported. +func (p *QueryParam) SetParamsWithStruct(v any) { + SetValWithStruct(p, "param", v) +} + +// Cookie is a map which to store the cookies. +type Cookie map[string]string + +// Add method impl the method in WithStruct interface. +func (c Cookie) Add(key, val string) { + c[key] = val +} + +// Del method impl the method in WithStruct interface. +func (c Cookie) Del(key string) { + delete(c, key) +} + +// SetCookie method sets a single val in Cookie. +func (c Cookie) SetCookie(key, val string) { + c[key] = val +} + +// SetCookies method sets multiple val in Cookie. +func (c Cookie) SetCookies(m map[string]string) { + for k, v := range m { + c[k] = v + } +} + +// SetCookiesWithStruct method sets multiple val in Cookie via a struct. +func (c Cookie) SetCookiesWithStruct(v any) { + SetValWithStruct(c, "cookie", v) +} + +// DelCookies method deletes multiple val in Cookie. +func (c Cookie) DelCookies(key ...string) { + for _, v := range key { + c.Del(v) + } +} + +// VisitAll method receive a function which can travel the all val. +func (c Cookie) VisitAll(f func(key, val string)) { + for k, v := range c { + f(k, v) + } +} + +// Reset clear the Cookie object. +func (c Cookie) Reset() { + for k := range c { + delete(c, k) + } +} + +// PathParam is a map which to store the cookies. +type PathParam map[string]string + +// Add method impl the method in WithStruct interface. +func (p PathParam) Add(key, val string) { + p[key] = val +} + +// Del method impl the method in WithStruct interface. +func (p PathParam) Del(key string) { + delete(p, key) +} + +// SetParam method sets a single val in PathParam. +func (p PathParam) SetParam(key, val string) { + p[key] = val +} + +// SetParams method sets multiple val in PathParam. +func (p PathParam) SetParams(m map[string]string) { + for k, v := range m { + p[k] = v + } +} + +// SetParamsWithStruct method sets multiple val in PathParam via a struct. +func (p PathParam) SetParamsWithStruct(v any) { + SetValWithStruct(p, "path", v) +} + +// DelParams method deletes multiple val in PathParams. +func (p PathParam) DelParams(key ...string) { + for _, v := range key { + p.Del(v) + } +} + +// VisitAll method receive a function which can travel the all val. +func (p PathParam) VisitAll(f func(key, val string)) { + for k, v := range p { + f(k, v) + } +} + +// Reset clear the PathParams object. +func (p PathParam) Reset() { + for k := range p { + delete(p, k) + } +} + +// FormData is a wrapper of fasthttp.Args, +// and it be used for url encode body and file body. +type FormData struct { + *fasthttp.Args +} + +// AddData method is a wrapper of Args's Add method. +func (f *FormData) AddData(key, val string) { + f.Add(key, val) +} + +// SetData method is a wrapper of Args's Set method. +func (f *FormData) SetData(key, val string) { + f.Set(key, val) +} + +// AddDatas method supports add multiple fields. +func (f *FormData) AddDatas(m map[string][]string) { + for k, v := range m { + for _, vv := range v { + f.Add(k, vv) + } + } +} + +// SetDatas method supports set multiple fields. +func (f *FormData) SetDatas(m map[string]string) { + for k, v := range m { + f.Set(k, v) + } +} + +// SetDatasWithStruct method supports set multiple fields via a struct. +func (f *FormData) SetDatasWithStruct(v any) { + SetValWithStruct(f, "form", v) +} + +// DelDatas method deletes multiple fields. +func (f *FormData) DelDatas(key ...string) { + for _, v := range key { + f.Del(v) + } +} + +// Reset clear the FormData object. +func (f *FormData) Reset() { + f.Args.Reset() +} + +// File is a struct which support send files via request. +type File struct { + name string + fieldName string + path string + reader io.ReadCloser +} + +// SetName method sets file name. +func (f *File) SetName(n string) { + f.name = n +} + +// SetFieldName method sets key of file in the body. +func (f *File) SetFieldName(n string) { + f.fieldName = n +} + +// SetPath method set file path. +func (f *File) SetPath(p string) { + f.path = p +} + +// SetReader method can receive a io.ReadCloser +// which will be closed in parserBody hook. +func (f *File) SetReader(r io.ReadCloser) { + f.reader = r +} + +// Reset clear the File object. +func (f *File) Reset() { + f.name = "" + f.fieldName = "" + f.path = "" + f.reader = nil +} + +var requestPool = &sync.Pool{ + New: func() any { + return &Request{ + header: &Header{RequestHeader: &fasthttp.RequestHeader{}}, + params: &QueryParam{Args: fasthttp.AcquireArgs()}, + cookies: &Cookie{}, + path: &PathParam{}, + boundary: "--FiberFormBoundary", + formData: &FormData{Args: fasthttp.AcquireArgs()}, + files: make([]*File, 0), + RawRequest: fasthttp.AcquireRequest(), + } + }, +} + +// AcquireRequest returns an empty request object from the pool. +// +// The returned request may be returned to the pool with ReleaseRequest when no longer needed. +// This allows reducing GC load. +func AcquireRequest() *Request { + req, ok := requestPool.Get().(*Request) + if !ok { + panic(errors.New("failed to type-assert to *Request")) + } + + return req +} + +// ReleaseRequest returns the object acquired via AcquireRequest to the pool. +// +// Do not access the released Request object, otherwise data races may occur. +func ReleaseRequest(req *Request) { + req.Reset() + requestPool.Put(req) +} + +var filePool sync.Pool + +// SetFileFunc The methods as follows is used by AcquireFile method. +// You can set file field via these method. +type SetFileFunc func(f *File) + +// SetFileName method sets file name. +func SetFileName(n string) SetFileFunc { + return func(f *File) { + f.SetName(n) + } +} + +// SetFileFieldName method sets key of file in the body. +func SetFileFieldName(p string) SetFileFunc { + return func(f *File) { + f.SetFieldName(p) + } +} + +// SetFilePath method set file path. +func SetFilePath(p string) SetFileFunc { + return func(f *File) { + f.SetPath(p) + } +} + +// SetFileReader method can receive a io.ReadCloser +func SetFileReader(r io.ReadCloser) SetFileFunc { + return func(f *File) { + f.SetReader(r) + } +} + +// AcquireFile returns an File object from the pool. +// And you can set field in the File with SetFileFunc. +// +// The returned file may be returned to the pool with ReleaseFile when no longer needed. +// This allows reducing GC load. +func AcquireFile(setter ...SetFileFunc) *File { + fv := filePool.Get() + if fv != nil { + f, ok := fv.(*File) + if !ok { + panic(errors.New("failed to type-assert to *File")) + } + for _, v := range setter { + v(f) + } + return f + } + f := &File{} + for _, v := range setter { + v(f) + } + return f +} + +// ReleaseFile returns the object acquired via AcquireFile to the pool. +// +// Do not access the released File object, otherwise data races may occur. +func ReleaseFile(f *File) { + f.Reset() + filePool.Put(f) +} + +// SetValWithStruct Set some values using structs. +// `p` is a structure that implements the WithStruct interface, +// The field name can be specified by `tagName`. +// `v` is a struct include some data. +// Note: This method only supports simple types and nested structs are not currently supported. +func SetValWithStruct(p WithStruct, tagName string, v any) { + valueOfV := reflect.ValueOf(v) + typeOfV := reflect.TypeOf(v) + + // The v should be struct or point of struct + if typeOfV.Kind() == reflect.Pointer && typeOfV.Elem().Kind() == reflect.Struct { + valueOfV = valueOfV.Elem() + typeOfV = typeOfV.Elem() + } else if typeOfV.Kind() != reflect.Struct { + return + } + + // Boring type judge. + // TODO: cover more types and complex data structure. + var setVal func(name string, value reflect.Value) + setVal = func(name string, val reflect.Value) { + switch val.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + p.Add(name, strconv.Itoa(int(val.Int()))) + case reflect.Bool: + if val.Bool() { + p.Add(name, "true") + } + case reflect.String: + p.Add(name, val.String()) + case reflect.Float32, reflect.Float64: + p.Add(name, strconv.FormatFloat(val.Float(), 'f', -1, 64)) + case reflect.Slice, reflect.Array: + for i := 0; i < val.Len(); i++ { + setVal(name, val.Index(i)) + } + default: + } + } + + for i := 0; i < typeOfV.NumField(); i++ { + field := typeOfV.Field(i) + if !field.IsExported() { + continue + } + + name := field.Tag.Get(tagName) + if name == "" { + name = field.Name + } + val := valueOfV.Field(i) + if val.IsZero() { + continue + } + // To cover slice and array, we delete the val then add it. + p.Del(name) + setVal(name, val) + } +} diff --git a/client/request_test.go b/client/request_test.go new file mode 100644 index 00000000000..07e5254e15f --- /dev/null +++ b/client/request_test.go @@ -0,0 +1,1623 @@ +package client + +import ( + "bytes" + "context" + "errors" + "io" + "mime/multipart" + "net" + "os" + "path/filepath" + "regexp" + "strings" + "testing" + "time" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" + "github.com/valyala/fasthttp/fasthttputil" +) + +func Test_Request_Method(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + req.SetMethod("GET") + require.Equal(t, "GET", req.Method()) + + req.SetMethod("POST") + require.Equal(t, "POST", req.Method()) + + req.SetMethod("PUT") + require.Equal(t, "PUT", req.Method()) + + req.SetMethod("DELETE") + require.Equal(t, "DELETE", req.Method()) + + req.SetMethod("PATCH") + require.Equal(t, "PATCH", req.Method()) + + req.SetMethod("OPTIONS") + require.Equal(t, "OPTIONS", req.Method()) + + req.SetMethod("HEAD") + require.Equal(t, "HEAD", req.Method()) + + req.SetMethod("TRACE") + require.Equal(t, "TRACE", req.Method()) + + req.SetMethod("CUSTOM") + require.Equal(t, "CUSTOM", req.Method()) +} + +func Test_Request_URL(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + + req.SetURL("http://example.com/normal") + require.Equal(t, "http://example.com/normal", req.URL()) + + req.SetURL("https://example.com/normal") + require.Equal(t, "https://example.com/normal", req.URL()) +} + +func Test_Request_Client(t *testing.T) { + t.Parallel() + + client := NewClient() + req := AcquireRequest() + + req.SetClient(client) + require.Equal(t, client, req.Client()) +} + +func Test_Request_Context(t *testing.T) { + t.Parallel() + + req := AcquireRequest() + ctx := req.Context() + key := struct{}{} + + require.Nil(t, ctx.Value(key)) + + ctx = context.WithValue(ctx, key, "string") + req.SetContext(ctx) + ctx = req.Context() + + v, ok := ctx.Value(key).(string) + require.True(t, ok) + require.Equal(t, "string", v) +} + +func Test_Request_Header(t *testing.T) { + t.Parallel() + + t.Run("add header", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.AddHeader("foo", "bar").AddHeader("foo", "fiber") + + res := req.Header("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set header", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.AddHeader("foo", "bar").SetHeader("foo", "fiber") + + res := req.Header("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add headers", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetHeader("foo", "bar"). + AddHeaders(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Header("foo") + require.Len(t, res, 3) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) + + res = req.Header("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetHeader("foo", "bar"). + SetHeaders(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Header("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.Header("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) +} + +func Test_Request_QueryParam(t *testing.T) { + t.Parallel() + + t.Run("add param", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.AddParam("foo", "bar").AddParam("foo", "fiber") + + res := req.Param("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.AddParam("foo", "bar").SetParam("foo", "fiber") + + res := req.Param("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetParam("foo", "bar"). + AddParams(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.Param("foo") + require.Len(t, res, 3) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + require.Equal(t, "buaa", res[2]) + + res = req.Param("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.Param("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.Param("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + p := AcquireRequest() + p.SetParamsWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Empty(t, p.Param("unexport")) + + require.Len(t, p.Param("TInt"), 1) + require.Equal(t, "5", p.Param("TInt")[0]) + + require.Len(t, p.Param("TString"), 1) + require.Equal(t, "string", p.Param("TString")[0]) + + require.Len(t, p.Param("TFloat"), 1) + require.Equal(t, "3.1", p.Param("TFloat")[0]) + + require.Len(t, p.Param("TBool"), 1) + + tslice := p.Param("TSlice") + require.Len(t, tslice, 2) + require.Equal(t, "foo", tslice[0]) + require.Equal(t, "bar", tslice[1]) + + tint := p.Param("TSlice") + require.Len(t, tint, 2) + require.Equal(t, "foo", tint[0]) + require.Equal(t, "bar", tint[1]) + }) + + t.Run("del params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelParams("foo", "bar") + + res := req.Param("foo") + require.Empty(t, res) + + res = req.Param("bar") + require.Empty(t, res) + }) +} + +func Test_Request_UA(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetUserAgent("fiber") + require.Equal(t, "fiber", req.UserAgent()) + + req.SetUserAgent("foo") + require.Equal(t, "foo", req.UserAgent()) +} + +func Test_Request_Referer(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetReferer("http://example.com") + require.Equal(t, "http://example.com", req.Referer()) + + req.SetReferer("https://example.com") + require.Equal(t, "https://example.com", req.Referer()) +} + +func Test_Request_Cookie(t *testing.T) { + t.Parallel() + + t.Run("set cookie", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetCookie("foo", "bar") + require.Equal(t, "bar", req.Cookie("foo")) + + req.SetCookie("foo", "bar1") + require.Equal(t, "bar1", req.Cookie("foo")) + }) + + t.Run("set cookies", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.SetCookies(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) + + t.Run("set cookies with struct", func(t *testing.T) { + t.Parallel() + type args struct { + CookieInt int `cookie:"int"` + CookieString string `cookie:"string"` + } + + req := AcquireRequest().SetCookiesWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.Cookie("int")) + require.Equal(t, "foo", req.Cookie("string")) + }) + + t.Run("del cookies", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetCookies(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + + req.DelCookies("foo") + require.Equal(t, "", req.Cookie("foo")) + require.Equal(t, "foo", req.Cookie("bar")) + }) +} + +func Test_Request_PathParam(t *testing.T) { + t.Parallel() + + t.Run("set path param", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetPathParam("foo", "bar") + require.Equal(t, "bar", req.PathParam("foo")) + + req.SetPathParam("foo", "bar1") + require.Equal(t, "bar1", req.PathParam("foo")) + }) + + t.Run("set path params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.SetPathParams(map[string]string{ + "foo": "bar1", + }) + require.Equal(t, "bar1", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) + + t.Run("set path params with struct", func(t *testing.T) { + t.Parallel() + type args struct { + CookieInt int `path:"int"` + CookieString string `path:"string"` + } + + req := AcquireRequest().SetPathParamsWithStruct(&args{ + CookieInt: 5, + CookieString: "foo", + }) + + require.Equal(t, "5", req.PathParam("int")) + require.Equal(t, "foo", req.PathParam("string")) + }) + + t.Run("del path params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.DelPathParams("foo") + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + }) + + t.Run("clear path params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + SetPathParams(map[string]string{ + "foo": "bar", + "bar": "foo", + }) + require.Equal(t, "bar", req.PathParam("foo")) + require.Equal(t, "foo", req.PathParam("bar")) + + req.ResetPathParams() + require.Equal(t, "", req.PathParam("foo")) + require.Equal(t, "", req.PathParam("bar")) + }) +} + +func Test_Request_FormData(t *testing.T) { + t.Parallel() + + t.Run("add form data", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.AddFormData("foo", "bar").AddFormData("foo", "fiber") + + res := req.FormData("foo") + require.Len(t, res, 2) + require.Equal(t, "bar", res[0]) + require.Equal(t, "fiber", res[1]) + }) + + t.Run("set param", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.AddFormData("foo", "bar").SetFormData("foo", "fiber") + + res := req.FormData("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + }) + + t.Run("add params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.SetFormData("foo", "bar"). + AddFormDatas(map[string][]string{ + "foo": {"fiber", "buaa"}, + "bar": {"foo"}, + }) + + res := req.FormData("foo") + require.Len(t, res, 3) + require.Contains(t, res, "bar") + require.Contains(t, res, "buaa") + require.Contains(t, res, "fiber") + + res = req.FormData("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set headers", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "foo": "fiber", + "bar": "foo", + }) + + res := req.FormData("foo") + require.Len(t, res, 1) + require.Equal(t, "fiber", res[0]) + + res = req.FormData("bar") + require.Len(t, res, 1) + require.Equal(t, "foo", res[0]) + }) + + t.Run("set params with struct", func(t *testing.T) { + t.Parallel() + + type args struct { + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `form:"int_slice"` + } + + p := AcquireRequest() + defer ReleaseRequest(p) + p.SetFormDatasWithStruct(&args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Empty(t, p.FormData("unexport")) + + require.Len(t, p.FormData("TInt"), 1) + require.Equal(t, "5", p.FormData("TInt")[0]) + + require.Len(t, p.FormData("TString"), 1) + require.Equal(t, "string", p.FormData("TString")[0]) + + require.Len(t, p.FormData("TFloat"), 1) + require.Equal(t, "3.1", p.FormData("TFloat")[0]) + + require.Len(t, p.FormData("TBool"), 1) + + tslice := p.FormData("TSlice") + require.Len(t, tslice, 2) + require.Contains(t, tslice, "bar") + require.Contains(t, tslice, "foo") + + tint := p.FormData("TSlice") + require.Len(t, tint, 2) + require.Contains(t, tint, "bar") + require.Contains(t, tint, "foo") + }) + + t.Run("del params", func(t *testing.T) { + t.Parallel() + req := AcquireRequest() + defer ReleaseRequest(req) + req.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "foo": "fiber", + "bar": "foo", + }).DelFormDatas("foo", "bar") + + res := req.FormData("foo") + require.Empty(t, res) + + res = req.FormData("bar") + require.Empty(t, res) + }) +} + +func Test_Request_File(t *testing.T) { + t.Parallel() + + t.Run("add file", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + AddFile("../.github/index.html"). + AddFiles(AcquireFile(SetFileName("tmp.txt"))) + + require.Equal(t, "../.github/index.html", req.File("index.html").path) + require.Equal(t, "../.github/index.html", req.FileByPath("../.github/index.html").path) + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + require.Nil(t, req.File("tmp2.txt")) + require.Nil(t, req.FileByPath("tmp2.txt")) + }) + + t.Run("add file by reader", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + AddFileWithReader("tmp.txt", io.NopCloser(strings.NewReader("world"))) + + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + + content, err := io.ReadAll(req.File("tmp.txt").reader) + require.NoError(t, err) + require.Equal(t, "world", string(content)) + }) + + t.Run("add files", func(t *testing.T) { + t.Parallel() + req := AcquireRequest(). + AddFiles(AcquireFile(SetFileName("tmp.txt")), AcquireFile(SetFileName("foo.txt"))) + + require.Equal(t, "tmp.txt", req.File("tmp.txt").name) + require.Equal(t, "foo.txt", req.File("foo.txt").name) + }) +} + +func Test_Request_Timeout(t *testing.T) { + t.Parallel() + + req := AcquireRequest().SetTimeout(5 * time.Second) + + require.Equal(t, 5*time.Second, req.Timeout()) +} + +func Test_Request_Invalid_URL(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + Get("http://example.com\r\n\r\nGET /\r\n\r\n") + + require.Equal(t, ErrURLFormat, err) + require.Equal(t, (*Response)(nil), resp) +} + +func Test_Request_Unsupport_Protocol(t *testing.T) { + t.Parallel() + + resp, err := AcquireRequest(). + Get("ftp://example.com") + require.Equal(t, ErrURLFormat, err) + require.Equal(t, (*Response)(nil), resp) +} + +func Test_Request_Get(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + req := AcquireRequest().SetClient(client) + + resp, err := req.Get("http://example.com") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "example.com", resp.String()) + resp.Close() + } +} + +func Test_Request_Post(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusCreated). + SendString(c.FormValue("foo")) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Post("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusCreated, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + resp.Close() + } +} + +func Test_Request_Head(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Head("/", func(c fiber.Ctx) error { + return c.SendString(c.Hostname()) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Head("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "", resp.String()) + resp.Close() + } +} + +func Test_Request_Put(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Put("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Put("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + + resp.Close() + } +} + +func Test_Request_Delete(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + + app.Delete("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusNoContent). + SendString("deleted") + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Delete("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusNoContent, resp.StatusCode()) + require.Equal(t, "", resp.String()) + + resp.Close() + } +} + +func Test_Request_Options(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + + app.Options("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusOK). + SendString("options") + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + Options("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "options", resp.String()) + + resp.Close() + } +} + +func Test_Request_Send(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + + app.Post("/", func(c fiber.Ctx) error { + return c.Status(fiber.StatusOK). + SendString("post") + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetURL("http://example.com"). + SetMethod(fiber.MethodPost). + Send() + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "post", resp.String()) + + resp.Close() + } +} + +func Test_Request_Patch(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + + app.Patch("/", func(c fiber.Ctx) error { + return c.SendString(c.FormValue("foo")) + }) + + go start() + time.Sleep(100 * time.Millisecond) + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + resp, err := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + Patch("http://example.com") + + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "bar", resp.String()) + + resp.Close() + } +} + +func Test_Request_Header_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + c.Request().Header.VisitAll(func(key, value []byte) { + if k := string(key); k == "K1" || k == "K2" { + _, err := c.Write(key) + require.NoError(t, err) + _, err = c.Write(value) + require.NoError(t, err) + } + }) + return nil + } + + wrapAgent := func(r *Request) { + r.SetHeader("k1", "v1"). + AddHeader("k1", "v11"). + AddHeaders(map[string][]string{ + "k1": {"v22", "v33"}, + }). + SetHeaders(map[string]string{ + "k2": "v2", + }). + AddHeader("k2", "v22") + } + + testRequest(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") +} + +func Test_Request_UserAgent_With_Server(t *testing.T) { + t.Parallel() + + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.UserAgent()) + } + + t.Run("default", func(t *testing.T) { + t.Parallel() + testRequest(t, handler, func(_ *Request) {}, defaultUserAgent, 5) + }) + + t.Run("custom", func(t *testing.T) { + t.Parallel() + testRequest(t, handler, func(agent *Request) { + agent.SetUserAgent("ua") + }, "ua", 5) + }) +} + +func Test_Request_Cookie_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.SendString( + c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) + } + + wrapAgent := func(req *Request) { + req.SetCookie("k1", "v1"). + SetCookies(map[string]string{ + "k2": "v2", + "k3": "v3", + "k4": "v4", + }).DelCookies("k4") + } + + testRequest(t, handler, wrapAgent, "v1v2v3") +} + +func Test_Request_Referer_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().Header.Referer()) + } + + wrapAgent := func(req *Request) { + req.SetReferer("http://referer.com") + } + + testRequest(t, handler, wrapAgent, "http://referer.com") +} + +func Test_Request_QueryString_With_Server(t *testing.T) { + t.Parallel() + handler := func(c fiber.Ctx) error { + return c.Send(c.Request().URI().QueryString()) + } + + wrapAgent := func(req *Request) { + req.SetParam("foo", "bar"). + SetParams(map[string]string{ + "bar": "baz", + }) + } + + testRequest(t, handler, wrapAgent, "foo=bar&bar=baz") +} + +func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { + t.Helper() + + basename := filepath.Base(filename) + require.Equal(t, fh.Filename, basename) + + b1, err := os.ReadFile(filepath.Clean(filename)) + require.NoError(t, err) + + b2 := make([]byte, fh.Size) + f, err := fh.Open() + require.NoError(t, err) + defer func() { require.NoError(t, f.Close()) }() + _, err = f.Read(b2) + require.NoError(t, err) + require.Equal(t, b1, b2) +} + +func Test_Request_Body_With_Server(t *testing.T) { + t.Parallel() + + t.Run("json body", func(t *testing.T) { + t.Parallel() + testRequest(t, + func(c fiber.Ctx) error { + require.Equal(t, "application/json", string(c.Request().Header.ContentType())) + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + agent.SetJSON(map[string]string{ + "success": "hello", + }) + }, + "{\"success\":\"hello\"}", + ) + }) + + t.Run("xml body", func(t *testing.T) { + t.Parallel() + testRequest(t, + func(c fiber.Ctx) error { + require.Equal(t, "application/xml", string(c.Request().Header.ContentType())) + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + type args struct { + Content string `xml:"content"` + } + agent.SetXML(args{ + Content: "hello", + }) + }, + "hello", + ) + }) + + t.Run("formdata", func(t *testing.T) { + t.Parallel() + testRequest(t, + func(c fiber.Ctx) error { + require.Equal(t, fiber.MIMEApplicationForm, string(c.Request().Header.ContentType())) + return c.Send([]byte("foo=" + c.FormValue("foo") + "&bar=" + c.FormValue("bar") + "&fiber=" + c.FormValue("fiber"))) + }, + func(agent *Request) { + agent.SetFormData("foo", "bar"). + SetFormDatas(map[string]string{ + "bar": "baz", + "fiber": "fast", + }) + }, + "foo=bar&bar=baz&fiber=fast") + }) + + t.Run("multipart form", func(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) + + mf, err := c.MultipartForm() + require.NoError(t, err) + require.Equal(t, "bar", mf.Value["foo"][0]) + + return c.Send(c.Request().Body()) + }) + + go start() + + client := NewClient().SetDial(ln) + + req := AcquireRequest(). + SetClient(client). + SetBoundary("myBoundary"). + SetFormData("foo", "bar"). + AddFiles(AcquireFile( + SetFileName("hello.txt"), + SetFileFieldName("foo"), + SetFileReader(io.NopCloser(strings.NewReader("world"))), + )) + + require.Equal(t, "myBoundary", req.Boundary()) + + resp, err := req.Post("http://exmaple.com") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + + form, err := multipart.NewReader(bytes.NewReader(resp.Body()), "myBoundary").ReadForm(1024 * 1024) + require.NoError(t, err) + require.Equal(t, "bar", form.Value["foo"][0]) + resp.Close() + }) + + t.Run("multipart form send file", func(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(fiber.HeaderContentType)) + + fh1, err := c.FormFile("field1") + require.NoError(t, err) + require.Equal(t, "name", fh1.Filename) + buf := make([]byte, fh1.Size) + f, err := fh1.Open() + require.NoError(t, err) + defer func() { require.NoError(t, f.Close()) }() + _, err = f.Read(buf) + require.NoError(t, err) + require.Equal(t, "form file", string(buf)) + + fh2, err := c.FormFile("file2") + require.NoError(t, err) + checkFormFile(t, fh2, "../.github/testdata/index.html") + + fh3, err := c.FormFile("file3") + require.NoError(t, err) + checkFormFile(t, fh3, "../.github/testdata/index.tmpl") + + return c.SendString("multipart form files") + }) + + go start() + + client := NewClient().SetDial(ln) + + for i := 0; i < 5; i++ { + req := AcquireRequest(). + SetClient(client). + AddFiles( + AcquireFile( + SetFileFieldName("field1"), + SetFileName("name"), + SetFileReader(io.NopCloser(bytes.NewReader([]byte("form file")))), + ), + ). + AddFile("../.github/testdata/index.html"). + AddFile("../.github/testdata/index.tmpl"). + SetBoundary("myBoundary") + + resp, err := req.Post("http://example.com") + require.NoError(t, err) + require.Equal(t, "multipart form files", resp.String()) + + resp.Close() + } + }) + + t.Run("multipart random boundary", func(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Post("/", func(c fiber.Ctx) error { + reg := regexp.MustCompile(`multipart/form-data; boundary=[\-\w]{35}`) + require.True(t, reg.MatchString(c.Get(fiber.HeaderContentType))) + + return c.Send(c.Request().Body()) + }) + + go start() + + client := NewClient().SetDial(ln) + + req := AcquireRequest(). + SetClient(client). + SetFormData("foo", "bar"). + AddFiles(AcquireFile( + SetFileName("hello.txt"), + SetFileFieldName("foo"), + SetFileReader(io.NopCloser(strings.NewReader("world"))), + )) + + resp, err := req.Post("http://exmaple.com") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + }) + + t.Run("raw body", func(t *testing.T) { + t.Parallel() + testRequest(t, + func(c fiber.Ctx) error { + return c.SendString(string(c.Request().Body())) + }, + func(agent *Request) { + agent.SetRawBody([]byte("hello")) + }, + "hello", + ) + }) +} + +func Test_Request_Error_Body_With_Server(t *testing.T) { + t.Parallel() + t.Run("json error", func(t *testing.T) { + t.Parallel() + testRequestFail(t, + func(c fiber.Ctx) error { + return c.SendString("") + }, + func(agent *Request) { + agent.SetJSON(complex(1, 1)) + }, + errors.New("json: unsupported type: complex128"), + ) + }) + + t.Run("xml error", func(t *testing.T) { + t.Parallel() + testRequestFail(t, + func(c fiber.Ctx) error { + return c.SendString("") + }, + func(agent *Request) { + agent.SetXML(complex(1, 1)) + }, + errors.New("xml: unsupported type: complex128"), + ) + }) + + t.Run("form body with invalid boundary", func(t *testing.T) { + t.Parallel() + + _, err := AcquireRequest(). + SetBoundary("*"). + AddFileWithReader("t.txt", io.NopCloser(strings.NewReader("world"))). + Get("http://example.com") + require.Equal(t, "set boundary error: mime: invalid boundary character", err.Error()) + }) + + t.Run("open non exist file", func(t *testing.T) { + t.Parallel() + + _, err := AcquireRequest(). + AddFile("non-exist-file!"). + Get("http://example.com") + require.Contains(t, err.Error(), "open non-exist-file!") + }) +} + +func Test_Request_Timeout_With_Server(t *testing.T) { + t.Parallel() + + app, ln, start := createHelperServer(t) + app.Get("/", func(c fiber.Ctx) error { + time.Sleep(time.Millisecond * 200) + return c.SendString("timeout") + }) + go start() + + client := NewClient().SetDial(ln) + + _, err := AcquireRequest(). + SetClient(client). + SetTimeout(50 * time.Millisecond). + Get("http://example.com") + + require.Equal(t, ErrTimeoutOrCancel, err) +} + +func Test_Request_MaxRedirects(t *testing.T) { + t.Parallel() + + ln := fasthttputil.NewInmemoryListener() + + app := fiber.New() + + app.Get("/", func(c fiber.Ctx) error { + if c.Request().URI().QueryArgs().Has("foo") { + return c.Redirect().To("/foo") + } + return c.Redirect().To("/") + }) + app.Get("/foo", func(c fiber.Ctx) error { + return c.SendString("redirect") + }) + + go func() { require.NoError(t, app.Listener(ln, fiber.ListenConfig{DisableStartupMessage: true})) }() + + t.Run("success", func(t *testing.T) { + t.Parallel() + + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + + resp, err := AcquireRequest(). + SetClient(client). + SetMaxRedirects(1). + Get("http://example.com?foo") + body := resp.String() + code := resp.StatusCode() + + require.Equal(t, 200, code) + require.Equal(t, "redirect", body) + require.NoError(t, err) + + resp.Close() + }) + + t.Run("error", func(t *testing.T) { + t.Parallel() + + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + + resp, err := AcquireRequest(). + SetClient(client). + SetMaxRedirects(1). + Get("http://example.com") + + require.Nil(t, resp) + require.Equal(t, "too many redirects detected when doing the request", err.Error()) + }) + + t.Run("MaxRedirects", func(t *testing.T) { + t.Parallel() + + client := NewClient().SetDial(func(_ string) (net.Conn, error) { return ln.Dial() }) //nolint:wrapcheck // not needed + + req := AcquireRequest(). + SetClient(client). + SetMaxRedirects(3) + + require.Equal(t, 3, req.MaxRedirects()) + }) +} + +func Test_SetValWithStruct(t *testing.T) { + t.Parallel() + + // test SetValWithStruct vai QueryParam struct. + type args struct { + unexport int + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + t.Run("the struct should be applied", func(t *testing.T) { + t.Parallel() + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + SetValWithStruct(p, "param", args{ + unexport: 5, + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: false, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Equal(t, "", string(p.Peek("unexport"))) + require.Equal(t, []byte("5"), p.Peek("TInt")) + require.Equal(t, []byte("string"), p.Peek("TString")) + require.Equal(t, []byte("3.1"), p.Peek("TFloat")) + require.Equal(t, "", string(p.Peek("TBool"))) + require.True(t, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + }) + + t.Run("the pointer of a struct should be applied", func(t *testing.T) { + t.Parallel() + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + SetValWithStruct(p, "param", &args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + + require.Equal(t, []byte("5"), p.Peek("TInt")) + require.Equal(t, []byte("string"), p.Peek("TString")) + require.Equal(t, []byte("3.1"), p.Peek("TFloat")) + require.Equal(t, "true", string(p.Peek("TBool"))) + require.True(t, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(t, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + }) + + t.Run("the zero val should be ignore", func(t *testing.T) { + t.Parallel() + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + SetValWithStruct(p, "param", &args{ + TInt: 0, + TString: "", + TFloat: 0.0, + }) + + require.Equal(t, "", string(p.Peek("TInt"))) + require.Equal(t, "", string(p.Peek("TString"))) + require.Equal(t, "", string(p.Peek("TFloat"))) + require.Empty(t, p.PeekMulti("TSlice")) + require.Empty(t, p.PeekMulti("int_slice")) + }) + + t.Run("error type should ignore", func(t *testing.T) { + t.Parallel() + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + SetValWithStruct(p, "param", 5) + require.Equal(t, 0, p.Len()) + }) +} + +func Benchmark_SetValWithStruct(b *testing.B) { + // test SetValWithStruct vai QueryParam struct. + type args struct { + unexport int + TInt int + TString string + TFloat float64 + TBool bool + TSlice []string + TIntSlice []int `param:"int_slice"` + } + + b.Run("the struct should be applied", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", args{ + unexport: 5, + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: false, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + } + + require.Equal(b, "", string(p.Peek("unexport"))) + require.Equal(b, []byte("5"), p.Peek("TInt")) + require.Equal(b, []byte("string"), p.Peek("TString")) + require.Equal(b, []byte("3.1"), p.Peek("TFloat")) + require.Equal(b, "", string(p.Peek("TBool"))) + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + }) + + b.Run("the pointer of a struct should be applied", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", &args{ + TInt: 5, + TString: "string", + TFloat: 3.1, + TBool: true, + TSlice: []string{"foo", "bar"}, + TIntSlice: []int{1, 2}, + }) + } + + require.Equal(b, []byte("5"), p.Peek("TInt")) + require.Equal(b, []byte("string"), p.Peek("TString")) + require.Equal(b, []byte("3.1"), p.Peek("TFloat")) + require.Equal(b, "true", string(p.Peek("TBool"))) + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "foo" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("TSlice") { + if string(v) == "bar" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "1" { + return true + } + } + return false + }()) + + require.True(b, func() bool { + for _, v := range p.PeekMulti("int_slice") { + if string(v) == "2" { + return true + } + } + return false + }()) + }) + + b.Run("the zero val should be ignore", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", &args{ + TInt: 0, + TString: "", + TFloat: 0.0, + }) + } + + require.Empty(b, string(p.Peek("TInt"))) + require.Empty(b, string(p.Peek("TString"))) + require.Empty(b, string(p.Peek("TFloat"))) + require.Empty(b, len(p.PeekMulti("TSlice"))) + require.Empty(b, len(p.PeekMulti("int_slice"))) + }) + + b.Run("error type should ignore", func(b *testing.B) { + p := &QueryParam{ + Args: fasthttp.AcquireArgs(), + } + + b.ReportAllocs() + b.StartTimer() + + for i := 0; i < b.N; i++ { + SetValWithStruct(p, "param", 5) + } + + require.Equal(b, 0, p.Len()) + }) +} diff --git a/client/response.go b/client/response.go new file mode 100644 index 00000000000..f6ecd6fcd85 --- /dev/null +++ b/client/response.go @@ -0,0 +1,184 @@ +package client + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/fs" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/gofiber/utils/v2" + "github.com/valyala/fasthttp" +) + +// Response is the result of a request. This object is used to access the response data. +type Response struct { + client *Client + request *Request + cookie []*fasthttp.Cookie + + RawResponse *fasthttp.Response +} + +// setClient method sets client object in response instance. +// Use core object in the client. +func (r *Response) setClient(c *Client) { + r.client = c +} + +// setRequest method sets Request object in response instance. +// The request will be released when the Response.Close is called. +func (r *Response) setRequest(req *Request) { + r.request = req +} + +// Status method returns the HTTP status string for the executed request. +func (r *Response) Status() string { + return string(r.RawResponse.Header.StatusMessage()) +} + +// StatusCode method returns the HTTP status code for the executed request. +func (r *Response) StatusCode() int { + return r.RawResponse.StatusCode() +} + +// Protocol method returns the HTTP response protocol used for the request. +func (r *Response) Protocol() string { + return string(r.RawResponse.Header.Protocol()) +} + +// Header method returns the response headers. +func (r *Response) Header(key string) string { + return utils.UnsafeString(r.RawResponse.Header.Peek(key)) +} + +// Cookies method to access all the response cookies. +func (r *Response) Cookies() []*fasthttp.Cookie { + return r.cookie +} + +// Body method returns HTTP response as []byte array for the executed request. +func (r *Response) Body() []byte { + return r.RawResponse.Body() +} + +// String method returns the body of the server response as String. +func (r *Response) String() string { + return strings.TrimSpace(string(r.Body())) +} + +// JSON method will unmarshal body to json. +func (r *Response) JSON(v any) error { + return r.client.jsonUnmarshal(r.Body(), v) +} + +// XML method will unmarshal body to xml. +func (r *Response) XML(v any) error { + return r.client.xmlUnmarshal(r.Body(), v) +} + +// Save method will save the body to a file or io.Writer. +func (r *Response) Save(v any) error { + switch p := v.(type) { + case string: + file := filepath.Clean(p) + dir := filepath.Dir(file) + + // create directory + if _, err := os.Stat(dir); err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("failed to check directory: %w", err) + } + + if err = os.MkdirAll(dir, 0o750); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + } + + // create file + outFile, err := os.Create(file) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer func() { _ = outFile.Close() }() //nolint:errcheck // not needed + + _, err = io.Copy(outFile, bytes.NewReader(r.Body())) + if err != nil { + return fmt.Errorf("failed to write response body to file: %w", err) + } + + return nil + case io.Writer: + _, err := io.Copy(p, bytes.NewReader(r.Body())) + if err != nil { + return fmt.Errorf("failed to write response body to io.Writer: %w", err) + } + defer func() { + if pc, ok := p.(io.WriteCloser); ok { + _ = pc.Close() //nolint:errcheck // not needed + } + }() + + return nil + default: + return ErrNotSupportSaveMethod + } +} + +// Reset clear Response object. +func (r *Response) Reset() { + r.client = nil + r.request = nil + + for len(r.cookie) != 0 { + t := r.cookie[0] + r.cookie = r.cookie[1:] + fasthttp.ReleaseCookie(t) + } + + r.RawResponse.Reset() +} + +// Close method will release Request object and Response object, +// after call Close please don't use these object. +func (r *Response) Close() { + if r.request != nil { + tmp := r.request + r.request = nil + ReleaseRequest(tmp) + } + ReleaseResponse(r) +} + +var responsePool = &sync.Pool{ + New: func() any { + return &Response{ + cookie: []*fasthttp.Cookie{}, + RawResponse: fasthttp.AcquireResponse(), + } + }, +} + +// AcquireResponse returns an empty response object from the pool. +// +// The returned response may be returned to the pool with ReleaseResponse when no longer needed. +// This allows reducing GC load. +func AcquireResponse() *Response { + resp, ok := responsePool.Get().(*Response) + if !ok { + panic("unexpected type from responsePool.Get()") + } + return resp +} + +// ReleaseResponse returns the object acquired via AcquireResponse to the pool. +// +// Do not access the released Response object, otherwise data races may occur. +func ReleaseResponse(resp *Response) { + resp.Reset() + responsePool.Put(resp) +} diff --git a/client/response_test.go b/client/response_test.go new file mode 100644 index 00000000000..622e8357149 --- /dev/null +++ b/client/response_test.go @@ -0,0 +1,418 @@ +package client + +import ( + "bytes" + "crypto/tls" + "encoding/xml" + "io" + "net" + "os" + "testing" + + "github.com/gofiber/fiber/v3/internal/tlstest" + + "github.com/gofiber/fiber/v3" + "github.com/stretchr/testify/require" +) + +func Test_Response_Status(t *testing.T) { + t.Parallel() + + setupApp := func() *testServer { + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) + }) + + return server + } + + t.Run("success", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + require.NoError(t, err) + require.Equal(t, "OK", resp.Status()) + resp.Close() + }) + + t.Run("fail", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example/fail") + + require.NoError(t, err) + require.Equal(t, "Proxy Authentication Required", resp.Status()) + resp.Close() + }) +} + +func Test_Response_Status_Code(t *testing.T) { + t.Parallel() + + setupApp := func() *testServer { + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + app.Get("/fail", func(c fiber.Ctx) error { + return c.SendStatus(407) + }) + }) + + return server + } + + t.Run("success", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + require.NoError(t, err) + require.Equal(t, 200, resp.StatusCode()) + resp.Close() + }) + + t.Run("fail", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example/fail") + + require.NoError(t, err) + require.Equal(t, 407, resp.StatusCode()) + resp.Close() + }) +} + +func Test_Response_Protocol(t *testing.T) { + t.Parallel() + + t.Run("http", func(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("foo") + }) + }) + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example") + + require.NoError(t, err) + require.Equal(t, "HTTP/1.1", resp.Protocol()) + resp.Close() + }) + + t.Run("https", func(t *testing.T) { + t.Parallel() + + serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() + require.NoError(t, err) + + ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") + require.NoError(t, err) + + ln = tls.NewListener(ln, serverTLSConf) + + app := fiber.New() + app.Get("/", func(c fiber.Ctx) error { + return c.SendString(c.Scheme()) + }) + + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := NewClient() + resp, err := client.SetTLSConfig(clientTLSConf).Get("https://" + ln.Addr().String()) + + require.NoError(t, err) + require.Equal(t, clientTLSConf, client.TLSConfig()) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "https", resp.String()) + require.Equal(t, "HTTP/1.1", resp.Protocol()) + + resp.Close() + }) +} + +func Test_Response_Header(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + c.Response().Header.Add("foo", "bar") + return c.SendString("helo world") + }) + }) + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, "bar", resp.Header("foo")) + resp.Close() +} + +func Test_Response_Cookie(t *testing.T) { + t.Parallel() + + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + c.Cookie(&fiber.Cookie{ + Name: "foo", + Value: "bar", + }) + return c.SendString("helo world") + }) + }) + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, "bar", string(resp.Cookies()[0].Value())) + resp.Close() +} + +func Test_Response_Body(t *testing.T) { + t.Parallel() + + setupApp := func() *testServer { + server := startTestServer(t, func(app *fiber.App) { + app.Get("/", func(c fiber.Ctx) error { + return c.SendString("hello world") + }) + + app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + + app.Get("/xml", func(c fiber.Ctx) error { + return c.SendString("success") + }) + }) + + return server + } + + t.Run("raw body", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, []byte("hello world"), resp.Body()) + resp.Close() + }) + + t.Run("string body", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com") + + require.NoError(t, err) + require.Equal(t, "hello world", resp.String()) + resp.Close() + }) + + t.Run("json body", func(t *testing.T) { + t.Parallel() + type body struct { + Status string `json:"status"` + } + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + require.NoError(t, err) + + tmp := &body{} + err = resp.JSON(tmp) + require.NoError(t, err) + require.Equal(t, "success", tmp.Status) + resp.Close() + }) + + t.Run("xml body", func(t *testing.T) { + t.Parallel() + type body struct { + Name xml.Name `xml:"status"` + Status string `xml:"name"` + } + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/xml") + + require.NoError(t, err) + + tmp := &body{} + err = resp.XML(tmp) + require.NoError(t, err) + require.Equal(t, "success", tmp.Status) + resp.Close() + }) +} + +func Test_Response_Save(t *testing.T) { + t.Parallel() + + setupApp := func() *testServer { + server := startTestServer(t, func(app *fiber.App) { + app.Get("/json", func(c fiber.Ctx) error { + return c.SendString("{\"status\":\"success\"}") + }) + }) + + return server + } + + t.Run("file path", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + require.NoError(t, err) + + err = resp.Save("./test/tmp.json") + require.NoError(t, err) + defer func() { + _, err := os.Stat("./test/tmp.json") + require.NoError(t, err) + + err = os.RemoveAll("./test") + require.NoError(t, err) + }() + + file, err := os.Open("./test/tmp.json") + require.NoError(t, err) + defer func(file *os.File) { + err := file.Close() + require.NoError(t, err) + }(file) + + data, err := io.ReadAll(file) + require.NoError(t, err) + require.Equal(t, "{\"status\":\"success\"}", string(data)) + }) + + t.Run("io.Writer", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + require.NoError(t, err) + + buf := &bytes.Buffer{} + + err = resp.Save(buf) + require.NoError(t, err) + require.Equal(t, "{\"status\":\"success\"}", buf.String()) + }) + + t.Run("error type", func(t *testing.T) { + t.Parallel() + + server := setupApp() + defer server.stop() + + client := NewClient().SetDial(server.dial()) + + resp, err := AcquireRequest(). + SetClient(client). + Get("http://example.com/json") + + require.NoError(t, err) + + err = resp.Save(nil) + require.Error(t, err) + }) +} diff --git a/client_test.go b/client_test.go deleted file mode 100644 index 57cc4e4d2eb..00000000000 --- a/client_test.go +++ /dev/null @@ -1,1337 +0,0 @@ -//nolint:wrapcheck // We must not wrap errors in tests -package fiber - -import ( - "bytes" - "crypto/tls" - "encoding/base64" - "encoding/json" - "encoding/xml" - "errors" - "io" - "mime/multipart" - "net" - "os" - "path/filepath" - "regexp" - "strings" - "testing" - "time" - - "github.com/gofiber/fiber/v3/internal/tlstest" - "github.com/stretchr/testify/require" - "github.com/valyala/fasthttp" - "github.com/valyala/fasthttp/fasthttputil" -) - -func Test_Client_Invalid_URL(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString(c.Host()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Get("http://example.com\r\n\r\nGET /\r\n\r\n") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - require.Error(t, errs[0], - `Expected error "missing required Host header in request"`) -} - -func Test_Client_Unsupported_Protocol(t *testing.T) { - t.Parallel() - - a := Get("ftp://example.com") - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - require.ErrorContains(t, errs[0], `unsupported protocol "ftp". http and https are supported`) -} - -func Test_Client_Get(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString(c.Host()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "example.com", body) - require.Empty(t, errs) - } -} - -func Test_Client_Head(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Head("/", func(c Ctx) error { - return c.SendStatus(StatusAccepted) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - for i := 0; i < 5; i++ { - a := Head("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusAccepted, code) - require.Equal(t, "", body) - require.Empty(t, errs) - } -} - -func Test_Client_Post(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Post("/", func(c Ctx) error { - return c.Status(StatusCreated). - SendString(c.FormValue("foo")) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - args := AcquireArgs() - - args.Set("foo", "bar") - - a := Post("http://example.com"). - Form(args) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusCreated, code) - require.Equal(t, "bar", body) - require.Empty(t, errs) - - ReleaseArgs(args) - } -} - -func Test_Client_Put(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Put("/", func(c Ctx) error { - return c.SendString(c.FormValue("foo")) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - args := AcquireArgs() - - args.Set("foo", "bar") - - a := Put("http://example.com"). - Form(args) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "bar", body) - require.Empty(t, errs) - - ReleaseArgs(args) - } -} - -func Test_Client_Patch(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Patch("/", func(c Ctx) error { - return c.SendString(c.FormValue("foo")) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - args := AcquireArgs() - - args.Set("foo", "bar") - - a := Patch("http://example.com"). - Form(args) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "bar", body) - require.Empty(t, errs) - - ReleaseArgs(args) - } -} - -func Test_Client_Delete(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Delete("/", func(c Ctx) error { - return c.Status(StatusNoContent). - SendString("deleted") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - args := AcquireArgs() - - a := Delete("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusNoContent, code) - require.Equal(t, "", body) - require.Empty(t, errs) - - ReleaseArgs(args) - } -} - -func Test_Client_UserAgent(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.Send(c.Request().Header.UserAgent()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - t.Run("default", func(t *testing.T) { - t.Parallel() - for i := 0; i < 5; i++ { - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, defaultUserAgent, body) - require.Empty(t, errs) - } - }) - - t.Run("custom", func(t *testing.T) { - t.Parallel() - for i := 0; i < 5; i++ { - c := AcquireClient() - c.UserAgent = "ua" - - a := c.Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "ua", body) - require.Empty(t, errs) - ReleaseClient(c) - } - }) -} - -func Test_Client_Agent_Set_Or_Add_Headers(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - c.Request().Header.VisitAll(func(key, value []byte) { - if k := string(key); k == "K1" || k == "K2" { - _, err := c.Write(key) - require.NoError(t, err) - _, err = c.Write(value) - require.NoError(t, err) - } - }) - return nil - } - - wrapAgent := func(a *Agent) { - a.Set("k1", "v1"). - SetBytesK([]byte("k1"), "v1"). - SetBytesV("k1", []byte("v1")). - AddBytesK([]byte("k1"), "v11"). - AddBytesV("k1", []byte("v22")). - AddBytesKV([]byte("k1"), []byte("v33")). - SetBytesKV([]byte("k2"), []byte("v2")). - Add("k2", "v22") - } - - testAgent(t, handler, wrapAgent, "K1v1K1v11K1v22K1v33K2v2K2v22") -} - -func Test_Client_Agent_Connection_Close(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - if c.Request().Header.ConnectionClose() { - return c.SendString("close") - } - return c.SendString("not close") - } - - wrapAgent := func(a *Agent) { - a.ConnectionClose() - } - - testAgent(t, handler, wrapAgent, "close") -} - -func Test_Client_Agent_UserAgent(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Header.UserAgent()) - } - - wrapAgent := func(a *Agent) { - a.UserAgent("ua"). - UserAgentBytes([]byte("ua")) - } - - testAgent(t, handler, wrapAgent, "ua") -} - -func Test_Client_Agent_Cookie(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.SendString( - c.Cookies("k1") + c.Cookies("k2") + c.Cookies("k3") + c.Cookies("k4")) - } - - wrapAgent := func(a *Agent) { - a.Cookie("k1", "v1"). - CookieBytesK([]byte("k2"), "v2"). - CookieBytesKV([]byte("k2"), []byte("v2")). - Cookies("k3", "v3", "k4", "v4"). - CookiesBytesKV([]byte("k3"), []byte("v3"), []byte("k4"), []byte("v4")) - } - - testAgent(t, handler, wrapAgent, "v1v2v3v4") -} - -func Test_Client_Agent_Referer(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Header.Referer()) - } - - wrapAgent := func(a *Agent) { - a.Referer("http://referer.com"). - RefererBytes([]byte("http://referer.com")) - } - - testAgent(t, handler, wrapAgent, "http://referer.com") -} - -func Test_Client_Agent_ContentType(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Header.ContentType()) - } - - wrapAgent := func(a *Agent) { - a.ContentType("custom-type"). - ContentTypeBytes([]byte("custom-type")) - } - - testAgent(t, handler, wrapAgent, "custom-type") -} - -func Test_Client_Agent_Host(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString(c.Host()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Get("http://1.1.1.1:8080"). - Host("example.com"). - HostBytes([]byte("example.com")) - - require.Equal(t, "1.1.1.1:8080", a.HostClient.Addr) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "example.com", body) - require.Empty(t, errs) -} - -func Test_Client_Agent_QueryString(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().URI().QueryString()) - } - - wrapAgent := func(a *Agent) { - a.QueryString("foo=bar&bar=baz"). - QueryStringBytes([]byte("foo=bar&bar=baz")) - } - - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} - -func Test_Client_Agent_BasicAuth(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - // Get authorization header - auth := c.Get(HeaderAuthorization) - // Decode the header contents - raw, err := base64.StdEncoding.DecodeString(auth[6:]) - require.NoError(t, err) - - return c.Send(raw) - } - - wrapAgent := func(a *Agent) { - a.BasicAuth("foo", "bar"). - BasicAuthBytes([]byte("foo"), []byte("bar")) - } - - testAgent(t, handler, wrapAgent, "foo:bar") -} - -func Test_Client_Agent_BodyString(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.BodyString("foo=bar&bar=baz") - } - - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} - -func Test_Client_Agent_Body(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.Body([]byte("foo=bar&bar=baz")) - } - - testAgent(t, handler, wrapAgent, "foo=bar&bar=baz") -} - -func Test_Client_Agent_BodyStream(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.BodyStream(strings.NewReader("body stream"), -1) - } - - testAgent(t, handler, wrapAgent, "body stream") -} - -func Test_Client_Agent_Custom_Response(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("custom") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - a := AcquireAgent() - resp := AcquireResponse() - - req := a.Request() - req.Header.SetMethod(MethodGet) - req.SetRequestURI("http://example.com") - - require.NoError(t, a.Parse()) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.SetResponse(resp). - String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "custom", body) - require.Equal(t, "custom", string(resp.Body())) - require.Empty(t, errs) - - ReleaseResponse(resp) - } -} - -func Test_Client_Agent_Dest(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("dest") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - t.Run("small dest", func(t *testing.T) { - t.Parallel() - dest := []byte("de") - - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.Dest(dest[:0]).String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "dest", body) - require.Equal(t, "de", string(dest)) - require.Empty(t, errs) - }) - - t.Run("enough dest", func(t *testing.T) { - t.Parallel() - dest := []byte("foobar") - - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.Dest(dest[:0]).String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "dest", body) - require.Equal(t, "destar", string(dest)) - require.Empty(t, errs) - }) -} - -// readErrorConn is a struct for testing retryIf -type readErrorConn struct { - net.Conn -} - -func (*readErrorConn) Read(_ []byte) (int, error) { - return 0, errors.New("error") -} - -func (*readErrorConn) Write(p []byte) (int, error) { - return len(p), nil -} - -func (*readErrorConn) Close() error { - return nil -} - -func (*readErrorConn) LocalAddr() net.Addr { - return nil -} - -func (*readErrorConn) RemoteAddr() net.Addr { - return nil -} - -func (*readErrorConn) SetReadDeadline(_ time.Time) error { - return nil -} - -func (*readErrorConn) SetWriteDeadline(_ time.Time) error { - return nil -} - -func Test_Client_Agent_RetryIf(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Post("http://example.com"). - RetryIf(func(_ *Request) bool { - return true - }) - dialsCount := 0 - a.HostClient.Dial = func(_ string) (net.Conn, error) { - dialsCount++ - switch dialsCount { - case 1: - return &readErrorConn{}, nil - case 2: - return &readErrorConn{}, nil - case 3: - return &readErrorConn{}, nil - case 4: - return ln.Dial() - default: - t.Fatalf("unexpected number of dials: %d", dialsCount) - } - panic("unreachable") - } - - _, _, errs := a.String() - require.Equal(t, 4, dialsCount) - require.Empty(t, errs) -} - -func Test_Client_Agent_Json(t *testing.T) { - t.Parallel() - // Test without ctype parameter - handler := func(c Ctx) error { - require.Equal(t, MIMEApplicationJSON, string(c.Request().Header.ContentType())) - - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.JSON(data{Success: true}) - } - - testAgent(t, handler, wrapAgent, `{"success":true}`) - - // Test with ctype parameter - handler = func(c Ctx) error { - require.Equal(t, "application/problem+json", string(c.Request().Header.ContentType())) - - return c.Send(c.Request().Body()) - } - - wrapAgent = func(a *Agent) { - a.JSON(data{Success: true}, "application/problem+json") - } - - testAgent(t, handler, wrapAgent, `{"success":true}`) -} - -func Test_Client_Agent_Json_Error(t *testing.T) { - t.Parallel() - a := Get("http://example.com"). - JSONEncoder(json.Marshal). - JSON(complex(1, 1)) - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - wantErr := new(json.UnsupportedTypeError) - require.ErrorAs(t, errs[0], &wantErr) -} - -func Test_Client_Agent_XML(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - require.Equal(t, MIMEApplicationXML, string(c.Request().Header.ContentType())) - - return c.Send(c.Request().Body()) - } - - wrapAgent := func(a *Agent) { - a.XML(data{Success: true}) - } - - testAgent(t, handler, wrapAgent, "true") -} - -func Test_Client_Agent_XML_Error(t *testing.T) { - t.Parallel() - a := Get("http://example.com"). - XML(complex(1, 1)) - - _, body, errs := a.String() - require.Equal(t, "", body) - require.Len(t, errs, 1) - wantErr := new(xml.UnsupportedTypeError) - require.ErrorAs(t, errs[0], &wantErr) -} - -func Test_Client_Agent_Form(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - require.Equal(t, MIMEApplicationForm, string(c.Request().Header.ContentType())) - - return c.Send(c.Request().Body()) - } - - args := AcquireArgs() - - args.Set("foo", "bar") - - wrapAgent := func(a *Agent) { - a.Form(args) - } - - testAgent(t, handler, wrapAgent, "foo=bar") - - ReleaseArgs(args) -} - -func Test_Client_Agent_MultipartForm(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Post("/", func(c Ctx) error { - require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) - - mf, err := c.MultipartForm() - require.NoError(t, err) - require.Equal(t, "bar", mf.Value["foo"][0]) - - return c.Send(c.Request().Body()) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - args := AcquireArgs() - - args.Set("foo", "bar") - - a := Post("http://example.com"). - Boundary("myBoundary"). - MultipartForm(args) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "--myBoundary\r\nContent-Disposition: form-data; name=\"foo\"\r\n\r\nbar\r\n--myBoundary--\r\n", body) - require.Empty(t, errs) - ReleaseArgs(args) -} - -func Test_Client_Agent_MultipartForm_Errors(t *testing.T) { - t.Parallel() - - a := AcquireAgent() - a.mw = &errorMultipartWriter{} - - args := AcquireArgs() - args.Set("foo", "bar") - - ff1 := &FormFile{"", "name1", []byte("content"), false} - ff2 := &FormFile{"", "name2", []byte("content"), false} - a.FileData(ff1, ff2). - MultipartForm(args) - - require.Len(t, a.errs, 4) - ReleaseArgs(args) -} - -func Test_Client_Agent_MultipartForm_SendFiles(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Post("/", func(c Ctx) error { - require.Equal(t, "multipart/form-data; boundary=myBoundary", c.Get(HeaderContentType)) - - fh1, err := c.FormFile("field1") - require.NoError(t, err) - require.Equal(t, "name", fh1.Filename) - buf := make([]byte, fh1.Size) - f, err := fh1.Open() - require.NoError(t, err) - defer func() { - err := f.Close() - require.NoError(t, err) - }() - _, err = f.Read(buf) - require.NoError(t, err) - require.Equal(t, "form file", string(buf)) - - fh2, err := c.FormFile("index") - require.NoError(t, err) - checkFormFile(t, fh2, ".github/testdata/index.html") - - fh3, err := c.FormFile("file3") - require.NoError(t, err) - checkFormFile(t, fh3, ".github/testdata/index.tmpl") - - return c.SendString("multipart form files") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - for i := 0; i < 5; i++ { - ff := AcquireFormFile() - ff.Fieldname = "field1" - ff.Name = "name" - ff.Content = []byte("form file") - - a := Post("http://example.com"). - Boundary("myBoundary"). - FileData(ff). - SendFiles(".github/testdata/index.html", "index", ".github/testdata/index.tmpl"). - MultipartForm(nil) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "multipart form files", body) - require.Empty(t, errs) - - ReleaseFormFile(ff) - } -} - -func checkFormFile(t *testing.T, fh *multipart.FileHeader, filename string) { - t.Helper() - - basename := filepath.Base(filename) - require.Equal(t, fh.Filename, basename) - - b1, err := os.ReadFile(filename) //nolint:gosec // We're in a test so reading user-provided files by name is fine - require.NoError(t, err) - - b2 := make([]byte, fh.Size) - f, err := fh.Open() - require.NoError(t, err) - defer func() { - err := f.Close() - require.NoError(t, err) - }() - _, err = f.Read(b2) - require.NoError(t, err) - require.Equal(t, b1, b2) -} - -func Test_Client_Agent_Multipart_Random_Boundary(t *testing.T) { - t.Parallel() - - a := Post("http://example.com"). - MultipartForm(nil) - - reg := regexp.MustCompile(`multipart/form-data; boundary=\w{30}`) - - require.True(t, reg.Match(a.req.Header.Peek(HeaderContentType))) -} - -func Test_Client_Agent_Multipart_Invalid_Boundary(t *testing.T) { - t.Parallel() - - a := Post("http://example.com"). - Boundary("*"). - MultipartForm(nil) - - require.Len(t, a.errs, 1) - require.ErrorContains(t, a.errs[0], "mime: invalid boundary character") -} - -func Test_Client_Agent_SendFile_Error(t *testing.T) { - t.Parallel() - - a := Post("http://example.com"). - SendFile("non-exist-file!", "") - - require.Len(t, a.errs, 1) - require.ErrorIs(t, a.errs[0], os.ErrNotExist) -} - -func Test_Client_Debug(t *testing.T) { - t.Parallel() - handler := func(c Ctx) error { - return c.SendString("debug") - } - - var output bytes.Buffer - - wrapAgent := func(a *Agent) { - a.Debug(&output) - } - - testAgent(t, handler, wrapAgent, "debug", 1) - - str := output.String() - - require.Contains(t, str, "Connected to example.com(InmemoryListener)") - require.Contains(t, str, "GET / HTTP/1.1") - require.Contains(t, str, "User-Agent: fiber") - require.Contains(t, str, "Host: example.com\r\n\r\n") - require.Contains(t, str, "HTTP/1.1 200 OK") - require.Contains(t, str, "Content-Type: text/plain; charset=utf-8\r\nContent-Length: 5\r\n\r\ndebug") -} - -func Test_Client_Agent_Timeout(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - time.Sleep(time.Millisecond * 200) - return c.SendString("timeout") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Get("http://example.com"). - Timeout(time.Millisecond * 50) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - require.ErrorIs(t, errs[0], fasthttp.ErrTimeout) -} - -func Test_Client_Agent_Reuse(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("reuse") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - a := Get("http://example.com"). - Reuse() - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "reuse", body) - require.Empty(t, errs) - - code, body, errs = a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, "reuse", body) - require.Empty(t, errs) -} - -func Test_Client_Agent_InsecureSkipVerify(t *testing.T) { - t.Parallel() - - cer, err := tls.LoadX509KeyPair("./.github/testdata/ssl.pem", "./.github/testdata/ssl.key") - require.NoError(t, err) - - //nolint:gosec // We're in a test so using old ciphers is fine - serverTLSConf := &tls.Config{ - Certificates: []tls.Certificate{cer}, - } - - ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") - require.NoError(t, err) - - ln = tls.NewListener(ln, serverTLSConf) - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("ignore tls") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - code, body, errs := Get("https://" + ln.Addr().String()). - InsecureSkipVerify(). - InsecureSkipVerify(). - String() - - require.Empty(t, errs) - require.Equal(t, StatusOK, code) - require.Equal(t, "ignore tls", body) -} - -func Test_Client_Agent_TLS(t *testing.T) { - t.Parallel() - - serverTLSConf, clientTLSConf, err := tlstest.GetTLSConfigs() - require.NoError(t, err) - - ln, err := net.Listen(NetworkTCP4, "127.0.0.1:0") - require.NoError(t, err) - - ln = tls.NewListener(ln, serverTLSConf) - - app := New() - - app.Get("/", func(c Ctx) error { - return c.SendString("tls") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - code, body, errs := Get("https://" + ln.Addr().String()). - TLSConfig(clientTLSConf). - String() - - require.Empty(t, errs) - require.Equal(t, StatusOK, code) - require.Equal(t, "tls", body) -} - -func Test_Client_Agent_MaxRedirectsCount(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - if c.Request().URI().QueryArgs().Has("foo") { - return c.Redirect().To("/foo") - } - return c.Redirect().To("/") - }) - app.Get("/foo", func(c Ctx) error { - return c.SendString("redirect") - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - t.Run("success", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com?foo"). - MaxRedirectsCount(1) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, 200, code) - require.Equal(t, "redirect", body) - require.Empty(t, errs) - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com"). - MaxRedirectsCount(1) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - _, body, errs := a.String() - - require.Equal(t, "", body) - require.Len(t, errs, 1) - require.ErrorIs(t, errs[0], fasthttp.ErrTooManyRedirects) - }) -} - -func Test_Client_Agent_Struct(t *testing.T) { - t.Parallel() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", func(c Ctx) error { - return c.JSON(data{true}) - }) - - app.Get("/error", func(c Ctx) error { - return c.SendString(`{"success"`) - }) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - t.Run("success", func(t *testing.T) { - t.Parallel() - - a := Get("http://example.com") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - var d data - - code, body, errs := a.Struct(&d) - - require.Equal(t, StatusOK, code) - require.Equal(t, `{"success":true}`, string(body)) - require.Empty(t, errs) - require.True(t, d.Success) - }) - - t.Run("pre error", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com") - - errPre := errors.New("pre errors") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - a.errs = append(a.errs, errPre) - - var d data - _, body, errs := a.Struct(&d) - - require.Equal(t, "", string(body)) - require.Len(t, errs, 1) - require.ErrorIs(t, errs[0], errPre) - require.False(t, d.Success) - }) - - t.Run("error", func(t *testing.T) { - t.Parallel() - a := Get("http://example.com/error") - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - var d data - - code, body, errs := a.JSONDecoder(json.Unmarshal).Struct(&d) - - require.Equal(t, StatusOK, code) - require.Equal(t, `{"success"`, string(body)) - require.Len(t, errs, 1) - wantErr := new(json.SyntaxError) - require.ErrorAs(t, errs[0], &wantErr) - require.EqualValues(t, 10, wantErr.Offset) - }) - - t.Run("nil jsonDecoder", func(t *testing.T) { - t.Parallel() - a := AcquireAgent() - defer ReleaseAgent(a) - defer a.ConnectionClose() - request := a.Request() - request.Header.SetMethod(MethodGet) - request.SetRequestURI("http://example.com") - err := a.Parse() - require.NoError(t, err) - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - var d data - code, body, errs := a.Struct(&d) - require.Equal(t, StatusOK, code) - require.Equal(t, `{"success":true}`, string(body)) - require.Empty(t, errs) - require.True(t, d.Success) - }) -} - -func Test_Client_Agent_Parse(t *testing.T) { - t.Parallel() - - a := Get("https://example.com:10443") - - require.NoError(t, a.Parse()) -} - -func testAgent(t *testing.T, handler Handler, wrapAgent func(agent *Agent), excepted string, count ...int) { - t.Helper() - - ln := fasthttputil.NewInmemoryListener() - - app := New() - - app.Get("/", handler) - - go func() { - require.NoError(t, app.Listener(ln, ListenConfig{ - DisableStartupMessage: true, - })) - }() - - c := 1 - if len(count) > 0 { - c = count[0] - } - - for i := 0; i < c; i++ { - a := Get("http://example.com") - - wrapAgent(a) - - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - - code, body, errs := a.String() - - require.Equal(t, StatusOK, code) - require.Equal(t, excepted, body) - require.Empty(t, errs) - } -} - -type data struct { - Success bool `json:"success" xml:"success"` -} - -type errorMultipartWriter struct { - count int -} - -func (*errorMultipartWriter) Boundary() string { return "myBoundary" } -func (*errorMultipartWriter) SetBoundary(_ string) error { return nil } -func (e *errorMultipartWriter) CreateFormFile(_, _ string) (io.Writer, error) { - if e.count == 0 { - e.count++ - return nil, errors.New("CreateFormFile error") - } - return errorWriter{}, nil -} -func (*errorMultipartWriter) WriteField(_, _ string) error { return errors.New("WriteField error") } -func (*errorMultipartWriter) Close() error { return errors.New("Close error") } - -type errorWriter struct{} - -func (errorWriter) Write(_ []byte) (int, error) { return 0, errors.New("Write error") } diff --git a/listen_test.go b/listen_test.go index d92b9fb3962..a5d419ac869 100644 --- a/listen_test.go +++ b/listen_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/stretchr/testify/require" + "github.com/valyala/fasthttp" "github.com/valyala/fasthttp/fasthttputil" ) @@ -68,22 +69,30 @@ func Test_Listen_Graceful_Shutdown(t *testing.T) { Time time.Duration ExpectedBody string ExpectedStatusCode int - ExceptedErrsLen int + ExpectedErr error }{ - {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExceptedErrsLen: 0}, - {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: 0, ExceptedErrsLen: 1}, + {Time: 100 * time.Millisecond, ExpectedBody: "example.com", ExpectedStatusCode: StatusOK, ExpectedErr: nil}, + {Time: 500 * time.Millisecond, ExpectedBody: "", ExpectedStatusCode: StatusOK, ExpectedErr: errors.New("InmemoryListener is already closed: use of closed network connection")}, } for _, tc := range testCases { time.Sleep(tc.Time) - a := Get("http://example.com") - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() + req := fasthttp.AcquireRequest() + req.SetRequestURI("http://example.com") - require.Equal(t, tc.ExpectedStatusCode, code) - require.Equal(t, tc.ExpectedBody, body) - require.Len(t, errs, tc.ExceptedErrsLen) + client := fasthttp.HostClient{} + client.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } + + resp := fasthttp.AcquireResponse() + err := client.Do(req, resp) + + require.Equal(t, tc.ExpectedErr, err) + require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) + require.Equal(t, tc.ExpectedBody, string(resp.Body())) + + fasthttp.ReleaseRequest(req) + fasthttp.ReleaseResponse(resp) } mu.Lock() diff --git a/log/default.go b/log/default.go index e9c3d1bbbd9..67fa137ecfd 100644 --- a/log/default.go +++ b/log/default.go @@ -34,6 +34,7 @@ func (l *defaultLogger) privateLog(lv Level, fmtArgs []any) { if lv == LevelPanic { panic(buf.String()) } + buf.Reset() bytebufferpool.Put(buf) if lv == LevelFatal { @@ -56,6 +57,7 @@ func (l *defaultLogger) privateLogf(lv Level, format string, fmtArgs []any) { } else { _, _ = fmt.Fprint(buf, fmtArgs...) } + _ = l.stdlog.Output(l.depth, buf.String()) //nolint:errcheck // It is fine to ignore the error if lv == LevelPanic { panic(buf.String()) diff --git a/middleware/proxy/proxy.go b/middleware/proxy/proxy.go index 284b67c8f57..167a0e6f31f 100644 --- a/middleware/proxy/proxy.go +++ b/middleware/proxy/proxy.go @@ -2,7 +2,6 @@ package proxy import ( "bytes" - "crypto/tls" "net/url" "strings" "sync" @@ -105,13 +104,6 @@ var client = &fasthttp.Client{ var lock sync.RWMutex -// WithTLSConfig update http client with a user specified tls.config -// This function should be called before Do and Forward. -// Deprecated: use WithClient instead. -func WithTLSConfig(tlsConfig *tls.Config) { - client.TLSConfig = tlsConfig -} - // WithClient sets the global proxy client. // This function should be called before Do and Forward. func WithClient(cli *fasthttp.Client) { diff --git a/middleware/proxy/proxy_test.go b/middleware/proxy/proxy_test.go index 408ee71a5f3..4aa0065040a 100644 --- a/middleware/proxy/proxy_test.go +++ b/middleware/proxy/proxy_test.go @@ -11,8 +11,10 @@ import ( "time" "github.com/gofiber/fiber/v3" - "github.com/gofiber/fiber/v3/internal/tlstest" + clientpkg "github.com/gofiber/fiber/v3/client" "github.com/stretchr/testify/require" + + "github.com/gofiber/fiber/v3/internal/tlstest" "github.com/valyala/fasthttp" ) @@ -25,8 +27,6 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str ln, err := net.Listen(fiber.NetworkTCP4, "127.0.0.1:0") require.NoError(t, err) - addr := ln.Addr().String() - go func() { require.NoError(t, target.Listener(ln, fiber.ListenConfig{ DisableStartupMessage: true, @@ -34,6 +34,7 @@ func createProxyTestServer(t *testing.T, handler fiber.Handler) (*fiber.App, str }() time.Sleep(2 * time.Second) + addr := ln.Addr().String() return target, addr } @@ -104,8 +105,8 @@ func Test_Proxy(t *testing.T) { require.Equal(t, fiber.StatusTeapot, resp.StatusCode) } -// go test -run Test_Proxy_Balancer_WithTLSConfig -func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { +// go test -run Test_Proxy_Balancer_WithTlsConfig +func Test_Proxy_Balancer_WithTlsConfig(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() @@ -118,7 +119,7 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { app := fiber.New() - app.Get("/tlsbalaner", func(c fiber.Ctx) error { + app.Get("/tlsbalancer", func(c fiber.Ctx) error { return c.SendString("tls balancer") }) @@ -137,15 +138,18 @@ func Test_Proxy_Balancer_WithTLSConfig(t *testing.T) { })) }() - code, body, errs := fiber.Get("https://" + addr + "/tlsbalaner").TLSConfig(clientTLSConf).String() + client := clientpkg.NewClient() + client.SetTLSConfig(clientTLSConf) - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "tls balancer", body) + resp, err := client.Get("https://" + addr + "/tlsbalancer") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls balancer", string(resp.Body())) + resp.Close() } -// go test -run Test_Proxy_Forward_WithTLSConfig_To_Http -func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) { +// go test -run Test_Proxy_Forward_WithTlsConfig_To_Http +func Test_Proxy_Forward_WithTlsConfig_To_Http(t *testing.T) { t.Parallel() _, targetAddr := createProxyTestServer(t, func(c fiber.Ctx) error { @@ -172,14 +176,15 @@ func Test_Proxy_Forward_WithTLSConfig_To_Http(t *testing.T) { })) }() - code, body, errs := fiber.Get("https://" + proxyAddr). - InsecureSkipVerify(). - Timeout(5 * time.Second). - String() + client := clientpkg.NewClient() + client.SetTimeout(5 * time.Second) + client.TLSConfig().InsecureSkipVerify = true - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "hello from target", body) + resp, err := client.Get("https://" + proxyAddr) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "hello from target", string(resp.Body())) + resp.Close() } // go test -run Test_Proxy_Forward @@ -203,8 +208,8 @@ func Test_Proxy_Forward(t *testing.T) { require.Equal(t, "forwarded", string(b)) } -// go test -run Test_Proxy_Forward_WithTLSConfig -func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { +// go test -run Test_Proxy_Forward_WithClient_TLSConfig +func Test_Proxy_Forward_WithClient_TLSConfig(t *testing.T) { t.Parallel() serverTLSConf, _, err := tlstest.GetTLSConfigs() @@ -225,7 +230,9 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { clientTLSConf := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // We're in a test func, so this is fine // disable certificate verification - WithTLSConfig(clientTLSConf) + WithClient(&fasthttp.Client{ + TLSConfig: clientTLSConf, + }) app.Use(Forward("https://" + addr + "/tlsfwd")) go func() { @@ -234,11 +241,14 @@ func Test_Proxy_Forward_WithTLSConfig(t *testing.T) { })) }() - code, body, errs := fiber.Get("https://" + addr).TLSConfig(clientTLSConf).String() + client := clientpkg.NewClient() + client.SetTLSConfig(clientTLSConf) - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "tls forward", body) + resp, err := client.Get("https://" + addr) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "tls forward", string(resp.Body())) + resp.Close() } // go test -run Test_Proxy_Modify_Response @@ -415,7 +425,7 @@ func Test_Proxy_Do_WithRedirect(t *testing.T) { return Do(c, "https://google.com") }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -431,7 +441,7 @@ func Test_Proxy_DoRedirects_RestoreOriginalURL(t *testing.T) { return DoRedirects(c, "http://google.com", 1) }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) _, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -447,7 +457,7 @@ func Test_Proxy_DoRedirects_TooManyRedirects(t *testing.T) { return DoRedirects(c, "http://google.com", 0) }) - resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil), 1500) + resp, err1 := app.Test(httptest.NewRequest(fiber.MethodGet, "/test", nil)) require.NoError(t, err1) body, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -586,10 +596,13 @@ func Test_Proxy_Forward_Global_Client(t *testing.T) { })) }() - code, body, errs := fiber.Get("http://" + addr).String() - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_global_client", body) + client := clientpkg.NewClient() + + resp, err := client.Get("http://" + addr) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_global_client", string(resp.Body())) + resp.Close() } // go test -race -run Test_Proxy_Forward_Local_Client @@ -615,10 +628,13 @@ func Test_Proxy_Forward_Local_Client(t *testing.T) { })) }() - code, body, errs := fiber.Get("http://" + addr).String() - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_local_client", body) + client := clientpkg.NewClient() + + resp, err := client.Get("http://" + addr) + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_local_client", string(resp.Body())) + resp.Close() } // go test -run Test_ProxyBalancer_Custom_Client @@ -666,7 +682,7 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { app1 := fiber.New() app1.Get("/test", func(c fiber.Ctx) error { - return c.SendString("test_local_client:" + fiber.Query[string](c, "query_test")) + return c.SendString("test_local_client:" + c.Query("query_test")) }) proxyAddr := ln.Addr().String() @@ -679,13 +695,24 @@ func Test_Proxy_Domain_Forward_Local(t *testing.T) { Dial: fasthttp.Dial, })) - go func() { require.NoError(t, app.Listener(ln)) }() - go func() { require.NoError(t, app1.Listener(ln1)) }() + go func() { + require.NoError(t, app.Listener(ln, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + go func() { + require.NoError(t, app1.Listener(ln1, fiber.ListenConfig{ + DisableStartupMessage: true, + })) + }() + + client := clientpkg.NewClient() - code, body, errs := fiber.Get("http://" + localDomain + "/test?query_test=true").String() - require.Empty(t, errs) - require.Equal(t, fiber.StatusOK, code) - require.Equal(t, "test_local_client:true", body) + resp, err := client.Get("http://" + localDomain + "/test?query_test=true") + require.NoError(t, err) + require.Equal(t, fiber.StatusOK, resp.StatusCode()) + require.Equal(t, "test_local_client:true", string(resp.Body())) + resp.Close() } // go test -run Test_Proxy_Balancer_Forward_Local diff --git a/redirect_test.go b/redirect_test.go index d49f5267714..6dd2ae6d198 100644 --- a/redirect_test.go +++ b/redirect_test.go @@ -291,41 +291,45 @@ func Test_Redirect_Request(t *testing.T) { CookieValue string ExpectedBody string ExpectedStatusCode int - ExceptedErrsLen int + ExpectedErr error }{ { URL: "/", CookieValue: "key:value,key2:value2,co\\:m\\,ma:Fi\\:ber\\, v3", ExpectedBody: `{"inputs":{},"messages":{"co:m,ma":"Fi:ber, v3","key":"value","key2":"value2"}}`, ExpectedStatusCode: StatusOK, - ExceptedErrsLen: 0, + ExpectedErr: nil, }, { URL: "/with-inputs?name=john&surname=doe", CookieValue: "key:value,key2:value2,key:value,key2:value2,old_input_data_name:john,old_input_data_surname:doe", ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{"key":"value","key2":"value2"}}`, ExpectedStatusCode: StatusOK, - ExceptedErrsLen: 0, + ExpectedErr: nil, }, { URL: "/just-inputs?name=john&surname=doe", CookieValue: "old_input_data_name:john,old_input_data_surname:doe", ExpectedBody: `{"inputs":{"name":"john","surname":"doe"},"messages":{}}`, ExpectedStatusCode: StatusOK, - ExceptedErrsLen: 0, + ExpectedErr: nil, }, } for _, tc := range testCases { - a := Get("http://example.com" + tc.URL) - a.Cookie(FlashCookieName, tc.CookieValue) - a.MaxRedirectsCount(1) - a.HostClient.Dial = func(_ string) (net.Conn, error) { return ln.Dial() } - code, body, errs := a.String() - - require.Equal(t, tc.ExpectedStatusCode, code) - require.Equal(t, tc.ExpectedBody, body) - require.Len(t, errs, tc.ExceptedErrsLen) + client := &fasthttp.HostClient{ + Dial: func(_ string) (net.Conn, error) { + return ln.Dial() + }, + } + req, resp := fasthttp.AcquireRequest(), fasthttp.AcquireResponse() + req.SetRequestURI("http://example.com" + tc.URL) + req.Header.SetCookie(FlashCookieName, tc.CookieValue) + err := client.DoRedirects(req, resp, 1) + + require.NoError(t, err) + require.Equal(t, tc.ExpectedBody, string(resp.Body())) + require.Equal(t, tc.ExpectedStatusCode, resp.StatusCode()) } } From c5a809f5c5ec52039e064e16ec3c3684d7b70b1c Mon Sep 17 00:00:00 2001 From: Giovanni Rivera Date: Sun, 3 Mar 2024 23:49:25 -0800 Subject: [PATCH 2/2] :broom: v3 (Maintenance): Update docs to reflect fiber.Ctx struct to interface change (#2880) * :broom: [v3 Maintenance]: Update docs to reflect fiber.Ctx struct to interface change Summary: - Update `Static.Next()` in `/docs/api/app.md` to use the `Ctx` interface - Update `/docs/api/ctx.md` to use the `Ctx` interface Related Issues: #2879 * :broom: [v3 Maintenance]: Update Ctx struct description to interface Related Issues: #2879 --- docs/api/ctx.md | 168 ++++++++++++++++++++++++------------------------ 1 file changed, 84 insertions(+), 84 deletions(-) diff --git a/docs/api/ctx.md b/docs/api/ctx.md index 3b12c7b35be..ad709242417 100644 --- a/docs/api/ctx.md +++ b/docs/api/ctx.md @@ -2,7 +2,7 @@ id: ctx title: 🧠 Ctx description: >- - The Ctx struct represents the Context which hold the HTTP request and + The Ctx interface represents the Context which hold the HTTP request and response. It has methods for the request query string, parameters, body, HTTP headers, and so on. sidebar_position: 3 @@ -17,10 +17,10 @@ Based on the request’s [Accept](https://developer.mozilla.org/en-US/docs/Web/H ::: ```go title="Signature" -func (c *Ctx) Accepts(offers ...string) string -func (c *Ctx) AcceptsCharsets(offers ...string) string -func (c *Ctx) AcceptsEncodings(offers ...string) string -func (c *Ctx) AcceptsLanguages(offers ...string) string +func (c Ctx) Accepts(offers ...string) string +func (c Ctx) AcceptsCharsets(offers ...string) string +func (c Ctx) AcceptsEncodings(offers ...string) string +func (c Ctx) AcceptsLanguages(offers ...string) string ``` ```go title="Example" @@ -106,7 +106,7 @@ Params is used to get all route parameters. Using Params method to get params. ```go title="Signature" -func (c *Ctx) AllParams() map[string]string +func (c Ctx) AllParams() map[string]string ``` ```go title="Example" @@ -130,7 +130,7 @@ app.Get("/user/*", func(c fiber.Ctx) error { Returns the [\*App](ctx.md) reference so you could easily access all application settings. ```go title="Signature" -func (c *Ctx) App() *App +func (c Ctx) App() *App ``` ```go title="Example" @@ -148,7 +148,7 @@ If the header is **not** already set, it creates the header with the specified v ::: ```go title="Signature" -func (c *Ctx) Append(field string, values ...string) +func (c Ctx) Append(field string, values ...string) ``` ```go title="Example" @@ -168,7 +168,7 @@ app.Get("/", func(c fiber.Ctx) error { Sets the HTTP response [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) header field to `attachment`. ```go title="Signature" -func (c *Ctx) Attachment(filename ...string) +func (c Ctx) Attachment(filename ...string) ``` ```go title="Example" @@ -196,7 +196,7 @@ If the header is **not** specified or there is **no** proper format, **text/plai ::: ```go title="Signature" -func (c *Ctx) AutoFormat(body any) error +func (c Ctx) AutoFormat(body any) error ``` ```go title="Example" @@ -230,7 +230,7 @@ app.Get("/", func(c fiber.Ctx) error { Returns the base URL \(**protocol** + **host**\) as a `string`. ```go title="Signature" -func (c *Ctx) BaseURL() string +func (c Ctx) BaseURL() string ``` ```go title="Example" @@ -248,7 +248,7 @@ Add vars to default view var map binding to template engine. Variables are read by the Render method and may be overwritten. ```go title="Signature" -func (c *Ctx) Bind(vars Map) error +func (c Ctx) Bind(vars Map) error ``` ```go title="Example" @@ -268,7 +268,7 @@ app.Get("/", func(c fiber.Ctx) error { Returns the raw request **body**. ```go title="Signature" -func (c *Ctx) BodyRaw() []byte +func (c Ctx) BodyRaw() []byte ``` ```go title="Example" @@ -288,7 +288,7 @@ app.Post("/", func(c fiber.Ctx) error { As per the header `Content-Encoding`, this method will try to perform a file decompression from the **body** bytes. In case no `Content-Encoding` header is sent, it will perform as [BodyRaw](#bodyraw). ```go title="Signature" -func (c *Ctx) Body() []byte +func (c Ctx) Body() []byte ``` ```go title="Example" @@ -318,7 +318,7 @@ It is important to specify the correct struct tag based on the content type to b | `text/xml` | xml | ```go title="Signature" -func (c *Ctx) BodyParser(out any) error +func (c Ctx) BodyParser(out any) error ``` ```go title="Example" @@ -362,7 +362,7 @@ app.Post("/", func(c fiber.Ctx) error { Expire a client cookie \(_or all cookies if left empty\)_ ```go title="Signature" -func (c *Ctx) ClearCookie(key ...string) +func (c Ctx) ClearCookie(key ...string) ``` ```go title="Example" @@ -415,7 +415,7 @@ ClientHelloInfo contains information from a ClientHello message in order to guid You can refer to the [ClientHelloInfo](https://golang.org/pkg/crypto/tls/#ClientHelloInfo) struct documentation for more information on the returned struct. ```go title="Signature" -func (c *Ctx) ClientHelloInfo() *tls.ClientHelloInfo +func (c Ctx) ClientHelloInfo() *tls.ClientHelloInfo ``` ```go title="Example" @@ -431,7 +431,7 @@ app.Get("/hello", func(c fiber.Ctx) error { Returns [\*fasthttp.RequestCtx](https://godoc.org/github.com/valyala/fasthttp#RequestCtx) that is compatible with the context.Context interface that requires a deadline, a cancellation signal, and other values across API boundaries. ```go title="Signature" -func (c *Ctx) Context() *fasthttp.RequestCtx +func (c Ctx) Context() *fasthttp.RequestCtx ``` :::info @@ -443,7 +443,7 @@ Please read the [Fasthttp Documentation](https://pkg.go.dev/github.com/valyala/f Set cookie ```go title="Signature" -func (c *Ctx) Cookie(cookie *Cookie) +func (c Ctx) Cookie(cookie *Cookie) ``` ```go @@ -481,7 +481,7 @@ This method is similar to [BodyParser](ctx.md#bodyparser), but for cookie parame It is important to use the struct tag "cookie". For example, if you want to parse a cookie with a field called Age, you would use a struct field of `cookie:"age"`. ```go title="Signature" -func (c *Ctx) CookieParser(out any) error +func (c Ctx) CookieParser(out any) error ``` ```go title="Example" @@ -512,7 +512,7 @@ app.Get("/", func(c fiber.Ctx) error { Get cookie value by key, you could pass an optional default value that will be returned if the cookie key does not exist. ```go title="Signature" -func (c *Ctx) Cookies(key string, defaultValue ...string) string +func (c Ctx) Cookies(key string, defaultValue ...string) string ``` ```go title="Example" @@ -536,7 +536,7 @@ Typically, browsers will prompt the user to download. By default, the [Content-D Override this default with the **filename** parameter. ```go title="Signature" -func (c *Ctx) Download(file string, filename ...string) error +func (c Ctx) Download(file string, filename ...string) error ``` ```go title="Example" @@ -558,7 +558,7 @@ If the Accept header is **not** specified, the first handler will be used. ::: ```go title="Signature" -func (c *Ctx) Format(handlers ...ResFmt) error +func (c Ctx) Format(handlers ...ResFmt) error ``` ```go title="Example" @@ -607,7 +607,7 @@ app.Get("/default", func(c fiber.Ctx) error { MultipartForm files can be retrieved by name, the **first** file from the given key is returned. ```go title="Signature" -func (c *Ctx) FormFile(key string) (*multipart.FileHeader, error) +func (c Ctx) FormFile(key string) (*multipart.FileHeader, error) ``` ```go title="Example" @@ -625,7 +625,7 @@ app.Post("/", func(c fiber.Ctx) error { Any form values can be retrieved by name, the **first** value from the given key is returned. ```go title="Signature" -func (c *Ctx) FormValue(key string, defaultValue ...string) string +func (c Ctx) FormValue(key string, defaultValue ...string) string ``` ```go title="Example" @@ -650,7 +650,7 @@ When a client sends the Cache-Control: no-cache request header to indicate an en Read more on [https://expressjs.com/en/4x/api.html\#req.fresh](https://expressjs.com/en/4x/api.html#req.fresh) ```go title="Signature" -func (c *Ctx) Fresh() bool +func (c Ctx) Fresh() bool ``` ## Get @@ -662,7 +662,7 @@ The match is **case-insensitive**. ::: ```go title="Signature" -func (c *Ctx) Get(key string, defaultValue ...string) string +func (c Ctx) Get(key string, defaultValue ...string) string ``` ```go title="Example" @@ -682,7 +682,7 @@ app.Get("/", func(c fiber.Ctx) error { Returns the HTTP request headers as a map. Since a header can be set multiple times in a single request, the values of the map are slices of strings containing all the different values of the header. ```go title="Signature" -func (c *Ctx) GetReqHeaders() map[string][]string +func (c Ctx) GetReqHeaders() map[string][]string ``` > _Returned value is only valid within the handler. Do not store any references. @@ -697,7 +697,7 @@ The match is **case-insensitive**. ::: ```go title="Signature" -func (c *Ctx) GetRespHeader(key string, defaultValue ...string) string +func (c Ctx) GetRespHeader(key string, defaultValue ...string) string ``` ```go title="Example" @@ -717,7 +717,7 @@ app.Get("/", func(c fiber.Ctx) error { Returns the HTTP response headers as a map. Since a header can be set multiple times in a single request, the values of the map are slices of strings containing all the different values of the header. ```go title="Signature" -func (c *Ctx) GetRespHeaders() map[string][]string +func (c Ctx) GetRespHeaders() map[string][]string ``` > _Returned value is only valid within the handler. Do not store any references. @@ -728,7 +728,7 @@ func (c *Ctx) GetRespHeaders() map[string][]string Generates URLs to named routes, with parameters. URLs are relative, for example: "/user/1831" ```go title="Signature" -func (c *Ctx) GetRouteURL(routeName string, params Map) (string, error) +func (c Ctx) GetRouteURL(routeName string, params Map) (string, error) ``` ```go title="Example" @@ -753,7 +753,7 @@ app.Get("/test", func(c fiber.Ctx) error { Returns the hostname derived from the [Host](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Host) HTTP header. ```go title="Signature" -func (c *Ctx) Hostname() string +func (c Ctx) Hostname() string ``` ```go title="Example" @@ -774,7 +774,7 @@ app.Get("/", func(c fiber.Ctx) error { Returns the remote IP address of the request. ```go title="Signature" -func (c *Ctx) IP() string +func (c Ctx) IP() string ``` ```go title="Example" @@ -798,7 +798,7 @@ app := fiber.New(fiber.Config{ Returns an array of IP addresses specified in the [X-Forwarded-For](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For) request header. ```go title="Signature" -func (c *Ctx) IPs() []string +func (c Ctx) IPs() []string ``` ```go title="Example" @@ -824,7 +824,7 @@ If the request has **no** body, it returns **false**. ::: ```go title="Signature" -func (c *Ctx) Is(extension string) bool +func (c Ctx) Is(extension string) bool ``` ```go title="Example" @@ -844,7 +844,7 @@ app.Get("/", func(c fiber.Ctx) error { Returns true if request came from localhost ```go title="Signature" -func (c *Ctx) IsFromLocal() bool { +func (c Ctx) IsFromLocal() bool { ``` ```go title="Example" @@ -866,7 +866,7 @@ JSON also sets the content header to the `ctype` parameter. If no `ctype` is pas ::: ```go title="Signature" -func (c *Ctx) JSON(data any, ctype ...string) error +func (c Ctx) JSON(data any, ctype ...string) error ``` ```go title="Example" @@ -918,7 +918,7 @@ Sends a JSON response with JSONP support. This method is identical to [JSON](ctx Override this by passing a **named string** in the method. ```go title="Signature" -func (c *Ctx) JSONP(data any, callback ...string) error +func (c Ctx) JSONP(data any, callback ...string) error ``` ```go title="Example" @@ -947,7 +947,7 @@ app.Get("/", func(c fiber.Ctx) error { Joins the links followed by the property to populate the response’s [Link](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Link) HTTP header field. ```go title="Signature" -func (c *Ctx) Links(link ...string) +func (c Ctx) Links(link ...string) ``` ```go title="Example" @@ -972,7 +972,7 @@ This is useful if you want to pass some **specific** data to the next middleware ::: ```go title="Signature" -func (c *Ctx) Locals(key any, value ...any) any +func (c Ctx) Locals(key any, value ...any) any ``` ```go title="Example" @@ -1030,7 +1030,7 @@ over route-specific data within your application. Sets the response [Location](https://developer.mozilla.org/ru/docs/Web/HTTP/Headers/Location) HTTP header to the specified path parameter. ```go title="Signature" -func (c *Ctx) Location(path string) +func (c Ctx) Location(path string) ``` ```go title="Example" @@ -1049,7 +1049,7 @@ Returns a string corresponding to the HTTP method of the request: `GET`, `POST`, Optionally, you could override the method by passing a string. ```go title="Signature" -func (c *Ctx) Method(override ...string) string +func (c Ctx) Method(override ...string) string ``` ```go title="Example" @@ -1068,7 +1068,7 @@ app.Post("/", func(c fiber.Ctx) error { To access multipart form entries, you can parse the binary with `MultipartForm()`. This returns a `map[string][]string`, so given a key, the value will be a string slice. ```go title="Signature" -func (c *Ctx) MultipartForm() (*multipart.Form, error) +func (c Ctx) MultipartForm() (*multipart.Form, error) ``` ```go title="Example" @@ -1107,7 +1107,7 @@ app.Post("/", func(c fiber.Ctx) error { When **Next** is called, it executes the next method in the stack that matches the current route. You can pass an error struct within the method that will end the chaining and call the [error handler](https://docs.gofiber.io/guide/error-handling). ```go title="Signature" -func (c *Ctx) Next() error +func (c Ctx) Next() error ``` ```go title="Example" @@ -1132,7 +1132,7 @@ app.Get("/", func(c fiber.Ctx) error { Returns the original request URL. ```go title="Signature" -func (c *Ctx) OriginalURL() string +func (c Ctx) OriginalURL() string ``` ```go title="Example" @@ -1157,7 +1157,7 @@ Defaults to empty string \(`""`\), if the param **doesn't** exist. ::: ```go title="Signature" -func (c *Ctx) Params(key string, defaultValue ...string) string +func (c Ctx) Params(key string, defaultValue ...string) string ``` ```go title="Example" @@ -1209,7 +1209,7 @@ Defaults to the integer zero \(`0`\), if the param **doesn't** exist. ::: ```go title="Signature" -func (c *Ctx) ParamsInt(key string) (int, error) +func (c Ctx) ParamsInt(key string) (int, error) ``` ```go title="Example" @@ -1229,7 +1229,7 @@ This method is equivalent of using `atoi` with ctx.Params This method is similar to BodyParser, but for path parameters. It is important to use the struct tag "params". For example, if you want to parse a path parameter with a field called Pass, you would use a struct field of params:"pass" ```go title="Signature" -func (c *Ctx) ParamsParser(out any) error +func (c Ctx) ParamsParser(out any) error ``` ```go title="Example" @@ -1249,7 +1249,7 @@ app.Get("/user/:id", func(c fiber.Ctx) error { Contains the path part of the request URL. Optionally, you could override the path by passing a string. For internal redirects, you might want to call [RestartRouting](ctx.md#restartrouting) instead of [Next](ctx.md#next). ```go title="Signature" -func (c *Ctx) Path(override ...string) string +func (c Ctx) Path(override ...string) string ``` ```go title="Example" @@ -1270,7 +1270,7 @@ app.Get("/users", func(c fiber.Ctx) error { Contains the request protocol string: `http` or `https` for **TLS** requests. ```go title="Signature" -func (c *Ctx) Protocol() string +func (c Ctx) Protocol() string ``` ```go title="Example" @@ -1288,7 +1288,7 @@ app.Get("/", func(c fiber.Ctx) error { Queries is a function that returns an object containing a property for each query string parameter in the route. ```go title="Signature" -func (c *Ctx) Queries() map[string]string +func (c Ctx) Queries() map[string]string ``` ```go title="Example" @@ -1356,7 +1356,7 @@ If there is **no** query string, it returns an **empty string**. ::: ```go title="Signature" -func (c *Ctx) Query(key string, defaultValue ...string) string +func (c Ctx) Query(key string, defaultValue ...string) string ``` ```go title="Example" @@ -1417,7 +1417,7 @@ This method is similar to [BodyParser](ctx.md#bodyparser), but for query paramet It is important to use the struct tag "query". For example, if you want to parse a query parameter with a field called Pass, you would use a struct field of `query:"pass"`. ```go title="Signature" -func (c *Ctx) QueryParser(out any) error +func (c Ctx) QueryParser(out any) error ``` ```go title="Example" @@ -1459,7 +1459,7 @@ For more parser settings please look here [Config](fiber.md#config) A struct containing the type and a slice of ranges will be returned. ```go title="Signature" -func (c *Ctx) Range(size int) (Range, error) +func (c Ctx) Range(size int) (Range, error) ``` ```go title="Example" @@ -1484,7 +1484,7 @@ If **not** specified, status defaults to **302 Found**. ::: ```go title="Signature" -func (c *Ctx) Redirect(location string, status ...int) error +func (c Ctx) Redirect(location string, status ...int) error ``` ```go title="Example" @@ -1519,7 +1519,7 @@ If you want to send queries to route, you must add **"queries"** key typed as ** ::: ```go title="Signature" -func (c *Ctx) RedirectToRoute(routeName string, params fiber.Map, status ...int) error +func (c Ctx) RedirectToRoute(routeName string, params fiber.Map, status ...int) error ``` ```go title="Example" @@ -1552,7 +1552,7 @@ If **not** specified, status defaults to **302 Found**. ::: ```go title="Signature" -func (c *Ctx) RedirectBack(fallback string, status ...int) error +func (c Ctx) RedirectBack(fallback string, status ...int) error ``` ```go title="Example" @@ -1574,7 +1574,7 @@ app.Get("/back", func(c fiber.Ctx) error { Renders a view with data and sends a `text/html` response. By default `Render` uses the default [**Go Template engine**](https://pkg.go.dev/html/template/). If you want to use another View engine, please take a look at our [**Template middleware**](https://docs.gofiber.io/template). ```go title="Signature" -func (c *Ctx) Render(name string, bind any, layouts ...string) error +func (c Ctx) Render(name string, bind any, layouts ...string) error ``` ## Request @@ -1582,7 +1582,7 @@ func (c *Ctx) Render(name string, bind any, layouts ...string) error Request return the [\*fasthttp.Request](https://godoc.org/github.com/valyala/fasthttp#Request) pointer ```go title="Signature" -func (c *Ctx) Request() *fasthttp.Request +func (c Ctx) Request() *fasthttp.Request ``` ```go title="Example" @@ -1598,7 +1598,7 @@ This method is similar to [BodyParser](ctx.md#bodyparser), but for request heade It is important to use the struct tag "reqHeader". For example, if you want to parse a request header with a field called Pass, you would use a struct field of `reqHeader:"pass"`. ```go title="Signature" -func (c *Ctx) ReqHeaderParser(out any) error +func (c Ctx) ReqHeaderParser(out any) error ``` ```go title="Example" @@ -1632,7 +1632,7 @@ app.Get("/", func(c fiber.Ctx) error { Response return the [\*fasthttp.Response](https://godoc.org/github.com/valyala/fasthttp#Response) pointer ```go title="Signature" -func (c *Ctx) Response() *fasthttp.Response +func (c Ctx) Response() *fasthttp.Response ``` ```go title="Example" @@ -1648,7 +1648,7 @@ app.Get("/", func(c fiber.Ctx) error { Instead of executing the next method when calling [Next](ctx.md#next), **RestartRouting** restarts execution from the first method that matches the current route. This may be helpful after overriding the path, i. e. an internal redirect. Note that handlers might be executed again which could result in an infinite loop. ```go title="Signature" -func (c *Ctx) RestartRouting() error +func (c Ctx) RestartRouting() error ``` ```go title="Example" @@ -1667,7 +1667,7 @@ app.Get("/old", func(c fiber.Ctx) error { Returns the matched [Route](https://pkg.go.dev/github.com/gofiber/fiber?tab=doc#Route) struct. ```go title="Signature" -func (c *Ctx) Route() *Route +func (c Ctx) Route() *Route ``` ```go title="Example" @@ -1703,7 +1703,7 @@ func MyMiddleware() fiber.Handler { Method is used to save **any** multipart file to disk. ```go title="Signature" -func (c *Ctx) SaveFile(fh *multipart.FileHeader, path string) error +func (c Ctx) SaveFile(fh *multipart.FileHeader, path string) error ``` ```go title="Example" @@ -1736,7 +1736,7 @@ app.Post("/", func(c fiber.Ctx) error { Method is used to save **any** multipart file to an external storage system. ```go title="Signature" -func (c *Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error +func (c Ctx) SaveFileToStorage(fileheader *multipart.FileHeader, path string, storage Storage) error ``` ```go title="Example" @@ -1771,7 +1771,7 @@ app.Post("/", func(c fiber.Ctx) error { A boolean property that is `true` , if a **TLS** connection is established. ```go title="Signature" -func (c *Ctx) Secure() bool +func (c Ctx) Secure() bool ``` ```go title="Example" @@ -1784,7 +1784,7 @@ c.Protocol() == "https" Sets the HTTP response body. ```go title="Signature" -func (c *Ctx) Send(body []byte) error +func (c Ctx) Send(body []byte) error ``` ```go title="Example" @@ -1800,8 +1800,8 @@ Use this if you **don't need** type assertion, recommended for **faster** perfor ::: ```go title="Signature" -func (c *Ctx) SendString(body string) error -func (c *Ctx) SendStream(stream io.Reader, size ...int) error +func (c Ctx) SendString(body string) error +func (c Ctx) SendStream(stream io.Reader, size ...int) error ``` ```go title="Example" @@ -1823,7 +1823,7 @@ Method doesnΒ΄t use **gzipping** by default, set it to **true** to enable. ::: ```go title="Signature" title="Signature" -func (c *Ctx) SendFile(file string, compress ...bool) error +func (c Ctx) SendFile(file string, compress ...bool) error ``` ```go title="Example" @@ -1858,7 +1858,7 @@ You can find all used status codes and messages [here](https://github.com/gofibe ::: ```go title="Signature" -func (c *Ctx) SendStatus(status int) error +func (c Ctx) SendStatus(status int) error ``` ```go title="Example" @@ -1877,7 +1877,7 @@ app.Get("/not-found", func(c fiber.Ctx) error { Sets the response’s HTTP header field to the specified `key`, `value`. ```go title="Signature" -func (c *Ctx) Set(key string, val string) +func (c Ctx) Set(key string, val string) ``` ```go title="Example" @@ -1968,7 +1968,7 @@ app.Get("/query", func(c fiber.Ctx) error { Sets the user specified implementation for context interface. ```go title="Signature" -func (c *Ctx) SetUserContext(ctx context.Context) +func (c Ctx) SetUserContext(ctx context.Context) ``` ```go title="Example" @@ -1986,7 +1986,7 @@ app.Get("/", func(c fiber.Ctx) error { [https://expressjs.com/en/4x/api.html\#req.stale](https://expressjs.com/en/4x/api.html#req.stale) ```go title="Signature" -func (c *Ctx) Stale() bool +func (c Ctx) Stale() bool ``` ## Status @@ -1998,7 +1998,7 @@ Method is a **chainable**. ::: ```go title="Signature" -func (c *Ctx) Status(status int) *Ctx +func (c Ctx) Status(status int) Ctx ``` ```go title="Example" @@ -2023,7 +2023,7 @@ Returns a string slice of subdomains in the domain name of the request. The application property subdomain offset, which defaults to `2`, is used for determining the beginning of the subdomain segments. ```go title="Signature" -func (c *Ctx) Subdomains(offset ...int) []string +func (c Ctx) Subdomains(offset ...int) []string ``` ```go title="Example" @@ -2042,7 +2042,7 @@ app.Get("/", func(c fiber.Ctx) error { Sets the [Content-Type](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Type) HTTP header to the MIME type listed [here](https://github.com/nginx/nginx/blob/master/conf/mime.types) specified by the file **extension**. ```go title="Signature" -func (c *Ctx) Type(ext string, charset ...string) *Ctx +func (c Ctx) Type(ext string, charset ...string) Ctx ``` ```go title="Example" @@ -2063,7 +2063,7 @@ UserContext returns a context implementation that was set by user earlier or returns a non-nil, empty context, if it was not set earlier. ```go title="Signature" -func (c *Ctx) UserContext() context.Context +func (c Ctx) UserContext() context.Context ``` ```go title="Example" @@ -2084,7 +2084,7 @@ Multiple fields are **allowed**. ::: ```go title="Signature" -func (c *Ctx) Vary(fields ...string) +func (c Ctx) Vary(fields ...string) ``` ```go title="Example" @@ -2107,7 +2107,7 @@ app.Get("/", func(c fiber.Ctx) error { Write adopts the Writer interface ```go title="Signature" -func (c *Ctx) Write(p []byte) (n int, err error) +func (c Ctx) Write(p []byte) (n int, err error) ``` ```go title="Example" @@ -2123,7 +2123,7 @@ app.Get("/", func(c fiber.Ctx) error { Writef adopts the string with variables ```go title="Signature" -func (c *Ctx) Writef(f string, a ...any) (n int, err error) +func (c Ctx) Writef(f string, a ...any) (n int, err error) ``` ```go title="Example" @@ -2140,7 +2140,7 @@ app.Get("/", func(c fiber.Ctx) error { WriteString adopts the string ```go title="Signature" -func (c *Ctx) WriteString(s string) (n int, err error) +func (c Ctx) WriteString(s string) (n int, err error) ``` ```go title="Example" @@ -2156,7 +2156,7 @@ app.Get("/", func(c fiber.Ctx) error { A Boolean property, that is `true`, if the request’s [X-Requested-With](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers) header field is [XMLHttpRequest](https://developer.mozilla.org/en-US/docs/Web/API/XMLHttpRequest), indicating that the request was issued by a client library \(such as [jQuery](https://api.jquery.com/jQuery.ajax/)\). ```go title="Signature" -func (c *Ctx) XHR() bool +func (c Ctx) XHR() bool ``` ```go title="Example" @@ -2178,7 +2178,7 @@ XML also sets the content header to **application/xml**. ::: ```go title="Signature" -func (c *Ctx) XML(data any) error +func (c Ctx) XML(data any) error ``` ```go title="Example"