diff --git a/.golangci.yaml b/.golangci.yaml index c66f1369a..009dd1b3c 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -3,6 +3,8 @@ run: modules-download-mode: vendor timeout: 4m issues: + exclude-dirs: + - maintenance/errors/raven exclude-use-default: false include: - EXC0005 diff --git a/.golangci.yml b/.golangci.yml deleted file mode 100644 index 1f771c922..000000000 --- a/.golangci.yml +++ /dev/null @@ -1,11 +0,0 @@ -# for further options, see https://github.com/golangci/golangci-lint/blob/master/.golangci.example.yml -run: - modules-download-mode: vendor - timeout: 4m -issues: - include: - - EXC0005 -linters: - enable: - - gofumpt -linters-settings: diff --git a/backend/couchdb/auth.go b/backend/couchdb/auth.go index 0b44af8bd..17f2e969e 100644 --- a/backend/couchdb/auth.go +++ b/backend/couchdb/auth.go @@ -14,8 +14,8 @@ func (l *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { return l.transport.RoundTrip(req) } -func (rt *AuthTransport) Transport() http.RoundTripper { - return rt.transport +func (l *AuthTransport) Transport() http.RoundTripper { + return l.transport } func (l *AuthTransport) SetTransport(rt http.RoundTripper) { diff --git a/backend/couchdb/db.go b/backend/couchdb/db.go index 6bdd1481d..0cb7b0151 100644 --- a/backend/couchdb/db.go +++ b/backend/couchdb/db.go @@ -64,6 +64,7 @@ func clientAndDB(ctx context.Context, dbName string, cfg *Config) (*kivik.Client if db.Err() != nil { return nil, nil, db.Err() } + return client, db, err } @@ -74,6 +75,7 @@ func Client(cfg *Config) (*kivik.Client, error) { if err != nil { return nil, err } + rts := []transport.ChainableRoundTripper{ &AuthTransport{ Username: cfg.User, @@ -85,8 +87,10 @@ func Client(cfg *Config) (*kivik.Client, error) { if !cfg.DisableRequestLogging { rts = append(rts, &transport.LoggingRoundTripper{}) } + chain := transport.Chain(rts...) tr := couchdb.SetTransport(chain) + err = client.Authenticate(ctx, tr) if err != nil { return nil, err @@ -97,9 +101,11 @@ func Client(cfg *Config) (*kivik.Client, error) { func ParseConfig() (*Config, error) { var cfg Config + err := env.Parse(&cfg) if err != nil { return nil, err } + return &cfg, nil } diff --git a/backend/couchdb/health_check.go b/backend/couchdb/health_check.go index c60e3fc52..eca644330 100644 --- a/backend/couchdb/health_check.go +++ b/backend/couchdb/health_check.go @@ -7,7 +7,9 @@ import ( "time" kivik "github.com/go-kivik/kivik/v3" + "github.com/pace/bricks/maintenance/health/servicehealthcheck" + "github.com/pace/bricks/maintenance/log" ) // HealthCheck checks the state of the object storage client. It must not be changed @@ -38,7 +40,9 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health checkTime := time.Now() var doc Doc + var err error + var row *kivik.Row check: @@ -55,10 +59,17 @@ check: if kivik.StatusCode(row.Err) == http.StatusNotFound { goto put } + h.state.SetErrorState(fmt.Errorf("failed to get: %#v", row.Err)) + return h.state.GetState() } - defer row.Body.Close() + + defer func() { + if err := row.Body.Close(); err != nil { + log.Ctx(ctx).Debug().Err(err).Msg("Failed closing body") + } + }() // check if document exists if row.Rev != "" { @@ -77,7 +88,9 @@ check: put: // update document doc.ID = h.Config.HealthCheckKey + doc.Time = time.Now().Format(healthCheckTimeFormat) + _, err = h.DB.Put(ctx, h.Config.HealthCheckKey, doc) if err != nil { // not yet created, try to create @@ -87,13 +100,16 @@ put: h.state.SetErrorState(fmt.Errorf("failed to put object: %v", err)) return h.state.GetState() } + goto put } if kivik.StatusCode(err) == http.StatusConflict { goto check } + h.state.SetErrorState(fmt.Errorf("failed to put object: %v", err)) + return h.state.GetState() } @@ -103,6 +119,7 @@ put: healthy: // If uploading and downloading worked set the Health Check to healthy h.state.SetHealthy() + return h.state.GetState() } @@ -124,7 +141,7 @@ func wasConcurrentHealthCheck(checkTime time.Time, observedValue string) bool { allowedEnd := checkTime.Add(healthCheckConcurrentSpan) // timestamp we got from the document is in allowed range - // concider it healthy + // consider it healthy return t.After(allowedStart) && t.Before(allowedEnd) } diff --git a/backend/k8sapi/client.go b/backend/k8sapi/client.go index 78ffda131..32ebb4250 100644 --- a/backend/k8sapi/client.go +++ b/backend/k8sapi/client.go @@ -15,6 +15,7 @@ import ( "strings" "github.com/caarlos0/env/v10" + "github.com/pace/bricks/http/transport" "github.com/pace/bricks/maintenance/log" ) @@ -40,6 +41,7 @@ func NewClient() (*Client, error) { if err != nil { return nil, err } + cl.Podname = hostname // parse environment including secrets mounted by kubernetes @@ -52,32 +54,38 @@ func NewClient() (*Client, error) { if err != nil { return nil, fmt.Errorf("failed to read %q: %v", cl.cfg.CACertFile, err) } + cl.CACert = []byte(strings.TrimSpace(string(caData))) namespaceData, err := os.ReadFile(cl.cfg.NamespaceFile) if err != nil { return nil, fmt.Errorf("failed to read %q: %v", cl.cfg.NamespaceFile, err) } + cl.Namespace = strings.TrimSpace(string(namespaceData)) tokenData, err := os.ReadFile(cl.cfg.TokenFile) if err != nil { return nil, fmt.Errorf("failed to read %q: %v", cl.cfg.CACertFile, err) } + cl.Token = strings.TrimSpace(string(tokenData)) // add kubernetes api server cert chain := transport.NewDefaultTransportChain() pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(cl.CACert) if !ok { return nil, fmt.Errorf("failed to load kubernetes ca cert") } + chain.Final(&http.Transport{ TLSClientConfig: &tls.Config{ RootCAs: pool, }, }) + cl.HttpClient.Transport = chain return &cl, nil @@ -107,8 +115,9 @@ func (c *Client) SimpleRequest(ctx context.Context, method, url string, requestO defer resp.Body.Close() if resp.StatusCode > 299 { - body, _ := io.ReadAll(resp.Body) // nolint: errcheck + body, _ := io.ReadAll(resp.Body) log.Ctx(ctx).Debug().Msgf("failed to do api request, due to: %s", string(body)) + return fmt.Errorf("k8s request failed with %s", resp.Status) } diff --git a/backend/k8sapi/pod.go b/backend/k8sapi/pod.go index 2d0a8ebac..3fb2af2da 100644 --- a/backend/k8sapi/pod.go +++ b/backend/k8sapi/pod.go @@ -30,6 +30,7 @@ func (c *Client) SetPodLabel(ctx context.Context, namespace, podname, label, val } url := fmt.Sprintf("https://%s:%d/api/v1/namespaces/%s/pods/%s", c.cfg.Host, c.cfg.Port, namespace, podname) + var resp interface{} return c.SimpleRequest(ctx, http.MethodPatch, url, &pr, &resp) diff --git a/backend/objstore/health_objstore.go b/backend/objstore/health_objstore.go index fd268d614..0c034b823 100644 --- a/backend/objstore/health_objstore.go +++ b/backend/objstore/health_objstore.go @@ -8,6 +8,7 @@ import ( "time" "github.com/minio/minio-go/v7" + "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" @@ -58,6 +59,7 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health defer func() { go func() { defer errors.HandleWithCtx(ctx, "HealthCheck remove s3 object version") + ctx := log.WithContext(context.Background()) err = h.Client.RemoveObject( @@ -91,7 +93,12 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health h.state.SetErrorState(fmt.Errorf("failed to get object: %v", err)) return h.state.GetState() } - defer obj.Close() + + defer func() { + if err := obj.Close(); err != nil { + log.Ctx(ctx).Debug().Err(err).Msg("Failed closing object") + } + }() // Assert expectations buf, err := io.ReadAll(obj) @@ -106,12 +113,14 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health } h.state.SetErrorState(fmt.Errorf("unexpected content: %q <-> %q", string(buf), string(expContent))) + return h.state.GetState() } healthy: // If uploading and downloading worked set the Health Check to healthy h.state.SetHealthy() + return h.state.GetState() } @@ -127,7 +136,7 @@ func wasConcurrentHealthCheck(checkTime time.Time, observedValue string) bool { allowedEnd := checkTime.Add(healthCheckConcurrentSpan) // timestamp we got from the document is in allowed range - // concider it healthy + // consider it healthy return t.After(allowedStart) && t.Before(allowedEnd) } diff --git a/backend/objstore/health_objstore_test.go b/backend/objstore/health_objstore_test.go index 9c0eda18c..7035c5f25 100644 --- a/backend/objstore/health_objstore_test.go +++ b/backend/objstore/health_objstore_test.go @@ -8,19 +8,20 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + http2 "github.com/pace/bricks/http" "github.com/pace/bricks/maintenance/log" - "github.com/stretchr/testify/assert" ) func setup() *http.Response { r := http2.Router() rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/check", nil) + req := httptest.NewRequest(http.MethodGet, "/health/check", nil) r.ServeHTTP(rec, req) - resp := rec.Result() - defer resp.Body.Close() - return resp + + return rec.Result() + } // TestIntegrationHealthCheck tests if object storage health check ist working like expected @@ -28,10 +29,19 @@ func TestIntegrationHealthCheck(t *testing.T) { if testing.Short() { t.SkipNow() } + RegisterHealthchecks() time.Sleep(1 * time.Second) // by the magic of asynchronous code, I here-by present a magic wait + resp := setup() - if resp.StatusCode != 200 { + + defer func() { + if err := resp.Body.Close(); err != nil { + log.Println(err) + } + }() + + if resp.StatusCode != http.StatusOK { t.Errorf("Expected /health/check to respond with 200, got: %d", resp.StatusCode) } @@ -39,6 +49,7 @@ func TestIntegrationHealthCheck(t *testing.T) { if err != nil { log.Fatal(err) } + if !strings.Contains(string(data), "objstore OK") { t.Errorf("Expected /health/check to return OK, got: %s", string(data)) } @@ -46,6 +57,7 @@ func TestIntegrationHealthCheck(t *testing.T) { func TestConcurrentHealth(t *testing.T) { ct := time.Date(2020, 12, 16, 15, 30, 46, 0, time.UTC) + tests := []struct { name string checkTime time.Time diff --git a/backend/objstore/objstore.go b/backend/objstore/objstore.go index 36a1b596f..102be0c04 100644 --- a/backend/objstore/objstore.go +++ b/backend/objstore/objstore.go @@ -38,10 +38,11 @@ func Client() (*minio.Client, error) { return DefaultClientFromEnv() } -// Client with environment based configuration. Registers healthchecks automatically. If yo do not want to use healthchecks +// DefaultClientFromEnv with environment based configuration. Registers healthchecks automatically. If yo do not want to use healthchecks // consider calling CustomClient. func DefaultClientFromEnv() (*minio.Client, error) { registerHealthchecks() + return CustomClient(cfg.Endpoint, &minio.Options{ Secure: cfg.UseSSL, Region: cfg.Region, @@ -53,14 +54,17 @@ func DefaultClientFromEnv() (*minio.Client, error) { // CustomClient with customized client func CustomClient(endpoint string, opts *minio.Options) (*minio.Client, error) { opts.Transport = newCustomTransport(endpoint) + client, err := minio.New(endpoint, opts) if err != nil { return nil, err } + log.Logger().Info().Str("endpoint", endpoint). Str("region", opts.Region). Bool("ssl", opts.Secure). Msg("S3 connection created") + return client, nil } @@ -94,6 +98,7 @@ func registerHealthchecks() { if err != nil { log.Warnf("Failed to create check for bucket: %v", err) } + if !ok { err := client.MakeBucket(ctx, cfg.HealthCheckBucketName, minio.MakeBucketOptions{ Region: cfg.Region, @@ -102,6 +107,7 @@ func registerHealthchecks() { log.Warnf("Failed to create bucket: %v", err) } } + servicehealthcheck.RegisterHealthCheck("objstore", &HealthCheck{ Client: client, }) diff --git a/backend/postgres/errors.go b/backend/postgres/errors.go index b72d5ee20..a7a2f9220 100644 --- a/backend/postgres/errors.go +++ b/backend/postgres/errors.go @@ -34,5 +34,6 @@ func IsErrConnectionFailed(err error) bool { return true } } + return false } diff --git a/backend/postgres/health_postgres.go b/backend/postgres/health_postgres.go index 5b5dfe866..dec6ae4ab 100644 --- a/backend/postgres/health_postgres.go +++ b/backend/postgres/health_postgres.go @@ -7,6 +7,7 @@ import ( "time" "github.com/go-pg/pg/orm" + "github.com/pace/bricks/maintenance/health/servicehealthcheck" ) @@ -55,6 +56,7 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health } // If no error occurred set the State of this Health Check to healthy h.state.SetHealthy() + return h.state.GetState() } diff --git a/backend/postgres/health_postgres_test.go b/backend/postgres/health_postgres_test.go index 2bb978fd6..f540c46e0 100644 --- a/backend/postgres/health_postgres_test.go +++ b/backend/postgres/health_postgres_test.go @@ -12,20 +12,23 @@ import ( "time" "github.com/go-pg/pg/orm" + "github.com/stretchr/testify/require" + http2 "github.com/pace/bricks/http" "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" - "github.com/stretchr/testify/require" ) func setup() *http.Response { r := http2.Router() rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/check", nil) + req := httptest.NewRequest(http.MethodGet, "/health/check", nil) r.ServeHTTP(rec, req) + resp := rec.Result() defer resp.Body.Close() + return resp } @@ -33,9 +36,18 @@ func TestIntegrationHealthCheck(t *testing.T) { if testing.Short() { t.SkipNow() } + time.Sleep(1 * time.Second) // by the magic of asynchronous code, I here-by present a magic wait + resp := setup() - if resp.StatusCode != 200 { + + defer func() { + if err := resp.Body.Close(); err != nil { + log.Println(err) + } + }() + + if resp.StatusCode != http.StatusOK { t.Errorf("Expected /health/check to respond with 200, got: %d", resp.StatusCode) } @@ -43,6 +55,7 @@ func TestIntegrationHealthCheck(t *testing.T) { if err != nil { log.Fatal(err) } + if !strings.Contains(string(data[:]), "postgresdefault OK") { t.Errorf("Expected /health/check to return OK, got: %q", string(data[:])) } @@ -68,6 +81,7 @@ func TestHealthCheckCaching(t *testing.T) { // get the error for the first time require.Equal(t, servicehealthcheck.Err, res.State) require.Equal(t, "TestHealthCheckCaching", res.Msg) + res = h.HealthCheck(ctx) pool.err = nil // getting the cached error diff --git a/backend/postgres/metrics.go b/backend/postgres/metrics.go index 69766aaf3..555b6d93e 100644 --- a/backend/postgres/metrics.go +++ b/backend/postgres/metrics.go @@ -75,6 +75,7 @@ func NewConnectionPoolMetrics() *ConnectionPoolMetrics { []string{"database", "pool"}, ), } + return &m } @@ -121,7 +122,9 @@ func (m *ConnectionPoolMetrics) ObserveRegularly(ctx context.Context, db *pg.DB, // cleaning up the related resources. go func() { ticker := time.NewTicker(time.Minute) + defer close(trigger) + for { select { case <-ticker.C: @@ -143,7 +146,7 @@ func (m *ConnectionPoolMetrics) ObserveRegularly(ctx context.Context, db *pg.DB, } // ObserveWhenTriggered starts observing the given postgres pool. The pool name -// behaves as decribed for the ObserveRegularly method. The metrics are observed +// behaves as described for the ObserveRegularly method. The metrics are observed // for every emitted value from the trigger channel. The trigger channel allows // passing a response channel that will be closed once the metrics were // collected. It is also possible to pass nil. You should close the trigger @@ -152,13 +155,16 @@ func (m *ConnectionPoolMetrics) ObserveWhenTriggered(trigger <-chan chan<- struc // check that pool name is unique m.poolMetricsMx.Lock() defer m.poolMetricsMx.Unlock() + if _, ok := m.poolMetrics[poolName]; ok { return fmt.Errorf("invalid pool name: %q: %w", poolName, ErrNotUnique) } + m.poolMetrics[poolName] = struct{}{} // start goroutine go m.gatherConnectionPoolMetrics(trigger, db, poolName) + return nil } @@ -188,6 +194,7 @@ func (m *ConnectionPoolMetrics) gatherConnectionPoolMetrics(trigger <-chan chan< if done != nil { close(done) } + prevStats = *stats } } diff --git a/backend/postgres/metrics_test.go b/backend/postgres/metrics_test.go index 98ecd8ffa..d9cc27310 100644 --- a/backend/postgres/metrics_test.go +++ b/backend/postgres/metrics_test.go @@ -5,6 +5,7 @@ package postgres_test import ( "context" "errors" + "net/http" "net/http/httptest" "testing" "time" @@ -25,6 +26,7 @@ func ExampleConnectionPoolMetrics() { if err := metrics.ObserveRegularly(context.Background(), myDB, "my_db"); err != nil { panic(err) } + prometheus.MustRegister(metrics) } @@ -36,6 +38,7 @@ func TestIntegrationConnectionPoolMetrics(t *testing.T) { metricsRegistry := prometheus.NewRegistry() metrics := NewConnectionPoolMetrics() metricsRegistry.MustRegister(metrics) + db := ConnectionPool() trigger := make(chan chan<- struct{}) err := metrics.ObserveWhenTriggered(trigger, db, "test") @@ -44,6 +47,7 @@ func TestIntegrationConnectionPoolMetrics(t *testing.T) { if _, err := db.Exec(`SELECT 1;`); err != nil { t.Fatalf("could not query postgres database: %s", err) } + whenDone := make(chan struct{}) select { case trigger <- whenDone: @@ -58,7 +62,8 @@ func TestIntegrationConnectionPoolMetrics(t *testing.T) { // query metrics resp := httptest.NewRecorder() handler := promhttp.HandlerFor(metricsRegistry, promhttp.HandlerOpts{}) - handler.ServeHTTP(resp, httptest.NewRequest("GET", "/metrics", nil)) + handler.ServeHTTP(resp, httptest.NewRequest(http.MethodGet, "/metrics", nil)) + body := resp.Body.String() assert.Regexp(t, `pace_postgres_connection_pool_hits.*?\Wpool="test"\W`, body) assert.Regexp(t, `pace_postgres_connection_pool_misses.*?\Wpool="test"\W`, body) diff --git a/backend/postgres/options_test.go b/backend/postgres/options_test.go index 0ea808de4..08b131c8b 100644 --- a/backend/postgres/options_test.go +++ b/backend/postgres/options_test.go @@ -10,7 +10,9 @@ import ( func TestWithApplicationName(t *testing.T) { param := "ApplicationName" + var conf Config + f := WithApplicationName(param) f(&conf) require.Equal(t, conf.ApplicationName, param) @@ -18,7 +20,9 @@ func TestWithApplicationName(t *testing.T) { func TestWithDatabase(t *testing.T) { param := "Database" + var conf Config + f := WithDatabase(param) f(&conf) require.Equal(t, conf.Database, param) @@ -26,7 +30,9 @@ func TestWithDatabase(t *testing.T) { func TestWithDialTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithDialTimeout(param) f(&conf) require.Equal(t, conf.DialTimeout, param) @@ -34,7 +40,9 @@ func TestWithDialTimeout(t *testing.T) { func TestWithHealthCheckResultTTL(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithHealthCheckResultTTL(param) f(&conf) require.Equal(t, conf.HealthCheckResultTTL, param) @@ -42,7 +50,9 @@ func TestWithHealthCheckResultTTL(t *testing.T) { func TestWithHealthCheckTableName(t *testing.T) { param := "HealthCheckTableName" + var conf Config + f := WithHealthCheckTableName(param) f(&conf) require.Equal(t, conf.HealthCheckTableName, param) @@ -50,7 +60,9 @@ func TestWithHealthCheckTableName(t *testing.T) { func TestWithHost(t *testing.T) { param := "Host" + var conf Config + f := WithHost(param) f(&conf) require.Equal(t, conf.Host, param) @@ -58,7 +70,9 @@ func TestWithHost(t *testing.T) { func TestWithIdleCheckFrequency(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithIdleCheckFrequency(param) f(&conf) require.Equal(t, conf.IdleCheckFrequency, param) @@ -66,7 +80,9 @@ func TestWithIdleCheckFrequency(t *testing.T) { func TestWithIdleTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithIdleTimeout(param) f(&conf) require.Equal(t, conf.IdleTimeout, param) @@ -74,7 +90,9 @@ func TestWithIdleTimeout(t *testing.T) { func TestWithMaxConnAge(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithMaxConnAge(param) f(&conf) require.Equal(t, conf.MaxConnAge, param) @@ -82,7 +100,9 @@ func TestWithMaxConnAge(t *testing.T) { func TestWithMaxRetries(t *testing.T) { param := 42 + var conf Config + f := WithMaxRetries(param) f(&conf) require.Equal(t, conf.MaxRetries, param) @@ -90,7 +110,9 @@ func TestWithMaxRetries(t *testing.T) { func TestWithMaxRetryBackoff(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithMaxRetryBackoff(param) f(&conf) require.Equal(t, conf.MaxRetryBackoff, param) @@ -98,7 +120,9 @@ func TestWithMaxRetryBackoff(t *testing.T) { func TestWithMinIdleConns(t *testing.T) { param := 42 + var conf Config + f := WithMinIdleConns(param) f(&conf) require.Equal(t, conf.MinIdleConns, param) @@ -106,7 +130,9 @@ func TestWithMinIdleConns(t *testing.T) { func TestWithMinRetryBackoff(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithMinRetryBackoff(param) f(&conf) require.Equal(t, conf.MinRetryBackoff, param) @@ -114,7 +140,9 @@ func TestWithMinRetryBackoff(t *testing.T) { func TestWithPassword(t *testing.T) { param := "Password" + var conf Config + f := WithPassword(param) f(&conf) require.Equal(t, conf.Password, param) @@ -122,7 +150,9 @@ func TestWithPassword(t *testing.T) { func TestWithPoolSize(t *testing.T) { param := 42 + var conf Config + f := WithPoolSize(param) f(&conf) require.Equal(t, conf.PoolSize, param) @@ -130,7 +160,9 @@ func TestWithPoolSize(t *testing.T) { func TestWithPoolTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithPoolTimeout(param) f(&conf) require.Equal(t, conf.PoolTimeout, param) @@ -138,7 +170,9 @@ func TestWithPoolTimeout(t *testing.T) { func TestWithPort(t *testing.T) { param := 42 + var conf Config + f := WithPort(param) f(&conf) require.Equal(t, conf.Port, param) @@ -146,7 +180,9 @@ func TestWithPort(t *testing.T) { func TestWithReadTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithReadTimeout(param) f(&conf) require.Equal(t, conf.ReadTimeout, param) @@ -154,7 +190,9 @@ func TestWithReadTimeout(t *testing.T) { func TestWithRetryStatementTimeout(t *testing.T) { param := true + var conf Config + f := WithRetryStatementTimeout(param) f(&conf) require.Equal(t, conf.RetryStatementTimeout, param) @@ -162,7 +200,9 @@ func TestWithRetryStatementTimeout(t *testing.T) { func TestWithUser(t *testing.T) { param := "User" + var conf Config + f := WithUser(param) f(&conf) require.Equal(t, conf.User, param) @@ -170,7 +210,9 @@ func TestWithUser(t *testing.T) { func TestWithWriteTimeout(t *testing.T) { param := 5 * time.Second + var conf Config + f := WithWriteTimeout(param) f(&conf) require.Equal(t, conf.WriteTimeout, param) @@ -194,7 +236,9 @@ func TestWithLogReadWriteOnly(t *testing.T) { for _, tc := range cases { read := tc[0] write := tc[1] + var conf Config + f := WithQueryLogging(read, write) f(&conf) assert.Equal(t, conf.LogRead, read) diff --git a/backend/postgres/postgres.go b/backend/postgres/postgres.go index ec0b27af0..572bf7fbe 100644 --- a/backend/postgres/postgres.go +++ b/backend/postgres/postgres.go @@ -14,13 +14,12 @@ import ( "sync" "time" - "github.com/opentracing/opentracing-go" - olog "github.com/opentracing/opentracing-go/log" - "github.com/rs/zerolog" - "github.com/caarlos0/env/v10" "github.com/go-pg/pg" + "github.com/opentracing/opentracing-go" + olog "github.com/opentracing/opentracing-go/log" "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" @@ -34,7 +33,7 @@ type Config struct { Database string `env:"POSTGRES_DB" envDefault:"postgres"` // ApplicationName is the application name. Used in logs on Pg side. - // Only availaible from pg-9.0. + // Only available from pg-9.0. ApplicationName string `env:"POSTGRES_APPLICATION_NAME" envDefault:"-"` // Maximum number of retries before giving up. MaxRetries int `env:"POSTGRES_MAX_RETRIES" envDefault:"5"` @@ -169,18 +168,22 @@ var ( // logging and metrics. func DefaultConnectionPool() *pg.DB { var err error + defaultPoolOnce.Do(func() { if defaultPool == nil { defaultPool = ConnectionPool() // add metrics metrics := NewConnectionPoolMetrics() prometheus.MustRegister(metrics) + err = metrics.ObserveRegularly(context.Background(), defaultPool, "default") } }) + if err != nil { panic(err) } + return defaultPool } @@ -231,6 +234,7 @@ func CustomConnectionPool(opts *pg.Options) *pg.DB { Str("database", opts.Database). Str("as", opts.ApplicationName). Msg("PostgreSQL connection pool created") + db := pg.Connect(opts) if cfg.LogWrite || cfg.LogRead { db.OnQueryProcessed(queryLogger) @@ -259,6 +263,7 @@ func determineQueryMode(qry string) queryMode { if strings.HasPrefix(strings.ToLower(strings.TrimSpace(qry)), "select") { return readMode } + return writeMode } @@ -268,16 +273,18 @@ func queryLogger(event *pg.QueryProcessedEvent) { if !(cfg.LogRead || cfg.LogWrite) { return } - // we can only and should only perfom the following check if we have the information availaible + // we can only and should only perfom the following check if we have the information available mode := determineQueryMode(q) if mode == readMode && !cfg.LogRead { return } + if mode == writeMode && !cfg.LogWrite { return } } + ctx := event.DB.Context() dur := float64(time.Since(event.StartTime)) / float64(time.Millisecond) @@ -310,6 +317,7 @@ func queryLogger(event *pg.QueryProcessedEvent) { // this is only a display issue not a "real" issue le.Msgf("%v", qe) } + le.Msg(q) } @@ -326,6 +334,7 @@ func getQueryType(s string) string { if len(p) > 0 { return strings.ToUpper(s[:p[0]]) } + return strings.ToUpper(s) } @@ -378,5 +387,6 @@ func metricsAdapter(event *pg.QueryProcessedEvent, opts *pg.Options) { metricQueryRowsTotal.With(labels).Add(float64(r.RowsReturned())) metricQueryAffectedTotal.With(labels).Add(math.Max(0, float64(r.RowsAffected()))) } + metricQueryDurationSeconds.With(labels).Observe(dur) } diff --git a/backend/postgres/postgres_test.go b/backend/postgres/postgres_test.go index 5206720e9..144e293a3 100644 --- a/backend/postgres/postgres_test.go +++ b/backend/postgres/postgres_test.go @@ -12,15 +12,17 @@ func TestIntegrationConnectionPool(t *testing.T) { if testing.Short() { t.SkipNow() } + db := ConnectionPool() + var result struct { Calc int } + _, err := db.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) //nolint:errcheck if err != nil { t.Errorf("got %v", err) } - // Note: This test can't actually test the logging correctly // but the code will be accessed } @@ -29,15 +31,17 @@ func TestIntegrationConnectionPoolNoLogging(t *testing.T) { if testing.Short() { t.SkipNow() } + db := ConnectionPool(WithQueryLogging(false, false)) + var result struct { Calc int } + _, err := db.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) //nolint:errcheck if err != nil { t.Errorf("got %v", err) } - // Note: This test can't actually test the logging correctly // but the code will be accessed } diff --git a/backend/queue/metrics.go b/backend/queue/metrics.go index 03769e321..1fe8a6671 100644 --- a/backend/queue/metrics.go +++ b/backend/queue/metrics.go @@ -5,10 +5,11 @@ import ( "time" "github.com/adjust/rmq/v5" + "github.com/prometheus/client_golang/prometheus" + pberrors "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/pkg/routine" - "github.com/prometheus/client_golang/prometheus" ) type queueStatsGauges struct { @@ -31,11 +32,13 @@ func gatherMetrics(connection rmq.Connection) { log.Ctx(ctx).Debug().Err(err).Msg("rmq metrics: could not get open queues") pberrors.Handle(ctx, err) } + stats, err := connection.CollectStats(queues) if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("rmq metrics: could not collect stats") pberrors.Handle(ctx, err) } + for queue, queueStats := range stats.QueueStats { labels := prometheus.Labels{ "queue": queue, @@ -50,7 +53,7 @@ func gatherMetrics(connection rmq.Connection) { }) } -func registerConnection(connection rmq.Connection) queueStatsGauges { +func registerConnection(_ rmq.Connection) queueStatsGauges { gauges := queueStatsGauges{ readyGauge: prometheus.NewGaugeVec(prometheus.GaugeOpts{ Namespace: "rmq", diff --git a/backend/queue/rmq.go b/backend/queue/rmq.go index 83e76111e..ff76922fe 100644 --- a/backend/queue/rmq.go +++ b/backend/queue/rmq.go @@ -6,13 +6,13 @@ import ( "sync" "time" + "github.com/adjust/rmq/v5" + "github.com/pace/bricks/backend/redis" pberrors "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/pkg/routine" - - "github.com/adjust/rmq/v5" ) var ( @@ -32,29 +32,34 @@ type queueHealth struct { func (h *queueHealth) isMarkedHealthy() bool { h.mu.Lock() defer h.mu.Unlock() + return h.markedUnhealthyAt.IsZero() } func (h *queueHealth) markUnhealthy() { h.mu.Lock() defer h.mu.Unlock() + h.markedUnhealthyAt = time.Now() } func (h *queueHealth) markHealthy() { h.mu.Lock() defer h.mu.Unlock() + h.markedUnhealthyAt = time.Time{} } func (h *queueHealth) getMarkedUnhealthyAt() time.Time { h.mu.Lock() defer h.mu.Unlock() + return h.markedUnhealthyAt } func initDefault() error { var err error + initMutex.Lock() defer initMutex.Unlock() @@ -79,8 +84,10 @@ func initDefault() error { rmqConnection = nil return err } + gatherMetrics(rmqConnection) servicehealthcheck.RegisterHealthCheck("rmq", &HealthCheck{}) + return nil } @@ -94,14 +101,18 @@ func NewQueue(name string, healthyLimit int) (rmq.Queue, error) { if err != nil { return nil, err } + queue, err := rmqConnection.OpenQueue(name) if err != nil { return nil, err } + if _, ok := queueHealthLimits.Load(name); ok { return queue, nil } + queueHealthLimits.Store(name, &queueHealth{limit: healthyLimit}) + return queue, nil } @@ -123,22 +134,28 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("rmq HealthCheck: could not get open queues") h.state.SetErrorState(fmt.Errorf("error while retrieving open queues: %s", err)) + return h.state.GetState() } + stats, err := rmqConnection.CollectStats(queues) if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("rmq HealthCheck: could not collect stats") h.state.SetErrorState(fmt.Errorf("error while collecting stats: %s", err)) + return h.state.GetState() } + queueHealthLimits.Range(func(k, v interface{}) bool { name := k.(string) hl := v.(*queueHealth) stat := stats.QueueStats[name] + if stat.ReadyCount > int64(hl.limit) { if hl.isMarkedHealthy() { hl.markUnhealthy() h.state.SetHealthy() + return true } // queue health is still pending @@ -147,11 +164,15 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health } h.state.SetErrorState(fmt.Errorf("Queue '%s' exceeded safe health limit of '%d'", name, hl.limit)) + return false } + h.state.SetHealthy() hl.markHealthy() + return true }) + return h.state.GetState() } diff --git a/backend/queue/rmq_test.go b/backend/queue/rmq_test.go index 25f263d54..5636c8830 100644 --- a/backend/queue/rmq_test.go +++ b/backend/queue/rmq_test.go @@ -5,23 +5,28 @@ import ( "testing" "time" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/assert" + + "github.com/pace/bricks/maintenance/log" ) func TestIntegrationHealthCheck(t *testing.T) { if testing.Short() { t.SkipNow() } + ctx := log.WithContext(context.Background()) cfg.HealthCheckPendingStateInterval = time.Second * 2 q1, err := NewQueue("integrationTestTasks", 1) assert.NoError(t, err) + err = q1.Publish("nothing here") assert.NoError(t, err) time.Sleep(time.Second) + check := &HealthCheck{IgnoreInterval: true} + res := check.HealthCheck(ctx) if res.State != "OK" { t.Errorf("Expected health check to be OK for a non-full queue: state %s, message: %s", res.State, res.Msg) @@ -37,12 +42,14 @@ func TestIntegrationHealthCheck(t *testing.T) { } // queue health pending time.Sleep(time.Second) + res = check.HealthCheck(ctx) if res.State != "OK" { t.Errorf("Expected health check to be OK") } // queue health no longer pending time.Sleep(time.Second * 2) + res = check.HealthCheck(ctx) if res.State == "OK" { t.Errorf("Expected health check to be ERR for a full queue") @@ -57,6 +64,7 @@ func TestIntegrationHealthCheck(t *testing.T) { err = q1.Publish("nothing here") assert.NoError(t, err) + err = q1.Publish("nothing here either") assert.NoError(t, err) // queue health pending again diff --git a/backend/redis/errors.go b/backend/redis/errors.go index aeac7d08e..e343b6923 100644 --- a/backend/redis/errors.go +++ b/backend/redis/errors.go @@ -16,5 +16,6 @@ func IsErrConnectionFailed(err error) bool { // go-redis has this check internally for network errors _, ok := err.(net.Error) + return ok } diff --git a/backend/redis/health_redis.go b/backend/redis/health_redis.go index fc286b42a..8b735e6e0 100644 --- a/backend/redis/health_redis.go +++ b/backend/redis/health_redis.go @@ -6,8 +6,9 @@ import ( "context" "time" - "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/redis/go-redis/v9" + + "github.com/pace/bricks/maintenance/health/servicehealthcheck" ) // HealthCheck checks the state of a redis connection. It must not be changed @@ -39,5 +40,6 @@ func (h *HealthCheck) HealthCheck(ctx context.Context) servicehealthcheck.Health } // If reading an writing worked set the Health Check to healthy h.state.SetHealthy() + return h.state.GetState() } diff --git a/backend/redis/health_redis_test.go b/backend/redis/health_redis_test.go index 2dd753da0..e35476f78 100644 --- a/backend/redis/health_redis_test.go +++ b/backend/redis/health_redis_test.go @@ -17,11 +17,10 @@ import ( func setup() *http.Response { r := http2.Router() rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/check", nil) + req := httptest.NewRequest(http.MethodGet, "/health/check", nil) r.ServeHTTP(rec, req) - resp := rec.Result() - defer resp.Body.Close() - return resp + + return rec.Result() } // TestIntegrationHealthCheck tests if redis health check ist working like expected @@ -29,9 +28,16 @@ func TestIntegrationHealthCheck(t *testing.T) { if testing.Short() { t.SkipNow() } + time.Sleep(time.Second) + resp := setup() - if resp.StatusCode != 200 { + + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { t.Errorf("Expected /health/check to respond with 200, got: %d", resp.StatusCode) } @@ -39,6 +45,7 @@ func TestIntegrationHealthCheck(t *testing.T) { if err != nil { log.Fatal(err) } + if !strings.Contains(string(data), "redis OK") { t.Errorf("Expected /health/check to return OK, got: %q", string(data[:])) } diff --git a/backend/redis/redis.go b/backend/redis/redis.go index a1e6ad848..b9a386ef9 100755 --- a/backend/redis/redis.go +++ b/backend/redis/redis.go @@ -10,10 +10,11 @@ import ( "github.com/caarlos0/env/v10" "github.com/opentracing/opentracing-go" olog "github.com/opentracing/opentracing-go/log" - "github.com/pace/bricks/maintenance/health/servicehealthcheck" - "github.com/pace/bricks/maintenance/log" "github.com/prometheus/client_golang/prometheus" "github.com/redis/go-redis/v9" + + "github.com/pace/bricks/maintenance/health/servicehealthcheck" + "github.com/pace/bricks/maintenance/log" ) type config struct { @@ -160,11 +161,11 @@ type logtracerValues struct { span opentracing.Span } -func (lt *logtracer) DialHook(next redis.DialHook) redis.DialHook { +func (l *logtracer) DialHook(next redis.DialHook) redis.DialHook { return next } -func (lt *logtracer) ProcessHook(next redis.ProcessHook) redis.ProcessHook { +func (l *logtracer) ProcessHook(next redis.ProcessHook) redis.ProcessHook { return func(ctx context.Context, cmd redis.Cmder) error { startedAt := time.Now() @@ -186,13 +187,15 @@ func (lt *logtracer) ProcessHook(next redis.ProcessHook) redis.ProcessHook { _ = next(ctx, cmd) vals := ctx.Value(logtracerKey{}).(*logtracerValues) - le := log.Ctx(ctx).Debug().Str("cmd", cmd.Name()).Str("sentry:category", "redis") + le := log.Ctx(ctx).Debug().Str("cmd", cmd.Name()).Str("sentry:category", "redis") //nolint:zerologlint // add error cmdErr := cmd.Err() if cmdErr != nil { vals.span.LogFields(olog.Error(cmdErr)) + le = le.Err(cmdErr) + paceRedisCmdFailed.With(prometheus.Labels{ "method": cmd.Name(), }).Inc() diff --git a/cmd/pb/main.go b/cmd/pb/main.go index 69652da49..84a61fd5e 100644 --- a/cmd/pb/main.go +++ b/cmd/pb/main.go @@ -18,6 +18,7 @@ func main() { Args: cobra.MaximumNArgs(1), } addRootCommands(rootCmd) + err := rootCmd.Execute() if err != nil { log.Fatal(err) @@ -27,6 +28,7 @@ func main() { // pace ... func addRootCommands(rootCmd *cobra.Command) { var restSource string + rootCmdNew := &cobra.Command{ Use: "new NAME", Args: cobra.ExactArgs(1), @@ -67,6 +69,7 @@ func addRootCommands(rootCmd *cobra.Command) { rootCmd.AddCommand(rootCmdEdit) var runCmd string + rootCmdRun := &cobra.Command{ Use: "run NAME", Args: cobra.ExactArgs(1), @@ -81,6 +84,7 @@ func addRootCommands(rootCmd *cobra.Command) { rootCmd.AddCommand(rootCmdRun) var testGoConvey bool + rootCmdTest := &cobra.Command{ Use: "test NAME", Args: cobra.ExactArgs(1), @@ -136,6 +140,7 @@ func (e *errorDefinitionsOutputFlag) Type() string { // pace service generate ... func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { var pkgName, path, source string + cmdRest := &cobra.Command{ Use: "rest", Args: cobra.NoArgs, @@ -153,6 +158,7 @@ func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { rootCmdGenerate.AddCommand(cmdRest) var commandsPath string + cmdCommands := &cobra.Command{ Use: "commands NAME", Args: cobra.ExactArgs(1), @@ -165,6 +171,7 @@ func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { rootCmdGenerate.AddCommand(cmdCommands) var dockerfilePath string + cmdDockerfile := &cobra.Command{ Use: "dockerfile NAME", Args: cobra.ExactArgs(1), @@ -179,6 +186,7 @@ func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { rootCmdGenerate.AddCommand(cmdDockerfile) var makefilePath string + cmdMakefile := &cobra.Command{ Use: "makefile NAME", Args: cobra.ExactArgs(1), @@ -192,6 +200,7 @@ func addServiceGenerateCommands(rootCmdGenerate *cobra.Command) { rootCmdGenerate.AddCommand(cmdMakefile) var errorsDefinitionsPkgName, errorsDefinitionsPath, errorsDefinitionsSource string + errorDefinitionsOutput := goOutputFlag cmdErrorDefinitions := &cobra.Command{ Use: "error-definitions", diff --git a/grpc/client.go b/grpc/client.go index 43824b3f1..4c37ca069 100644 --- a/grpc/client.go +++ b/grpc/client.go @@ -6,6 +6,9 @@ import ( "context" "time" + grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" + grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" + grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" @@ -14,29 +17,16 @@ import ( "github.com/pace/bricks/http/security" "github.com/pace/bricks/locale" "github.com/pace/bricks/maintenance/log" - - grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" - grpc_retry "github.com/grpc-ecosystem/go-grpc-middleware/retry" - grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" ) -func DialContext(ctx context.Context, addr string) (*grpc.ClientConn, error) { - return dialCtx(ctx, addr) -} - -func Dial(addr string) (*grpc.ClientConn, error) { - return dialCtx(context.Background(), addr) -} - -func dialCtx(ctx context.Context, addr string) (*grpc.ClientConn, error) { - var conn *grpc.ClientConn - +func NewClient(addr string) (*grpc.ClientConn, error) { clientMetrics := grpc_prometheus.NewClientMetrics() opts := []grpc_retry.CallOption{ grpc_retry.WithBackoff(grpc_retry.BackoffLinear(100 * time.Millisecond)), } - conn, err := grpc.DialContext(ctx, addr, grpc.WithTransportCredentials(insecure.NewCredentials()), + + return grpc.NewClient(addr, grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithChainStreamInterceptor( grpc_opentracing.StreamClientInterceptor(), grpc_opentracing.StreamClientInterceptor(), @@ -49,6 +39,7 @@ func dialCtx(ctx context.Context, addr string) (*grpc.ClientConn, error) { Str("type", "stream"). Err(err). Msg("GRPC requested") + return cs, err }, ), @@ -64,23 +55,26 @@ func dialCtx(ctx context.Context, addr string) (*grpc.ClientConn, error) { Str("type", "unary"). Err(err). Msg("GRPC requested") + return err }, ), ) - return conn, err } func prepareClientContext(ctx context.Context) context.Context { if loc, ok := locale.FromCtx(ctx); ok { ctx = metadata.AppendToOutgoingContext(ctx, MetadataKeyLocale, loc.Serialize()) } + if token, ok := security.GetTokenFromContext(ctx); ok { ctx = metadata.AppendToOutgoingContext(ctx, MetadataKeyBearerToken, token.GetValue()) } + if reqID := log.RequestIDFromContext(ctx); reqID != "" { ctx = metadata.AppendToOutgoingContext(ctx, MetadataKeyRequestID, reqID) } + ctx = EncodeContextWithUTMData(ctx) if dep := middleware.ExternalDependencyContextFromContext(ctx); dep != nil { diff --git a/grpc/middleware.go b/grpc/middleware.go index 04cea44e2..1a7ac841f 100644 --- a/grpc/middleware.go +++ b/grpc/middleware.go @@ -7,8 +7,9 @@ import ( "encoding/gob" "strings" - "github.com/pace/bricks/pkg/tracking/utm" "google.golang.org/grpc/metadata" + + "github.com/pace/bricks/pkg/tracking/utm" ) const utmMetadataKey = "utm-bin" // IMPORTANT -bin post-fix allows us to send binary data via grpc metadata, otherwise it will break the protocol @@ -18,10 +19,12 @@ func ContextWithUTMFromMetadata(parentCtx context.Context, md metadata.MD) conte if len(dataSlice) == 0 { return parentCtx } + var utmData utm.UTMData if err := gob.NewDecoder(strings.NewReader(dataSlice[0])).Decode(&utmData); err != nil { return parentCtx } + return utm.ContextWithUTMData(parentCtx, utmData) } @@ -30,9 +33,11 @@ func EncodeContextWithUTMData(parentCtx context.Context) context.Context { if !exists { return parentCtx } + w := strings.Builder{} if err := gob.NewEncoder(&w).Encode(utmData); err != nil { return parentCtx } + return metadata.AppendToOutgoingContext(parentCtx, utmMetadataKey, w.String()) } diff --git a/grpc/middleware_test.go b/grpc/middleware_test.go index 474070bd3..4001bda71 100644 --- a/grpc/middleware_test.go +++ b/grpc/middleware_test.go @@ -6,9 +6,10 @@ import ( "context" "testing" - "github.com/pace/bricks/pkg/tracking/utm" "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" + + "github.com/pace/bricks/pkg/tracking/utm" ) func TestEncodeContextWithUTMData(t *testing.T) { @@ -25,6 +26,7 @@ func TestEncodeContextWithUTMData(t *testing.T) { ctx = EncodeContextWithUTMData(ctx) md, exists := metadata.FromOutgoingContext(ctx) require.True(t, exists) + ctx2 := context.Background() ctx2 = ContextWithUTMFromMetadata(ctx2, md) utmData, exists := utm.FromContext(ctx2) diff --git a/grpc/server.go b/grpc/server.go index 99d9296d6..5ba822759 100644 --- a/grpc/server.go +++ b/grpc/server.go @@ -9,28 +9,28 @@ import ( "strings" "time" + "github.com/caarlos0/env/v10" grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware" grpc_auth "github.com/grpc-ecosystem/go-grpc-middleware/auth" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" grpc_ctxtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" grpc_opentracing "github.com/grpc-ecosystem/go-grpc-middleware/tracing/opentracing" - "github.com/pace/bricks/http/middleware" - "github.com/pace/bricks/http/security" - "github.com/pace/bricks/locale" - "github.com/pace/bricks/maintenance/errors" - "github.com/pace/bricks/maintenance/log" - "github.com/pace/bricks/maintenance/log/hlog" "github.com/rs/xid" "github.com/rs/zerolog" zlog "github.com/rs/zerolog/log" - - "github.com/caarlos0/env/v10" "google.golang.org/grpc" "google.golang.org/grpc/metadata" "google.golang.org/grpc/peer" + + "github.com/pace/bricks/http/middleware" + "github.com/pace/bricks/http/security" + "github.com/pace/bricks/locale" + "github.com/pace/bricks/maintenance/errors" + "github.com/pace/bricks/maintenance/log" + "github.com/pace/bricks/maintenance/log/hlog" ) -var InternalServerError = errors.New("internal server error") +var ErrInternalServer = errors.New("internal server error") type Config struct { Address string `env:"GRPC_ADDR" envDefault:":3001"` @@ -46,16 +46,20 @@ func ListenAndServe(gs *grpc.Server) error { if err != nil { return err } + log.Logger().Info().Str("addr", listener.Addr().String()).Msg("Starting grpc server ...") + err = gs.Serve(listener) if err != nil { return err } + return nil } func Listener() (net.Listener, error) { var cfg Config + err := env.Parse(&cfg) if err != nil { return nil, fmt.Errorf("failed to parse grpc server environment: %w", err) @@ -65,6 +69,7 @@ func Listener() (net.Listener, error) { if err != nil { return nil, fmt.Errorf("unable to create grpc listener for %q: %w", cfg.Address, err) } + return tcpListener, nil } @@ -82,6 +87,7 @@ func Server(ab AuthBackend) *grpc.Server { wrappedStream := grpc_middleware.WrapServerStream(stream) wrappedStream.WrappedContext = ctx + var addr string if p, ok := peer.FromContext(ctx); ok { addr = p.Addr.String() @@ -98,12 +104,15 @@ func Server(ab AuthBackend) *grpc.Server { Str("user_agent", strings.Join(md.Get("user-agent"), ",")). Err(err). Msg("GRPC completed Stream") + return err }, func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) (err error) { defer errors.HandleWithCtx(stream.Context(), "GRPC "+info.FullMethod) - err = InternalServerError // default in case of a panic + + err = ErrInternalServer // default in case of a panic err = handler(srv, stream) + return err }, grpc_auth.StreamServerInterceptor(ab.AuthorizeStream), @@ -131,12 +140,15 @@ func Server(ab AuthBackend) *grpc.Server { Str("user_agent", strings.Join(md.Get("user-agent"), ",")). Err(err). Msg("GRPC completed Unary") + return }, func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) { defer errors.HandleWithCtx(ctx, "GRPC "+info.FullMethod) - err = InternalServerError // default in case of a panic + + err = ErrInternalServer // default in case of a panic resp, err = handler(ctx, req) + return }, grpc_auth.UnaryServerInterceptor(ab.AuthorizeUnary), @@ -152,11 +164,14 @@ func prepareContext(ctx context.Context) (context.Context, metadata.MD) { // add request context if req_id is given var reqID xid.ID + if ri := md.Get(MetadataKeyRequestID); len(ri) > 0 { var err error + reqID, err = xid.FromString(ri[0]) if err != nil { log.Debugf("unable to parse xid from req_id: %v", err) + reqID = xid.New() } } else { diff --git a/grpc/server_test.go b/grpc/server_test.go index a2bda5f50..ebb3d69e7 100644 --- a/grpc/server_test.go +++ b/grpc/server_test.go @@ -7,12 +7,13 @@ import ( "context" "testing" - "github.com/pace/bricks/http/middleware" - "github.com/pace/bricks/locale" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc/metadata" + + "github.com/pace/bricks/http/middleware" + "github.com/pace/bricks/locale" + "github.com/pace/bricks/maintenance/log" ) func TestPrepareContext(t *testing.T) { @@ -23,6 +24,7 @@ func TestPrepareContext(t *testing.T) { assert.NotEmpty(t, log.RequestIDFromContext(ctx0)) var buf0 bytes.Buffer + l := log.Ctx(ctx0).Output(&buf0) l.Debug().Msg("test") assert.Contains(t, buf0.String(), "{\"level\":\"debug\",\"req_id\":\""+ @@ -40,12 +42,14 @@ func TestPrepareContext(t *testing.T) { assert.Len(t, md.Get(MetadataKeyRequestID), 0) assert.Len(t, md.Get(MetadataKeyBearerToken), 0) assert.Equal(t, "c690uu0ta2rv348epm8g", log.RequestIDFromContext(ctx1)) + loc, ok := locale.FromCtx(ctx1) assert.True(t, ok) assert.Equal(t, "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5", loc.Language()) assert.Equal(t, "Europe/Paris", loc.Timezone()) var buf1 bytes.Buffer + l = log.Ctx(ctx1).Output(&buf1) l.Debug().Msg("test") assert.Contains(t, buf1.String(), "{\"level\":\"debug\",\"req_id\":\"c690uu0ta2rv348epm8g\",\"time\":\"") @@ -63,10 +67,12 @@ func TestPrepareContext(t *testing.T) { assert.Equal(t, "c690uu0ta2rv348epm8g", log.RequestIDFromContext(ctx1)) var buf2 bytes.Buffer + l = log.Ctx(ctx2).Output(&buf2) l.Debug().Msg("test") assert.Contains(t, buf2.String(), "{\"level\":\"debug\",\"req_id\":\"c690uu0ta2rv348epm8g\",\"time\":\"") assert.Contains(t, buf2.String(), ",\"message\":\"test\"}\n") + _, ok = locale.FromCtx(ctx2) assert.False(t, ok) diff --git a/http/jsonapi/errors_test.go b/http/jsonapi/errors_test.go index bef0155e2..7d960f454 100644 --- a/http/jsonapi/errors_test.go +++ b/http/jsonapi/errors_test.go @@ -14,6 +14,7 @@ import ( func TestErrorObjectWritesExpectedErrorMessage(t *testing.T) { err := &ErrorObject{Title: "Title test.", Detail: "Detail test."} + var input error = err output := input.Error() @@ -31,9 +32,9 @@ func TestMarshalErrorsWritesTheExpectedPayload(t *testing.T) { }{ { Title: "TestFieldsAreSerializedAsNeeded", - In: []*ErrorObject{{ID: "0", Title: "Test title.", Detail: "Test detail", Status: "400", Code: "E1100"}}, + In: []*ErrorObject{{ID: "0", Title: "Test title.", Detail: "Test detail", Status: "http.StatusBadRequest", Code: "E1100"}}, Out: map[string]interface{}{"errors": []interface{}{ - map[string]interface{}{"id": "0", "title": "Test title.", "detail": "Test detail", "status": "400", "code": "E1100"}, + map[string]interface{}{"id": "0", "title": "Test title.", "detail": "Test detail", "status": "http.StatusBadRequest", "code": "E1100"}, }}, }, { @@ -47,9 +48,11 @@ func TestMarshalErrorsWritesTheExpectedPayload(t *testing.T) { for _, testRow := range marshalErrorsTableTasts { t.Run(testRow.Title, func(t *testing.T) { buffer, output := bytes.NewBuffer(nil), map[string]interface{}{} + var writer io.Writer = buffer _ = MarshalErrors(writer, testRow.In) + err := json.Unmarshal(buffer.Bytes(), &output) if err != nil { t.Fatal(err) diff --git a/http/jsonapi/generator/generate.go b/http/jsonapi/generator/generate.go index 56fce589c..4c714af04 100644 --- a/http/jsonapi/generator/generate.go +++ b/http/jsonapi/generator/generate.go @@ -29,14 +29,14 @@ type Generator struct { generatedArrayTypes map[string]bool } -func loadSwaggerFromURI(loader *openapi3.SwaggerLoader, url *url.URL) (*openapi3.Swagger, error) { // nolint: interfacer +func loadSwaggerFromURI(loader *openapi3.SwaggerLoader, url *url.URL) (*openapi3.Swagger, error) { var schema *openapi3.Swagger resp, err := http.Get(url.String()) if err != nil { return nil, err } - defer resp.Body.Close() // nolint: errcheck + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { @@ -55,6 +55,7 @@ func loadSwaggerFromURI(loader *openapi3.SwaggerLoader, url *url.URL) (*openapi3 // based on the passed schema source (url or file path) func (g *Generator) BuildSource(source, packagePath, packageName string) (string, error) { loader := openapi3.NewSwaggerLoader() + var schema *openapi3.Swagger if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { @@ -69,7 +70,7 @@ func (g *Generator) BuildSource(source, packagePath, packageName string) (string } } else { // read spec - data, err := os.ReadFile(source) // nolint: gosec + data, err := os.ReadFile(source) //nolint:gosec if err != nil { return "", err } diff --git a/http/jsonapi/generator/generate_handler.go b/http/jsonapi/generator/generate_handler.go index 82c69a315..9ba71a6b9 100644 --- a/http/jsonapi/generator/generate_handler.go +++ b/http/jsonapi/generator/generate_handler.go @@ -27,7 +27,7 @@ const ( pkgOpentracing = "github.com/opentracing/opentracing-go" pkgOAuth2 = "github.com/pace/bricks/http/oauth2" pkgOIDC = "github.com/pace/bricks/http/oidc" - pkgApiKey = "github.com/pace/bricks/http/security/apikey" + pkgAPIKey = "github.com/pace/bricks/http/security/apikey" //nolint:gosec pkgDecimal = "github.com/shopspring/decimal" ) @@ -63,12 +63,14 @@ func (g *Generator) BuildHandler(schema *openapi3.Swagger) error { for k := range paths { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) var routes []*route for _, pattern := range keys { path := paths[pattern] + err := g.buildPath(pattern, path, &routes, schema.Components.SecuritySchemes) if err != nil { return err @@ -148,15 +150,17 @@ func (g *Generator) generateRequestResponseTypes(routes []*route, schema *openap return nil } -func (g *Generator) generateResponseInterface(route *route, schema *openapi3.Swagger) error { - var methods []jen.Code - methods = append(methods, jen.Qual("net/http", "ResponseWriter")) +func (g *Generator) generateResponseInterface(route *route, _ *openapi3.Swagger) error { + methods := []jen.Code{ + jen.Qual("net/http", "ResponseWriter"), + } // sort by key keys := make([]string, 0, len(route.operation.Responses)) for k := range route.operation.Responses { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) for _, code := range keys { @@ -203,6 +207,7 @@ func (g *Generator) generateResponseInterface(route *route, schema *openapi3.Swa if err != nil { return err } + method.Params(typeReference) defer func() { // defer to put methods after type @@ -255,12 +260,13 @@ func (g *Generator) generateResponseInterface(route *route, schema *openapi3.Swa return nil } -func (g *Generator) generateRequestStruct(route *route, schema *openapi3.Swagger) error { +func (g *Generator) generateRequestStruct(route *route, _ *openapi3.Swagger) error { body := route.operation.RequestBody - var fields []jen.Code // add http request - fields = append(fields, jen.Id("Request").Op("*").Qual("net/http", "Request").Tag(noValidation)) + fields := []jen.Code{ + jen.Id("Request").Op("*").Qual("net/http", "Request").Tag(noValidation), + } // add request type if body != nil { @@ -279,6 +285,7 @@ func (g *Generator) generateRequestStruct(route *route, schema *openapi3.Swagger for _, param := range route.operation.Parameters { paramName := generateParamName(param) paramStmt := jen.Id(paramName) + tags := make(map[string]string) if param.Value.Required { tags["valid"] = "required" @@ -309,6 +316,7 @@ func (g *Generator) generateRequestStruct(route *route, schema *openapi3.Swagger g.addGoDoc(route.requestType, "is a standard http.Request extended with the\n"+ "un-marshaled content object") } + g.goSource.Type().Id(route.requestType).Struct(fields...) return nil @@ -335,7 +343,7 @@ func (g *Generator) buildServiceInterface(routes []*route, schema *openapi3.Swag return nil } -func (g *Generator) buildSubServiceInterface(route *route, schema *openapi3.Swagger) error { +func (g *Generator) buildSubServiceInterface(route *route, _ *openapi3.Swagger) error { methods := make([]jen.Code, 0) if route.operation.Description != "" { @@ -343,6 +351,7 @@ func (g *Generator) buildSubServiceInterface(route *route, schema *openapi3.Swag } else { methods = append(methods, jen.Comment(fmt.Sprintf("%s %s", route.serviceFunc, route.operation.Summary))) } + methods = append(methods, jen.Id(route.serviceFunc).Params( jen.Qual("context", "Context"), jen.Id(route.responseType), @@ -360,7 +369,9 @@ func (g *Generator) buildRouter(routes []*route, schema *openapi3.Swagger) error if err != nil { return nil } + g.addGoDoc("Router", "implements: "+schema.Info.Title+"\n\n"+schema.Info.Description) + serviceInterfaceVariable := jen.Id("service").Interface() if hasSecuritySchema(schema) { g.goSource.Func().Id("Router").Params( @@ -369,6 +380,7 @@ func (g *Generator) buildRouter(routes []*route, schema *openapi3.Swagger) error g.goSource.Func().Id("Router").Params( serviceInterfaceVariable).Op("*").Qual(pkgGorillaMux, "Router").Block(routerBody...) } + return nil } @@ -377,7 +389,9 @@ func (g *Generator) buildRouterWithFallbackAsArg(routes []*route, schema *openap if err != nil { return nil } + g.addGoDoc("Router", "implements: "+schema.Info.Title+"\n\n"+schema.Info.Description) + serviceInterfaceVariable := jen.Id("service").Interface() if hasSecuritySchema(schema) { g.goSource.Func().Id("RouterWithFallback").Params( @@ -386,6 +400,7 @@ func (g *Generator) buildRouterWithFallbackAsArg(routes []*route, schema *openap g.goSource.Func().Id("RouterWithFallback").Params( serviceInterfaceVariable, jen.Id("fallback").Qual("net/http", "Handler")).Op("*").Qual(pkgGorillaMux, "Router").Block(routerBody...) } + return nil } @@ -401,15 +416,18 @@ func (g *Generator) buildRouterHelpers(routes []*route, schema *openapi3.Swagger // add all route handlers for i := 0; i < len(sortableRoutes); i++ { route := sortableRoutes[i] + var routeCallParams *jen.Statement if needsSecurity { routeCallParams = jen.List(jen.Id("service"), jen.Id("authBackend")) } else { routeCallParams = jen.List(jen.Id("service")) } + primaryHandler := jen.Id(route.handler).Call(routeCallParams) fallbackHandler := jen.Id(fallbackName) ifElse := make([]jen.Code, 0) + for _, handler := range []jen.Code{primaryHandler, fallbackHandler} { block := jen.Return(handler) ifElse = append(ifElse, block) @@ -431,6 +449,7 @@ func (g *Generator) buildRouterHelpers(routes []*route, schema *openapi3.Swagger } else { callParams = jen.List(jen.Id("service").Id("interface{}"), fallback) } + helper := jen.Func().Id(generateHandlerTypeAssertionHelperName(route.handler)). Params(callParams).Qual("net/http", "Handler").Block(implGuard).Line().Line() @@ -444,7 +463,9 @@ func (g *Generator) buildRouterHelpers(routes []*route, schema *openapi3.Swagger func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi3.Swagger, fallback jen.Code) ([]jen.Code, error) { needsSecurity := hasSecuritySchema(schema) startInd := 0 - var routeStmts []jen.Code + + var routeStmts []jen.Code //nolint:prealloc + if needsSecurity { startInd++ routeStmts = make([]jen.Code, 2, (len(routes)+2)*len(schema.Servers)+2) @@ -453,7 +474,9 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi for name := range schema.Components.SecuritySchemes { names = append(names, name) } + sort.Stable(sort.StringSlice(names)) + caser := cases.Title(language.Und, cases.NoLower) for _, name := range names { routeStmts = append(routeStmts, jen.Id("authBackend").Dot("Init"+caser.String(name)).Call(jen.Id("cfg"+caser.String(name)))) @@ -466,16 +489,20 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi // Note: we don't restrict host, scheme and port to ease development pathsIdx := make(map[string]struct{}) + var paths []string + for _, server := range schema.Servers { - serverUrl, err := url.Parse(server.URL) + serverURL, err := url.Parse(server.URL) if err != nil { return nil, err } - if _, ok := pathsIdx[serverUrl.Path]; !ok { - paths = append(paths, serverUrl.Path) + + if _, ok := pathsIdx[serverURL.Path]; !ok { + paths = append(paths, serverURL.Path) } - pathsIdx[serverUrl.Path] = struct{}{} + + pathsIdx[serverURL.Path] = struct{}{} } // but generate subrouters for each server @@ -494,12 +521,14 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi // add all route handlers for i := 0; i < len(sortableRoutes); i++ { route := sortableRoutes[i] + var routeCallParams *jen.Statement if needsSecurity { routeCallParams = jen.List(jen.Id("service"), fallback, jen.Id("authBackend")) } else { routeCallParams = jen.List(jen.Id("service"), fallback) } + helper := jen.Id(generateHandlerTypeAssertionHelperName(route.handler)).Call(routeCallParams) routeStmt := jen.Id(subrouterID).Dot("Methods").Call(jen.Lit(route.method)). Dot("Path").Call(jen.Lit(route.url.Path)) @@ -510,6 +539,7 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi if len(value) != 1 { panic("query paths can only handle one query parameter with the same name!") } + routeStmt.Dot("Queries").Call(jen.Lit(key), jen.Lit(value[0])) } } @@ -530,7 +560,7 @@ func (g *Generator) buildRouterBodyWithFallback(routes []*route, schema *openapi return routeStmts, nil } -func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern string, pathItem *openapi3.PathItem, secSchemes map[string]*openapi3.SecuritySchemeRef) (*route, error) { +func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern string, _ *openapi3.PathItem, secSchemes map[string]*openapi3.SecuritySchemeRef) (*route, error) { needsSecurity := len(secSchemes) > 0 route := &route{ method: strings.ToUpper(method), @@ -545,11 +575,14 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // use OperationID for go function names or generate the name caser := cases.Title(language.Und, cases.NoLower) + oid := caser.String(op.OperationID) if oid == "" { log.Warnf("Note: Avoid automatic method name generation for path (use OperationID): %s", pattern) + oid = generateName(method, op, pattern) } + handler := oid + "Handler" route.handler = handler route.serviceFunc = oid @@ -559,6 +592,7 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // check if handler has request body var requestBody bool + if body := op.RequestBody; body != nil { if mt := body.Value.Content.Get(jsonapiContent); mt != nil { requestBody = true @@ -567,25 +601,31 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // generate handler function gen := g // generator is used less frequent then the jen group, make available with longer name + var auth *jen.Group + if needsSecurity { if op.Security != nil { var err error + auth, err = generateAuthorization(op, secSchemes) if err != nil { return nil, err } } } + g.addGoDoc(handler, fmt.Sprintf(`handles request/response marshaling and validation for %s %s`, method, pattern)) + var params *jen.Statement if needsSecurity { params = jen.List(jen.Id("service").Id(generateSubServiceName(route.handler)), jen.Id("authBackend").Id(authBackendInterface)) } else { params = jen.List(jen.Id("service").Id(generateSubServiceName(route.handler))) } + g.goSource.Func().Id(handler).Params(params).Qual("net/http", "Handler").Block( jen.Return().Qual("net/http", "HandlerFunc").Call( jen.Func().Params( @@ -618,14 +658,17 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // vars in case parameters are given g.Line().Comment("Scan and validate incoming request parameters") + if len(route.operation.Parameters) > 0 { // path parameters need the vars needVars := false + for _, param := range route.operation.Parameters { if param.Value.In == "path" { needVars = true } } + if needVars { g.Id("vars").Op(":=").Qual(pkgGorillaMux, "Vars").Call(jen.Id("r")) } @@ -695,7 +738,9 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern // otherwise directly call the service if requestBody { g.Line().Comment("Unmarshal the service request body") + isArray := false + mt := op.RequestBody.Value.Content.Get(jsonapiContent) if mt != nil { data := mt.Schema.Value.Properties["data"] @@ -705,6 +750,7 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern } } } + if isArray { typeName := nameFromSchemaRef(mt.Schema.Value.Properties["data"].Value.Items) g.List(jen.Id("ok"), jen.Id("data")).Op(":="). @@ -746,32 +792,40 @@ func (g *Generator) buildHandler(method string, op *openapi3.Operation, pattern func generateAuthorization(op *openapi3.Operation, secSchemes map[string]*openapi3.SecuritySchemeRef) (*jen.Group, error) { req := *op.Security r := &jen.Group{} + if len(req[0]) == 0 { return r, nil } multipleSecSchemes := len(req[0]) > 1 + var err error + if multipleSecSchemes { r, err = generateAuthorizationForMultipleSecSchemas(op, secSchemes) } else { r, err = generateAuthorizationForSingleSecSchema(op, secSchemes) } + if err != nil { return nil, err } r.Line().Id("r").Op("=").Id("r.WithContext").Call(jen.Id("ctx")) + return r, nil } func generateAuthorizationForSingleSecSchema(op *openapi3.Operation, schemas map[string]*openapi3.SecuritySchemeRef) (*jen.Group, error) { req := *op.Security r := &jen.Group{} + if len(req[0]) == 0 { return nil, nil } + caser := cases.Title(language.Und, cases.NoLower) + for name, secConfig := range (*op.Security)[0] { securityScheme := schemas[name] switch securityScheme.Value.Type { @@ -785,22 +839,30 @@ func generateAuthorizationForSingleSecSchema(op *openapi3.Operation, schemas map if len(secConfig) > 0 { return nil, fmt.Errorf("security config for api key authorization needs %d values but had: %d", 0, len(secConfig)) } + r.Line().List(jen.Id("ctx"), jen.Id("ok")).Op(":=").Id("authBackend."+authFuncPrefix+caser.String(name)).Call(jen.Id("r"), jen.Id("w")) default: return nil, fmt.Errorf("security Scheme of type %q is not suppported", securityScheme.Value.Type) } } + r.Line().If(jen.Op("!").Id("ok")).Block(jen.Return()) + return r, nil } func generateAuthorizationForMultipleSecSchemas(op *openapi3.Operation, secSchemes map[string]*openapi3.SecuritySchemeRef) (*jen.Group, error) { - var orderedSec [][]string + orderedSec := make([][]string, len((*op.Security)[0])) + i := 0 + // Security Schemes are sorted for a reliable order of the code for name, val := range (*op.Security)[0] { vals := []string{name} - orderedSec = append(orderedSec, append(vals, val...)) + orderedSec[i] = append(vals, val...) + + i++ } + sort.Slice(orderedSec, func(i, j int) bool { return orderedSec[i][0] < orderedSec[j][0] }) @@ -814,11 +876,13 @@ func generateAuthorizationForMultipleSecSchemas(op *openapi3.Operation, secSchem r.Line().Var().Id("ctx").Id("context.Context") r.Line().Var().Id("ok").Id("bool") + for _, val := range orderedSec { name := val[0] securityScheme := secSchemes[name] innerBlock := &jen.Group{} innerBlock.Line().List(jen.Id("ctx"), jen.Id("ok")).Op("=").Id("authBackend." + authFuncPrefix + caser.String(name)) + switch securityScheme.Value.Type { case "oauth2", "openIdConnect": if len(val) >= 2 { @@ -830,25 +894,31 @@ func generateAuthorizationForMultipleSecSchemas(op *openapi3.Operation, secSchem if len(val) > 1 { return nil, fmt.Errorf("security config for api key authorization needs %d values but had: %d", 0, len(val)) } + innerBlock.Call(jen.Id("r"), jen.Id("w")) default: return nil, fmt.Errorf("security Scheme of type %q is not suppported", securityScheme.Value.Type) } + innerBlock.Line().If(jen.Op("!").Id("ok")).Block(jen.Return()) r.Line().If(jen.Id("authBackend." + authCanAuthFuncPrefix + caser.String(name))).Call(jen.Id("r")).Block(innerBlock).Else() } + r.Block(last) + return r, nil } var asciiName = regexp.MustCompile("([^a-zA-Z]+)") -func generateName(method string, op *openapi3.Operation, pattern string) string { +func generateName(method string, _ *openapi3.Operation, pattern string) string { name := method parts := strings.Split(asciiName.ReplaceAllString(pattern, "/"), "/") + for _, part := range parts { name += goNameHelper(part) } + return goNameHelper(name) } @@ -857,6 +927,7 @@ func generateMethodName(description string) string { for i := 0; i < len(parts); i++ { parts[i] = goNameHelper(parts[i]) } + return goNameHelper(strings.Join(parts, "")) } diff --git a/http/jsonapi/generator/generate_helper.go b/http/jsonapi/generator/generate_helper.go index 9cf93b1ff..00e66353e 100644 --- a/http/jsonapi/generator/generate_helper.go +++ b/http/jsonapi/generator/generate_helper.go @@ -22,7 +22,7 @@ func (g *Generator) addGoDoc(typeName, description string) { } } -func (g *Generator) goType(stmt *jen.Statement, schema *openapi3.Schema, tags map[string]string) *typeGenerator { // nolint: gocyclo +func (g *Generator) goType(stmt *jen.Statement, schema *openapi3.Schema, tags map[string]string) *typeGenerator { return &typeGenerator{ g: g, stmt: stmt, @@ -39,7 +39,7 @@ type typeGenerator struct { isParam bool } -func (g *typeGenerator) invoke() error { // nolint: gocyclo +func (g *typeGenerator) invoke() error { switch g.schema.Type { case "string": switch g.schema.Format { @@ -61,6 +61,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "date": addValidator(g.tags, "time(2006-01-02)") + if g.isParam { g.stmt.Qual("time", "Time") } else { @@ -68,6 +69,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "uuid": addValidator(g.tags, "uuid") + if g.schema.Nullable { g.stmt.Op("*").String() } else { @@ -75,6 +77,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "decimal": addValidator(g.tags, "matches(^(\\d*\\.)?\\d+$)") + if g.isParam { g.stmt.Qual(pkgDecimal, "Decimal") } else { @@ -89,6 +92,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "integer": removeOmitempty(g.tags) + switch g.schema.Format { case "int32": if g.schema.Nullable { @@ -114,6 +118,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "float": removeOmitempty(g.tags) + if g.schema.Nullable { g.stmt.Op("*").Float32() } else { @@ -123,6 +128,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo fallthrough default: removeOmitempty(g.tags) + if g.schema.Nullable { g.stmt.Op("*").Float64() } else { @@ -131,13 +137,15 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo } case "boolean": removeOmitempty(g.tags) + if g.schema.Nullable { g.stmt.Op("*").Bool() } else { g.stmt.Bool() } - case "array": // nolint: goconst + case "array": removeOmitempty(g.tags) + err := g.g.goType(g.stmt.Index(), g.schema.Items.Value, g.tags).invoke() if err != nil { return err @@ -156,7 +164,7 @@ func (g *typeGenerator) invoke() error { // nolint: gocyclo // in case the field/value is optional // an empty value needs to be added to the enum validator if hasValidator(g.tags, "optional") { - strs = append(strs, "") + strs = append(strs, "") //nolint:makezero } addValidator(g.tags, fmt.Sprintf("in(%v)", strings.Join(strs, "|"))) @@ -182,6 +190,7 @@ func addValidator(tags map[string]string, validator string) { if cur != "" { validator = tags["valid"] + "," + validator } + tags["valid"] = validator } @@ -190,6 +199,7 @@ func hasValidator(tags map[string]string, validator string) bool { if !ok { return false } + validators := strings.Split(validatorCfg, ",") for _, v := range validators { if strings.HasPrefix(v, validator) { @@ -207,6 +217,7 @@ func goNameHelper(name string) string { name = caser.String(name) name = strings.Replace(name, "Url", "URL", -1) name = idRegex.ReplaceAllString(name, "ID") + return name } @@ -215,5 +226,6 @@ func nameFromSchemaRef(ref *openapi3.SchemaRef) string { if name == "." { return "" } + return name } diff --git a/http/jsonapi/generator/generate_security.go b/http/jsonapi/generator/generate_security.go index b7f4ba045..b3ade53b3 100644 --- a/http/jsonapi/generator/generate_security.go +++ b/http/jsonapi/generator/generate_security.go @@ -27,17 +27,25 @@ func (g *Generator) buildSecurityBackendInterface(schema *openapi3.Swagger) erro if !hasSecuritySchema(schema) { return nil } + securitySchemes := schema.Components.SecuritySchemes // r contains the methods for the security interface r := &jen.Group{} // Because the order of the values while iterating over a map is randomized the generated result can only be tested if the keys are sorted - var keys []string + keys := make([]string, len(securitySchemes)) + i := 0 + for k := range securitySchemes { - keys = append(keys, k) + keys[i] = k + + i++ } + sort.Stable(sort.StringSlice(keys)) + hasDuplicatedSecuritySchema := false + for _, pathItem := range schema.Paths { for _, op := range pathItem.Operations() { if op.Security != nil { @@ -47,9 +55,12 @@ func (g *Generator) buildSecurityBackendInterface(schema *openapi3.Swagger) erro } caser := cases.Title(language.Und, cases.NoLower) + for _, name := range keys { value := securitySchemes[name] + r.Line().Id(authFuncPrefix + caser.String(name)) + switch value.Value.Type { case "oauth2": r.Params(jen.Id("r").Id("*http.Request"), jen.Id("w").Id("http.ResponseWriter"), jen.Id("scope").String()).Params(jen.Id("context.Context"), jen.Id("bool")) @@ -59,7 +70,7 @@ func (g *Generator) buildSecurityBackendInterface(schema *openapi3.Swagger) erro r.Line().Id("Init" + caser.String(name)).Params(jen.Id("cfg"+caser.String(name)).Op("*").Qual(pkgOIDC, "Config")) case "apiKey": r.Params(jen.Id("r").Id("*http.Request"), jen.Id("w").Id("http.ResponseWriter")).Params(jen.Id("context.Context"), jen.Id("bool")) - r.Line().Id("Init" + caser.String(name)).Params(jen.Id("cfg"+caser.String(name)).Op("*").Qual(pkgApiKey, "Config")) + r.Line().Id("Init" + caser.String(name)).Params(jen.Id("cfg"+caser.String(name)).Op("*").Qual(pkgAPIKey, "Config")) default: return errors.New("security schema type not supported: " + value.Value.Type) } @@ -70,6 +81,7 @@ func (g *Generator) buildSecurityBackendInterface(schema *openapi3.Swagger) erro } g.goSource.Type().Id(authBackendInterface).Interface(r) + return nil } @@ -78,18 +90,26 @@ func (g *Generator) buildSecurityConfigs(schema *openapi3.Swagger) error { if !hasSecuritySchema(schema) { return nil } + securitySchemes := schema.Components.SecuritySchemes // Because the order of the values while iterating over a map is randomized the generated result can only be tested if the keys are sorted - var keys []string + keys := make([]string, len(securitySchemes)) + i := 0 + for k := range securitySchemes { - keys = append(keys, k) + keys[i] = k + + i++ } + sort.Stable(sort.StringSlice(keys)) for _, name := range keys { value := securitySchemes[name] instanceVal := jen.Dict{} + var pkgName string + switch value.Value.Type { case "oauth2": pkgName = pkgOAuth2 @@ -112,6 +132,7 @@ func (g *Generator) buildSecurityConfigs(schema *openapi3.Swagger) error { case "openIdConnect": pkgName = pkgOIDC instanceVal[jen.Id("Description")] = jen.Lit(value.Value.Description) + if e, ok := value.Value.Extensions["openIdConnectUrl"]; ok { var url string if data, ok := e.(json.RawMessage); ok { @@ -119,20 +140,23 @@ func (g *Generator) buildSecurityConfigs(schema *openapi3.Swagger) error { if err != nil { return err } + instanceVal[jen.Id("OpenIdConnectURL")] = jen.Lit(url) } } case "apiKey": - pkgName = pkgApiKey + pkgName = pkgAPIKey instanceVal[jen.Id("Description")] = jen.Lit(value.Value.Description) instanceVal[jen.Id("In")] = jen.Lit(value.Value.In) instanceVal[jen.Id("Name")] = jen.Lit(value.Value.Name) default: return errors.New("security schema type not supported: " + value.Value.Type) } + caser := cases.Title(language.Und, cases.NoLower) g.goSource.Var().Id("cfg"+caser.String(name)).Op("=").Op("&").Qual(pkgName, "Config").Values(instanceVal) } + return nil } @@ -142,10 +166,13 @@ func getValuesFromFlow(flow *openapi3.OAuthFlow) jen.Dict { r[jen.Id("AuthorizationURL")] = jen.Lit(flow.AuthorizationURL) r[jen.Id("TokenURL")] = jen.Lit(flow.TokenURL) r[jen.Id("RefreshURL")] = jen.Lit(flow.RefreshURL) + scopes := jen.Dict{} for scope, descr := range flow.Scopes { scopes[jen.Lit(scope)] = jen.Lit(descr) } + r[jen.Id("Scopes")] = jen.Map(jen.String()).String().Values(scopes) + return r } diff --git a/http/jsonapi/generator/generate_test.go b/http/jsonapi/generator/generate_test.go index f4569a28d..0b3da63b4 100644 --- a/http/jsonapi/generator/generate_test.go +++ b/http/jsonapi/generator/generate_test.go @@ -30,20 +30,24 @@ func TestGenerator(t *testing.T) { } g := Generator{} + result, err := g.BuildSource(testCase.source, filepath.Dir(testCase.pkg), filepath.Base(testCase.pkg)) if err != nil { t.Fatal(err) } + if os.Getenv("PACE_TEST_GENERATOR_WRITE") != "" { f, err := os.Create(fmt.Sprintf("testout/test.%s.out.go", testCase.pkg)) if err != nil { t.Fatal(err) } + _, err = f.WriteString(result) if err != nil { t.Fatal(err) } } + if string(expected[:]) != result { diff := difflib.UnifiedDiff{ A: difflib.SplitLines(string(expected[:])), @@ -53,7 +57,7 @@ func TestGenerator(t *testing.T) { Context: 3, } text, _ := difflib.GetUnifiedDiffString(diff) - t.Errorf(text) + t.Error(text) } }) } diff --git a/http/jsonapi/generator/generate_types.go b/http/jsonapi/generator/generate_types.go index 1b0b95ef9..233d3babe 100644 --- a/http/jsonapi/generator/generate_types.go +++ b/http/jsonapi/generator/generate_types.go @@ -26,6 +26,7 @@ func (g *Generator) BuildTypes(schema *openapi3.Swagger) error { for k := range schemas { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) for _, name := range keys { @@ -55,65 +56,78 @@ func (g *Generator) BuildTypes(schema *openapi3.Swagger) error { return nil } -func (g *Generator) buildType(prefix string, stmt *jen.Statement, schema *openapi3.SchemaRef, tags map[string]string, ptr bool) error { // nolint: gocyclo +func (g *Generator) buildType(prefix string, stmt *jen.Statement, schema *openapi3.SchemaRef, tags map[string]string, ptr bool) error { name := nameFromSchemaRef(schema) val := schema.Value switch val.Type { - case "array": // nolint: goconst + case "array": if schema.Ref != "" { // handle references stmt.Id(name) return nil } g.generatedArrayTypes[prefix] = true + return g.buildType(prefix, stmt.Index(), val.Items, tags, ptr) - case "object": // nolint: goconst + case "object": if schema.Ref != "" { // handle references if ptr { stmt.Op("*").Id(name) } else { stmt.Id(name) } + return nil } + if val.AdditionalPropertiesAllowed != nil && *val.AdditionalPropertiesAllowed { if len(val.Properties) > 0 { log.Warnf("%s properties are ignored. Only %s of type map[string]interface{} is generated ", prefix, prefix) } + stmt.Map(jen.String()).Interface() + return nil } + if val.AdditionalProperties != nil { if len(val.Properties) > 0 { log.Warnf("%s properties are ignored. Only %s of type map[string]type is generated ", prefix, prefix) } + stmt.Map(jen.String()) + if val.AdditionalProperties.Ref != "" { stmt.Op("*").Id(nameFromSchemaRef(val.AdditionalProperties)) return nil } + if val.AdditionalProperties.Value != nil { err := g.goType(stmt, val.AdditionalProperties.Value, make(map[string]string)).invoke() if err != nil { return err } } + return nil } if data := val.Properties["data"]; data != nil { if data.Ref != "" { return g.buildType(prefix+"Ref", stmt, data, make(map[string]string), ptr) - } else if data.Value.Type == "array" { // nolint: goconst + } else if data.Value.Type == "array" { item := prefix + "Item" if ptr { stmt.Index().Op("*").Id(item) } else { stmt.Index().Id(item) } + g.addGoDoc(item, data.Value.Description) + itemStmt := g.goSource.Type().Id(item) + return g.structJSONAPI(prefix, itemStmt, data.Value.Items.Value) } else if data.Value.Type == "object" { // This ensures that the code does only treat objects with data properties that // are objects themselves as legitimate JSONAPI struct, otherwise we want them to be treated as simple data objects. @@ -141,6 +155,7 @@ func (g *Generator) buildType(prefix string, stmt *jen.Statement, schema *openap if len(val.AllOf)+len(val.AnyOf)+len(val.OneOf) > 0 { log.Warnf("Can't generate allOf, anyOf and oneOf for type %q", prefix) stmt.Qual("encoding/json", "RawMessage") + return nil } @@ -172,6 +187,7 @@ func (g *Generator) buildTypeStruct(name string, stmt *jen.Statement, schema *op } else { stmt.Id(name) } + return nil } @@ -191,11 +207,12 @@ func (g *Generator) generateTypeReference(fallbackName string, schema *openapi3. // in case the type referenced is defined already directly reference it sv := schema.Value - if sv.Type == "object" && sv.Properties["data"] != nil && sv.Properties["data"].Ref != "" { // nolint: goconst + if sv.Type == "object" && sv.Properties["data"] != nil && sv.Properties["data"].Ref != "" { id := nameFromSchemaRef(schema.Value.Properties["data"]) if g.generatedArrayTypes[id] { return jen.Id(id), nil } + if noPtr { return jen.Id(id), nil } @@ -207,11 +224,13 @@ func (g *Generator) generateTypeReference(fallbackName string, schema *openapi3. t, ok := g.newType(fallbackName) if ok { g.addGoDoc(fallbackName, schema.Value.Description) + err := g.buildType(fallbackName, g.goSource.Add(t), schema, make(map[string]string), true) if err != nil { return nil, err } } + if noPtr { return jen.Id(fallbackName), nil } @@ -219,7 +238,7 @@ func (g *Generator) generateTypeReference(fallbackName string, schema *openapi3. return jen.Op("*").Id(fallbackName), nil } -func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *openapi3.Schema) error { // nolint: gocyclo +func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *openapi3.Schema) error { var fields []jen.Code propID := schema.Properties["id"] @@ -234,6 +253,7 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op if err != nil { return err } + fields = append(fields, id) // add attributes @@ -242,6 +262,7 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op if err != nil { return err } + fields = append(fields, attrFields...) } @@ -249,10 +270,12 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op links := schema.Properties["links"] if links != nil { linksAttr := jen.Id("Links") + err := g.buildTypeStruct(prefix+"Links", linksAttr, links.Value, true) if err != nil { return err } + fields = append(fields, linksAttr) } @@ -265,8 +288,10 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op if err != nil { log.Fatal(err) } + metaAttr.Comment("Resource meta data (json:api meta)") }() + fields = append(fields, metaAttr) } @@ -276,6 +301,7 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op if err != nil { return err } + fields = append(fields, relFields...) } @@ -283,10 +309,7 @@ func (g *Generator) structJSONAPI(prefix string, stmt *jen.Statement, schema *op // generate meta function if any if meta != nil { - err := g.generateJSONAPIMeta(prefix, stmt, meta.Value) - if err != nil { - return err - } + g.generateJSONAPIMeta(prefix, stmt, meta.Value) } return nil @@ -299,26 +322,32 @@ func (g *Generator) generateAttrField(prefix, name string, schema *openapi3.Sche if err != nil { return nil, err } + field.Tag(tags) + if schema.Ref == "" { g.commentOrExample(field, schema.Value) } + return field, nil } -func (g *Generator) generateStructFields(prefix string, schema *openapi3.Schema, jsonAPIObject bool) ([]jen.Code, error) { +func (g *Generator) generateStructFields(prefix string, schema *openapi3.Schema, _ bool) ([]jen.Code, error) { // sort by key keys := make([]string, 0, len(schema.Properties)) for k := range schema.Properties { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) - var fields []jen.Code + fields := make([]jen.Code, 0) + for _, attrName := range keys { attrSchema := schema.Properties[attrName] tags := make(map[string]string) addJSONAPITags(tags, "attr", attrName) + if attrSchema.Value.AdditionalPropertiesAllowed != nil && *attrSchema.Value.AdditionalPropertiesAllowed || attrSchema.Value.AdditionalProperties != nil { @@ -332,20 +361,24 @@ func (g *Generator) generateStructFields(prefix string, schema *openapi3.Schema, if err != nil { return nil, err } + fields = append(fields, field) } + return fields, nil } -func (g *Generator) generateStructRelationships(prefix string, schema *openapi3.Schema, jsonAPI bool) ([]jen.Code, error) { +func (g *Generator) generateStructRelationships(prefix string, schema *openapi3.Schema, _ bool) ([]jen.Code, error) { // sort by key keys := make([]string, 0, len(schema.Properties)) for k := range schema.Properties { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) - var relationships []jen.Code + relationships := make([]jen.Code, 0) + for _, relName := range keys { relSchema := schema.Properties[relName] tags := make(map[string]string) @@ -363,22 +396,23 @@ func (g *Generator) generateStructRelationships(prefix string, schema *openapi3. switch data.Value.Type { // case array = one-to-many - case "array": // nolint: goconst + case "array": name := data.Value.Items.Value.Properties["type"].Value.Enum[0].(string) rel.Index().Op("*").Id(goNameHelper(name)).Tag(tags) // case object = belongs-to - case "object": // nolint: goconst + case "object": name := data.Value.Properties["type"].Value.Enum[0].(string) rel.Op("*").Id(goNameHelper(name)).Tag(tags) } relationships = append(relationships, rel) } + return relationships, nil } // generateJSONAPIMeta generates a function that implements JSONAPIMeta -func (g *Generator) generateJSONAPIMeta(typeName string, stmt *jen.Statement, schema *openapi3.Schema) error { +func (g *Generator) generateJSONAPIMeta(typeName string, stmt *jen.Statement, schema *openapi3.Schema) { stmt.Line().Comment("JSONAPIMeta implements the meta data API for json:api").Line(). Func().Params(jen.Id("r").Op("*").Id(typeName)).Id("JSONAPIMeta").Params().Op("*").Qual(pkgJSONAPI, "Meta").BlockFunc( func(g *jen.Group) { @@ -391,6 +425,7 @@ func (g *Generator) generateJSONAPIMeta(typeName string, stmt *jen.Statement, sc for k := range schema.Properties { keys = append(keys, k) } + sort.Stable(sort.StringSlice(keys)) for _, attrName := range keys { @@ -399,8 +434,6 @@ func (g *Generator) generateJSONAPIMeta(typeName string, stmt *jen.Statement, sc g.Return(jen.Op("&").Id("meta")) }) - - return nil } func (g *Generator) generateIDField(idType, objectType *openapi3.Schema) (*jen.Statement, error) { @@ -408,13 +441,16 @@ func (g *Generator) generateIDField(idType, objectType *openapi3.Schema) (*jen.S tags := map[string]string{ "jsonapi": fmt.Sprintf("primary,%s,omitempty", objectType.Enum[0]), } + err := g.goType(id, idType, tags).invoke() if err != nil { return nil, err } + addValidator(tags, "optional") id.Tag(tags) g.commentOrExample(id, idType) + return id, nil } @@ -424,13 +460,16 @@ func (g *Generator) newType(name string) (*jen.Statement, bool) { if g.generatedTypes[name] { return nil, false } + g.generatedTypes[name] = true + return jen.Type().Id(name), true } func addRequiredOptionalTag(tags map[string]string, name string, schema *openapi3.Schema) { // check if field is required isRequired := false + for _, required := range schema.Required { if required == name { isRequired = true @@ -455,6 +494,7 @@ func removeOmitempty(tags map[string]string) { if v, ok := tags["jsonapi"]; ok { tags["jsonapi"] = strings.ReplaceAll(v, ",omitempty", "") } + if v, ok := tags["json"]; ok { tags["json"] = strings.ReplaceAll(v, ",omitempty", "") } diff --git a/http/jsonapi/generator/internal/fueling/fueling_test.go b/http/jsonapi/generator/internal/fueling/fueling_test.go index d97d46708..b0d4534ac 100644 --- a/http/jsonapi/generator/internal/fueling/fueling_test.go +++ b/http/jsonapi/generator/internal/fueling/fueling_test.go @@ -3,6 +3,7 @@ package fueling import ( "context" "io" + "net/http" "net/http/httptest" "strings" "testing" @@ -36,7 +37,7 @@ func (t *testService) WaitOnPumpStatusChange(context.Context, WaitOnPumpStatusCh func TestErrorReporting(t *testing.T) { r := Router(&testService{t}) rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/fueling/beta/gas-stations/d7101f72-a672-453c-9d36-d5809ef0ded6/approaching", strings.NewReader(`{ + req := httptest.NewRequest(http.MethodPost, "/fueling/beta/gas-stations/d7101f72-a672-453c-9d36-d5809ef0ded6/approaching", strings.NewReader(`{ "data": { "type": "approaching", "id": "c3f037ea-492e-4033-9b4b-4efc7beca16c", @@ -53,8 +54,9 @@ func TestErrorReporting(t *testing.T) { resp := rec.Result() defer resp.Body.Close() + b, _ := io.ReadAll(resp.Body) - require.Equalf(t, 422, resp.StatusCode, "expected 422 got: %s", string(b)) + require.Equalf(t, http.StatusUnprocessableEntity, resp.StatusCode, "expected 422 got: %s", string(b)) assert.Contains(t, string(b), `can't parse content: got value \"47.8\" expected type float32: Invalid type provided`) } diff --git a/http/jsonapi/generator/internal/pay/pay_test.go b/http/jsonapi/generator/internal/pay/pay_test.go index 19aa98c9a..3fdd94cd6 100644 --- a/http/jsonapi/generator/internal/pay/pay_test.go +++ b/http/jsonapi/generator/internal/pay/pay_test.go @@ -31,6 +31,7 @@ func (s *testService) CreatePaymentMethodSEPA(ctx context.Context, w CreatePayme if str := "Jon"; r.Content.FirstName != str { s.t.Errorf("expected FirstName to be %q, got %q", str, r.Content.FirstName) } + if str := "Haid-und-Neu-Str."; r.Content.Address.Street != str { s.t.Errorf("expected Address.Street to be %q, got %q", str, r.Content.Address.Street) } @@ -76,6 +77,7 @@ func (s *testService) ProcessPayment(ctx context.Context, w ProcessPaymentRespon if r.Content.PriceIncludingVAT.String() != "69.34" { s.t.Errorf(`expected priceIncludingVAT "69.34", got %q`, r.Content.PriceIncludingVAT) } + amount := decimal.RequireFromString("11.07") rate := decimal.RequireFromString("19.0") priceWithVat := decimal.RequireFromString("69.34") @@ -139,7 +141,7 @@ func (s testAuthBackend) InitProfileKey(cfgProfileKey *apikey.Config) { func TestHandler(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/pay/beta/payment-methods/sepa-direct-debit", strings.NewReader(`{ + req := httptest.NewRequest(http.MethodPost, "/pay/beta/payment-methods/sepa-direct-debit", strings.NewReader(`{ "data": { "id": "2a1319c3-c136-495d-b59a-47b3246d08af", "type": "paymentMethod", @@ -166,12 +168,14 @@ func TestHandler(t *testing.T) { resp := rec.Result() defer resp.Body.Close() - if resp.StatusCode != 201 { + if resp.StatusCode != http.StatusCreated { t.Errorf("expected OK got: %d", resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } @@ -188,7 +192,7 @@ func TestHandler(t *testing.T) { func TestHandlerDecimal(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/pay/beta/transaction/1337.42?queryDecimal=123.456", strings.NewReader(`{ + req := httptest.NewRequest(http.MethodPost, "/pay/beta/transaction/1337.42?queryDecimal=123.456", strings.NewReader(`{ "data": { "id": "5d3607f4-7855-4bfc-b926-1e662c225f06", "type": "transaction", @@ -213,12 +217,14 @@ func TestHandlerDecimal(t *testing.T) { resp := rec.Result() defer resp.Body.Close() - if resp.StatusCode != 201 { + if resp.StatusCode != http.StatusCreated { t.Errorf("expected OK got: %d", resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } @@ -242,7 +248,7 @@ func assertDecimal(t *testing.T, got, want decimal.Decimal) { func TestHandlerPanic(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/pay/beta/payment-methods?include=paymentToken", nil) + req := httptest.NewRequest(http.MethodGet, "/pay/beta/payment-methods?include=paymentToken", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) req.Header.Set("Content-Type", runtime.JSONAPIContentType) @@ -253,10 +259,12 @@ func TestHandlerPanic(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Errorf("expected 500 got: %d", resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } } @@ -264,7 +272,7 @@ func TestHandlerPanic(t *testing.T) { func TestHandlerError(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/pay/beta/payment-methods", nil) + req := httptest.NewRequest(http.MethodGet, "/pay/beta/payment-methods", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) req.Header.Set("Content-Type", runtime.JSONAPIContentType) @@ -275,10 +283,12 @@ func TestHandlerError(t *testing.T) { if resp.StatusCode != http.StatusInternalServerError { t.Errorf("expected 500 got: %d", resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } } diff --git a/http/jsonapi/generator/internal/poi/poi_test.go b/http/jsonapi/generator/internal/poi/poi_test.go index da132ec1d..b364bc3b2 100644 --- a/http/jsonapi/generator/internal/poi/poi_test.go +++ b/http/jsonapi/generator/internal/poi/poi_test.go @@ -35,9 +35,11 @@ func (s *testService) CheckForPaceApp(ctx context.Context, w CheckForPaceAppResp if r.ParamFilterLatitude != 41.859194 { s.t.Errorf("expected ParamLatitude to be %f, got: %f", 41.859194, r.ParamFilterLatitude) } + if r.ParamFilterLongitude != -87.646984 { s.t.Errorf("expected ParamLongitude to be %f, got: %f", -87.646984, r.ParamFilterLatitude) } + if r.ParamFilterAppType != "fueling" { s.t.Errorf("expected ParamAppType to be %q, got: %q", "fueling", r.ParamFilterAppType) } @@ -225,7 +227,7 @@ func (s testAuthBackend) InitOIDC(cfgOIDC *oidc.Config) {} func TestHandler(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/poi/beta/apps/query?"+ + req := httptest.NewRequest(http.MethodGet, "/poi/beta/apps/query?"+ "filter[latitude]=41.859194&filter[longitude]=-87.646984&filter[appType]=fueling", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) req.Header.Set("Content-Type", runtime.JSONAPIContentType) @@ -233,41 +235,51 @@ func TestHandler(t *testing.T) { r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { t.Errorf("expected OK got: %d", resp.StatusCode) t.Error(rec.Body.String()) + return } var data struct { Data []map[string]interface{} `json:"data"` } + err := json.NewDecoder(resp.Body).Decode(&data) if err != nil { t.Fatal(err) return } + if len(data.Data) != 10 { t.Error("Expected 10 apps") return } + if data.Data[0]["type"] != "locationBasedAppWithRefs" { t.Error("Expected type locationBasedAppWithRefs") return } + attributes, ok := data.Data[0]["attributes"].(map[string]interface{}) if !ok { t.Error("Expected attributes do be present") return } + if attributes["androidInstantAppUrl"] != "https://foobar.com" { t.Error(`Expected androidInstantAppUrl to be "https://foobar.com"`) } + if attributes["title"] != "some app" { t.Error(`Expected androidInstantAppUrl to be "some app"`) } + if attributes["appType"] != "some type" { t.Error(`Expected androidInstantAppUrl to be "some type"`) } @@ -276,48 +288,58 @@ func TestHandler(t *testing.T) { func TestHandlerWithTimeInQuery(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/poi/beta/apps?filter[since]=2020-05-06T12%3A22%3A54%2E000888456", nil) + req := httptest.NewRequest(http.MethodGet, "/poi/beta/apps?filter[since]=2020-05-06T12%3A22%3A54%2E000888456", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) req.Header.Set("Content-Type", runtime.JSONAPIContentType) r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { t.Errorf("expected OK got: %d", resp.StatusCode) t.Error(rec.Body.String()) + return } var data struct { Data []map[string]interface{} `json:"data"` } + err := json.NewDecoder(resp.Body).Decode(&data) if err != nil { t.Fatal(err) return } + if len(data.Data) != 10 { t.Error("Expected 10 apps") return } + if data.Data[0]["type"] != "locationBasedApp" { t.Error("Expected type locationBasedApp") return } + attributes, ok := data.Data[0]["attributes"].(map[string]interface{}) if !ok { t.Error("Expected attributes do be present") return } + if attributes["androidInstantAppUrl"] != "https://foobar.com" { t.Error(`Expected androidInstantAppUrl to be "https://foobar.com"`) } + if attributes["title"] != "some app" { t.Error(`Expected androidInstantAppUrl to be "some app"`) } + if attributes["appType"] != "some type" { t.Error(`Expected androidInstantAppUrl to be "some type"`) } @@ -326,7 +348,7 @@ func TestHandlerWithTimeInQuery(t *testing.T) { func TestCreatePolicyHandler(t *testing.T) { r := Router(&testService{t}, &testAuthBackend{}) rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/poi/beta/policies", strings.NewReader(`{ + req := httptest.NewRequest(http.MethodPost, "/poi/beta/policies", strings.NewReader(`{ "data": { "id": "f106ac99-213c-4cf7-8c1b-1e841516026b", "type": "policies", @@ -355,11 +377,14 @@ func TestCreatePolicyHandler(t *testing.T) { r.ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 200 { + if resp.StatusCode != http.StatusOK { t.Errorf("expected OK got: %d", resp.StatusCode) t.Error(rec.Body.String()) + return } } diff --git a/http/jsonapi/generator/internal/securitytest/security_test.go b/http/jsonapi/generator/internal/securitytest/security_test.go index 5b9d3cfa9..5bdb07d47 100644 --- a/http/jsonapi/generator/internal/securitytest/security_test.go +++ b/http/jsonapi/generator/internal/securitytest/security_test.go @@ -8,9 +8,10 @@ import ( "net/http/httptest" "testing" + "github.com/stretchr/testify/require" + "github.com/pace/bricks/http/oauth2" "github.com/pace/bricks/http/security/apikey" - "github.com/stretchr/testify/require" ) type testService struct{} @@ -57,39 +58,47 @@ func TestSecurityBothAuthenticationMethods(t *testing.T) { // oauth2 OK, profileKey OK, canAuth: both w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r := httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result := w.Result() require.Equal(t, http.StatusOK, result.StatusCode) + _ = result.Body.Close() // oauth2 ok, profileKey OK, canAuth: none authBackend.canAuthProfileKey = false authBackend.canAuthOauth = false w = httptest.NewRecorder() - r = httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r = httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result = w.Result() require.Equal(t, http.StatusUnauthorized, result.StatusCode) + _ = result.Body.Close() // oauth2 400, profileKey OK, canAuth = oauth2 authBackend.canAuthProfileKey = false authBackend.canAuthOauth = true w = httptest.NewRecorder() authBackend.oauth2Code = http.StatusBadRequest - r = httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r = httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result = w.Result() require.Equal(t, http.StatusBadRequest, result.StatusCode) + _ = result.Body.Close() // oauth2 400, profileKey OK, canAuth = profileKey authBackend.canAuthProfileKey = true authBackend.canAuthOauth = false w = httptest.NewRecorder() authBackend.oauth2Code = http.StatusBadRequest - r = httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r = httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result = w.Result() require.Equal(t, http.StatusOK, result.StatusCode) + _ = result.Body.Close() // oauth2 400, profileKey 500, canAuth = both w = httptest.NewRecorder() @@ -97,9 +106,11 @@ func TestSecurityBothAuthenticationMethods(t *testing.T) { authBackend.oauth2Code = http.StatusBadRequest authBackend.canAuthProfileKey = true authBackend.canAuthOauth = true - r = httptest.NewRequest("GET", "http://test.de/pay/beta/test", nil) + r = httptest.NewRequest(http.MethodGet, "http://test.de/pay/beta/test", nil) router.ServeHTTP(w, r) + result = w.Result() // Alphabetic order => get the error of the alphabetic first security scheme require.Equal(t, http.StatusBadRequest, result.StatusCode) + _ = result.Body.Close() } diff --git a/http/jsonapi/generator/route.go b/http/jsonapi/generator/route.go index 0b7b32935..11b9a49a1 100644 --- a/http/jsonapi/generator/route.go +++ b/http/jsonapi/generator/route.go @@ -24,7 +24,9 @@ func (r *route) parseURL() (err error) { if err != nil { return } + r.queryValues = r.url.Query() // cache query values + return } @@ -45,9 +47,11 @@ func (l *sortableRouteList) Less(i, j int) bool { if a, b := pathLen(elemI.url.Path), pathLen(elemJ.url.Path); a != b { return a > b } + if a, b := strings.Count(elemJ.url.Path, "{"), strings.Count(elemI.url.Path, "{"); a != b { return a > b } + return len(elemI.queryValues) > len(elemJ.queryValues) } diff --git a/http/jsonapi/generator/route_test.go b/http/jsonapi/generator/route_test.go index cec8bc409..2eb7e53e6 100644 --- a/http/jsonapi/generator/route_test.go +++ b/http/jsonapi/generator/route_test.go @@ -35,16 +35,21 @@ func TestSortableRouteList(t *testing.T) { "/beta/receipts/{transactionID}.{fileFormat}", } list := make(sortableRouteList, len(paths)) + for i, path := range paths { route := &route{pattern: path} require.NoError(t, route.parseURL()) + list[i] = route } + sort.Stable(&list) + actual := make([]string, len(paths)) for i, route := range list { actual[i] = route.pattern } + assert.Equal(t, []string{ "/beta/payment-method-kinds/applepay/authorize", "/beta/payment-methods/{paymentMethodId}/notification", diff --git a/http/jsonapi/middleware/error_middleware.go b/http/jsonapi/middleware/error_middleware.go index 2822b40a5..63f2daadc 100644 --- a/http/jsonapi/middleware/error_middleware.go +++ b/http/jsonapi/middleware/error_middleware.go @@ -22,15 +22,19 @@ func (e *errorMiddleware) Write(b []byte) (int, error) { log.Req(e.req).Warn().Msgf("Error already sent, ignoring: %q", string(b)) return 0, nil } - repliesJsonApi := e.Header().Get("Content-Type") == runtime.JSONAPIContentType - requestsJsonApi := e.req.Header.Get("Accept") == runtime.JSONAPIContentType - if e.statusCode >= 400 && requestsJsonApi && !repliesJsonApi { + + repliesJSONAPI := e.Header().Get("Content-Type") == runtime.JSONAPIContentType + requestsJSONAPI := e.req.Header.Get("Accept") == runtime.JSONAPIContentType + + if e.statusCode >= 400 && requestsJSONAPI && !repliesJSONAPI { if e.hasBytes { log.Req(e.req).Warn().Msgf("Body already contains data from previous writes: ignoring: %q", string(b)) return 0, nil } + e.hasErr = true runtime.WriteError(e.ResponseWriter, e.statusCode, errors.New(strings.Trim(string(b), "\n"))) + return 0, nil } @@ -38,6 +42,7 @@ func (e *errorMiddleware) Write(b []byte) (int, error) { if err == nil && n > 0 { e.hasBytes = true } + return n, err } @@ -46,7 +51,7 @@ func (e *errorMiddleware) WriteHeader(code int) { e.ResponseWriter.WriteHeader(code) } -// ErrorMiddleware is a middleware that wraps http.ResponseWriter +// Error is a middleware that wraps http.ResponseWriter // such that it forces responses with status codes 4xx/5xx to have // Content-Type: application/vnd.api+json func Error(next http.Handler) http.Handler { diff --git a/http/jsonapi/middleware/error_middleware_test.go b/http/jsonapi/middleware/error_middleware_test.go index fb0c80a26..d92bbdedb 100644 --- a/http/jsonapi/middleware/error_middleware_test.go +++ b/http/jsonapi/middleware/error_middleware_test.go @@ -8,31 +8,34 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/pace/bricks/http/jsonapi/runtime" ) const payload = "dummy response data" func TestErrorMiddleware(t *testing.T) { - for _, statusCode := range []int{200, 201, 400, 402, 500, 503} { + for _, statusCode := range []int{http.StatusOK, http.StatusCreated, http.StatusBadRequest, 402, 500, 503} { for _, responseContentType := range []string{"text/plain", "text/html", runtime.JSONAPIContentType} { r := mux.NewRouter() r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", responseContentType) w.WriteHeader(statusCode) _, _ = io.WriteString(w, payload) - }).Methods("GET") + }).Methods(http.MethodGet) r.Use(Error) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) r.ServeHTTP(rec, req) resp := rec.Result() b, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { t.Fatal(err) } @@ -40,6 +43,7 @@ func TestErrorMiddleware(t *testing.T) { if statusCode != resp.StatusCode { t.Fatalf("status codes differ: expected %v, got %v", statusCode, resp.StatusCode) } + if resp.StatusCode < 400 || responseContentType == runtime.JSONAPIContentType { if payload != string(b) { t.Fatalf("payloads differ: expected %v, got %v", payload, string(b)) @@ -53,9 +57,11 @@ func TestErrorMiddleware(t *testing.T) { if err != nil { t.Fatal(err) } + if len(e.List) != 1 { t.Fatalf("expected only one record, got %v", len(e.List)) } + if payload != e.List[0].Title { t.Fatalf("error titles differ: expected %v, got %v", payload, e.List[0].Title) } @@ -67,7 +73,7 @@ func TestErrorMiddleware(t *testing.T) { func TestJsonApiErrorMiddlewareMultipleErrorWrite(t *testing.T) { r := mux.NewRouter() r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(400) + w.WriteHeader(http.StatusBadRequest) w.Header().Set("Content-Type", "text/html") if _, err := io.WriteString(w, payload); err != nil { t.Fatal(err) @@ -81,28 +87,35 @@ func TestJsonApiErrorMiddlewareMultipleErrorWrite(t *testing.T) { if _, err := io.WriteString(w, payload); err != nil { t.Fatal(err) } - }).Methods("GET") + }).Methods(http.MethodGet) r.Use(Error) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) r.ServeHTTP(rec, req) + resp := rec.Result() b, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { t.Fatal(err) } + var e struct { List runtime.Errors `json:"errors"` } + if err := json.Unmarshal(b, &e); err != nil { t.Fatal(err) } + if len(e.List) != 1 { t.Fatalf("expected only one record, got %v", len(e.List)) } + if payload != e.List[0].Title { t.Fatalf("error titles differ: expected %v, got %v", payload, e.List[0].Title) } @@ -111,7 +124,7 @@ func TestJsonApiErrorMiddlewareMultipleErrorWrite(t *testing.T) { func TestJsonApiErrorMiddlewareInvalidWriteOrder(t *testing.T) { r := mux.NewRouter() r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) if _, err := io.WriteString(w, payload); err != nil { t.Fatal(err) } @@ -119,22 +132,26 @@ func TestJsonApiErrorMiddlewareInvalidWriteOrder(t *testing.T) { if ok && !jsonWriter.hasBytes { t.Fatal("expected hasBytes flag to be set") } - w.WriteHeader(400) + w.WriteHeader(http.StatusBadRequest) w.Header().Set("Content-Type", "text/plain") _, _ = io.WriteString(w, payload) // will get discarded - }).Methods("GET") + }).Methods(http.MethodGet) r.Use(Error) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Accept", runtime.JSONAPIContentType) r.ServeHTTP(rec, req) + resp := rec.Result() b, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { t.Fatal(err) } + if payload != string(b) { t.Fatalf("bad response body, expected %q, got %q", payload, string(b)) } diff --git a/http/jsonapi/models_test.go b/http/jsonapi/models_test.go index 3ebd8e9e9..0f4dce073 100644 --- a/http/jsonapi/models_test.go +++ b/http/jsonapi/models_test.go @@ -113,6 +113,7 @@ func (b *Blog) JSONAPIRelationshipLinks(relation string) *Links { }, } } + if relation == "current_post" { return &Links{ "self": fmt.Sprintf("https://example.com/api/posts/%s", "3"), @@ -121,6 +122,7 @@ func (b *Blog) JSONAPIRelationshipLinks(relation string) *Links { }, } } + return nil } @@ -146,11 +148,13 @@ func (b *Blog) JSONAPIRelationshipMeta(relation string) *Meta { }, } } + if relation == "current_post" { return &Meta{ "detail": "extra current_post detail", } } + return nil } diff --git a/http/jsonapi/node.go b/http/jsonapi/node.go index 46b6f3cfd..214977bd2 100644 --- a/http/jsonapi/node.go +++ b/http/jsonapi/node.go @@ -88,6 +88,7 @@ func (l *Links) validate() (err error) { ) } } + return } diff --git a/http/jsonapi/request.go b/http/jsonapi/request.go index ab331cb43..be0598936 100644 --- a/http/jsonapi/request.go +++ b/http/jsonapi/request.go @@ -18,22 +18,22 @@ import ( ) const ( - unsupportedStructTagMsg = "Unsupported jsonapi tag annotation, %s" + unsupportedStructTagMsg = "unsupported jsonapi tag annotation, %s" ) var ( // ErrInvalidTime is returned when a struct has a time.Time type field, but // the JSON value was not a unix timestamp integer. - ErrInvalidTime = errors.New("Only numbers can be parsed as dates, unix timestamps") + ErrInvalidTime = errors.New("only numbers can be parsed as dates, unix timestamps") // ErrInvalidISO8601 is returned when a struct has a time.Time type field and includes // "iso8601" in the tag spec, but the JSON value was not an ISO8601 timestamp string. - ErrInvalidISO8601 = errors.New("Only strings can be parsed as dates, ISO8601 timestamps") + ErrInvalidISO8601 = errors.New("only strings can be parsed as dates, ISO8601 timestamps") // ErrUnknownFieldNumberType is returned when the JSON value was a float // (numeric) but the Struct field was a non numeric type (i.e. not int, uint, // float, etc) - ErrUnknownFieldNumberType = errors.New("The struct field was not of a known number type") + ErrUnknownFieldNumberType = errors.New("the struct field was not of a known number type") // ErrInvalidType is returned when the given type is incompatible with the expected type. - ErrInvalidType = errors.New("Invalid type provided") // I wish we used punctuation. + ErrInvalidType = errors.New("invalid type provided") // I wish we used punctuation. ) @@ -48,9 +48,11 @@ type ErrUnsupportedPtrType struct { func (eupt ErrUnsupportedPtrType) Error() string { typeName := eupt.t.Elem().Name() kind := eupt.t.Elem().Kind() + if kind.String() != "" && kind.String() != typeName { typeName = fmt.Sprintf("%s (%s)", typeName, kind.String()) } + return fmt.Sprintf( "jsonapi: Can't unmarshal %+v (%s) to struct field `%s`, which is a pointer to `%s`", eupt.rf, eupt.rf.Type().Kind(), eupt.structField.Name, typeName, @@ -83,7 +85,7 @@ func newErrUnsupportedPtrType(rf reflect.Value, t reflect.Type, structField refl // // ...do stuff with your blog... // // w.Header().Set("Content-Type", jsonapi.MediaType) -// w.WriteHeader(201) +// w.WriteHeader(http.StatusCreated) // // if err := jsonapi.MarshalPayload(w, blog); err != nil { // http.Error(w, err.Error(), 500) @@ -102,6 +104,7 @@ func UnmarshalPayload(in io.Reader, model interface{}) error { if payload.Included != nil { includedMap := make(map[string]*Node) + for _, included := range payload.Included { key := fmt.Sprintf("%s,%s", included.Type, included.ID) includedMap[key] = included @@ -109,6 +112,7 @@ func UnmarshalPayload(in io.Reader, model interface{}) error { return unmarshalNode(payload.Data, reflect.ValueOf(model), &includedMap) } + return unmarshalNode(payload.Data, reflect.ValueOf(model), nil) } @@ -133,10 +137,12 @@ func UnmarshalManyPayload(in io.Reader, t reflect.Type) ([]interface{}, error) { for _, data := range payload.Data { model := reflect.New(t.Elem()) + err := unmarshalNode(data, model, &includedMap) if err != nil { return nil, err } + models = append(models, model.Interface()) } @@ -157,6 +163,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) for i := 0; i < modelValue.NumField(); i++ { fieldType := modelType.Field(i) + tag := fieldType.Tag.Get("jsonapi") if tag == "" { continue @@ -186,10 +193,11 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) // Check the JSON API Type if data.Type != args[1] { er = fmt.Errorf( - "Trying to Unmarshal an object of type %#v, but %#v does not match", + "trying to Unmarshal an object of type %#v, but %#v does not match", data.Type, args[1], ) + break } @@ -251,6 +259,7 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) } structField := fieldType + value, err := unmarshalAttribute(attribute, args, structField, fieldValue) if err != nil { er = err @@ -273,8 +282,13 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) buf := bytes.NewBuffer(nil) - json.NewEncoder(buf).Encode(data.Relationships[args[1]]) // nolint: errcheck - json.NewDecoder(buf).Decode(relationship) // nolint: errcheck + if err := json.NewEncoder(buf).Encode(data.Relationships[args[1]]); err != nil { + return err + } + + if err := json.NewDecoder(buf).Decode(relationship); err != nil { + return err + } data := relationship.Data models := reflect.New(fieldValue.Type()).Elem() @@ -301,10 +315,13 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) buf := bytes.NewBuffer(nil) - json.NewEncoder(buf).Encode( // nolint: errcheck - data.Relationships[args[1]], - ) - json.NewDecoder(buf).Decode(relationship) // nolint: errcheck + if err := json.NewEncoder(buf).Encode(data.Relationships[args[1]]); err != nil { + return err + } + + if err := json.NewDecoder(buf).Decode(relationship); err != nil { + return err + } /* http://jsonapi.org/format/#document-resource-object-relationships @@ -357,6 +374,7 @@ func assign(field, value reflect.Value) { // initialize pointer so it's value // can be set by assignValue field.Set(reflect.New(field.Type().Elem())) + field = field.Elem() } @@ -392,6 +410,7 @@ func unmarshalAttribute( fieldValue reflect.Value, ) (value reflect.Value, err error) { var attribute interface{} + err = json.Unmarshal(rawAttribute, &attribute) if err != nil { return reflect.Value{}, err @@ -430,6 +449,7 @@ func unmarshalAttribute( if fieldValue.Type().Kind() == reflect.Slice { value = reflect.New(fieldValue.Type()) err = json.Unmarshal(rawAttribute, value.Interface()) + return } @@ -456,6 +476,7 @@ func unmarshalAttribute( func handleDecimal(attribute json.RawMessage) (reflect.Value, error) { var dec decimal.Decimal + err := json.Unmarshal(attribute, &dec) if err != nil { return reflect.Value{}, fmt.Errorf("can't decode decimal from value %q: %v", string(attribute), err) @@ -466,6 +487,7 @@ func handleDecimal(attribute json.RawMessage) (reflect.Value, error) { func handleMapStringSlice(attribute json.RawMessage) (reflect.Value, error) { var m map[string][]string + err := json.Unmarshal(attribute, &m) if err != nil { return reflect.Value{}, fmt.Errorf("can't decode map string slice from value %q: %v", string(attribute), err) @@ -476,6 +498,7 @@ func handleMapStringSlice(attribute json.RawMessage) (reflect.Value, error) { func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) (reflect.Value, error) { var isIso8601 bool + v := reflect.ValueOf(attribute) if len(args) > 2 { @@ -584,12 +607,13 @@ func handleNumeric( func handlePointer( attribute interface{}, - args []string, + _ []string, fieldType reflect.Type, fieldValue reflect.Value, structField reflect.StructField, ) (reflect.Value, error) { t := fieldValue.Type() + var concreteVal reflect.Value if attribute == nil { @@ -605,11 +629,13 @@ func handlePointer( concreteVal = reflect.ValueOf(&cVal) case map[string]interface{}: var err error + concreteVal, err = handleStruct(attribute, fieldValue) if err != nil { return reflect.Value{}, newErrUnsupportedPtrType( reflect.ValueOf(attribute), fieldType, structField) } + return concreteVal, err default: return reflect.Value{}, newErrUnsupportedPtrType( diff --git a/http/jsonapi/request_test.go b/http/jsonapi/request_test.go index 9ace13911..4edde086f 100644 --- a/http/jsonapi/request_test.go +++ b/http/jsonapi/request_test.go @@ -17,6 +17,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshall_attrStringSlice(t *testing.T) { @@ -34,6 +35,7 @@ func TestUnmarshall_attrStringSlice(t *testing.T) { }, }, } + b, err := json.Marshal(data) if err != nil { t.Fatal(err) @@ -53,9 +55,11 @@ func TestUnmarshall_attrStringSlice(t *testing.T) { if out.Decimal1.String() != "9.9999999999999999999" { t.Fatalf("Expected json dec1 data to be %#v got: %#v", "9.9999999999999999999", out.Decimal1.String()) } + if out.Decimal2.String() != "9.9999999999999999999" { t.Fatalf("Expected json dec2 data to be %#v got: %#v", "9.9999999999999999999", out.Decimal2.String()) } + if out.Decimal3.String() != "10" { t.Fatalf("Expected json dec2 data to be %#v got: %#v", 10, out.Decimal3.String()) } @@ -104,6 +108,7 @@ func TestUnmarshall_MapStringSlice(t *testing.T) { for _, tc := range tcs { t.Run(tc.name, func(t *testing.T) { out := &Book{} + b, err := json.Marshal(tc.input) if err != nil { t.Fatal(err) @@ -123,18 +128,26 @@ func TestUnmarshalToStructWithPointerAttr(t *testing.T) { "int-val": json.RawMessage(`8`), "float-val": json.RawMessage(`1.1`), } - if err := UnmarshalPayload(sampleWithPointerPayload(in), out); err != nil { + + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + + if err := UnmarshalPayload(payload, out); err != nil { t.Fatal(err) } + if *out.Name != "The name" { t.Fatalf("Error unmarshalling to string ptr") } + if !*out.IsActive { t.Fatalf("Error unmarshalling to bool ptr") } + if *out.IntVal != 8 { t.Fatalf("Error unmarshalling to int ptr") } + if *out.FloatVal != 1.1 { t.Fatalf("Error unmarshalling to float ptr") } @@ -156,7 +169,10 @@ func TestUnmarshalPayloadWithPointerID(t *testing.T) { out := new(WithPointer) attrs := map[string]json.RawMessage{} - if err := UnmarshalPayload(sampleWithPointerPayload(attrs), out); err != nil { + payload, err := sampleWithPointerPayload(attrs) + require.NoError(t, err) + + if err := UnmarshalPayload(payload, out); err != nil { t.Fatalf("Error unmarshalling to Foo") } @@ -164,6 +180,7 @@ func TestUnmarshalPayloadWithPointerID(t *testing.T) { if out.ID == nil { t.Fatalf("Error unmarshalling; expected ID ptr to be not nil") } + if e, a := uint64(2), *out.ID; e != a { t.Fatalf("Was expecting the ID to have a value of %d, got %d", e, a) } @@ -176,7 +193,10 @@ func TestUnmarshalPayloadWithPointerAttr_AbsentVal(t *testing.T) { "is-active": json.RawMessage(`true`), } - if err := UnmarshalPayload(sampleWithPointerPayload(in), out); err != nil { + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + + if err := UnmarshalPayload(payload, out); err != nil { t.Fatalf("Error unmarshalling to Foo") } @@ -198,14 +218,18 @@ func TestUnmarshalToStructWithPointerAttr_BadType_bool(t *testing.T) { } expectedErrorMessage := "jsonapi: Can't unmarshal true (bool) to struct field `Name`, which is a pointer to `string`" - err := UnmarshalPayload(sampleWithPointerPayload(in), out) + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } + if _, ok := err.(ErrUnsupportedPtrType); !ok { t.Fatalf("Unexpected error type: %s", reflect.TypeOf(err)) } @@ -218,14 +242,18 @@ func TestUnmarshalToStructWithPointerAttr_BadType_MapPtr(t *testing.T) { } expectedErrorMessage := "jsonapi: Can't unmarshal map[a:5] (map) to struct field `Name`, which is a pointer to `string`" - err := UnmarshalPayload(sampleWithPointerPayload(in), out) + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } + if _, ok := err.(ErrUnsupportedPtrType); !ok { t.Fatalf("Unexpected error type: %s", reflect.TypeOf(err)) } @@ -238,14 +266,18 @@ func TestUnmarshalToStructWithPointerAttr_BadType_Struct(t *testing.T) { } expectedErrorMessage := "jsonapi: Can't unmarshal map[A:5] (map) to struct field `Name`, which is a pointer to `string`" - err := UnmarshalPayload(sampleWithPointerPayload(in), out) + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } + if _, ok := err.(ErrUnsupportedPtrType); !ok { t.Fatalf("Unexpected error type: %s", reflect.TypeOf(err)) } @@ -258,14 +290,18 @@ func TestUnmarshalToStructWithPointerAttr_BadType_IntSlice(t *testing.T) { } expectedErrorMessage := "jsonapi: Can't unmarshal [4 5] (slice) to struct field `Name`, which is a pointer to `string`" - err := UnmarshalPayload(sampleWithPointerPayload(in), out) + payload, err := sampleWithPointerPayload(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("Expected error due to invalid type.") } + if err.Error() != expectedErrorMessage { t.Fatalf("Unexpected error message: %s", err.Error()) } + if _, ok := err.(ErrUnsupportedPtrType); !ok { t.Fatalf("Unexpected error type: %s", reflect.TypeOf(err)) } @@ -285,6 +321,7 @@ func TestStringPointerField(t *testing.T) { }, }, } + payload, err := json.Marshal(data) if err != nil { t.Fatal(err) @@ -299,6 +336,7 @@ func TestStringPointerField(t *testing.T) { if book.Description == nil { t.Fatal("Was not expecting a nil pointer for book.Description") } + if expected, actual := description, *book.Description; expected != actual { t.Fatalf("Was expecting descript to be `%s`, got `%s`", expected, actual) } @@ -306,7 +344,11 @@ func TestStringPointerField(t *testing.T) { func TestMalformedTag(t *testing.T) { out := new(BadModel) - err := UnmarshalPayload(samplePayload(), out) + + payload, err := samplePayload() + require.NoError(t, err) + + err = UnmarshalPayload(payload, out) if err == nil || err != ErrBadJSONAPIStructTag { t.Fatalf("Did not error out with wrong number of arguments in tag") } @@ -317,7 +359,6 @@ func TestUnmarshalInvalidJSON(t *testing.T) { out := new(Blog) err := UnmarshalPayload(in, out) - if err == nil { t.Fatalf("Did not error out the invalid JSON.") } @@ -341,11 +382,14 @@ func TestUnmarshalInvalidJSON_BadType(t *testing.T) { in[test.Field] = test.BadValue expectedErrorMessage := test.Error.Error() - err := UnmarshalPayload(samplePayloadWithBadTypes(in), out) + payload, err := samplePayloadWithBadTypes(in) + require.NoError(t, err) + err = UnmarshalPayload(payload, out) if err == nil { t.Fatalf("(Test %d) Expected error due to invalid type.", i+1) } + if err.Error() != expectedErrorMessage { t.Fatalf("(Test %d) Unexpected error message: %q \nexpected: %q", i+1, expectedErrorMessage, err.Error()) } @@ -354,7 +398,9 @@ func TestUnmarshalInvalidJSON_BadType(t *testing.T) { } func TestUnmarshalSetsID(t *testing.T) { - in := samplePayloadWithID() + in, err := samplePayloadWithID() + require.NoError(t, err) + out := new(Blog) if err := UnmarshalPayload(in, out); err != nil { @@ -369,10 +415,12 @@ func TestUnmarshalSetsID(t *testing.T) { func TestUnmarshal_nonNumericID(t *testing.T) { data := samplePayloadWithoutIncluded() data["data"].(map[string]interface{})["id"] = "non-numeric-id" + payload, err := json.Marshal(data) if err != nil { t.Fatal(err) } + in := bytes.NewReader(payload) out := new(Post) @@ -411,6 +459,7 @@ func TestUnmarshalParsesISO8601(t *testing.T) { } in := bytes.NewBuffer(nil) + err := json.NewEncoder(in).Encode(payload) if err != nil { log.Fatal(err) @@ -440,6 +489,7 @@ func TestUnmarshalParsesISO8601TimePointer(t *testing.T) { } in := bytes.NewBuffer(nil) + err := json.NewEncoder(in).Encode(payload) if err != nil { t.Fatal(err) @@ -469,6 +519,7 @@ func TestUnmarshalInvalidISO8601(t *testing.T) { } in := bytes.NewBuffer(nil) + err := json.NewEncoder(in).Encode(payload) if err != nil { t.Fatal(err) @@ -486,6 +537,7 @@ func TestUnmarshalRelationshipsWithoutIncluded(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) out := new(Post) @@ -536,6 +588,7 @@ func TestUnmarshalNullRelationship(t *testing.T) { }, }, } + data, err := json.Marshal(sample) if err != nil { t.Fatal(err) @@ -569,6 +622,7 @@ func TestUnmarshalNullRelationshipInSlice(t *testing.T) { }, }, } + data, err := json.Marshal(sample) if err != nil { t.Fatal(err) @@ -703,7 +757,10 @@ func TestUnmarshalNestedRelationshipsSideloaded(t *testing.T) { func TestUnmarshalNestedRelationshipsEmbedded_withClientIDs(t *testing.T) { model := new(Blog) - if err := UnmarshalPayload(samplePayload(), model); err != nil { + payload, err := samplePayload() + require.NoError(t, err) + + if err := UnmarshalPayload(payload, model); err != nil { t.Fatal(err) } @@ -713,7 +770,11 @@ func TestUnmarshalNestedRelationshipsEmbedded_withClientIDs(t *testing.T) { } func unmarshalSamplePayload() (*Blog, error) { - in := samplePayload() + in, err := samplePayload() + if err != nil { + return nil, err + } + out := new(Blog) if err := UnmarshalPayload(in, out); err != nil { @@ -749,6 +810,7 @@ func TestUnmarshalManyPayload(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) posts, err := UnmarshalManyPayload(in, reflect.TypeOf(new(Post))) @@ -805,6 +867,7 @@ func TestManyPayload_withLinks(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) payload := new(ManyPayload) @@ -822,6 +885,7 @@ func TestManyPayload_withLinks(t *testing.T) { if !ok { t.Fatal("Was expecting a non nil ptr Link field") } + if e, a := firstPageURL, first; e != a { t.Fatalf("Was expecting links.%s to have a value of %s, got %s", KeyFirstPage, e, a) } @@ -830,6 +894,7 @@ func TestManyPayload_withLinks(t *testing.T) { if !ok { t.Fatal("Was expecting a non nil ptr Link field") } + if e, a := prevPageURL, prev; e != a { t.Fatalf("Was expecting links.%s to have a value of %s, got %s", KeyPreviousPage, e, a) } @@ -838,6 +903,7 @@ func TestManyPayload_withLinks(t *testing.T) { if !ok { t.Fatal("Was expecting a non nil ptr Link field") } + if e, a := nextPageURL, next; e != a { t.Fatalf("Was expecting links.%s to have a value of %s, got %s", KeyNextPage, e, a) } @@ -846,6 +912,7 @@ func TestManyPayload_withLinks(t *testing.T) { if !ok { t.Fatal("Was expecting a non nil ptr Link field") } + if e, a := lastPageURL, last; e != a { t.Fatalf("Was expecting links.%s to have a value of %s, got %s", KeyLastPage, e, a) } @@ -869,6 +936,7 @@ func TestUnmarshalCustomTypeAttributes(t *testing.T) { }, }, } + payload, err := json.Marshal(data) if err != nil { t.Fatal(err) @@ -883,9 +951,11 @@ func TestUnmarshalCustomTypeAttributes(t *testing.T) { if expected, actual := customInt, customAttributeTypes.Int; expected != actual { t.Fatalf("Was expecting custom int to be `%d`, got `%d`", expected, actual) } + if expected, actual := customInt, *customAttributeTypes.IntPtr; expected != actual { t.Fatalf("Was expecting custom int pointer to be `%d`, got `%d`", expected, actual) } + if customAttributeTypes.IntPtrNull != nil { t.Fatalf("Was expecting custom int pointer to be , got `%d`", customAttributeTypes.IntPtrNull) } @@ -893,6 +963,7 @@ func TestUnmarshalCustomTypeAttributes(t *testing.T) { if expected, actual := customFloat, customAttributeTypes.Float; expected != actual { t.Fatalf("Was expecting custom float to be `%f`, got `%f`", expected, actual) } + if expected, actual := customString, customAttributeTypes.String; expected != actual { t.Fatalf("Was expecting custom string to be `%s`, got `%s`", expected, actual) } @@ -912,6 +983,7 @@ func TestUnmarshalCustomTypeAttributes_ErrInvalidType(t *testing.T) { }, }, } + payload, err := json.Marshal(data) if err != nil { t.Fatal(err) @@ -919,6 +991,7 @@ func TestUnmarshalCustomTypeAttributes_ErrInvalidType(t *testing.T) { // Parse JSON API payload customAttributeTypes := new(CustomAttributeTypes) + err = UnmarshalPayload(bytes.NewReader(payload), customAttributeTypes) if err == nil { t.Fatal("Expected an error unmarshalling the payload due to type mismatch, got none") @@ -963,7 +1036,7 @@ func samplePayloadWithoutIncluded() map[string]interface{} { } } -func samplePayload() io.Reader { +func samplePayload() (io.Reader, error) { payload := &OnePayload{ Data: &Node{ Type: "blogs", @@ -1028,12 +1101,15 @@ func samplePayload() io.Reader { } out := bytes.NewBuffer(nil) - json.NewEncoder(out).Encode(payload) // nolint: errcheck - return out + if err := json.NewEncoder(out).Encode(payload); err != nil { + return nil, err + } + + return out, nil } -func samplePayloadWithID() io.Reader { +func samplePayloadWithID() (io.Reader, error) { payload := &OnePayload{ Data: &Node{ ID: "2", @@ -1046,12 +1122,16 @@ func samplePayloadWithID() io.Reader { } out := bytes.NewBuffer(nil) - json.NewEncoder(out).Encode(payload) // nolint: errcheck - return out + err := json.NewEncoder(out).Encode(payload) + if err != nil { + return nil, err + } + + return out, nil } -func samplePayloadWithBadTypes(m map[string]json.RawMessage) io.Reader { +func samplePayloadWithBadTypes(m map[string]json.RawMessage) (io.Reader, error) { payload := &OnePayload{ Data: &Node{ ID: "2", @@ -1061,12 +1141,16 @@ func samplePayloadWithBadTypes(m map[string]json.RawMessage) io.Reader { } out := bytes.NewBuffer(nil) - json.NewEncoder(out).Encode(payload) // nolint: errcheck - return out + err := json.NewEncoder(out).Encode(payload) + if err != nil { + return nil, err + } + + return out, nil } -func sampleWithPointerPayload(m map[string]json.RawMessage) io.Reader { +func sampleWithPointerPayload(m map[string]json.RawMessage) (io.Reader, error) { payload := &OnePayload{ Data: &Node{ ID: "2", @@ -1076,9 +1160,12 @@ func sampleWithPointerPayload(m map[string]json.RawMessage) io.Reader { } out := bytes.NewBuffer(nil) - json.NewEncoder(out).Encode(payload) // nolint: errcheck - return out + if err := json.NewEncoder(out).Encode(payload); err != nil { + return nil, err + } + + return out, nil } func testModel() *Blog { @@ -1153,6 +1240,7 @@ func samplePayloadWithSideloaded() io.Reader { testModel := testModel() out := bytes.NewBuffer(nil) + err := MarshalPayload(out, testModel) if err != nil { panic(err) @@ -1163,12 +1251,14 @@ func samplePayloadWithSideloaded() io.Reader { func sampleSerializedEmbeddedTestModel() *Blog { out := bytes.NewBuffer(nil) + err := MarshalOnePayloadEmbedded(out, testModel()) if err != nil { panic(err) } blog := new(Blog) + err = UnmarshalPayload(out, blog) if err != nil { panic(err) @@ -1182,11 +1272,13 @@ func TestUnmarshalNestedStructPtr(t *testing.T) { Firstname string `jsonapi:"attr,firstname"` Surname string `jsonapi:"attr,surname"` } + type Movie struct { ID string `jsonapi:"primary,movies"` Name string `jsonapi:"attr,name"` Director *Director `jsonapi:"attr,director"` } + sample := map[string]interface{}{ "data": map[string]interface{}{ "type": "movies", @@ -1205,6 +1297,7 @@ func TestUnmarshalNestedStructPtr(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) out := new(Movie) @@ -1215,9 +1308,11 @@ func TestUnmarshalNestedStructPtr(t *testing.T) { if out.Name != "The Shawshank Redemption" { t.Fatalf("expected out.Name to be `The Shawshank Redemption`, but got `%s`", out.Name) } + if out.Director.Firstname != "Frank" { t.Fatalf("expected out.Director.Firstname to be `Frank`, but got `%s`", out.Director.Firstname) } + if out.Director.Surname != "Darabont" { t.Fatalf("expected out.Director.Surname to be `Darabont`, but got `%s`", out.Director.Surname) } @@ -1265,6 +1360,7 @@ func TestUnmarshalNestedStruct(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) out := new(Company) @@ -1369,6 +1465,7 @@ func TestUnmarshalNestedStructSlice(t *testing.T) { if err != nil { t.Fatal(err) } + in := bytes.NewReader(data) out := new(Company) diff --git a/http/jsonapi/response.go b/http/jsonapi/response.go index b4c967f23..856001fe7 100644 --- a/http/jsonapi/response.go +++ b/http/jsonapi/response.go @@ -96,6 +96,7 @@ func Marshal(models interface{}) (Payloader, error) { if er := jl.validate(); er != nil { return nil, er } + payload.Links = linkableModels.JSONAPILinks() } @@ -109,6 +110,7 @@ func Marshal(models interface{}) (Payloader, error) { if reflect.Indirect(vals).Kind() != reflect.Struct { return nil, ErrUnexpectedType } + return marshalOne(models) default: return nil, ErrUnexpectedType @@ -127,6 +129,7 @@ func MarshalPayloadWithoutIncluded(w io.Writer, model interface{}) error { if err != nil { return err } + payload.clearIncluded() return json.NewEncoder(w).Encode(payload) @@ -142,6 +145,7 @@ func marshalOne(model interface{}) (*OnePayload, error) { if err != nil { return nil, err } + payload := &OnePayload{Data: rootNode} payload.Included = nodeMapValues(&included) @@ -163,8 +167,10 @@ func marshalMany(models []interface{}) (*ManyPayload, error) { if err != nil { return nil, err } + payload.Data = append(payload.Data, node) } + payload.Included = nodeMapValues(&included) return payload, nil @@ -202,6 +208,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, node := new(Node) var er error + value := reflect.ValueOf(model) if value.IsNil() { return nil, nil @@ -212,6 +219,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, for i := 0; i < modelValue.NumField(); i++ { structField := modelValue.Type().Field(i) + tag := structField.Tag.Get(annotationJSONAPI) if tag == "" { continue @@ -304,6 +312,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, if node.Attributes == nil { node.Attributes = make(map[string]json.RawMessage) } + var err error if fieldValue.Type() == reflect.TypeOf(decimal.Decimal{}) { @@ -312,6 +321,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, if !decimal.MarshalJSONWithoutQuotes { return nil, fmt.Errorf("decimal.MarshalJSONWithoutQuotes needs to be turned on to export decimals as numbers") } + node.Attributes[args[1]] = json.RawMessage(d.String()) } else if fieldValue.Type() == reflect.TypeOf(new(decimal.Decimal)) { // A decimal pointer may be nil @@ -327,6 +337,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, if !decimal.MarshalJSONWithoutQuotes { return nil, fmt.Errorf("decimal.MarshalJSONWithoutQuotes needs to be turned on to export decimals as numbers") } + node.Attributes[args[1]] = json.RawMessage(d.String()) } } else if fieldValue.Type() == reflect.TypeOf(time.Time{}) { @@ -341,6 +352,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, } else { node.Attributes[args[1]], err = json.Marshal(t.Unix()) } + if err != nil { return nil, err } @@ -364,6 +376,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, } else { node.Attributes[args[1]], err = json.Marshal(tm.Unix()) } + if err != nil { return nil, err } @@ -383,20 +396,24 @@ func visitModelNode(model interface{}, included *map[string]*Node, // We need to pass a pointer value ptr := reflect.New(fieldValue.Type()) ptr.Elem().Set(fieldValue) + n, err1 := visitModelNode(ptr.Interface(), nil, false) if err1 != nil { return nil, err1 } + node.Attributes[args[1]], err = json.Marshal(n.Attributes) } else if fieldValue.Type().Kind() == reflect.Ptr && fieldValue.Elem().Kind() == reflect.Struct { n, err1 := visitModelNode(fieldValue.Interface(), nil, false) if err1 != nil { return nil, err1 } + node.Attributes[args[1]], err = json.Marshal(n.Attributes) } else { node.Attributes[args[1]], err = json.Marshal(fieldValue.Interface()) } + if err != nil { return nil, err } @@ -441,11 +458,13 @@ func visitModelNode(model interface{}, included *map[string]*Node, er = err break } + relationship.Links = relLinks relationship.Meta = relMeta if sideload { shallowNodes := []*Node{} + for _, n := range relationship.Data { appendIncluded(included, n) shallowNodes = append(shallowNodes, toShallowNode(n)) @@ -461,7 +480,6 @@ func visitModelNode(model interface{}, included *map[string]*Node, } } else { // to-one relationships - // Handle null relationship case if fieldValue.IsNil() { node.Relationships[args[1]] = &RelationshipOneNode{Data: nil} @@ -480,6 +498,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, if sideload { appendIncluded(included, relationship) + node.Relationships[args[1]] = &RelationshipOneNode{ Data: toShallowNode(relationship), Links: relLinks, @@ -509,6 +528,7 @@ func visitModelNode(model interface{}, included *map[string]*Node, if er := jl.validate(); er != nil { return nil, er } + node.Links = linkableModel.JSONAPILinks() } @@ -564,6 +584,7 @@ func nodeMapValues(m *map[string]*Node) []*Node { nodes := make([]*Node, len(mp)) i := 0 + for _, n := range mp { nodes[i] = n i++ @@ -577,9 +598,11 @@ func convertToSliceInterface(i *interface{}) ([]interface{}, error) { if vals.Kind() != reflect.Slice { return nil, ErrExpectedSlice } + var response []interface{} for x := 0; x < vals.Len(); x++ { response = append(response, vals.Index(x).Interface()) } + return response, nil } diff --git a/http/jsonapi/response_test.go b/http/jsonapi/response_test.go index 405932f46..5bd62e73b 100644 --- a/http/jsonapi/response_test.go +++ b/http/jsonapi/response_test.go @@ -13,11 +13,10 @@ import ( "testing" "time" + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "github.com/pace/bricks/pkg/isotime" - - "github.com/shopspring/decimal" ) func TestMarshalPayload(t *testing.T) { @@ -25,12 +24,15 @@ func TestMarshalPayload(t *testing.T) { if e != nil { panic(e) } + book := &Book{ID: 1, Decimal1: d} books := []*Book{book, {ID: 2}} + var jsonData map[string]interface{} // One out1 := bytes.NewBuffer(nil) + err := MarshalPayload(out1, book) if err != nil { t.Fatal(err) @@ -43,13 +45,16 @@ func TestMarshalPayload(t *testing.T) { if err := json.Unmarshal(out1.Bytes(), &jsonData); err != nil { t.Fatal(err) } + if _, ok := jsonData["data"].(map[string]interface{}); !ok { t.Fatalf("data key did not contain an Hash/Dict/Map") } + fmt.Println(out1.String()) // Many out2 := bytes.NewBuffer(nil) + err = MarshalPayload(out2, books) if err != nil { t.Fatal(err) @@ -58,6 +63,7 @@ func TestMarshalPayload(t *testing.T) { if err := json.Unmarshal(out2.Bytes(), &jsonData); err != nil { t.Fatal(err) } + if _, ok := jsonData["data"].([]interface{}); !ok { t.Fatalf("data key did not contain an Array") } @@ -65,6 +71,7 @@ func TestMarshalPayload(t *testing.T) { func TestMarshalPayloadWithNulls(t *testing.T) { books := []*Book{nil, {ID: 101}, nil} + var jsonData map[string]interface{} out := bytes.NewBuffer(nil) @@ -75,14 +82,17 @@ func TestMarshalPayloadWithNulls(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } + raw, ok := jsonData["data"] if !ok { t.Fatalf("data key does not exist") } + arr, ok := raw.([]interface{}) if !ok { t.Fatalf("data is not an Array") } + for i := 0; i < len(arr); i++ { if books[i] == nil && arr[i] != nil || books[i] != nil && arr[i] == nil { @@ -139,6 +149,7 @@ func TestWithoutOmitsEmptyAnnotationOnRelation(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } + relationships := jsonData["data"].(map[string]interface{})["relationships"].(map[string]interface{}) // Verifiy the "posts" relation was an empty array @@ -146,18 +157,22 @@ func TestWithoutOmitsEmptyAnnotationOnRelation(t *testing.T) { if !ok { t.Fatal("Was expecting the data.relationships.posts key/value to have been present") } + postsMap, ok := posts.(map[string]interface{}) if !ok { t.Fatal("data.relationships.posts was not a map") } + postsData, ok := postsMap["data"] if !ok { t.Fatal("Was expecting the data.relationships.posts.data key/value to have been present") } + postsDataSlice, ok := postsData.([]interface{}) if !ok { t.Fatal("data.relationships.posts.data was not a slice []") } + if len(postsDataSlice) != 0 { t.Fatal("Was expecting the data.relationships.posts.data value to have been an empty array []") } @@ -167,14 +182,17 @@ func TestWithoutOmitsEmptyAnnotationOnRelation(t *testing.T) { if !postExists { t.Fatal("Was expecting the data.relationships.current_post key/value to have NOT been omitted") } + currentPostMap, ok := currentPost.(map[string]interface{}) if !ok { t.Fatal("data.relationships.current_post was not a map") } + currentPostData, ok := currentPostMap["data"] if !ok { t.Fatal("Was expecting the data.relationships.current_post.data key/value to have been present") } + if currentPostData != nil { t.Fatal("Was expecting the data.relationships.current_post.data value to have been nil/null") } @@ -199,6 +217,7 @@ func TestWithOmitsEmptyAnnotationOnRelation(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } + payload := jsonData["data"].(map[string]interface{}) // Verify relationship was NOT set @@ -231,6 +250,7 @@ func TestWithOmitsEmptyAnnotationOnRelation_MixedData(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } + payload := jsonData["data"].(map[string]interface{}) // Verify relationship was set @@ -290,19 +310,24 @@ func TestWithOmitsEmptyAnnotationOnAttribute(t *testing.T) { // Verify that there is no field "phones" in attributes payload := jsonData["data"].(map[string]interface{}) + attributes := payload["attributes"].(map[string]interface{}) if _, ok := attributes["title"]; !ok { t.Fatal("Was expecting the data.attributes.title to have NOT been omitted") } + if _, ok := attributes["phones"]; ok { t.Fatal("Was expecting the data.attributes.phones to have been omitted") } + if _, ok := attributes["address"]; ok { t.Fatal("Was expecting the data.attributes.phones to have been omitted") } + if _, ok := attributes["tags"]; !ok { t.Fatal("Was expecting the data.attributes.tags to have NOT been omitted") } + if _, ok := attributes["account"]; !ok { t.Fatal("Was expecting the data.attributes.account to have NOT been omitted") } @@ -325,6 +350,7 @@ func TestMarshalIDPtr(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } + data := jsonData["data"].(map[string]interface{}) // attributes := data["attributes"].(map[string]interface{}) @@ -333,6 +359,7 @@ func TestMarshalIDPtr(t *testing.T) { if !exists { t.Fatal("Was expecting the data.id member to exist") } + if val != id { t.Fatalf("Was expecting the data.id member to be `%s`, got `%s`", id, val) } @@ -346,6 +373,7 @@ func TestMarshalOnePayload_omitIDString(t *testing.T) { foo := &Foo{Title: "Foo"} out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, foo); err != nil { t.Fatal(err) } @@ -354,6 +382,7 @@ func TestMarshalOnePayload_omitIDString(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } + payload := jsonData["data"].(map[string]interface{}) // Verify that empty ID of type string gets omitted. See: @@ -368,6 +397,7 @@ func TestMarshall_invalidIDType(t *testing.T) { type badIDStruct struct { ID *bool `jsonapi:"primary,cars"` } + id := true o := &badIDStruct{ID: &id} @@ -394,12 +424,14 @@ func TestOmitsEmptyAnnotation(t *testing.T) { if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { t.Fatal(err) } + attributes := jsonData["data"].(map[string]interface{})["attributes"].(map[string]interface{}) // Verify that the specifically omitted field were omitted if val, exists := attributes["title"]; exists { t.Fatalf("Was expecting the data.attributes.title key/value to have been omitted - it was not and had a value of %v", val) } + if val, exists := attributes["pages"]; exists { t.Fatalf("Was expecting the data.attributes.pages key/value to have been omitted - it was not and had a value of %v", val) } @@ -664,12 +696,14 @@ func TestSupportsLinkable(t *testing.T) { if data.Links == nil { t.Fatal("Expected data.links") } + links := *data.Links self, hasSelf := links["self"] if !hasSelf { t.Fatal("Expected 'self' link to be present") } + if _, isString := self.(string); !isString { t.Fatal("Expected 'self' to contain a string") } @@ -678,6 +712,7 @@ func TestSupportsLinkable(t *testing.T) { if !hasComments { t.Fatal("expect 'comments' to be present") } + commentsMap, isMap := comments.(map[string]interface{}) if !isMap { t.Fatal("Expected 'comments' to contain a map") @@ -687,6 +722,7 @@ func TestSupportsLinkable(t *testing.T) { if !hasHref { t.Fatal("Expect 'comments' to contain an 'href' key/value") } + if _, isString := commentsHref.(string); !isString { t.Fatal("Expected 'href' to contain a string") } @@ -695,16 +731,19 @@ func TestSupportsLinkable(t *testing.T) { if !hasMeta { t.Fatal("Expect 'comments' to contain a 'meta' key/value") } + commentsMetaMap, isMap := commentsMeta.(map[string]interface{}) if !isMap { t.Fatal("Expected 'comments' to contain a map") } commentsMetaObject := Meta(commentsMetaMap) + countsMap, isMap := commentsMetaObject["counts"].(map[string]interface{}) if !isMap { t.Fatal("Expected 'counts' to contain a map") } + for k, v := range countsMap { if _, isNum := v.(float64); !isNum { t.Fatalf("Exepected value at '%s' to be a numeric (float64)", k) @@ -777,6 +816,7 @@ func TestRelations(t *testing.T) { if relations["posts"].(map[string]interface{})["links"] == nil { t.Fatalf("Posts relationship links were not materialized") } + if relations["posts"].(map[string]interface{})["meta"] == nil { t.Fatalf("Posts relationship meta were not materialized") } @@ -788,6 +828,7 @@ func TestRelations(t *testing.T) { if relations["current_post"].(map[string]interface{})["links"] == nil { t.Fatalf("Current post relationship links were not materialized") } + if relations["current_post"].(map[string]interface{})["meta"] == nil { t.Fatalf("Current post relationship meta were not materialized") } @@ -975,6 +1016,7 @@ func TestMarshalMany_SliceOfInterfaceAndSliceOfStructsSameJSON(t *testing.T) { {ID: 2, Author: "shwoodard", ISBN: "xyz"}, } interfaces := []interface{}{} + for _, s := range structs { interfaces = append(interfaces, s) } @@ -984,6 +1026,7 @@ func TestMarshalMany_SliceOfInterfaceAndSliceOfStructsSameJSON(t *testing.T) { if err := MarshalPayload(structsOut, structs); err != nil { t.Fatal(err) } + interfacesOut := new(bytes.Buffer) if err := MarshalPayload(interfacesOut, interfaces); err != nil { t.Fatal(err) @@ -994,6 +1037,7 @@ func TestMarshalMany_SliceOfInterfaceAndSliceOfStructsSameJSON(t *testing.T) { if err := json.Unmarshal(structsOut.Bytes(), &structsData); err != nil { t.Fatal(err) } + if err := json.Unmarshal(interfacesOut.Bytes(), &interfacesData); err != nil { t.Fatal(err) } @@ -1009,9 +1053,11 @@ func TestMarshal_InvalidIntefaceArgument(t *testing.T) { if err := MarshalPayload(out, true); err != ErrUnexpectedType { t.Fatal("Was expecting an error") } + if err := MarshalPayload(out, 25); err != ErrUnexpectedType { t.Fatal("Was expecting an error") } + if err := MarshalPayload(out, Book{}); err != ErrUnexpectedType { t.Fatal("Was expecting an error") } diff --git a/http/jsonapi/runtime.go b/http/jsonapi/runtime.go index 7fd67db16..d7f382f33 100644 --- a/http/jsonapi/runtime.go +++ b/http/jsonapi/runtime.go @@ -106,6 +106,7 @@ func (r *Runtime) instrumentCall(start Event, stop Event, c func() error) error } begin := time.Now() + Instrumentation(r, start, instrumentationGUID, time.Duration(0)) if err := c(); err != nil { @@ -128,5 +129,6 @@ func newUUID() (string, error) { uuid[8] = uuid[8]&^0xc0 | 0x80 // version 4 (pseudo-random); see section 4.1.3 uuid[6] = uuid[6]&^0xf0 | 0x40 + return fmt.Sprintf("%x-%x-%x-%x-%x", uuid[0:4], uuid[4:6], uuid[6:8], uuid[8:10], uuid[10:]), nil } diff --git a/http/jsonapi/runtime/error.go b/http/jsonapi/runtime/error.go index ecacebf7c..df3388b9c 100644 --- a/http/jsonapi/runtime/error.go +++ b/http/jsonapi/runtime/error.go @@ -58,6 +58,7 @@ func (e Errors) Error() string { for i, err := range e { messages[i] = err.Error() } + return strings.Join(messages, "\n") } @@ -110,6 +111,7 @@ func WriteError(w http.ResponseWriter, code int, err error) { // render the error to the client enc := json.NewEncoder(w) enc.SetIndent("", " ") + err = enc.Encode(errList) if err != nil { log.Logger().Info().Str("req_id", reqID). diff --git a/http/jsonapi/runtime/error_test.go b/http/jsonapi/runtime/error_test.go index 77c124067..24dc656fc 100644 --- a/http/jsonapi/runtime/error_test.go +++ b/http/jsonapi/runtime/error_test.go @@ -51,12 +51,15 @@ func TestErrorMarshaling(t *testing.T) { if resp.StatusCode != testCase.httpStatus { t.Errorf("expected the response code %d got: %d", testCase.httpStatus, resp.StatusCode) } + if ct := resp.Header.Get("Content-Type"); ct != JSONAPIContentType { t.Errorf("expected the response code %q got: %q", JSONAPIContentType, ct) } var errList errorObjects + dec := json.NewDecoder(resp.Body) + err := dec.Decode(&errList) if err != nil { t.Fatal(err) @@ -82,6 +85,7 @@ func TestErrors(t *testing.T) { &Error{Title: "foo2", Detail: "bar2"}, } result := "foo\nfoo2" + if errs.Error() != result { t.Errorf("expected %q got: %q", result, errs.Error()) } @@ -89,9 +93,9 @@ func TestErrors(t *testing.T) { func TestError(t *testing.T) { err := Error{} - err.setHTTPStatus(200) + err.setHTTPStatus(http.StatusOK) - result := "200" + result := "http.StatusOK" if err.Status != result { t.Errorf("expected %q got: %q", result, err.Status) } diff --git a/http/jsonapi/runtime/marshalling.go b/http/jsonapi/runtime/marshalling.go index c28a2ed80..3fbf9d369 100644 --- a/http/jsonapi/runtime/marshalling.go +++ b/http/jsonapi/runtime/marshalling.go @@ -18,7 +18,9 @@ import ( // In case of an error, an jsonapi error message will be directly send to the client func Unmarshal(w http.ResponseWriter, r *http.Request, data interface{}) bool { // don't leak , but error can't be handled - defer r.Body.Close() // nolint: errcheck + defer func() { + _ = r.Body.Close() + }() // verify that the client accepts our response // Note: logically this would be done before marshalling, @@ -57,7 +59,9 @@ func Unmarshal(w http.ResponseWriter, r *http.Request, data interface{}) bool { // In case of an error, an jsonapi error message will be directly send to the client func UnmarshalMany(w http.ResponseWriter, r *http.Request, t reflect.Type) (bool, []interface{}) { // don't leak , but error can't be handled - defer r.Body.Close() // nolint: errcheck + defer func() { + _ = r.Body.Close() + }() // verify that the client accepts our response // Note: logically this would be done before marshalling, @@ -91,6 +95,7 @@ func UnmarshalMany(w http.ResponseWriter, r *http.Request, t reflect.Type) (bool return false, nil } } + return true, data } diff --git a/http/jsonapi/runtime/marshalling_test.go b/http/jsonapi/runtime/marshalling_test.go index 36ae03aaf..8a3653d12 100644 --- a/http/jsonapi/runtime/marshalling_test.go +++ b/http/jsonapi/runtime/marshalling_test.go @@ -15,7 +15,7 @@ import ( func TestUnmarshalAccept(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", nil) + req := httptest.NewRequest(http.MethodPost, "/", nil) ok := Unmarshal(rec, req, nil) if ok { @@ -24,6 +24,7 @@ func TestUnmarshalAccept(t *testing.T) { resp := rec.Result() defer resp.Body.Close() + if resp.StatusCode != http.StatusNotAcceptable { t.Errorf("Expected status code %d got: %d", http.StatusNotAcceptable, resp.StatusCode) } @@ -31,7 +32,7 @@ func TestUnmarshalAccept(t *testing.T) { func TestUnmarshalContentType(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", nil) + req := httptest.NewRequest(http.MethodPost, "/", nil) req.Header.Set("Accept", JSONAPIContentType) ok := Unmarshal(rec, req, nil) @@ -41,6 +42,7 @@ func TestUnmarshalContentType(t *testing.T) { resp := rec.Result() defer resp.Body.Close() + if resp.StatusCode != http.StatusUnsupportedMediaType { t.Errorf("Expected status code %d got: %d", http.StatusUnsupportedMediaType, resp.StatusCode) } @@ -48,7 +50,7 @@ func TestUnmarshalContentType(t *testing.T) { func TestUnmarshalContent(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data": 1}`)) + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"data": 1}`)) req.Header.Set("Accept", JSONAPIContentType) req.Header.Set("Content-Type", JSONAPIContentType) @@ -58,6 +60,7 @@ func TestUnmarshalContent(t *testing.T) { } var article Article + ok := Unmarshal(rec, req, &article) if ok { t.Error("Un-marshalling should fail") @@ -65,6 +68,7 @@ func TestUnmarshalContent(t *testing.T) { resp := rec.Result() defer resp.Body.Close() + if resp.StatusCode != http.StatusUnprocessableEntity { t.Errorf("Expected status code %d got: %d", http.StatusUnprocessableEntity, resp.StatusCode) } @@ -72,7 +76,7 @@ func TestUnmarshalContent(t *testing.T) { func TestUnmarshalArticle(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":{ + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"data":{ "type": "articles", "id": "cb855aff-f03c-4307-9a22-ab5fcc6b6d7c", "attributes": { @@ -88,20 +92,23 @@ func TestUnmarshalArticle(t *testing.T) { } var article Article - ok := Unmarshal(rec, req, &article) + ok := Unmarshal(rec, req, &article) if !ok { t.Error("Un-marshalling should have been ok") } resp := rec.Result() defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d got: %d", http.StatusOK, resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } @@ -109,6 +116,7 @@ func TestUnmarshalArticle(t *testing.T) { if article.ID != uuid { t.Errorf("article.ID expected %q got: %q", uuid, article.ID) } + if article.Title != "This is my first blog" { t.Errorf("article.ID expected \"This is my first blog\" got: %q", article.Title) } @@ -116,7 +124,7 @@ func TestUnmarshalArticle(t *testing.T) { func TestUnmarshalArticles(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", strings.NewReader(`{"data":[ + req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{"data":[ { "type":"article", "id": "82180c8d-0ab6-4946-9298-61d3c8d13da4", @@ -134,31 +142,35 @@ func TestUnmarshalArticles(t *testing.T) { ]}`)) req.Header.Set("Accept", JSONAPIContentType) req.Header.Set("Content-Type", JSONAPIContentType) + type Article struct { ID string `jsonapi:"primary,article" valid:"optional,uuid"` Title string `jsonapi:"attr,title" valid:"required"` } ok, articles := UnmarshalMany(rec, req, reflect.TypeOf(new(Article))) - if !ok { t.Error("Un-marshalling many should have been ok") } resp := rec.Result() defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d got: %d", http.StatusOK, resp.StatusCode) + b, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } + t.Error(string(b[:])) } if len(articles) != 2 { t.Errorf("Expected 2 articles, got %d", len(articles)) } + expected := []*Article{ { ID: "82180c8d-0ab6-4946-9298-61d3c8d13da4", @@ -169,11 +181,13 @@ func TestUnmarshalArticles(t *testing.T) { Title: "This is the second article", }, } + for i := range articles { got := articles[i].(*Article) if expected[i].ID != got.ID { t.Errorf("article.ID expected %q got: %q", expected[i].ID, got.ID) } + if expected[i].Title != got.Title { t.Errorf("article.ID expected \"%s\" got: %q", expected[i].ID, got.Title) } @@ -196,6 +210,7 @@ func TestMarshalArticle(t *testing.T) { resp := rec.Result() defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { t.Errorf("Expected status code %d got: %d", http.StatusOK, resp.StatusCode) } @@ -245,6 +260,7 @@ func TestMarshalConnectionError(t *testing.T) { t.Fatal("Was not expecting a panic") } }() + rec := writer{} Marshal(rec, &struct{}{}, http.StatusOK) } diff --git a/http/jsonapi/runtime/parameters.go b/http/jsonapi/runtime/parameters.go index 54f661b89..3b42b2dc7 100644 --- a/http/jsonapi/runtime/parameters.go +++ b/http/jsonapi/runtime/parameters.go @@ -72,6 +72,7 @@ func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanP size := len(input) array := reflect.MakeSlice(reValue.Type(), size, size) invalid := 0 + for i := 0; i < size; i++ { if input[i] == "" { invalid++ @@ -79,7 +80,8 @@ func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanP } arrElem := array.Index(i - invalid) - n, _ := Scan(input[i], arrElem.Addr().Interface()) // nolint: gosec + n, _ := Scan(input[i], arrElem.Addr().Interface()) + if n != 1 { WriteError(w, http.StatusBadRequest, param.BuildInvalidValueError(arrElem.Type(), input[i])) return false @@ -89,6 +91,7 @@ func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanP if invalid > 0 { array = array.Slice(0, size-invalid) } + reValue.Set(array) // skip parsing at the bottom of the loop @@ -110,6 +113,7 @@ func ScanParameters(w http.ResponseWriter, r *http.Request, parameters ...*ScanP return false } } + return true } @@ -121,7 +125,9 @@ func Scan(str string, data interface{}) (int, error) { if err != nil { return 0, err } + *d = nd + return 1, nil } @@ -133,6 +139,7 @@ func Scan(str string, data interface{}) (int, error) { } *t = nt + return 1, nil } @@ -143,5 +150,5 @@ func Scan(str string, data interface{}) (int, error) { return 1, nil } - return fmt.Sscan(str, data) // nolint: gosec + return fmt.Sscan(str, data) } diff --git a/http/jsonapi/runtime/parameters_test.go b/http/jsonapi/runtime/parameters_test.go index d5dcfe994..3c1d5e304 100644 --- a/http/jsonapi/runtime/parameters_test.go +++ b/http/jsonapi/runtime/parameters_test.go @@ -4,6 +4,7 @@ package runtime import ( "encoding/json" + "net/http" "net/http/httptest" "testing" "time" @@ -21,11 +22,12 @@ func TestScanStringParametersInQuery(t *testing.T) { } for _, tc := range tests { - req := httptest.NewRequest("GET", tc.path, nil) + req := httptest.NewRequest(http.MethodGet, tc.path, nil) rec := httptest.NewRecorder() + var param0 string + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInQuery, "", "q"}) - // Parsing if !ok { t.Errorf("expected the scanning of %q to be successful", tc.path) } @@ -50,11 +52,12 @@ func TestScanTimeParametersInQuery(t *testing.T) { } for _, tc := range tests { - req := httptest.NewRequest("GET", tc.path, nil) + req := httptest.NewRequest(http.MethodGet, tc.path, nil) rec := httptest.NewRecorder() + var param0 time.Time + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInQuery, "", "q"}) - // Parsing if !ok { t.Errorf("expected the scanning of %q to be successful", tc.path) } @@ -80,11 +83,12 @@ func TestScanBoolParametersInQuery(t *testing.T) { } for _, tc := range tests { - req := httptest.NewRequest("GET", tc.path, nil) + req := httptest.NewRequest(http.MethodGet, tc.path, nil) rec := httptest.NewRecorder() + var param0 bool + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInQuery, "", "b"}) - // Parsing if !ok { t.Errorf("expected the scanning of %q to be successful", tc.path) } @@ -96,20 +100,33 @@ func TestScanBoolParametersInQuery(t *testing.T) { } func TestScanNumericParametersInPath(t *testing.T) { - req := httptest.NewRequest("GET", "/foo/", nil) + req := httptest.NewRequest(http.MethodGet, "/foo/", nil) rec := httptest.NewRecorder() + var param0 uint + var param1 uint8 + var param2 uint16 + var param3 uint32 + var param4 uint64 + var param10 int + var param11 int8 + var param12 int16 + var param13 int32 + var param14 int64 + var param20 float32 + var param21 float64 + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInPath, "12", "num"}, &ScanParameter{¶m1, ScanInPath, "12", "num"}, @@ -134,15 +151,19 @@ func TestScanNumericParametersInPath(t *testing.T) { if param0 != uint(12) { t.Errorf("expected parsing result %#v got: %#v", uint(12), param0) } + if param1 != uint8(12) { t.Errorf("expected parsing result %#v got: %#v", uint8(12), param1) } + if param2 != uint16(12) { t.Errorf("expected parsing result %#v got: %#v", uint16(12), param2) } + if param3 != uint32(12) { t.Errorf("expected parsing result %#v got: %#v", uint32(12), param3) } + if param4 != uint64(12) { t.Errorf("expected parsing result %#v got: %#v", uint64(12), param4) } @@ -151,15 +172,19 @@ func TestScanNumericParametersInPath(t *testing.T) { if param10 != int(-12) { t.Errorf("expected parsing result %#v got: %#v", int(-12), param10) } + if param11 != int8(-12) { t.Errorf("expected parsing result %#v got: %#v", int8(-12), param11) } + if param12 != int16(-12) { t.Errorf("expected parsing result %#v got: %#v", int16(-12), param12) } + if param13 != int32(-12) { t.Errorf("expected parsing result %#v got: %#v", int32(-12), param13) } + if param14 != int64(-12) { t.Errorf("expected parsing result %#v got: %#v", int64(-12), param14) } @@ -168,19 +193,26 @@ func TestScanNumericParametersInPath(t *testing.T) { if param20 != float32(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float32(-12.123123123123123123123123), param20) } + if param21 != float64(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float64(-12.123123123123123123123123), param21) } } func TestScanNumericParametersInQueryUint(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=12", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=12", nil) rec := httptest.NewRecorder() + var param0 uint + var param1 uint8 + var param2 uint16 + var param3 uint32 + var param4 uint64 + ok := ScanParameters(rec, req, &ScanParameter{¶m0, ScanInQuery, "", "num"}, &ScanParameter{¶m1, ScanInQuery, "", "num"}, @@ -198,28 +230,38 @@ func TestScanNumericParametersInQueryUint(t *testing.T) { if param0 != uint(12) { t.Errorf("expected parsing result %#v got: %#v", uint(12), param0) } + if param1 != uint8(12) { t.Errorf("expected parsing result %#v got: %#v", uint8(12), param1) } + if param2 != uint16(12) { t.Errorf("expected parsing result %#v got: %#v", uint16(12), param2) } + if param3 != uint32(12) { t.Errorf("expected parsing result %#v got: %#v", uint32(12), param3) } + if param4 != uint64(12) { t.Errorf("expected parsing result %#v got: %#v", uint64(12), param4) } } func TestScanNumericParametersInQueryInt(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12", nil) rec := httptest.NewRecorder() + var param10 int + var param11 int8 + var param12 int16 + var param13 int32 + var param14 int64 + ok := ScanParameters(rec, req, &ScanParameter{¶m10, ScanInQuery, "", "num"}, &ScanParameter{¶m11, ScanInQuery, "", "num"}, @@ -237,25 +279,32 @@ func TestScanNumericParametersInQueryInt(t *testing.T) { if param10 != int(-12) { t.Errorf("expected parsing result %#v got: %#v", int(-12), param10) } + if param11 != int8(-12) { t.Errorf("expected parsing result %#v got: %#v", int8(-12), param11) } + if param12 != int16(-12) { t.Errorf("expected parsing result %#v got: %#v", int16(-12), param12) } + if param13 != int32(-12) { t.Errorf("expected parsing result %#v got: %#v", int32(-12), param13) } + if param14 != int64(-12) { t.Errorf("expected parsing result %#v got: %#v", int64(-12), param14) } } func TestScanNumericParametersInQueryFloat(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12.123123123123123123123123", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12.123123123123123123123123", nil) rec := httptest.NewRecorder() + var param20 float32 + var param21 float64 + ok := ScanParameters(rec, req, &ScanParameter{¶m20, ScanInQuery, "", "num"}, &ScanParameter{¶m21, ScanInQuery, "", "num"}, @@ -270,15 +319,18 @@ func TestScanNumericParametersInQueryFloat(t *testing.T) { if param20 != float32(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float32(-12.123123123123123123123123), param20) } + if param21 != float64(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float64(-12.123123123123123123123123), param21) } } func TestScanNumericParametersInQueryFloatArray(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12.123123123123123123123123&num=-987.123123123123123123123123&num=", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12.123123123123123123123123&num=-987.123123123123123123123123&num=", nil) rec := httptest.NewRecorder() + var param []float32 + ok := ScanParameters(rec, req, &ScanParameter{¶m, ScanInQuery, "", "num"}, ) @@ -296,15 +348,18 @@ func TestScanNumericParametersInQueryFloatArray(t *testing.T) { if param[0] != float32(-12.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float32(-12.123123123123123123123123), param[0]) } + if param[1] != float32(-987.123123123123123123123123) { t.Errorf("expected parsing result %#v got: %#v", float32(-987.123123123123123123123123), param[1]) } } func TestScanNumericParametersInQueryFloatArrayFail(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12.123123123123123123123123&num=stuff", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12.123123123123123123123123&num=stuff", nil) rec := httptest.NewRecorder() + var param []float32 + ok := ScanParameters(rec, req, &ScanParameter{¶m, ScanInQuery, "", "num"}, ) @@ -318,7 +373,9 @@ func TestScanNumericParametersInQueryFloatArrayFail(t *testing.T) { defer resp.Body.Close() var errList errorObjects + dec := json.NewDecoder(resp.Body) + err := dec.Decode(&errList) if err != nil { t.Fatal(err) @@ -332,19 +389,24 @@ func TestScanNumericParametersInQueryFloatArrayFail(t *testing.T) { if r := "invalid value for num"; errObj.Title != r { t.Errorf("expected title %q got: %q", r, errObj.Title) } + if r := "400"; errObj.Status != r { t.Errorf("expected status %q got: %q", r, errObj.Status) } + if r := "num"; (*errObj.Source)["parameter"] != r { t.Errorf("expected source parameter %q got: %q", r, (*errObj.Source)["parameter"]) } } func TestScanParametersHeader(t *testing.T) { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("num", "123") + rec := httptest.NewRecorder() + var param int + ok := ScanParameters(rec, req, &ScanParameter{¶m, ScanInHeader, "", "num"}, ) @@ -366,9 +428,11 @@ func TestScanParametersHeader(t *testing.T) { } func TestScanParametersError(t *testing.T) { - req := httptest.NewRequest("GET", "/foo?num=-12", nil) + req := httptest.NewRequest(http.MethodGet, "/foo?num=-12", nil) rec := httptest.NewRecorder() + var param uint + ok := ScanParameters(rec, req, &ScanParameter{¶m, ScanInQuery, "", "num"}, ) @@ -382,7 +446,9 @@ func TestScanParametersError(t *testing.T) { defer resp.Body.Close() var errList errorObjects + dec := json.NewDecoder(resp.Body) + err := dec.Decode(&errList) if err != nil { t.Fatal(err) @@ -396,9 +462,11 @@ func TestScanParametersError(t *testing.T) { if r := "invalid value for num"; errObj.Title != r { t.Errorf("expected title %q got: %q", r, errObj.Title) } + if r := "400"; errObj.Status != r { t.Errorf("expected status %q got: %q", r, errObj.Status) } + if r := "num"; (*errObj.Source)["parameter"] != r { t.Errorf("expected source parameter %q got: %q", r, (*errObj.Source)["parameter"]) } diff --git a/http/jsonapi/runtime/standard_params.go b/http/jsonapi/runtime/standard_params.go index c45859d55..90d52da0d 100644 --- a/http/jsonapi/runtime/standard_params.go +++ b/http/jsonapi/runtime/standard_params.go @@ -75,26 +75,34 @@ type UrlQueryParameters struct { // even if any errors occur. The returned error combines all errors of pagination, filter and sorting. func ReadURLQueryParameters(r *http.Request, mapper ColumnMapper, sanitizer ValueSanitizer) (*UrlQueryParameters, error) { result := &UrlQueryParameters{} + var errs []error + if err := result.readPagination(r); err != nil { errs = append(errs, err) } + if err := result.readSorting(r, mapper); err != nil { errs = append(errs, err) } + if err := result.readFilter(r, mapper, sanitizer); err != nil { errs = append(errs, err) } + if len(errs) == 0 { return result, nil } + if len(errs) == 1 { return result, errs[0] } + var errAggregate []string for _, err := range errs { errAggregate = append(errAggregate, err.Error()) } + return result, fmt.Errorf("reading URL Query Parameters cased multiple errors: %v", strings.Join(errAggregate, ",")) } @@ -103,6 +111,7 @@ func (u *UrlQueryParameters) AddToQuery(query *orm.Query) *orm.Query { if u.HasPagination { query.Offset(u.PageSize * u.PageNr).Limit(u.PageSize) } + for name, filterValues := range u.Filter { if len(filterValues) == 0 { continue @@ -112,26 +121,33 @@ func (u *UrlQueryParameters) AddToQuery(query *orm.Query) *orm.Query { query.Where(name+" = ?", filterValues[0]) continue } + query.Where(name+" IN (?)", pg.In(filterValues)) } + for _, val := range u.Order { query.Order(val) } + return query } func (u *UrlQueryParameters) readPagination(r *http.Request) error { pageStr := r.URL.Query().Get("page[number]") sizeStr := r.URL.Query().Get("page[size]") + if pageStr == "" { u.HasPagination = false return nil } + u.HasPagination = true + pageNr, err := strconv.Atoi(pageStr) if err != nil { return err } + var pageSize int if sizeStr != "" { pageSize, err = strconv.Atoi(sizeStr) @@ -141,11 +157,14 @@ func (u *UrlQueryParameters) readPagination(r *http.Request) error { } else { pageSize = cfg.DefaultPageSize } + if (pageSize < cfg.MinPageSize) || (pageSize > cfg.MaxPageSize) { return fmt.Errorf("invalid pagesize not between min. and max. value, min: %d, max: %d", cfg.MinPageSize, cfg.MaxPageSize) } + u.PageNr = pageNr u.PageSize = pageSize + return nil } @@ -154,19 +173,25 @@ func (u *UrlQueryParameters) readSorting(r *http.Request, mapper ColumnMapper) e if sort == "" { return nil } + sorting := strings.Split(sort, ",") var order string + var resultedOrders []string + var errSortingWithReason []string + for _, val := range sorting { if val == "" { continue } + order = " ASC" if strings.HasPrefix(val, "-") { order = " DESC" } + val = strings.TrimPrefix(val, "-") key, isValid := mapper.Map(val) @@ -174,38 +199,50 @@ func (u *UrlQueryParameters) readSorting(r *http.Request, mapper ColumnMapper) e errSortingWithReason = append(errSortingWithReason, val) continue } + resultedOrders = append(resultedOrders, key+order) } + u.Order = resultedOrders + if len(errSortingWithReason) > 0 { return fmt.Errorf("at least one sorting parameter is not valid: %q", strings.Join(errSortingWithReason, ",")) } + return nil } func (u *UrlQueryParameters) readFilter(r *http.Request, mapper ColumnMapper, sanitizer ValueSanitizer) error { filter := make(map[string][]interface{}) + var invalidFilter []string + for queryName, queryValues := range r.URL.Query() { if !(strings.HasPrefix(queryName, "filter[") && strings.HasSuffix(queryName, "]")) { continue } + key, isValid := getFilterKey(queryName, mapper) if !isValid { invalidFilter = append(invalidFilter, key) continue } + filterValues, isValid := getFilterValues(key, queryValues, sanitizer) if !isValid { invalidFilter = append(invalidFilter, key) continue } + filter[key] = filterValues } + u.Filter = filter + if len(invalidFilter) != 0 { return fmt.Errorf("at least one filter parameter is not valid: %q", strings.Join(invalidFilter, ",")) } + return nil } @@ -213,14 +250,17 @@ func getFilterKey(queryName string, modelMapping ColumnMapper) (string, bool) { field := strings.TrimPrefix(queryName, "filter[") field = strings.TrimSuffix(field, "]") mapped, isValid := modelMapping.Map(field) + if !isValid { return field, false } + return mapped, true } func getFilterValues(fieldName string, queryValues []string, sanitizer ValueSanitizer) ([]interface{}, bool) { var filterValues []interface{} + for _, value := range queryValues { separatedValues := strings.Split(value, ",") for _, separatedValue := range separatedValues { @@ -228,8 +268,10 @@ func getFilterValues(fieldName string, queryValues []string, sanitizer ValueSani if err != nil { return nil, false } + filterValues = append(filterValues, sanitized) } } + return filterValues, true } diff --git a/http/jsonapi/runtime/standard_params_test.go b/http/jsonapi/runtime/standard_params_test.go index 72be89027..22a23fa0f 100644 --- a/http/jsonapi/runtime/standard_params_test.go +++ b/http/jsonapi/runtime/standard_params_test.go @@ -4,6 +4,7 @@ package runtime_test import ( "context" + "net/http" "net/http/httptest" "sort" "testing" @@ -34,6 +35,7 @@ func TestIntegrationFilterParameter(t *testing.T) { // Setup a := assert.New(t) db := setupDatabase(a) + defer func() { // Tear Down err := db.DropTable(&TestModel{}, &orm.DropTableOptions{}) @@ -45,20 +47,24 @@ func TestIntegrationFilterParameter(t *testing.T) { } mapper := runtime.NewMapMapper(mappingNames) // filter - r := httptest.NewRequest("GET", "http://abc.de/whatEver?filter[test]=b", nil) + r := httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?filter[test]=b", nil) urlParams, err := runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) a.NoError(err) + var modelsFilter []TestModel + q := db.Model(&modelsFilter) q = urlParams.AddToQuery(q) count, _ := q.SelectAndCount() a.Equal(1, count) a.Equal("b", modelsFilter[0].FilterName) - r = httptest.NewRequest("GET", "http://abc.de/whatEver?filter[test]=a,b", nil) + r = httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?filter[test]=a,b", nil) urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) a.NoError(err) + var modelsFilter2 []TestModel + q = db.Model(&modelsFilter2) q = urlParams.AddToQuery(q) count, _ = q.SelectAndCount() @@ -70,10 +76,12 @@ func TestIntegrationFilterParameter(t *testing.T) { a.Equal("b", modelsFilter2[1].FilterName) // Paging - r = httptest.NewRequest("GET", "http://abc.de/whatEver?page[number]=1&page[size]=2", nil) + r = httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?page[number]=1&page[size]=2", nil) urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) assert.NoError(t, err) + var modelsPaging []TestModel + q = db.Model(&modelsPaging) q = urlParams.AddToQuery(q) err = q.Select() @@ -85,10 +93,12 @@ func TestIntegrationFilterParameter(t *testing.T) { a.Equal("d", modelsPaging[1].FilterName) // Sorting - r = httptest.NewRequest("GET", "http://abc.de/whatEver?sort=-test", nil) + r = httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?sort=-test", nil) urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) assert.NoError(t, err) + var modelsSort []TestModel + q = db.Model(&modelsSort) q = urlParams.AddToQuery(q) err = q.Select() @@ -102,10 +112,12 @@ func TestIntegrationFilterParameter(t *testing.T) { a.Equal("a", modelsSort[5].FilterName) // Combine all - r = httptest.NewRequest("GET", "http://abc.de/whatEver?sort=-test&filter[test]=a,b,e,f&page[number]=1&page[size]=2", nil) + r = httptest.NewRequest(http.MethodGet, "http://abc.de/whatEver?sort=-test&filter[test]=a,b,e,f&page[number]=1&page[size]=2", nil) urlParams, err = runtime.ReadURLQueryParameters(r, mapper, &testValueSanitizer{}) assert.NoError(t, err) + var modelsCombined []TestModel + q = db.Model(&modelsCombined) q = urlParams.AddToQuery(q) err = q.Select() @@ -121,6 +133,7 @@ func setupDatabase(a *assert.Assertions) *pg.DB { err := db.CreateTable(&TestModel{}, &orm.CreateTableOptions{}) a.NoError(err) + _, err = db.Model(&TestModel{ FilterName: "a", }).Insert() diff --git a/http/jsonapi/runtime/validation.go b/http/jsonapi/runtime/validation.go index 86e4da00e..ff329681e 100644 --- a/http/jsonapi/runtime/validation.go +++ b/http/jsonapi/runtime/validation.go @@ -9,6 +9,7 @@ import ( "time" valid "github.com/asaskevich/govalidator" + "github.com/pace/bricks/pkg/isotime" ) @@ -46,11 +47,11 @@ func ValidateRequest(w http.ResponseWriter, r *http.Request, data interface{}) b // The passed source is the source for validation errors (e.g. pointer for data or parameter) func ValidateStruct(w http.ResponseWriter, r *http.Request, data interface{}, source string) bool { ok, err := valid.ValidateStruct(data) - if !ok { switch errs := err.(type) { case valid.Errors: var e Errors + generateValidationErrors(errs, &e, source) WriteError(w, http.StatusUnprocessableEntity, e) case error: diff --git a/http/jsonapi/runtime/validation_test.go b/http/jsonapi/runtime/validation_test.go index cb7c4e8eb..15974b50b 100644 --- a/http/jsonapi/runtime/validation_test.go +++ b/http/jsonapi/runtime/validation_test.go @@ -17,10 +17,12 @@ func TestValidateParametersWithError(t *testing.T) { type access struct { Token string `valid:"uuid"` } + type input struct { UUID string `valid:"uuid"` Access access } + expected := map[string]interface{}{ "errors": []interface{}{ map[string]interface{}{ @@ -49,10 +51,9 @@ func TestValidateParametersWithError(t *testing.T) { } rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", nil) + req := httptest.NewRequest(http.MethodPost, "/", nil) ok := ValidateParameters(rec, req, &val) - if ok { t.Error("expected to fail the validation") } @@ -60,11 +61,12 @@ func TestValidateParametersWithError(t *testing.T) { resp := rec.Result() defer resp.Body.Close() - if resp.StatusCode != 422 { + if resp.StatusCode != http.StatusUnprocessableEntity { t.Error("expected UnprocessableEntity") } var data map[string]interface{} + err := json.NewDecoder(resp.Body).Decode(&data) if err != nil { t.Fatal(err) @@ -77,13 +79,14 @@ func TestValidateParametersWithError(t *testing.T) { func TestValidateRequest(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/", nil) + req := httptest.NewRequest(http.MethodPost, "/", nil) type args struct { w http.ResponseWriter r *http.Request data interface{} } + tests := []struct { name string args args diff --git a/http/jsonapi/runtime/value_sanitizers.go b/http/jsonapi/runtime/value_sanitizers.go index 38b6c0ccc..9924ed531 100644 --- a/http/jsonapi/runtime/value_sanitizers.go +++ b/http/jsonapi/runtime/value_sanitizers.go @@ -31,6 +31,7 @@ func (d datetimeSanitizer) SanitizeValue(fieldName string, value string) (interf if err != nil { return nil, err } + return t, nil } @@ -58,6 +59,7 @@ func (u uuidSanitizer) SanitizeValue(fieldName string, value string) (interface{ if _, err := uuid.Parse(value); err != nil { return nil, err } + return value, nil } @@ -68,6 +70,7 @@ func (c composableAndFieldRestrictedSanitizer) SanitizeValue(fieldName string, v if !found { return nil, fmt.Errorf("%w: %v", ErrInvalidFieldname, fieldName) } + return san.SanitizeValue(fieldName, value) } diff --git a/http/longpoll/longpoll.go b/http/longpoll/longpoll.go index 782b9721c..55e834fde 100644 --- a/http/longpoll/longpoll.go +++ b/http/longpoll/longpoll.go @@ -31,7 +31,7 @@ var Default = Config{ } // Until executes the given function fn until duration d is passed or context is canceled. -// The constaints of the Default configuration apply. +// The constraints of the Default configuration apply. func Until(ctx context.Context, d time.Duration, fn LongPollFunc) (ok bool, err error) { return Default.LongPollUntil(ctx, d, fn) } diff --git a/http/longpoll/longpoll_test.go b/http/longpoll/longpoll_test.go index 289860ab9..da07b5fde 100644 --- a/http/longpoll/longpoll_test.go +++ b/http/longpoll/longpoll_test.go @@ -14,8 +14,10 @@ func TestLongPollUntilBounds(t *testing.T) { ok, err := Until(context.Background(), -1, func(ctx context.Context) (bool, error) { budget, ok := ctx.Deadline() assert.True(t, ok) - assert.Equal(t, time.Millisecond*999, budget.Sub(time.Now()).Truncate(time.Millisecond)) // nolint: gosimple + assert.Equal(t, time.Millisecond*999, time.Until(budget).Truncate(time.Millisecond)) + called++ + return true, nil }) assert.True(t, ok) @@ -26,8 +28,10 @@ func TestLongPollUntilBounds(t *testing.T) { ok, err = Until(context.Background(), time.Hour, func(ctx context.Context) (bool, error) { budget, ok := ctx.Deadline() assert.True(t, ok) - assert.Equal(t, time.Second*59, budget.Sub(time.Now()).Truncate(time.Second)) // nolint: gosimple + assert.Equal(t, time.Second*59, time.Until(budget).Truncate(time.Second)) + called++ + return true, nil }) assert.True(t, ok) @@ -79,8 +83,10 @@ func TestLongPollUntilTimeout(t *testing.T) { func TestLongPollUntilTimeoutWithContext(t *testing.T) { called := 0 + ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() + ok, err := Until(ctx, time.Second*2, func(context.Context) (bool, error) { called++ return false, nil diff --git a/http/middleware/context.go b/http/middleware/context.go index 35bfc66f9..78c132788 100644 --- a/http/middleware/context.go +++ b/http/middleware/context.go @@ -29,6 +29,7 @@ func ContextTransfer(ctx, targetCtx context.Context) context.Context { if r := requestFromContext(ctx); r != nil { return contextWithRequest(targetCtx, r) } + return targetCtx } @@ -46,6 +47,7 @@ func requestFromContext(ctx context.Context) *ctxRequest { if v := ctx.Value((*ctxRequest)(nil)); v != nil { return v.(*ctxRequest) } + return nil } @@ -67,21 +69,26 @@ func GetXForwardedForHeaderFromContext(ctx context.Context) (string, error) { if ctxReq == nil { return "", fmt.Errorf("getting request from context: %w", ErrNotFound) } + xForwardedFor := ctxReq.XForwardedFor + ip, _, err := net.SplitHostPort(ctxReq.RemoteAddr) if err != nil { return "", fmt.Errorf( "%w (from context): could not get ip from remote address: %s", ErrInvalidRequest, err) } + if ip == "" { return "", fmt.Errorf( "%w (from context): could not get ip from remote address: %q", ErrInvalidRequest, ctxReq.RemoteAddr) } + if xForwardedFor != "" { xForwardedFor += ", " } + return xForwardedFor + ip, nil } @@ -93,5 +100,6 @@ func GetUserAgentFromContext(ctx context.Context) (string, error) { if ctxReq == nil { return "", fmt.Errorf("getting request from context: %w", ErrNotFound) } + return ctxReq.UserAgent, nil } diff --git a/http/middleware/context_test.go b/http/middleware/context_test.go index 2f12aee09..026697166 100644 --- a/http/middleware/context_test.go +++ b/http/middleware/context_test.go @@ -8,13 +8,14 @@ import ( "net/http" "testing" - . "github.com/pace/bricks/http/middleware" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + . "github.com/pace/bricks/http/middleware" ) func TestContextTransfer(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) require.NoError(t, err) r.Header.Set("User-Agent", "Foobar") RequestInContext(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { @@ -72,12 +73,14 @@ func TestGetXForwardedForHeaderFromContext(t *testing.T) { } for name, c := range cases { t.Run(name, func(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) require.NoError(t, err) + r.RemoteAddr = c.RemoteAddr if c.XForwardedFor != "" { r.Header.Set("X-Forwarded-For", c.XForwardedFor) } + RequestInContext(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { ctx := r.Context() xForwardedFor, err := GetXForwardedForHeaderFromContext(ctx) @@ -98,7 +101,7 @@ func TestGetXForwardedForHeaderFromContext(t *testing.T) { } func TestGetUserAgentFromContext(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/", nil) require.NoError(t, err) r.Header.Set("User-Agent", "Foobar") RequestInContext(http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { diff --git a/http/middleware/external_dependency.go b/http/middleware/external_dependency.go index 382a54b53..aa6f996bb 100644 --- a/http/middleware/external_dependency.go +++ b/http/middleware/external_dependency.go @@ -24,6 +24,7 @@ const ExternalDependencyHeaderName = "External-Dependencies" func ExternalDependency(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var edc ExternalDependencyContext + edw := externalDependencyWriter{ ResponseWriter: w, edc: &edc, @@ -39,6 +40,7 @@ func AddExternalDependency(ctx context.Context, name string, dur time.Duration) log.Ctx(ctx).Warn().Msgf("can't add external dependency %q with %s, because context is missing", name, dur) return } + ec.AddDependency(name, dur) } @@ -54,6 +56,7 @@ func (w *externalDependencyWriter) addHeader() { if len(w.edc.dependencies) > 0 { w.ResponseWriter.Header().Add(ExternalDependencyHeaderName, w.edc.String()) } + w.header = true } } @@ -78,6 +81,7 @@ func ExternalDependencyContextFromContext(ctx context.Context) *ExternalDependen if v := ctx.Value((*ExternalDependencyContext)(nil)); v != nil { return v.(*ExternalDependencyContext) } + return nil } @@ -100,14 +104,19 @@ func (c *ExternalDependencyContext) AddDependency(name string, duration time.Dur // String formats all external dependencies func (c *ExternalDependencyContext) String() string { var b strings.Builder + sep := len(c.dependencies) - 1 + for _, dep := range c.dependencies { b.WriteString(dep.String()) + if sep > 0 { b.WriteByte(',') + sep-- } } + return b.String() } @@ -119,6 +128,7 @@ func (c *ExternalDependencyContext) Parse(s string) { if index == -1 { continue // ignore the invalid values } + dur, err := strconv.ParseInt(value[index+1:], 10, 64) if err != nil { continue // ignore the invalid values diff --git a/http/middleware/external_dependency_test.go b/http/middleware/external_dependency_test.go index 77b668f51..5d0ccb806 100644 --- a/http/middleware/external_dependency_test.go +++ b/http/middleware/external_dependency_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_ExternalDependency_Middleare(t *testing.T) { @@ -30,7 +31,9 @@ func Test_ExternalDependency_Middleare(t *testing.T) { h := ExternalDependency(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { AddExternalDependency(r.Context(), "test", time.Second) - w.Write(nil) // nolint: errcheck + + _, err := w.Write(nil) + require.NoError(t, err) })) h.ServeHTTP(rec, req) assert.Equal(t, rec.Result().Header[ExternalDependencyHeaderName][0], "test:1000") diff --git a/http/middleware/metrics.go b/http/middleware/metrics.go index 77b21166d..88d71f6f7 100644 --- a/http/middleware/metrics.go +++ b/http/middleware/metrics.go @@ -63,9 +63,11 @@ func Metrics(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { paceHTTPInFlightGauge.Inc() defer paceHTTPInFlightGauge.Dec() + startTime := time.Now() srw := statusWriter{ResponseWriter: w} next.ServeHTTP(&srw, r) + dur := float64(time.Since(startTime)) / float64(time.Millisecond) labels := prometheus.Labels{ "code": strconv.Itoa(srw.status), @@ -91,10 +93,12 @@ func (w *statusWriter) WriteHeader(status int) { func (w *statusWriter) Write(b []byte) (int, error) { if w.status == 0 { - w.status = 200 + w.status = http.StatusOK } + n, err := w.ResponseWriter.Write(b) w.length += n + return n, err } @@ -103,5 +107,6 @@ func filterRequestSource(source string) string { case "uptime", "kubernetes", "nginx", "livetest": return source } + return "" } diff --git a/http/middleware/response_header.go b/http/middleware/response_header.go index 59a6ad158..2a6714b68 100644 --- a/http/middleware/response_header.go +++ b/http/middleware/response_header.go @@ -28,6 +28,7 @@ func ClientID(next http.Handler) http.Handler { w.Header().Add(ClientIDHeaderName, claim.AuthorizedParty) } } + next.ServeHTTP(w, r) }) } @@ -41,5 +42,6 @@ func (c clientIDClaim) Valid() error { if c.AuthorizedParty == "" { return ErrEmptyAuthorizedParty } + return nil } diff --git a/http/oauth2/authorizer.go b/http/oauth2/authorizer.go index d7058eeba..3cba3f5b2 100644 --- a/http/oauth2/authorizer.go +++ b/http/oauth2/authorizer.go @@ -51,7 +51,6 @@ func (a *Authorizer) WithScope(tok string) *Authorizer { // Error: writes all errors directly to response, returns unchanged context and false func (a *Authorizer) Authorize(r *http.Request, w http.ResponseWriter) (context.Context, bool) { ctx, ok := introspectRequest(r, w, a.introspection) - // Check if introspection was successful if !ok { return ctx, ok } @@ -60,6 +59,7 @@ func (a *Authorizer) Authorize(r *http.Request, w http.ResponseWriter) (context. // Check if the scope is valid for this user ok = validateScope(ctx, w, a.scope) } + return ctx, ok } @@ -68,6 +68,7 @@ func validateScope(ctx context.Context, w http.ResponseWriter, req Scope) bool { http.Error(w, fmt.Sprintf("Forbidden - requires scope %q", req), http.StatusForbidden) return false } + return true } diff --git a/http/oauth2/example_multi_backend_test.go b/http/oauth2/example_multi_backend_test.go index 4637f7491..91f3c447e 100644 --- a/http/oauth2/example_multi_backend_test.go +++ b/http/oauth2/example_multi_backend_test.go @@ -5,6 +5,7 @@ package oauth2_test import ( "context" "fmt" + "net/http" "net/http/httptest" "github.com/pace/bricks/http/oauth2" @@ -21,6 +22,7 @@ func (b multiAuthBackends) IntrospectToken(ctx context.Context, token string) (r return } } + return nil, oauth2.ErrInvalidToken } @@ -33,6 +35,7 @@ func (b *authBackend) IntrospectToken(ctx context.Context, token string) (*oauth Backend: b, }, nil } + return nil, oauth2.ErrInvalidToken } @@ -42,14 +45,13 @@ func Example_multipleBackends() { // authorized the request. The actual value used for the backend depends on // your implementation: you can use constants or pointers, like in this // example. - authorizer := oauth2.NewAuthorizer(multiAuthBackends{ &authBackend{"A", "token-a"}, &authBackend{"B", "token-b"}, &authBackend{"C", "token-c"}, }, nil) - r := httptest.NewRequest("GET", "/some/endpoint", nil) + r := httptest.NewRequest(http.MethodGet, "/some/endpoint", nil) r.Header.Set("Authorization", "Bearer token-b") if authorizer.CanAuthorizeRequest(r) { diff --git a/http/oauth2/middleware/scopes_middleware.go b/http/oauth2/middleware/scopes_middleware.go index e612776fc..68dd9b3ac 100644 --- a/http/oauth2/middleware/scopes_middleware.go +++ b/http/oauth2/middleware/scopes_middleware.go @@ -7,6 +7,7 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/pace/bricks/http/oauth2" ) @@ -34,6 +35,7 @@ func (m *ScopesMiddleware) Handler(next http.Handler) http.Handler { next.ServeHTTP(w, r) return } + http.Error(w, fmt.Sprintf("Forbidden - requires scope %q", m.RequiredScopes[routeName]), http.StatusForbidden) }) } diff --git a/http/oauth2/middleware/scopes_middleware_test.go b/http/oauth2/middleware/scopes_middleware_test.go index 7089f53a4..20f4322b7 100644 --- a/http/oauth2/middleware/scopes_middleware_test.go +++ b/http/oauth2/middleware/scopes_middleware_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/pace/bricks/http/oauth2" ) @@ -23,12 +24,14 @@ func TestScopesMiddleware(t *testing.T) { resp := w.Result() body, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { t.Fatal(err) } - if got, ex := resp.StatusCode, 200; got != ex { + if got, ex := resp.StatusCode, http.StatusOK; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } @@ -45,7 +48,9 @@ func TestScopesMiddleware(t *testing.T) { resp := w.Result() body, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { t.Fatal(err) } @@ -64,21 +69,23 @@ func setupRouter(requiredScope string, tokenScope string) *mux.Router { rs := RequiredScopes{ "GetFoo": oauth2.Scope(requiredScope), } - m := NewScopesMiddleware(rs) // nolint: staticcheck - om := oauth2.NewMiddleware(&tokenIntrospecter{returnedScope: tokenScope}) // nolint: staticcheck + m := NewScopesMiddleware(rs) + om := oauth2.NewMiddleware(&tokenIntrospecter{returnedScope: tokenScope}) //nolint:staticcheck r := mux.NewRouter() r.Use(om.Handler) r.Use(m.Handler) r.HandleFunc("/foo", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "Hello") + if _, err := fmt.Fprint(w, "Hello"); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } }).Name("GetFoo") return r } func setupRequest() *http.Request { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Authorization", "Bearer some-token") return req diff --git a/http/oauth2/oauth2.go b/http/oauth2/oauth2.go index e3bb6927d..2bed8b0a5 100644 --- a/http/oauth2/oauth2.go +++ b/http/oauth2/oauth2.go @@ -37,6 +37,7 @@ func (m *Middleware) Handler(next http.Handler) http.Handler { if !isOk { return } + next.ServeHTTP(w, r.WithContext(ctx)) }) } @@ -72,6 +73,7 @@ func introspectRequest(r *http.Request, w http.ResponseWriter, tokenIntro TokenI http.Error(w, "Unauthorized", http.StatusUnauthorized) return nil, false } + s, err := tokenIntro.IntrospectToken(ctx, tok) if err != nil { switch { @@ -85,16 +87,21 @@ func introspectRequest(r *http.Request, w http.ResponseWriter, tokenIntro TokenI http.Error(w, err.Error(), http.StatusInternalServerError) } + log.Req(r).Info().Msg(err.Error()) + return nil, false } + t := fromIntrospectResponse(s, tok) ctx = security.ContextWithToken(ctx, &t) + log.Req(r).Info(). Str("client_id", t.clientID). Str("user_id", t.userID). Msg("Oauth2") span.LogFields(olog.String("client_id", t.clientID), olog.String("user_id", t.userID)) + return ctx, true } @@ -108,6 +115,7 @@ func fromIntrospectResponse(s *IntrospectResponse, tokenValue string) token { } t.scope = Scope(s.Scope) + return t } @@ -125,60 +133,73 @@ func Request(r *http.Request) *http.Request { // the permissions represented by the provided scope are included in the valid scope. func HasScope(ctx context.Context, scope Scope) bool { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return false } + return scope.IsIncludedIn(oauth2token.scope) } // UserID returns the userID stored in ctx func UserID(ctx context.Context) (string, bool) { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return "", false } + return oauth2token.userID, true } // AuthTime returns the auth time stored in ctx as unix timestamp func AuthTime(ctx context.Context) (int64, bool) { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return 0, false } + return oauth2token.authTime, true } // Scopes returns the scopes stored in ctx func Scopes(ctx context.Context) []string { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return []string{} } + return oauth2token.scope.toSlice() } func AddScope(ctx context.Context, scope string) context.Context { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return ctx } + oauth2token.scope = oauth2token.scope.Add(scope) + return security.ContextWithToken(ctx, oauth2token) } // ClientID returns the clientID stored in ctx func ClientID(ctx context.Context) (string, bool) { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return "", false } + return oauth2token.clientID, true } @@ -186,10 +207,12 @@ func ClientID(ctx context.Context) (string, bool) { // authorization backend for the token. func Backend(ctx context.Context) (interface{}, bool) { tok, _ := security.GetTokenFromContext(ctx) + oauth2token, ok := tok.(*token) if !ok { return nil, false } + return oauth2token.backend, true } @@ -207,6 +230,7 @@ func BearerToken(ctx context.Context) (string, bool) { if tok, ok := security.GetTokenFromContext(ctx); ok { return tok.GetValue(), true } + return "", false } diff --git a/http/oauth2/oauth2_test.go b/http/oauth2/oauth2_test.go index c20e09e19..ccb353402 100644 --- a/http/oauth2/oauth2_test.go +++ b/http/oauth2/oauth2_test.go @@ -58,7 +58,7 @@ func TestHandlerIntrospectErrorAsMiddleware(t *testing.T) { r.Use(m.Handler) r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) req.Header.Set("Authorization", "Bearer some-token") w := httptest.NewRecorder() @@ -66,7 +66,9 @@ func TestHandlerIntrospectErrorAsMiddleware(t *testing.T) { resp := w.Result() body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { t.Fatal(err) } @@ -126,7 +128,7 @@ func TestAuthenticatorWithSuccess(t *testing.T) { for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer bearer") auth := NewAuthorizer(&tokenIntrospectedSuccessful{&IntrospectResponse{ @@ -138,20 +140,25 @@ func TestAuthenticatorWithSuccess(t *testing.T) { if tC.expectedScopes != "" { auth = auth.WithScope(tC.expectedScopes) } + authorize, b := auth.Authorize(r, w) resp := w.Result() body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { t.Fatal(err) } + if !b || authorize == nil { t.Errorf("Expected succesfull Authentication, but was not succesfull with code %d and body %q", resp.StatusCode, string(body)) return } + to, _ := security.GetTokenFromContext(authorize) - tok, ok := to.(*token) + tok, ok := to.(*token) if !ok || tok.value != "bearer" || tok.scope != Scope(tC.userScopes) || tok.clientID != tC.clientId || tok.userID != tC.userId { t.Errorf("Expected %v but got %v", auth.introspection.(*tokenIntrospectedSuccessful).response, tok) } @@ -168,23 +175,28 @@ func TestAuthenticationSuccessScopeError(t *testing.T) { }}, &Config{}).WithScope("DE") w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer bearer") _, b := auth.Authorize(r, w) resp := w.Result() body, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { t.Fatal(err) } + if b { t.Errorf("Expected error in Authentication, but was succesfull with code %d and body %v", resp.StatusCode, string(body)) } + if got, ex := w.Code, http.StatusForbidden; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } + if got, ex := string(body), "Forbidden - requires scope \"DE\"\n"; got != ex { t.Errorf("Expected status code %q, got %q", ex, got) } @@ -226,18 +238,22 @@ func TestAuthenticationWithErrors(t *testing.T) { t.Run(tC.desc, func(t *testing.T) { auth := NewAuthorizer(&tokenInspectorWithError{returnedErr: tC.returnedErr}, &Config{}) w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer bearer") + _, b := auth.Authorize(r, w) resp := w.Result() body, err := io.ReadAll(resp.Body) + defer resp.Body.Close() + if err != nil { t.Fatal(err) } + if b { - t.Errorf("Expected error in authentication, but was succesful with code %d and body %v", resp.StatusCode, string(body)) + t.Errorf("Expected error in authentication, but was successful with code %d and body %v", resp.StatusCode, string(body)) } if got, ex := w.Code, tC.expectedCode; got != ex { @@ -266,8 +282,10 @@ func Example() { if err != nil { panic(err) } + return } + _, err := fmt.Fprintf(w, "Your client may not have the right scopes to see the secret code") if err != nil { panic(err) @@ -290,7 +308,7 @@ func TestRequest(t *testing.T) { scope: Scope("scope1 scope2"), } - r := httptest.NewRequest("GET", "http://example.com", nil) + r := httptest.NewRequest(http.MethodGet, "http://example.com", nil) ctx := security.ContextWithToken(r.Context(), &to) r = r.WithContext(ctx) @@ -303,7 +321,7 @@ func TestRequest(t *testing.T) { } func TestRequestWithNoToken(t *testing.T) { - r := httptest.NewRequest("GET", "http://example.com", nil) + r := httptest.NewRequest(http.MethodGet, "http://example.com", nil) r2 := Request(r) header := r2.Header.Get("Authorization") @@ -402,6 +420,7 @@ func TestUnsuccessfulAccessors(t *testing.T) { func TestWithBearerToken(t *testing.T) { ctx := context.Background() ctx = WithBearerToken(ctx, "some access token") + token, ok := security.GetTokenFromContext(ctx) if !ok || token.GetValue() != "some access token" { t.Error("could not store bearer token in context") @@ -419,14 +438,17 @@ func TestAddScope(t *testing.T) { wantCtx := context.Background() wantCtx = WithBearerToken(wantCtx, "some access token") + tok, ok := security.GetTokenFromContext(wantCtx) if !ok { t.Error("could not get token from context") } + ouathToken, ok := tok.(*token) if !ok { t.Error("could not convert token to oauth token") } + ouathToken.scope = "scope1" tests := []struct { @@ -446,14 +468,17 @@ func TestAddScope(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := AddScope(tt.args.ctx, tt.args.scope) + gotTok, ok := security.GetTokenFromContext(got) if !ok { t.Error("could not get token from context") } + gotOauthToken, ok := gotTok.(*token) if !ok { t.Error("could not convert token to oauth token") } + if gotOauthToken.scope != tt.want.scope { t.Errorf("AddScope() = %v, want %v", gotOauthToken.scope, tt.want.scope) } diff --git a/http/oauth2/scope_test.go b/http/oauth2/scope_test.go index 01d95936c..bc0181657 100644 --- a/http/oauth2/scope_test.go +++ b/http/oauth2/scope_test.go @@ -35,6 +35,7 @@ func TestScope_Add(t *testing.T) { type args struct { scope string } + tests := []struct { name string s Scope diff --git a/http/router.go b/http/router.go index 63fd344b7..15d2648b9 100755 --- a/http/router.go +++ b/http/router.go @@ -6,9 +6,8 @@ import ( "net/http" "net/http/pprof" - "github.com/pace/bricks/maintenance/tracing" - "github.com/gorilla/mux" + "github.com/pace/bricks/http/middleware" "github.com/pace/bricks/locale" "github.com/pace/bricks/maintenance/errors" @@ -16,6 +15,7 @@ import ( "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/maintenance/metric" + "github.com/pace/bricks/maintenance/tracing" redactMdw "github.com/pace/bricks/pkg/redact/middleware" ) @@ -61,10 +61,10 @@ func Router() *mux.Router { // report Client ID back to caller r.Use(middleware.ClientID) - // support redacting of data accross the full request scope + // support redacting of data across the full request scope r.Use(redactMdw.Redact) - // makes some infos about the request accessable from the context + // makes some infos about the request accessible from the context r.Use(middleware.RequestInContext) // for prometheus diff --git a/http/router_test.go b/http/router_test.go index bf43f761b..ef2688c6f 100644 --- a/http/router_test.go +++ b/http/router_test.go @@ -11,19 +11,20 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/stretchr/testify/require" + "github.com/pace/bricks/http/jsonapi/runtime" "github.com/pace/bricks/maintenance/health" - "github.com/stretchr/testify/require" ) func TestHealthHandler(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/liveness", nil) + req := httptest.NewRequest(http.MethodGet, "/health/liveness", nil) Router().ServeHTTP(rec, req) resp := rec.Result() - require.Equal(t, 200, resp.StatusCode) + require.Equal(t, http.StatusOK, resp.StatusCode) data, err := io.ReadAll(resp.Body) require.NoError(t, err) @@ -52,14 +53,16 @@ func TestHealthRoutes(t *testing.T) { expectedResult: "OK\n", title: "route liveness", }} + health.SetCustomReadinessCheck(func(w http.ResponseWriter, r *http.Request) { _, err := fmt.Fprint(w, "Ready") require.NoError(t, err) }) + for _, tC := range tCs { t.Run(tC.title, func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", tC.route, nil) + req := httptest.NewRequest(http.MethodGet, tC.route, nil) Router().ServeHTTP(rec, req) @@ -73,13 +76,13 @@ func TestHealthRoutes(t *testing.T) { func TestCustomRoutes(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo/bar", nil) + req := httptest.NewRequest(http.MethodGet, "/foo/bar", nil) // example of a service foo exposing api bar fooRouter := mux.NewRouter() fooRouter.HandleFunc("/foo/bar", func(w http.ResponseWriter, r *http.Request) { runtime.WriteError(w, http.StatusNotImplemented, fmt.Errorf("Some error")) - }).Methods("GET") + }).Methods(http.MethodGet) r := Router() // service routers will be mounted like this diff --git a/http/security/apikey/authorizer.go b/http/security/apikey/authorizer.go index 72e9619f1..6a4742df2 100644 --- a/http/security/apikey/authorizer.go +++ b/http/security/apikey/authorizer.go @@ -48,12 +48,16 @@ func (a *Authorizer) Authorize(r *http.Request, w http.ResponseWriter) (context. if key == "" { log.Req(r).Info().Msg("No Api Key present in field " + a.authConfig.Name) http.Error(w, "Unauthorized", http.StatusUnauthorized) + return r.Context(), false } + if key == a.apiKey { return security.ContextWithToken(r.Context(), &token{key}), true } + http.Error(w, "ApiKey not valid", http.StatusUnauthorized) + return r.Context(), false } diff --git a/http/security/apikey/authorizer_test.go b/http/security/apikey/authorizer_test.go index 781141aff..9939c1beb 100644 --- a/http/security/apikey/authorizer_test.go +++ b/http/security/apikey/authorizer_test.go @@ -13,23 +13,30 @@ func TestApiKeyAuthenticationSuccessful(t *testing.T) { auth := NewAuthorizer(&Config{Name: "Authorization"}, "testkey") w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer testkey") _, b := auth.Authorize(r, w) resp := w.Result() + defer func() { + _ = resp.Body.Close() + }() + body, err := io.ReadAll(resp.Body) - resp.Body.Close() + if err != nil { t.Fatal(err) } + if !b { t.Errorf("Expected no error in authentication, but failed with code %d and body %v", resp.StatusCode, string(body)) } + if got, ex := w.Code, http.StatusOK; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } + if got, ex := string(body), ""; got != ex { t.Errorf("Expected status code %q, got %q", ex, got) } @@ -39,23 +46,30 @@ func TestApiKeyAuthenticationError(t *testing.T) { auth := NewAuthorizer(&Config{Name: "Authorization"}, "testkey") w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Add("Authorization", "Bearer wrongKey") _, b := auth.Authorize(r, w) resp := w.Result() + defer func() { + _ = resp.Body.Close() + }() + body, err := io.ReadAll(resp.Body) - resp.Body.Close() + if err != nil { t.Fatal(err) } + if b { t.Errorf("Expected error in Authentication, but was succesfull with code %d and body %v", resp.StatusCode, string(body)) } + if got, ex := w.Code, http.StatusUnauthorized; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } + if got, ex := string(body), "ApiKey not valid\n"; got != ex { t.Errorf("Expected error massage %q, got %q", ex, got) } @@ -65,22 +79,27 @@ func TestApiKeyAuthenticationNoKey(t *testing.T) { auth := NewAuthorizer(&Config{Name: "Authorization"}, "testkey") w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/", nil) + r := httptest.NewRequest(http.MethodGet, "/", nil) _, b := auth.Authorize(r, w) resp := w.Result() body, err := io.ReadAll(resp.Body) + resp.Body.Close() + if err != nil { t.Fatal(err) } + if b { t.Errorf("Expected error in Authentication, but was succesfull with code %d and body %v", resp.StatusCode, string(body)) } + if got, ex := w.Code, http.StatusUnauthorized; got != ex { t.Errorf("Expected status code %d, got %d", ex, got) } + if got, ex := string(body), "Unauthorized\n"; got != ex { t.Errorf("Expected status code %q, got %q", ex, got) } diff --git a/http/security/helper.go b/http/security/helper.go index 1528aeab3..c9c4a74bb 100644 --- a/http/security/helper.go +++ b/http/security/helper.go @@ -34,6 +34,7 @@ func GetBearerTokenFromHeader(authHeader string) string { if !hasPrefix { return "" } + return strings.TrimPrefix(authHeader, headerPrefix) } @@ -48,7 +49,9 @@ func GetTokenFromContext(ctx context.Context) (Token, bool) { if val == nil { return nil, false } + tok, ok := val.(Token) + return tok, ok } diff --git a/http/server.go b/http/server.go index e73f673bc..ba4b9ee59 100644 --- a/http/server.go +++ b/http/server.go @@ -9,6 +9,7 @@ import ( "time" "github.com/caarlos0/env/v10" + "github.com/pace/bricks/maintenance/log" ) @@ -31,6 +32,7 @@ func (cfg config) addrOrPort() string { if cfg.Addr != "" { return cfg.Addr } + return ":" + strconv.Itoa(cfg.Port) } diff --git a/http/server_test.go b/http/server_test.go index 9d682df7f..4301bb90c 100644 --- a/http/server_test.go +++ b/http/server_test.go @@ -6,18 +6,34 @@ import ( "os" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestServer(t *testing.T) { // Defaults - os.Setenv("ADDR", "") - os.Setenv("PORT", "") - os.Setenv("MAX_HEADER_BYTES", "") - os.Setenv("IDLE_TIMEOUT", "") - os.Setenv("READ_TIMEOUT", "") - os.Setenv("WRITE_TIMEOUT", "") + err := os.Setenv("ADDR", "") + require.NoError(t, err) + + err = os.Setenv("PORT", "") + require.NoError(t, err) + + err = os.Setenv("MAX_HEADER_BYTES", "") + require.NoError(t, err) + + err = os.Setenv("IDLE_TIMEOUT", "") + require.NoError(t, err) + + err = os.Setenv("READ_TIMEOUT", "") + require.NoError(t, err) + + err = os.Setenv("WRITE_TIMEOUT", "") + require.NoError(t, err) + parseConfig() + s := Server(nil) + cases := []struct { env string expected, actual interface{} @@ -35,14 +51,28 @@ func TestServer(t *testing.T) { } // custom - os.Setenv("ADDR", ":5432") - os.Setenv("PORT", "1234") - os.Setenv("MAX_HEADER_BYTES", "100") - os.Setenv("IDLE_TIMEOUT", "1s") - os.Setenv("READ_TIMEOUT", "2s") - os.Setenv("WRITE_TIMEOUT", "3s") + err = os.Setenv("ADDR", ":5432") + require.NoError(t, err) + + err = os.Setenv("PORT", "1234") + require.NoError(t, err) + + err = os.Setenv("MAX_HEADER_BYTES", "100") + require.NoError(t, err) + + err = os.Setenv("IDLE_TIMEOUT", "1s") + require.NoError(t, err) + + err = os.Setenv("READ_TIMEOUT", "2s") + require.NoError(t, err) + + err = os.Setenv("WRITE_TIMEOUT", "3s") + require.NoError(t, err) + parseConfig() + s = Server(nil) + cases = []struct { env string expected, actual interface{} @@ -64,6 +94,7 @@ func TestEnvironment(t *testing.T) { // Defaults os.Setenv("ENVIRONMENT", "") parseConfig() + if Environment() != "edge" { t.Errorf("Expected edge, got: %q", Environment()) } @@ -71,6 +102,7 @@ func TestEnvironment(t *testing.T) { // custom os.Setenv("ENVIRONMENT", "production") parseConfig() + if Environment() != "production" { t.Errorf("Expected production, got: %q", Environment()) } diff --git a/http/transport/attempt_round_tripper.go b/http/transport/attempt_round_tripper.go index 1212a8799..da2e5f269 100644 --- a/http/transport/attempt_round_tripper.go +++ b/http/transport/attempt_round_tripper.go @@ -37,11 +37,13 @@ func attemptFromCtx(ctx context.Context) int32 { if !ok { return 0 } + return a } func transportWithAttempt(rt http.RoundTripper) http.RoundTripper { ar := &attemptRoundTripper{attempt: 0} ar.SetTransport(rt) + return ar } diff --git a/http/transport/chainable.go b/http/transport/chainable.go index 970745b95..ed77e2fd2 100644 --- a/http/transport/chainable.go +++ b/http/transport/chainable.go @@ -41,7 +41,7 @@ type RoundTripperChain struct { } // Chain returns a round tripper chain with the specified chainable round trippers and http.DefaultTransport as transport. -// The transport can be overriden by using the Final method. +// The transport can be overridden by using the Final method. func Chain(rt ...ChainableRoundTripper) *RoundTripperChain { final := &finalRoundTripper{transport: http.DefaultTransport} c := &RoundTripperChain{first: final, current: final, final: final} @@ -67,6 +67,7 @@ func (c *RoundTripperChain) Use(rt ChainableRoundTripper) *RoundTripperChain { c.current.SetTransport(rt) rt.SetTransport(c.final) + c.current = rt return c diff --git a/http/transport/chainable_test.go b/http/transport/chainable_test.go index 924c2eff1..c1a7dbc0c 100644 --- a/http/transport/chainable_test.go +++ b/http/transport/chainable_test.go @@ -33,12 +33,18 @@ func TestRoundTripperRace(t *testing.T) { go func() { for i := 0; i < 10; i++ { - client.Get(server.URL + "/test001") // nolint: errcheck + resp, err := client.Get(server.URL + "/test001") + if err == nil { + _ = resp.Body.Close() + } } }() for i := 0; i < 10; i++ { - client.Get(server.URL + "/test002") // nolint: errcheck + resp, err := client.Get(server.URL + "/test002") + if err == nil { + _ = resp.Body.Close() + } } } @@ -48,16 +54,17 @@ func TestRoundTripperChaining(t *testing.T) { c := Chain().Final(transport) url := "/foo" - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest(http.MethodGet, url, nil) _, err := c.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } - if v := transport.req.Method; v != "GET" { - t.Errorf("Expected method %q, got %q", "GET", v) + if v := transport.req.Method; v != http.MethodGet { + t.Errorf("Expected method %q, got %q", http.MethodGet, v) } + if v := transport.req.URL.String(); v != url { t.Errorf("Expected URL %q, got %q", url, v) } @@ -68,19 +75,21 @@ func TestRoundTripperChaining(t *testing.T) { c.Use(&addHeaderRoundTripper{key: "foo", value: "bar"}).Final(transport) url := "/foo" - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest(http.MethodGet, url, nil) _, err := c.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } - if v := transport.req.Method; v != "GET" { - t.Errorf("Expected method %v, got %v", "GET", v) + if v := transport.req.Method; v != http.MethodGet { + t.Errorf("Expected method %v, got %v", http.MethodGet, v) } + if v := transport.req.URL.String(); v != url { t.Errorf("Expected URL %v, got %v", url, v) } + if v, ex := transport.req.Header.Get("foo"), "bar"; v != ex { t.Errorf("Expected header foo to eq %v, got %v", ex, v) } @@ -93,22 +102,25 @@ func TestRoundTripperChaining(t *testing.T) { c := Chain(rt1, rt2, rt3).Final(transport) url := "/foo" - req := httptest.NewRequest("GET", url, nil) + req := httptest.NewRequest(http.MethodGet, url, nil) _, err := c.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) } - if v := transport.req.Method; v != "GET" { - t.Errorf("Expected method %v, got %v", "GET", v) + if v := transport.req.Method; v != http.MethodGet { + t.Errorf("Expected method %v, got %v", http.MethodGet, v) } + if v := transport.req.URL.String(); v != url { t.Errorf("Expected URL %v, got %v", url, v) } + if v, ex := transport.req.Header.Get("foo"), "baroverride"; v != ex { t.Errorf("Expected header foo to eq %v, got %v", ex, v) } + if v, ex := transport.req.Header.Get("Authorization"), "Bearer 123"; v != ex { t.Errorf("Expected header Authorization to eq %v, got %v", ex, v) } diff --git a/http/transport/circuit_breaker_tripper.go b/http/transport/circuit_breaker_tripper.go index 25b8a50e4..938ed904d 100644 --- a/http/transport/circuit_breaker_tripper.go +++ b/http/transport/circuit_breaker_tripper.go @@ -48,6 +48,7 @@ func NewCircuitBreakerTripper(settings gobreaker.Settings) *circuitBreakerTrippe }, []string{"from", "to"}) var ok bool + var are prometheus.AlreadyRegisteredError if err := prometheus.Register(stateSwitchCounterVec); errors.As(err, &are) { stateSwitchCounterVec, ok = are.ExistingCollector.(*prometheus.CounterVec) @@ -84,7 +85,7 @@ func (c *circuitBreakerTripper) SetTransport(rt http.RoundTripper) { // RoundTrip executes a single HTTP transaction via Transport() func (c *circuitBreakerTripper) RoundTrip(req *http.Request) (*http.Response, error) { resp, err := c.breaker.Execute(func() (interface{}, error) { - return c.transport.RoundTrip(req) + return c.transport.RoundTrip(req) //nolint:bodyclose }) if err != nil { switch { diff --git a/http/transport/circuit_breaker_tripper_test.go b/http/transport/circuit_breaker_tripper_test.go index 4272f3f7a..f9dfcd9bf 100644 --- a/http/transport/circuit_breaker_tripper_test.go +++ b/http/transport/circuit_breaker_tripper_test.go @@ -13,7 +13,7 @@ import ( ) func TestCircuitBreakerTripper(t *testing.T) { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) t.Run("with_default_settings", func(t *testing.T) { breaker := NewDefaultCircuitBreakerTripper("testcircuitbreaker") diff --git a/http/transport/default_transport_test.go b/http/transport/default_transport_test.go index 78e656c0b..3a0f7da55 100644 --- a/http/transport/default_transport_test.go +++ b/http/transport/default_transport_test.go @@ -12,14 +12,22 @@ import ( "os" "testing" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pace/bricks/maintenance/log" ) func TestNewDefaultTransportChain(t *testing.T) { old := os.Getenv("HTTP_TRANSPORT_DUMP") - defer os.Setenv("HTTP_TRANSPORT_DUMP", old) - os.Setenv("HTTP_TRANSPORT_DUMP", "request,response,body") + + defer func() { + err := os.Setenv("HTTP_TRANSPORT_DUMP", old) + require.NoError(t, err) + }() + + err := os.Setenv("HTTP_TRANSPORT_DUMP", "request,response,body") + require.NoError(t, err) t.Run("Finalizer not set explicitly", func(t *testing.T) { b := "Hello World" @@ -29,19 +37,26 @@ func TestNewDefaultTransportChain(t *testing.T) { retry++ if retry == 5 { w.WriteHeader(http.StatusOK) - fmt.Fprint(w, b) + _, err := fmt.Fprint(w, b) + require.NoError(t, err) + return } + w.WriteHeader(http.StatusBadGateway) - fmt.Fprint(w, b) + + _, err := fmt.Fprint(w, b) + require.NoError(t, err) })) - req := httptest.NewRequest("GET", ts.URL, nil) + req := httptest.NewRequest(http.MethodGet, ts.URL, nil) req = req.WithContext(log.WithContext(context.Background())) + resp, err := tr.RoundTrip(req) if err != nil { t.Fatal(err) } + ts.Close() assert.Equal(t, retry, 5) @@ -60,8 +75,9 @@ func TestNewDefaultTransportChain(t *testing.T) { tr := &transportWithBody{body: "abc"} dt := NewDefaultTransportChain().Final(tr) - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req = req.WithContext(log.WithContext(context.Background())) + resp, err := dt.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) @@ -71,6 +87,7 @@ func TestNewDefaultTransportChain(t *testing.T) { if err != nil { t.Fatalf("Expected readable body, got error: %q", err.Error()) } + if ex, got := tr.body, string(body); ex != got { t.Errorf("Expected body %q, got %q", ex, got) } @@ -84,7 +101,7 @@ type transportWithBody struct { func (t *transportWithBody) RoundTrip(req *http.Request) (*http.Response, error) { body := io.NopCloser(bytes.NewReader([]byte(t.body))) - resp := &http.Response{Body: body, StatusCode: 200} + resp := &http.Response{Body: body, StatusCode: http.StatusOK} return resp, nil } diff --git a/http/transport/dump_options.go b/http/transport/dump_options.go index 2e37d1f97..2f0f58d2b 100644 --- a/http/transport/dump_options.go +++ b/http/transport/dump_options.go @@ -13,6 +13,7 @@ func NewDumpOptions(opts ...DumpOption) (DumpOptions, error) { return nil, err } } + return dumpOptions, nil } @@ -29,6 +30,7 @@ func (o DumpOptions) AnyEnabled(options ...string) bool { return true } } + return false } @@ -39,7 +41,9 @@ func WithDumpOption(option string, enabled bool) DumpOption { if !isDumpOptionValid(option) { return fmt.Errorf("invalid dump option %q", option) } + o[option] = enabled + return nil } } @@ -80,8 +84,10 @@ func mergeDumpOptions(globalOptions, reqOptions DumpOptions) DumpOptions { // req option already exists, ignore the global one continue } + reqOptions[globalKey] = globalVal } + return reqOptions } @@ -91,11 +97,13 @@ func CtxWithDumpRoundTripperOptions(ctx context.Context, opts DumpOptions) conte if opts == nil { return ctx } + return context.WithValue(ctx, dumpRoundTripperCtxKey{}, opts) } func DumpRoundTripperOptionsFromCtx(ctx context.Context) DumpOptions { do := ctx.Value(dumpRoundTripperCtxKey{}) dumpOptions, _ := do.(DumpOptions) + return dumpOptions } diff --git a/http/transport/dump_round_tripper.go b/http/transport/dump_round_tripper.go index 709408b96..3405aee59 100644 --- a/http/transport/dump_round_tripper.go +++ b/http/transport/dump_round_tripper.go @@ -38,18 +38,23 @@ type dumpRoundTripperConfig struct { func roundTripConfigViaEnv() DumpRoundTripperOption { return func(rt *DumpRoundTripper) (*DumpRoundTripper, error) { var cfg dumpRoundTripperConfig + err := env.Parse(&cfg) if err != nil { return rt, fmt.Errorf("failed to parse dump round tripper environment: %w", err) } + for _, option := range cfg.Options { if !isDumpOptionValid(option) { return nil, fmt.Errorf("invalid dump option %q", option) } + rt.options[option] = true } + rt.blacklistAnyDumpPrefixes = cfg.BlacklistAnyDumpPrefixes rt.blacklistBodyDumpPrefixes = cfg.BlacklistBodyDumpPrefixes + return rt, nil } } @@ -60,8 +65,10 @@ func RoundTripConfig(dumpOptions ...string) DumpRoundTripperOption { if !isDumpOptionValid(option) { return nil, fmt.Errorf("invalid dump option %q", option) } + rt.options[option] = true } + return rt, nil } } @@ -73,19 +80,23 @@ func NewDumpRoundTripperEnv() *DumpRoundTripper { if err != nil { log.Fatalf("failed to setup NewDumpRoundTripperEnv: %v", err) } + return rt } // NewDumpRoundTripper return the roundtripper with configured options func NewDumpRoundTripper(options ...DumpRoundTripperOption) (*DumpRoundTripper, error) { rt := &DumpRoundTripper{options: DumpOptions{}} + var err error + for _, option := range options { rt, err = option(rt) if err != nil { return rt, err } } + return rt, nil } @@ -108,12 +119,14 @@ func (l *DumpRoundTripper) ContainsBlacklistedPrefix(url *url.URL, blacklist []s if len(blacklist) == 0 { return false } + for _, prefix := range blacklist { // TODO (juf): Do benchmark and compare against using pre-constructed prefix-tree if strings.HasPrefix(url.String(), prefix) { return true } } + return false } @@ -156,6 +169,7 @@ func (l *DumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) if options.IsEnabled(DumpRoundTripperOptionRequest) { dl = dl.Bytes(DumpRoundTripperOptionRequest, reqDump) } + if options.IsEnabled(DumpRoundTripperOptionRequestHEX) { dl = dl.Str(DumpRoundTripperOptionRequestHEX, hex.EncodeToString(reqDump)) } @@ -177,9 +191,11 @@ func (l *DumpRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) if redactor != nil { respDump = []byte(redactor.Mask(string(respDump))) } + if options.IsEnabled(DumpRoundTripperOptionResponse) { dl = dl.Bytes(DumpRoundTripperOptionResponse, respDump) } + if options.IsEnabled(DumpRoundTripperOptionResponseHEX) { dl = dl.Str(DumpRoundTripperOptionResponseHEX, hex.EncodeToString(respDump)) } diff --git a/http/transport/dump_round_tripper_test.go b/http/transport/dump_round_tripper_test.go index e4b2bef1c..c3da4eb86 100644 --- a/http/transport/dump_round_tripper_test.go +++ b/http/transport/dump_round_tripper_test.go @@ -5,6 +5,7 @@ package transport import ( "bytes" "context" + "net/http" "net/http/httptest" "os" "testing" @@ -24,8 +25,9 @@ func TestNewDumpRoundTripperEnv(t *testing.T) { rt := NewDumpRoundTripperEnv() assert.NotNil(t, rt) - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) _, err := rt.RoundTrip(req) @@ -40,8 +42,14 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedComplete(t *testing.T) { ctx := log.Output(out).WithContext(context.Background()) require.NotPanics(t, func() { - defer os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX", os.Getenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX")) - os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX", "https://please-ignore-me") + defer func() { + err := os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX", os.Getenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX")) + require.NoError(t, err) + + err = os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_ALL_URL_PREFIX", "https://please-ignore-me") + require.NoError(t, err) + }() + rt, err := NewDumpRoundTripper( roundTripConfigViaEnv(), RoundTripConfig( @@ -54,8 +62,9 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedComplete(t *testing.T) { require.NoError(t, err) assert.NotNil(t, rt) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) _, err = rt.RoundTrip(req) @@ -72,7 +81,7 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedComplete(t *testing.T) { assert.Equal(t, "", out.String()) - reqWithPrefix := httptest.NewRequest("GET", "https://please-ignore-me.org/foo/", bytes.NewBufferString("Foo")) + reqWithPrefix := httptest.NewRequest(http.MethodGet, "https://please-ignore-me.org/foo/", bytes.NewBufferString("Foo")) reqWithPrefix = reqWithPrefix.WithContext(ctx) _, err = rt.RoundTrip(reqWithPrefix) @@ -86,8 +95,14 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedBody(t *testing.T) { ctx := log.Output(out).WithContext(context.Background()) require.NotPanics(t, func() { - defer os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX", os.Getenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX")) - os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX", "https://please-ignore-me") + defer func() { + err := os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX", os.Getenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX")) + require.NoError(t, err) + }() + + err := os.Setenv("HTTP_TRANSPORT_DUMP_DISABLE_DUMP_BODY_URL_PREFIX", "https://please-ignore-me") + require.NoError(t, err) + rt, err := NewDumpRoundTripper( roundTripConfigViaEnv(), RoundTripConfig( @@ -100,8 +115,9 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedBody(t *testing.T) { require.NoError(t, err) assert.NotNil(t, rt) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) _, err = rt.RoundTrip(req) @@ -118,7 +134,7 @@ func TestNewDumpRoundTripperEnvDisablePrefixBasedBody(t *testing.T) { assert.Equal(t, "", out.String()) - reqWithPrefix := httptest.NewRequest("GET", "https://please-ignore-me.org/foo/", bytes.NewBufferString("Foo")) + reqWithPrefix := httptest.NewRequest(http.MethodGet, "https://please-ignore-me.org/foo/", bytes.NewBufferString("Foo")) reqWithPrefix = reqWithPrefix.WithContext(ctx) _, err = rt.RoundTrip(reqWithPrefix) @@ -148,8 +164,9 @@ func TestNewDumpRoundTripper(t *testing.T) { ) require.NoError(t, err) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) _, err = rt.RoundTrip(req) @@ -176,9 +193,10 @@ func TestNewDumpRoundTripperRedacted(t *testing.T) { ) require.NoError(t, err) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo DE12345678909876543210 bar")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo DE12345678909876543210 bar")) ctx = redact.Default.WithContext(ctx) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) _, err = rt.RoundTrip(req) @@ -203,9 +221,10 @@ func TestNewDumpRoundTripperRedactedBasicAuth(t *testing.T) { ) require.NoError(t, err) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Authorization: Basic ZGVtbzpwQDU1dzByZA==")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Authorization: Basic ZGVtbzpwQDU1dzByZA==")) ctx = redact.Default.WithContext(ctx) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) _, err = rt.RoundTrip(req) @@ -229,8 +248,9 @@ func TestNewDumpRoundTripperSimple(t *testing.T) { ) require.NoError(t, err) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) + rt.SetTransport(&transportWithResponse{}) _, err = rt.RoundTrip(req) @@ -255,7 +275,7 @@ func TestNewDumpRoundTripperContextOptionsOverwrite(t *testing.T) { out := &bytes.Buffer{} ctx := log.Output(out).WithContext(context.Background()) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) _, err = rt.RoundTrip(req) @@ -276,9 +296,10 @@ func TestNewDumpRoundTripperContextOptionsOverwrite(t *testing.T) { WithDumpOption(DumpRoundTripperOptionResponse, false), ) require.NoError(t, err) + ctx = CtxWithDumpRoundTripperOptions(ctx, ctxDumpOptions) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) _, err = rt.RoundTrip(req) @@ -301,7 +322,7 @@ func TestNewDumpRoundTripperContextOptionsOverwriteBody(t *testing.T) { out := &bytes.Buffer{} ctx := log.Output(out).WithContext(context.Background()) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) _, err = rt.RoundTrip(req) @@ -321,9 +342,10 @@ func TestNewDumpRoundTripperContextOptionsOverwriteBody(t *testing.T) { WithDumpOption(DumpRoundTripperOptionBody, false), ) require.NoError(t, err) + ctx = CtxWithDumpRoundTripperOptions(ctx, ctxDumpOptions) - req := httptest.NewRequest("GET", "/foo", bytes.NewBufferString("Foo")) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewBufferString("Foo")) req = req.WithContext(ctx) _, err = rt.RoundTrip(req) diff --git a/http/transport/external_dependency_round_tripper_test.go b/http/transport/external_dependency_round_tripper_test.go index 3fbc5b060..c3c876345 100644 --- a/http/transport/external_dependency_round_tripper_test.go +++ b/http/transport/external_dependency_round_tripper_test.go @@ -8,8 +8,9 @@ import ( "net/http/httptest" "testing" - "github.com/pace/bricks/http/middleware" "github.com/stretchr/testify/assert" + + "github.com/pace/bricks/http/middleware" ) type edRoundTripperMock struct { @@ -24,9 +25,10 @@ func (m *edRoundTripperMock) RoundTrip(req *http.Request) (*http.Response, error func TestExternalDependencyRoundTripper(t *testing.T) { var edc middleware.ExternalDependencyContext + ctx := middleware.ContextWithExternalDependency(context.Background(), &edc) - r := httptest.NewRequest("GET", "http://example.com/test", nil) + r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) r = r.WithContext(ctx) mock := &edRoundTripperMock{ @@ -46,9 +48,10 @@ func TestExternalDependencyRoundTripper(t *testing.T) { func TestExternalDependencyRoundTripperWithName(t *testing.T) { var edc middleware.ExternalDependencyContext + ctx := middleware.ContextWithExternalDependency(context.Background(), &edc) - r := httptest.NewRequest("GET", "http://example.com/test", nil) + r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) r = r.WithContext(ctx) mock := &edRoundTripperMock{ diff --git a/http/transport/jaeger_round_tripper.go b/http/transport/jaeger_round_tripper.go index 5054abfd0..b988c7882 100755 --- a/http/transport/jaeger_round_tripper.go +++ b/http/transport/jaeger_round_tripper.go @@ -6,11 +6,11 @@ import ( "fmt" "net/http" - "github.com/pace/bricks/maintenance/log" - "github.com/pace/bricks/maintenance/tracing/wire" - "github.com/opentracing/opentracing-go" olog "github.com/opentracing/opentracing-go/log" + + "github.com/pace/bricks/maintenance/log" + "github.com/pace/bricks/maintenance/tracing/wire" ) // JaegerRoundTripper implements a chainable round tripper for tracing @@ -31,6 +31,7 @@ func (l *JaegerRoundTripper) SetTransport(rt http.RoundTripper) { // RoundTrip executes a HTTP request with distributed tracing func (l *JaegerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { operationName := fmt.Sprintf("%s %s", req.Method, req.URL.Path) + span, ctx := opentracing.StartSpanFromContext(req.Context(), operationName) defer span.Finish() @@ -45,6 +46,7 @@ func (l *JaegerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error if attempt > 0 { span.LogFields(olog.Int("attempt", int(attempt))) } + if err != nil { span.LogFields(olog.Error(err)) return nil, err diff --git a/http/transport/jaeger_round_tripper_test.go b/http/transport/jaeger_round_tripper_test.go index c7b87ca0a..ea7bfcf4a 100644 --- a/http/transport/jaeger_round_tripper_test.go +++ b/http/transport/jaeger_round_tripper_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/opentracing/opentracing-go" + _ "github.com/pace/bricks/maintenance/tracing" ) @@ -20,7 +21,8 @@ func TestJaegerRoundTripper(t *testing.T) { tr := &recordingTransportWithResponse{statusCode: 202} l.SetTransport(tr) - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) + _, err := l.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) @@ -31,6 +33,7 @@ func TestJaegerRoundTripper(t *testing.T) { if strings.Contains(spanString, "attempt") { t.Errorf("Expected attempt to not be included in span %q", spanString) } + exs := []string{`operationName:"GET /foo"`, "numericVal:202"} for _, ex := range exs { if !strings.Contains(spanString, ex) { @@ -44,15 +47,16 @@ func TestJaegerRoundTripper(t *testing.T) { tr := &recordingTransportWithError{err: e} l.SetTransport(tr) - req := httptest.NewRequest("GET", "/bar", nil) - _, err := l.RoundTrip(req) + req := httptest.NewRequest(http.MethodGet, "/bar", nil) + _, err := l.RoundTrip(req) if got, ex := err.Error(), e.Error(); got != ex { t.Fatalf("Expected error %q to be returned, got %q", ex, got) } spanString := fmt.Sprintf("%#v", tr.span) exs := []string{`operationName:"GET /bar"`, `log.Field{key:"error"`} + for _, ex := range exs { if !strings.Contains(spanString, ex) { t.Errorf("Expected %q to be included in span %v", ex, spanString) @@ -60,11 +64,12 @@ func TestJaegerRoundTripper(t *testing.T) { } }) t.Run("With retries", func(t *testing.T) { - tr := &retriedTransport{statusCodes: []int{502, 503, 200}} + tr := &retriedTransport{statusCodes: []int{502, 503, http.StatusOK}} l := Chain(NewDefaultRetryRoundTripper(), &JaegerRoundTripper{}) l.Final(tr) - req := httptest.NewRequest("GET", "/bar", nil) + req := httptest.NewRequest(http.MethodGet, "/bar", nil) + _, err := l.RoundTrip(req) if err != nil { t.Fatalf("Expected err to be nil, got %#v", err) @@ -73,6 +78,7 @@ func TestJaegerRoundTripper(t *testing.T) { span := opentracing.SpanFromContext(tr.ctx) spanString := fmt.Sprintf("%#v", span) exs := []string{`operationName:"GET /bar"`, `log.Field{key:"attempt", fieldType:2, numericVal:3`} + for _, ex := range exs { if !strings.Contains(spanString, ex) { t.Errorf("Expected %q to be included in span %v", ex, spanString) diff --git a/http/transport/locale_round_tripper_test.go b/http/transport/locale_round_tripper_test.go index cc27312d4..0eb2d0e14 100644 --- a/http/transport/locale_round_tripper_test.go +++ b/http/transport/locale_round_tripper_test.go @@ -8,10 +8,10 @@ import ( "net/http/httputil" "testing" - "github.com/pace/bricks/locale" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/pace/bricks/locale" ) type roundTripperMock struct { @@ -28,10 +28,11 @@ func TestLocaleRoundTrip(t *testing.T) { lrt := &LocaleRoundTripper{transport: mock} l := locale.NewLocale("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5", "Europe/Paris") - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) - lrt.RoundTrip(r.WithContext(locale.WithLocale(context.Background(), l))) // nolint: errcheck + _, err = lrt.RoundTrip(r.WithContext(locale.WithLocale(context.Background(), l))) + require.NoError(t, err) lctx, ok := locale.FromCtx(mock.r.Context()) require.True(t, ok) diff --git a/http/transport/logging_round_tripper.go b/http/transport/logging_round_tripper.go index b6465f236..b7f38fac2 100644 --- a/http/transport/logging_round_tripper.go +++ b/http/transport/logging_round_tripper.go @@ -30,17 +30,18 @@ func (l *LoggingRoundTripper) SetTransport(rt http.RoundTripper) { func (l *LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { ctx := req.Context() startTime := time.Now() - le := log.Ctx(ctx).Debug(). - Str("url", req.URL.String()). - Str("method", req.Method). - Str("sentry:type", "http"). - Str("sentry:category", "http") + le := log.Ctx(ctx).Debug(). //nolint:zerologlint + Str("url", req.URL.String()). + Str("method", req.Method). + Str("sentry:type", "http"). + Str("sentry:category", "http") resp, err := l.Transport().RoundTrip(req) dur := float64(time.Since(startTime)) / float64(time.Millisecond) le = le.Float64("duration", dur) attempt := attemptFromCtx(ctx) + if attempt > 0 { le = le.Int("attempt", int(attempt)) } diff --git a/http/transport/logging_round_tripper_test.go b/http/transport/logging_round_tripper_test.go index 02bf4555b..782b439b8 100644 --- a/http/transport/logging_round_tripper_test.go +++ b/http/transport/logging_round_tripper_test.go @@ -5,6 +5,7 @@ package transport import ( "bytes" "context" + "net/http" "net/http/httptest" "net/url" "strings" @@ -19,16 +20,18 @@ func TestLoggingRoundTripper(t *testing.T) { ctx := log.Output(out).WithContext(context.Background()) // create request with context and url - req := httptest.NewRequest("GET", "/foo", nil).WithContext(ctx) + req := httptest.NewRequest(http.MethodGet, "/foo", nil).WithContext(ctx) + url, err := url.Parse("http://example.com/foo") if err != nil { panic(err) } + req.URL = url t.Run("Without retries", func(t *testing.T) { l := &LoggingRoundTripper{} - l.SetTransport(&transportWithResponse{statusCode: 200}) + l.SetTransport(&transportWithResponse{statusCode: http.StatusOK}) _, err = l.RoundTrip(req) if err != nil { @@ -39,11 +42,12 @@ func TestLoggingRoundTripper(t *testing.T) { if !strings.Contains(got, "duration") { t.Errorf("Expected duration to be contained in log output, got %v", got) } + if strings.Contains(got, "retries") { t.Errorf("Expected retries to not be contained in log output, got %v", got) } - exs := []string{`"level":"debug"`, `"url":"http://example.com/foo"`, `"method":"GET"`, `"status_code":200`, `"message":"HTTP GET example.com"`} + exs := []string{`"level":"debug"`, `"url":"http://example.com/foo"`, `"method":http.MethodGet`, `"status_code":200`, `"message":"HTTP GET example.com"`} for _, ex := range exs { if !strings.Contains(got, ex) { t.Errorf("Expected %v to be contained in log output, got %v", ex, got) @@ -60,7 +64,8 @@ func TestLoggingRoundTripper(t *testing.T) { } got := out.String() - exs := []string{`"level":"debug"`, `"url":"http://example.com/foo"`, `"method":"GET"`, `"status_code":200`, `"message":"HTTP GET example.com"`, `"attempt":3`} + exs := []string{`"level":"debug"`, `"url":"http://example.com/foo"`, `"method":http.MethodGet`, `"status_code":200`, `"message":"HTTP GET example.com"`, `"attempt":3`} + for _, ex := range exs { if !strings.Contains(got, ex) { t.Errorf("Expected %v to be contained in log output, got %v", ex, got) diff --git a/http/transport/request_id.go b/http/transport/request_id.go index e4945cc19..74b33ed53 100644 --- a/http/transport/request_id.go +++ b/http/transport/request_id.go @@ -30,5 +30,6 @@ func (l *RequestIDRoundTripper) RoundTrip(req *http.Request) (*http.Response, er if reqID := log.RequestIDFromContext(ctx); reqID != "" { req.Header.Set("Request-Id", reqID) } + return l.Transport().RoundTrip(req) } diff --git a/http/transport/request_id_test.go b/http/transport/request_id_test.go index e5f6ff43a..c61feb2a9 100644 --- a/http/transport/request_id_test.go +++ b/http/transport/request_id_test.go @@ -8,9 +8,10 @@ import ( "testing" "github.com/gorilla/mux" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/pace/bricks/maintenance/log" ) func TestRequestIDRoundTripper(t *testing.T) { @@ -18,7 +19,7 @@ func TestRequestIDRoundTripper(t *testing.T) { rt.SetTransport(&transportWithResponse{}) t.Run("without req_id", func(t *testing.T) { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) _, err := rt.RoundTrip(req) assert.NoError(t, err) assert.Empty(t, req.Header["Request-Id"]) @@ -34,7 +35,7 @@ func TestRequestIDRoundTripper(t *testing.T) { require.Equal(t, ID, log.RequestID(r)) require.Equal(t, ID, log.RequestIDFromContext(r.Context())) - r1 := httptest.NewRequest("GET", "/foo", nil) + r1 := httptest.NewRequest(http.MethodGet, "/foo", nil) r1 = r1.WithContext(r.Context()) _, err := rt.RoundTrip(r1) @@ -44,7 +45,7 @@ func TestRequestIDRoundTripper(t *testing.T) { }) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) req.Header.Set("Request-Id", ID) r.ServeHTTP(rec, req) assert.Equal(t, http.StatusNoContent, rec.Code) diff --git a/http/transport/request_source_round_tripper_test.go b/http/transport/request_source_round_tripper_test.go index 3305ee1d2..83fcb43fe 100644 --- a/http/transport/request_source_round_tripper_test.go +++ b/http/transport/request_source_round_tripper_test.go @@ -3,6 +3,7 @@ package transport import ( + "net/http" "net/http/httptest" "testing" @@ -10,7 +11,7 @@ import ( ) func TestRequestSourceRoundTripper(t *testing.T) { - req := httptest.NewRequest("GET", "/foo", nil) + req := httptest.NewRequest(http.MethodGet, "/foo", nil) rt := RequestSourceRoundTripper{SourceName: "foobar"} rt.SetTransport(&transportWithResponse{}) diff --git a/http/transport/retry_round_tripper.go b/http/transport/retry_round_tripper.go index ffa30fe44..36c2c075d 100644 --- a/http/transport/retry_round_tripper.go +++ b/http/transport/retry_round_tripper.go @@ -27,6 +27,7 @@ func RetryNetErr() rehttp.RetryFn { if _, isNetError := attempt.Error.(*net.OpError); isNetError { return true } + return false } } @@ -98,6 +99,7 @@ func (l *RetryRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) transport: transportWithAttempt(l.Transport()), } retryTransport.RoundTripper = wrappedTransport + resp, err := retryTransport.RoundTrip(req) if err != nil { return nil, err diff --git a/http/transport/retry_round_tripper_test.go b/http/transport/retry_round_tripper_test.go index d1a26494c..23e509fcc 100644 --- a/http/transport/retry_round_tripper_test.go +++ b/http/transport/retry_round_tripper_test.go @@ -34,7 +34,7 @@ func TestRetryRoundTripper(t *testing.T) { name: "Successful response after some retries", args: args{ requestBody: []byte(`{"key":"value""}`), - statuses: []int{408, 502, 503, 504, 200}, + statuses: []int{408, 502, 503, 504, http.StatusOK}, }, wantRetries: 5, }, @@ -69,7 +69,7 @@ func TestRetryRoundTripper(t *testing.T) { name: "Exceed retries", args: args{ requestBody: []byte(`{"key":"value""}`), - statuses: []int{408, 502, 503, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 200}, + statuses: []int{408, 502, 503, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, 504, http.StatusOK}, }, wantRetries: 10, wantErr: ErrRetryFailed, @@ -84,7 +84,7 @@ func TestRetryRoundTripper(t *testing.T) { } rt.SetTransport(tr) - req := httptest.NewRequest("GET", "/foo", bytes.NewReader(tt.args.requestBody)) + req := httptest.NewRequest(http.MethodGet, "/foo", bytes.NewReader(tt.args.requestBody)) resp, err := rt.RoundTrip(req.WithContext(context.Background())) require.Equal(t, tt.wantRetries, tr.attempts) @@ -93,6 +93,7 @@ func TestRetryRoundTripper(t *testing.T) { require.ErrorIs(t, err, tt.wantErr) return } + require.NoError(t, err) body, err := io.ReadAll(resp.Body) @@ -120,6 +121,7 @@ func (t *retriedTransport) RoundTrip(req *http.Request) (*http.Response, error) if t.err != nil { return nil, fmt.Errorf("%w", t.err) } + readAll, _ := io.ReadAll(req.Body) body := io.NopCloser(bytes.NewReader(readAll)) resp := &http.Response{Body: body, StatusCode: t.statusCodes[t.attempts]} diff --git a/internal/service/generate/cmds.go b/internal/service/generate/cmds.go index a0706eec8..370215333 100644 --- a/internal/service/generate/cmds.go +++ b/internal/service/generate/cmds.go @@ -38,7 +38,7 @@ func Commands(path string, options CommandOptions) { filepath.Join(path, "cmd", options.ControlName), } for _, dir := range dirs { - err := os.MkdirAll(dir, 0o770) // nolint: gosec + err := os.MkdirAll(dir, 0o750) if err != nil { log.Fatal(fmt.Printf("Failed to create dir %s: %v", dir, err)) } @@ -46,7 +46,7 @@ func Commands(path string, options CommandOptions) { // Create commands files for _, dir := range dirs { - f, err := os.Create(filepath.Join(dir, "main.go")) + f, err := os.Create(filepath.Join(dir, "main.go")) //nolint:gosec if err != nil { log.Fatal(err) } @@ -59,6 +59,7 @@ func Commands(path string, options CommandOptions) { } else { generateControlMain(code, cmdName) } + _, err = f.WriteString(copyright()) if err != nil { log.Fatal(err) @@ -103,5 +104,6 @@ func copyright() string { stmt := "" now := time.Now() stmt += fmt.Sprintf("// Copyright © %04d by PACE Telematics GmbH. All rights reserved.\n", now.Year()) + return stmt } diff --git a/internal/service/generate/dockerfile.go b/internal/service/generate/dockerfile.go index 9418589ba..cbd6e3582 100644 --- a/internal/service/generate/dockerfile.go +++ b/internal/service/generate/dockerfile.go @@ -18,7 +18,7 @@ type DockerfileOptions struct { // Dockerfile generate a dockerfile using the given options // for specified path func Dockerfile(path string, options DockerfileOptions) { - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec if err != nil { log.Fatal(err) } diff --git a/internal/service/generate/error.go b/internal/service/generate/error.go index 7411b22ac..5a9e41713 100644 --- a/internal/service/generate/error.go +++ b/internal/service/generate/error.go @@ -18,6 +18,7 @@ type ErrorDefinitionFileOptions struct { func ErrorDefinitionFile(options ErrorDefinitionFileOptions) { // generate error definition g := generator.Generator{} + result, err := g.BuildSource(options.Source, options.Path, options.PkgName) if err != nil { log.Fatal(err) @@ -28,6 +29,7 @@ func ErrorDefinitionFile(options ErrorDefinitionFileOptions) { func ErrorDefinitionsMarkdown(options ErrorDefinitionFileOptions) { g := generator.Generator{} + result, err := g.BuildMarkdown(options.Source) if err != nil { log.Fatal(err) @@ -38,11 +40,16 @@ func ErrorDefinitionsMarkdown(options ErrorDefinitionFileOptions) { func writeResult(result, path string) { // create file - file, err := os.Create(path) + file, err := os.Create(path) //nolint:gosec if err != nil { log.Fatal(err) } - defer file.Close() // nolint: errcheck + + defer func() { + if err := file.Close(); err != nil { + log.Printf("failed closing file body: %v\n", err) + } + }() // write file _, err = file.WriteString(result) diff --git a/internal/service/generate/errordefinition/generator/generate.go b/internal/service/generate/errordefinition/generator/generate.go index 5fb79ac44..d8ae0a24a 100644 --- a/internal/service/generate/errordefinition/generator/generate.go +++ b/internal/service/generate/errordefinition/generator/generate.go @@ -27,6 +27,7 @@ type Generator struct { func loadDefinitionData(source string) ([]byte, error) { var data []byte + if strings.HasPrefix(source, "http://") || strings.HasPrefix(source, "https://") { loc, err := url.Parse(source) if err != nil { @@ -40,7 +41,8 @@ func loadDefinitionData(source string) ([]byte, error) { } else { // read definition file from disk var err error - data, err = os.ReadFile(source) // nolint: gosec + + data, err = os.ReadFile(source) //nolint:gosec if err != nil { return nil, err } @@ -54,12 +56,18 @@ func loadDefinitionDataFromURI(url *url.URL) ([]byte, error) { if err != nil { return nil, err } - defer resp.Body.Close() // nolint: errcheck + + defer func() { + if err := resp.Body.Close(); err != nil { + fmt.Fprintf(os.Stderr, "failed closing response body: %v", err) + } + }() body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } + return body, nil } @@ -73,6 +81,7 @@ func (g *Generator) BuildSource(source, packagePath, packageName string) (string // parse definition var errors runtime.Errors + err = json.Unmarshal(data, &errors) if err != nil { return "", err @@ -89,16 +98,15 @@ func (g *Generator) BuildDefinitions(errors runtime.Errors, packagePath, package // create a error code const for easier runtime error comparison - var constObjects []jen.Code - for _, jsonError := range errors { + constObjects := make([]jen.Code, 0) + for _, jsonError := range errors { // skip example if given if jsonError.Code == "EXAMPLE" { continue } constObjects = append(constObjects, jen.Id(fmt.Sprintf("ERR_CODE_%s", jsonError.Code)).Op("=").Lit(jsonError.Code)) - } if len(constObjects) > 0 { @@ -106,7 +114,6 @@ func (g *Generator) BuildDefinitions(errors runtime.Errors, packagePath, package } for _, jsonError := range errors { - // skip example if given if jsonError.Code == "EXAMPLE" { continue diff --git a/internal/service/generate/errordefinition/generator/markdown.go b/internal/service/generate/errordefinition/generator/markdown.go index 56f5af068..d3f483fab 100644 --- a/internal/service/generate/errordefinition/generator/markdown.go +++ b/internal/service/generate/errordefinition/generator/markdown.go @@ -37,6 +37,7 @@ func (g *Generator) BuildMarkdown(source string) (string, error) { func (g *Generator) parseDefinitions(data []byte) (ErrorDefinitions, error) { var parsedData []ErrorDefinition + err := json.Unmarshal(data, &parsedData) if err != nil { return nil, err @@ -70,12 +71,14 @@ func (g *Generator) generateMarkdown(eds ErrorDefinitions) (string, error) { if err != nil { return "", err } + _, err = output.WriteString(`|Code|Title| |-----------|-----------| `) if err != nil { panic(err) } + for _, detail := range details { _, err := output.WriteString(fmt.Sprintf("|%s|%s|\n", detail.Code, detail.Title)) if err != nil { diff --git a/internal/service/generate/makefile.go b/internal/service/generate/makefile.go index cec4de7bc..eedbdd0fe 100644 --- a/internal/service/generate/makefile.go +++ b/internal/service/generate/makefile.go @@ -17,7 +17,7 @@ type MakefileOptions struct { // Makefile generates a with given options for the // specified path func Makefile(path string, options MakefileOptions) { - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec if err != nil { log.Fatal(err) } diff --git a/internal/service/generate/rest.go b/internal/service/generate/rest.go index c06857bd3..07887aada 100644 --- a/internal/service/generate/rest.go +++ b/internal/service/generate/rest.go @@ -18,6 +18,7 @@ type RestOptions struct { func Rest(options RestOptions) { // generate jsonapi g := generator.Generator{} + result, err := g.BuildSource(options.Source, options.Path, options.PkgName) if err != nil { log.Fatal(err) @@ -28,7 +29,12 @@ func Rest(options RestOptions) { if err != nil { log.Fatal(err) } - defer file.Close() // nolint: errcheck + + defer func() { + if err := file.Close(); err != nil { + log.Printf("failed closing file body: %v\n", err) + } + }() // write file _, err = file.WriteString(result) diff --git a/internal/service/helper.go b/internal/service/helper.go index 5d99df655..312786b2d 100644 --- a/internal/service/helper.go +++ b/internal/service/helper.go @@ -30,6 +30,7 @@ func GoPath() string { if err != nil { log.Fatal(err) } + return filepath.Join(usr.HomeDir, "go") } @@ -45,6 +46,7 @@ func PacePath() string { if err != nil { log.Fatal(err) } + return filepath.Join(usr.HomeDir, "PACE") } @@ -64,7 +66,7 @@ func GoServicePackagePath(name string) string { // AutoInstall cmdName if not installed already using go get -u goGetPath func AutoInstall(cmdName, goGetPath string) { if _, err := os.Stat(GoBinCommand(cmdName)); os.IsNotExist(err) { - fmt.Fprintf(os.Stderr, "Installing %s using: go get -u %s\n", cmdName, goGetPath) // nolint: errcheck + fmt.Fprintf(os.Stderr, "Installing %s using: go get -u %s\n", cmdName, goGetPath) // assume error means no file SimpleExec("go", "get", "-u", goGetPath) } else if err != nil { @@ -79,10 +81,11 @@ func GoBinCommand(cmdName string) string { // SimpleExec executes the command and uses the parent process STDIN,STDOUT,STDERR func SimpleExec(cmdName string, arguments ...string) { - cmd := exec.Command(cmdName, arguments...) // nolint: gosec + cmd := exec.Command(cmdName, arguments...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr + err := cmd.Run() if err != nil { log.Fatal(err) @@ -91,11 +94,12 @@ func SimpleExec(cmdName string, arguments ...string) { // SimpleExecInPath executes the command and uses the parent process STDIN,STDOUT,STDERR in passed dir func SimpleExecInPath(dir, cmdName string, arguments ...string) { - cmd := exec.Command(cmdName, arguments...) // nolint: gosec + cmd := exec.Command(cmdName, arguments...) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr cmd.Dir = dir + err := cmd.Run() if err != nil { log.Fatal(err) @@ -104,10 +108,11 @@ func SimpleExecInPath(dir, cmdName string, arguments ...string) { // GoBinCommandText writes the command output to the passed writer func GoBinCommandText(w io.Writer, cmdName string, arguments ...string) { - cmd := exec.Command(cmdName, arguments...) // nolint: gosec + cmd := exec.Command(cmdName, arguments...) cmd.Stdin = os.Stdin cmd.Stdout = w cmd.Stderr = os.Stderr + err := cmd.Run() if err != nil { log.Fatal(err) diff --git a/internal/service/new.go b/internal/service/new.go index 51fba793a..48f744e4d 100644 --- a/internal/service/new.go +++ b/internal/service/new.go @@ -32,7 +32,8 @@ func New(name string, options NewOptions) { // add REST API if there was a source specified if options.RestSource != "" { restDir := filepath.Join(dir, "internal", "http", "rest") - err := os.MkdirAll(restDir, 0o770) // nolint: gosec + + err := os.MkdirAll(restDir, 0o750) if err != nil { log.Fatal(fmt.Printf("Failed to generate dir for rest api %s: %v", restDir, err)) } diff --git a/internal/service/service.go b/internal/service/service.go index 7c3d65b68..a83180060 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -29,6 +29,7 @@ func Path(name string) { if err != nil { log.Fatal(err) } + fmt.Println(dir) } @@ -36,10 +37,8 @@ func Path(name string) { // by env PACE_EDITOR or EDITOR func Edit(name string) { editor, ok := os.LookupEnv("PACE_EDITOR") - if !ok { editor, ok = os.LookupEnv("EDITOR") - if !ok { log.Fatal("No $PACE_EDITOR or $EDITOR defined!") } @@ -76,6 +75,7 @@ func Run(name string, options RunOptions) { } else { args, err = filepath.Glob(filepath.Join(dir, fmt.Sprintf("cmd/%s/*.go", options.CmdName))) } + if err != nil { log.Fatal(err) } @@ -115,12 +115,14 @@ func Test(name string, options TestOptions) { } } -// Lint executes golint or installes if not already installed +// Lint executes golint or installs if not already installed func Lint(name string) { AutoInstall("golint", "golang.org/x/lint/golint") var buf bytes.Buffer + GoBinCommandText(&buf, "go", "list", filepath.Join(GoServicePackagePath(name), "...")) + paths := strings.Split(buf.String(), "\n") // start go run diff --git a/locale/context.go b/locale/context.go index 07c69efa2..f6f86139f 100644 --- a/locale/context.go +++ b/locale/context.go @@ -24,7 +24,9 @@ func FromCtx(ctx context.Context) (*Locale, bool) { if val == nil { return new(Locale), false } + l, ok := val.(*Locale) + return l, ok } diff --git a/locale/http.go b/locale/http.go index 3112e8de7..29bb2aa49 100644 --- a/locale/http.go +++ b/locale/http.go @@ -20,9 +20,11 @@ func (l Locale) Request(r *http.Request) *http.Request { if l.HasLanguage() { r.Header.Set(HeaderAcceptLanguage, l.acceptLanguage) } + if l.HasTimezone() { r.Header.Set(HeaderAcceptTimezone, l.acceptTimezone) } + return r } diff --git a/locale/http_test.go b/locale/http_test.go index ec0f527c1..a0fd53a00 100644 --- a/locale/http_test.go +++ b/locale/http_test.go @@ -12,7 +12,7 @@ import ( ) func TestEmptyRequest(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) l := FromRequest(r) @@ -21,7 +21,7 @@ func TestEmptyRequest(t *testing.T) { } func TestFilledRequest(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) r.Header.Set(HeaderAcceptLanguage, "fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5") r.Header.Set(HeaderAcceptTimezone, "Europe/Paris") @@ -33,7 +33,7 @@ func TestFilledRequest(t *testing.T) { func TestExtendRequestWithEmptyLocale(t *testing.T) { l := new(Locale) - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) data, err := httputil.DumpRequest(l.Request(r), false) @@ -44,7 +44,7 @@ func TestExtendRequestWithEmptyLocale(t *testing.T) { func TestExtendRequest(t *testing.T) { l := NewLocale("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5", "Europe/Paris") - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) data, err := httputil.DumpRequest(l.Request(r), false) @@ -64,7 +64,7 @@ func (m *httpRecorderNext) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func TestMiddlewareWithoutLocale(t *testing.T) { - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) rec := new(httpRecorderNext) @@ -79,7 +79,7 @@ func TestMiddlewareWithoutLocale(t *testing.T) { func TestMiddlewareWithLocale(t *testing.T) { l := NewLocale("fr-CH, fr;q=0.9, en;q=0.8, de;q=0.7, *;q=0.5", "Europe/Paris") - r, err := http.NewRequest("GET", "http://example.com/test", nil) + r, err := http.NewRequest(http.MethodGet, "http://example.com/test", nil) require.NoError(t, err) rec := new(httpRecorderNext) diff --git a/locale/locale.go b/locale/locale.go index 180700045..a8123f378 100644 --- a/locale/locale.go +++ b/locale/locale.go @@ -88,6 +88,7 @@ func (l Locale) Now() time.Time { if err != nil { // if the tz doesn't exist return time.Now() } + return time.Now().In(loc) } @@ -100,5 +101,6 @@ func ParseLocale(serialized string) (*Locale, error) { if len(parts) != 2 { return nil, fmt.Errorf("invalid locale format: %q", serialized) } + return NewLocale(parts[0], parts[1]), nil } diff --git a/locale/locale_test.go b/locale/locale_test.go index 735d96e4d..2503ff882 100644 --- a/locale/locale_test.go +++ b/locale/locale_test.go @@ -44,6 +44,7 @@ func TestTimezone(t *testing.T) { loc, err := l.Location() assert.NoError(t, err) + timeInUTC := time.Date(2018, 8, 30, 12, 0, 0, 0, time.UTC) assert.Equal(t, "2018-08-30 14:00:00 +0200 CEST", timeInUTC.In(loc).String()) } @@ -58,6 +59,7 @@ func TestTimezoneAndLocale(t *testing.T) { loc, err := l.Location() assert.NoError(t, err) + timeInUTC := time.Date(2018, 8, 30, 12, 0, 0, 0, time.UTC) assert.Equal(t, "2018-08-30 14:00:00 +0200 CEST", timeInUTC.In(loc).String()) } diff --git a/locale/strategy.go b/locale/strategy.go index 59b3cbc46..6a05fb866 100644 --- a/locale/strategy.go +++ b/locale/strategy.go @@ -14,6 +14,7 @@ type Strategy func(ctx context.Context) *Locale // If only lang or timezone fallback should be defined as a fallback, the None value may be used. func NewFallbackStrategy(lang, timezone string) Strategy { l := NewLocale(lang, timezone) + return func(ctx context.Context) *Locale { return l } @@ -50,6 +51,7 @@ func (s *StrategyList) PushFront(strategies ...Strategy) { // Locale executes all strategies and returns the new locale func (s *StrategyList) Locale(ctx context.Context) *Locale { var l Locale + for i := s.strategies.Front(); i != nil; i = i.Next() { curLoc := (i.Value.(Strategy))(ctx) @@ -68,13 +70,16 @@ func (s *StrategyList) Locale(ctx context.Context) *Locale { break } } + return &l } // NewDefaultFallbackStrategy returns a strategy list configured via environment func NewDefaultFallbackStrategy() *StrategyList { var sl StrategyList + sl.PushFront(NewFallbackStrategy(cfg.Language, cfg.Timezone)) sl.PushFront(NewContextStrategy()) + return &sl } diff --git a/locale/strategy_test.go b/locale/strategy_test.go index 52df4f032..d0039c409 100644 --- a/locale/strategy_test.go +++ b/locale/strategy_test.go @@ -18,6 +18,7 @@ func TestStrategy(t *testing.T) { func TestStrategyWithCtx(t *testing.T) { var sl StrategyList + sl.PushBack( NewContextStrategy(), NewFallbackStrategy("de-DE", "Europe/Berlin"), diff --git a/maintenance/errors/bricks.go b/maintenance/errors/bricks.go index 601ec1913..3e3f6f6cb 100644 --- a/maintenance/errors/bricks.go +++ b/maintenance/errors/bricks.go @@ -32,6 +32,7 @@ func NewBricksError(opts ...BricksErrorOption) *BricksError { for _, opt := range opts { opt(e) } + return e } @@ -65,6 +66,7 @@ func (e *BricksError) AsRuntimeError() *runtime.Error { Title: e.title, Detail: e.detail, } + return j } diff --git a/maintenance/errors/context.go b/maintenance/errors/context.go index 0d49fc6f8..fe14b01a7 100644 --- a/maintenance/errors/context.go +++ b/maintenance/errors/context.go @@ -41,5 +41,6 @@ func IsStdLibContextError(err error) bool { if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { return true } + return false } diff --git a/maintenance/errors/context_test.go b/maintenance/errors/context_test.go index b5681d60d..2edfb4277 100644 --- a/maintenance/errors/context_test.go +++ b/maintenance/errors/context_test.go @@ -27,6 +27,7 @@ func TestHide(t *testing.T) { err error exposedErr error } + tests := []struct { name string args args diff --git a/maintenance/errors/error.go b/maintenance/errors/error.go index 77e6265b7..d06b39a7b 100644 --- a/maintenance/errors/error.go +++ b/maintenance/errors/error.go @@ -11,12 +11,13 @@ import ( "os" "time" + "github.com/prometheus/client_golang/prometheus" + "github.com/rs/zerolog" + "github.com/pace/bricks/http/jsonapi/runtime" "github.com/pace/bricks/http/oauth2" "github.com/pace/bricks/maintenance/errors/raven" "github.com/pace/bricks/maintenance/log" - "github.com/prometheus/client_golang/prometheus" - "github.com/rs/zerolog" ) var paceHTTPPanicCounter = prometheus.NewGauge(prometheus.GaugeOpts{ @@ -54,6 +55,7 @@ func requestFromContext(ctx context.Context) *http.Request { if v := ctx.Value(reqKey); v != nil { return v.(*http.Request) } + return nil } @@ -63,6 +65,7 @@ func ContextTransfer(ctx, targetCtx context.Context) context.Context { if r := requestFromContext(ctx); r != nil { return contextWithRequest(targetCtx, r) } + return targetCtx } @@ -97,18 +100,21 @@ func HandleRequest(handlerName string, w http.ResponseWriter, r *http.Request) { // HandleError reports the passed error to sentry func HandleError(rp interface{}, handlerName string, w http.ResponseWriter, r *http.Request) { ctx := r.Context() + pw, ok := rp.(*PanicWrap) if ok { log.Ctx(ctx).Error().Str("handler", handlerName).Msgf("Panic: %v", pw.err) + rp = pw.err // unwrap error } else { log.Ctx(ctx).Error().Str("handler", handlerName).Msgf("Error: %v", rp) } + log.Stack(ctx) sentryEvent{ctx, r, rp, 1, handlerName}.Send() - runtime.WriteError(w, http.StatusInternalServerError, errors.New("Internal Server Error")) + runtime.WriteError(w, http.StatusInternalServerError, errors.New("internal Server Error")) } // Handle logs the given error and reports it to sentry. @@ -116,10 +122,12 @@ func Handle(ctx context.Context, rp interface{}) { pw, ok := rp.(*PanicWrap) if ok { log.Ctx(ctx).Error().Msgf("Panic: %v", pw.err) + rp = pw.err // unwrap error } else { log.Ctx(ctx).Error().Msgf("Error: %v", rp) } + log.Stack(ctx) sentryEvent{ctx, nil, rp, 1, ""}.Send() @@ -171,6 +179,7 @@ func (e sentryEvent) build() *raven.Packet { } rvalStr := fmt.Sprint(rp) + var packet *raven.Packet if err, ok := rp.(error); ok { @@ -190,10 +199,12 @@ func (e sentryEvent) build() *raven.Packet { // add user userID, ok := oauth2.UserID(ctx) + user := raven.User{ID: userID} if r != nil { user.IP = log.ProxyAwareRemote(r) } + packet.Interfaces = append(packet.Interfaces, &user) if ok { packet.Tags = append(packet.Tags, raven.Tag{Key: "user_id", Value: userID}) @@ -204,14 +215,17 @@ func (e sentryEvent) build() *raven.Packet { packet.Extra["req_id"] = reqID packet.Tags = append(packet.Tags, raven.Tag{Key: "req_id", Value: reqID}) } + if traceID := log.TraceIDFromContext(ctx); traceID != "" { packet.Extra["uber_trace_id"] = traceID packet.Tags = append(packet.Tags, raven.Tag{Key: "trace_id", Value: traceID}) } + packet.Extra["handler"] = handlerName if clientID, ok := oauth2.ClientID(ctx); ok { packet.Extra["oauth2_client_id"] = clientID } + if scopes := oauth2.Scopes(ctx); len(scopes) > 0 { packet.Extra["oauth2_scopes"] = scopes } @@ -247,6 +261,7 @@ func getBreadcrumbs(ctx context.Context) []*raven.Breadcrumb { } result := make([]*raven.Breadcrumb, len(data)) + for i, d := range data { crumb, err := createBreadcrumb(d) if err != nil { @@ -268,6 +283,7 @@ func createBreadcrumb(data map[string]interface{}) (*raven.Breadcrumb, error) { if !ok { return nil, errors.New(`cannot parse "time"`) } + delete(data, "time") time, err := time.Parse(time.RFC3339, timeRaw) @@ -279,6 +295,7 @@ func createBreadcrumb(data map[string]interface{}) (*raven.Breadcrumb, error) { if !ok { return nil, errors.New(`cannot parse "level"`) } + delete(data, "level") level, err := translateZerologLevelToSentryLevel(levelRaw) @@ -290,12 +307,14 @@ func createBreadcrumb(data map[string]interface{}) (*raven.Breadcrumb, error) { if !ok { return nil, errors.New(`cannot parse "message"`) } + delete(data, "message") categoryRaw, ok := data["sentry:category"] if !ok { categoryRaw = "" } + delete(data, "sentry:category") category, ok := categoryRaw.(string) @@ -307,6 +326,7 @@ func createBreadcrumb(data map[string]interface{}) (*raven.Breadcrumb, error) { if !ok { typRaw = "" } + delete(data, "sentry:type") typ, ok := typRaw.(string) diff --git a/maintenance/errors/error_test.go b/maintenance/errors/error_test.go index dc56c2009..8744732cd 100644 --- a/maintenance/errors/error_test.go +++ b/maintenance/errors/error_test.go @@ -30,13 +30,14 @@ func TestHandler(t *testing.T) { }) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) mux.ServeHTTP(rec, req) if rec.Code != 500 { t.Errorf("Expected 500, got %d", rec.Code) } + if strings.Contains(rec.Body.String(), `"error":"Error"`) { t.Errorf(`Expected "error":"Error", got %q`, rec.Body.String()) } @@ -108,9 +109,9 @@ func Test_createBreadcrumb(t *testing.T) { "sentry:category": "http", "sentry:type": "http", "message": "HTTPS GET www.pace.car", - "method": "GET", + "method": http.MethodGet, "attempt": 1, - "status_code": 200, + "status_code": http.StatusOK, "duration": 227.717783, "url": "https://www.pace.car/", "req_id": "bpboj6bipt34r4teo7g0", @@ -122,9 +123,9 @@ func Test_createBreadcrumb(t *testing.T) { Timestamp: 1582795168, Type: "http", Data: map[string]interface{}{ - "method": "GET", + "method": http.MethodGet, "attempt": 1, - "status_code": 200, + "status_code": http.StatusOK, "duration": 227.717783, "url": "https://www.pace.car/", }, @@ -183,10 +184,10 @@ func Test_createBreadcrumb(t *testing.T) { // which should be passed to all subsequent requests and handler. func TestHandlerWithLogSink(t *testing.T) { rec1 := httptest.NewRecorder() - req1 := httptest.NewRequest("GET", "/test1", nil) + req1 := httptest.NewRequest(http.MethodGet, "/test1", nil) rec2 := httptest.NewRecorder() - req2 := httptest.NewRequest("GET", "/test2", nil) + req2 := httptest.NewRequest(http.MethodGet, "/test2", nil) var ( sink1Ctx context.Context @@ -196,25 +197,27 @@ func TestHandlerWithLogSink(t *testing.T) { mux := http.NewServeMux() mux.HandleFunc("/test1", func(w http.ResponseWriter, r *http.Request) { sink1Ctx = r.Context() + log.Ctx(r.Context()).Debug().Msg("ONLY FOR SINK1") w.WriteHeader(http.StatusOK) }) mux.HandleFunc("/test2", func(w http.ResponseWriter, r *http.Request) { require.NotEqual(t, "", log.RequestID(r), "request should have request id") + sink2Ctx = r.Context() client := &http.Client{ Transport: transport.Chain(&transport.LoggingRoundTripper{}, &transport.DumpRoundTripper{}), } - r0, err := http.NewRequest("GET", "https://www.pace.car/de", nil) + r0, err := http.NewRequest(http.MethodGet, "https://www.pace.car/de", nil) assert.NoError(t, err, `failed creating request to "/succeed"`) r0 = r0.WithContext(r.Context()) _, err = client.Do(r0) assert.NoError(t, err, `request to "/succeed" should not error`) - r1, err := http.NewRequest("GET", "http://localhost/fail", nil) + r1, err := http.NewRequest(http.MethodGet, "http://localhost/fail", nil) assert.NoError(t, err, `failed creating request to "/fail"`) r1 = r1.WithContext(r.Context()) @@ -228,22 +231,26 @@ func TestHandlerWithLogSink(t *testing.T) { panic("Sink2 Test Error, IGNORE") }) + handler := log.Handler()(Handler()(mux)) handler.ServeHTTP(rec1, req1) + resp1 := rec1.Result() require.Equal(t, http.StatusOK, resp1.StatusCode, "wrong status code") - resp1.Body.Close() + _ = resp1.Body.Close() handler.ServeHTTP(rec2, req2) + resp2 := rec2.Result() require.Equal(t, http.StatusInternalServerError, resp2.StatusCode, "wrong status code") - resp2.Body.Close() + _ = resp2.Body.Close() sink1, ok := log.SinkFromContext(sink1Ctx) assert.True(t, ok, "failed getting sink1") var sink1LogLines []json.RawMessage + assert.NoError(t, json.Unmarshal(sink1.ToJSON(), &sink1LogLines), "failed extracting logs from sink1") assert.Len(t, sink1LogLines, 2, "more log lines than expected") @@ -253,6 +260,7 @@ func TestHandlerWithLogSink(t *testing.T) { assert.True(t, ok, "failed getting sink2") var sink2LogLines []json.RawMessage + assert.NoError(t, json.Unmarshal(sink2.ToJSON(), &sink2LogLines), "failed extracting logs from sink2") assert.NotContains(t, string(sink2LogLines[0]), "ONLY FOR SINK1", "unexpected log line found") diff --git a/maintenance/errors/raven/client.go b/maintenance/errors/raven/client.go index e5f26fe24..b1e59014c 100644 --- a/maintenance/errors/raven/client.go +++ b/maintenance/errors/raven/client.go @@ -23,9 +23,10 @@ import ( "time" "github.com/certifi/gocertifi" + pkgErrors "github.com/pkg/errors" + "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/pkg/redact" - pkgErrors "github.com/pkg/errors" ) const ( @@ -65,6 +66,7 @@ func (timestamp *Timestamp) UnmarshalJSON(data []byte) error { } *timestamp = Timestamp(t) + return nil } @@ -111,7 +113,9 @@ func (t *Tag) UnmarshalJSON(data []byte) error { if err := json.Unmarshal(data, &tag); err != nil { return err } + *t = Tag{tag[0], tag[1]} + return nil } @@ -140,6 +144,7 @@ func (t *Tags) UnmarshalJSON(data []byte) error { } *t = tags + return nil } @@ -184,6 +189,7 @@ type Breadcrumb struct { func NewPacket(message string, interfaces ...Interface) *Packet { extra := Extra{} setExtraDefaults(extra) + return &Packet{ Message: message, Interfaces: interfaces, @@ -196,6 +202,7 @@ func NewPacketWithExtra(message string, extra Extra, interfaces ...Interface) *P if extra == nil { extra = Extra{} } + setExtraDefaults(extra) return &Packet{ @@ -210,6 +217,7 @@ func setExtraDefaults(extra Extra) Extra { extra["runtime.NumCPU"] = runtime.NumCPU() extra["runtime.GOMAXPROCS"] = runtime.GOMAXPROCS(0) // 0 just returns the current value extra["runtime.NumGoroutine"] = runtime.NumGoroutine() + return extra } @@ -219,25 +227,32 @@ func (packet *Packet) Init(project string) error { if packet.Project == "" { packet.Project = project } + if packet.EventID == "" { var err error + packet.EventID, err = uuid() if err != nil { return err } } + if time.Time(packet.Timestamp).IsZero() { packet.Timestamp = Timestamp(time.Now()) } + if packet.Level == "" { packet.Level = ERROR } + if packet.Logger == "" { packet.Logger = "root" } + if packet.ServerName == "" { packet.ServerName = hostname } + if packet.Platform == "" { packet.Platform = "go" } @@ -264,14 +279,17 @@ func (packet *Packet) AddTags(tags map[string]string) { func uuid() (string, error) { id := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, id) if err != nil { return "", err } + id[6] &= 0x0F // clear version id[6] |= 0x40 // set version to 4 (random uuid) id[8] &= 0x3F // clear variant id[8] |= 0x80 // set to IETF variant + return hex.EncodeToString(id), nil } @@ -282,6 +300,7 @@ func (packet *Packet) JSON() ([]byte, error) { } interfaces := make(map[string]Interface, len(packet.Interfaces)) + for _, inter := range packet.Interfaces { if inter != nil { interfaces[inter.Class()] = inter @@ -293,6 +312,7 @@ func (packet *Packet) JSON() ([]byte, error) { if err != nil { return nil, err } + packetJSON[len(packetJSON)-1] = ',' packetJSON = append(packetJSON, interfaceJSON[1:]...) } @@ -312,6 +332,7 @@ func (c *context) setTags(t map[string]string) { if c.tags == nil { c.tags = make(map[string]string) } + for k, v := range t { c.tags[k] = v } @@ -329,17 +350,21 @@ func (c *context) interfaces() []Interface { if c.user != nil { len++ } + if c.http != nil { len++ } + interfaces := make([]Interface, len) if c.user != nil { interfaces[i] = c.user i++ } + if c.http != nil { interfaces[i] = c.http } + return interfaces } @@ -349,6 +374,7 @@ var MaxQueueBuffer = 100 func newTransport() Transport { t := &HTTPTransport{} + rootCAs, err := gocertifi.CACerts() if err != nil { log.Println("raven: failed to load root TLS certificates:", err) @@ -360,6 +386,7 @@ func newTransport() Transport { }, } } + return t } @@ -372,12 +399,15 @@ func newClient(tags map[string]string) *Client { queue: make(chan *outgoingPacket, MaxQueueBuffer), } dsn := os.Getenv("SENTRY_DSN") + err := client.SetDSN(dsn) if err != nil && dsn != "" { log.Warnf("DSN environment was set to %q but failed: %v", dsn, err) } + client.SetRelease(os.Getenv("SENTRY_RELEASE")) client.SetEnvironment(os.Getenv("ENVIRONMENT")) + return client } @@ -445,6 +475,7 @@ var DefaultClient = newClient(nil) func (c *Client) SetIgnoreErrors(errs []string) error { joinedRegexp := strings.Join(errs, "|") + r, err := regexp.Compile(joinedRegexp) if err != nil { return fmt.Errorf("failed to compile regexp %q for %q: %v", joinedRegexp, errs, err) @@ -453,12 +484,14 @@ func (c *Client) SetIgnoreErrors(errs []string) error { c.mu.Lock() c.ignoreErrorsRegexp = r c.mu.Unlock() + return nil } func (c *Client) shouldExcludeErr(errStr string) bool { c.mu.RLock() defer c.mu.RUnlock() + return c.ignoreErrorsRegexp != nil && c.ignoreErrorsRegexp.MatchString(errStr) } @@ -484,6 +517,7 @@ func (client *Client) SetDSN(dsn string) error { if uri.User == nil { return ErrMissingUser } + publicKey := uri.User.Username() secretKey, hasSecretKey := uri.User.Password() uri.User = nil @@ -492,6 +526,7 @@ func (client *Client) SetDSN(dsn string) error { client.projectID = uri.Path[idx+1:] uri.Path = uri.Path[:idx+1] + "api/" + client.projectID + "/store/" } + if client.projectID == "" { return ErrMissingProjectID } @@ -514,6 +549,7 @@ func SetDSN(dsn string) error { return DefaultClient.SetDSN(dsn) } func (client *Client) SetRelease(release string) { client.mu.Lock() defer client.mu.Unlock() + client.release = release } @@ -521,6 +557,7 @@ func (client *Client) SetRelease(release string) { func (client *Client) SetEnvironment(environment string) { client.mu.Lock() defer client.mu.Unlock() + client.environment = environment } @@ -528,6 +565,7 @@ func (client *Client) SetEnvironment(environment string) { func (client *Client) SetDefaultLoggerName(name string) { client.mu.Lock() defer client.mu.Unlock() + client.defaultLoggerName = name } @@ -539,7 +577,9 @@ func (client *Client) SetSampleRate(rate float32) error { if rate < 0 || rate > 1 { return ErrInvalidSampleRate } + client.sampleRate = rate + return nil } @@ -559,12 +599,12 @@ func SetSampleRate(rate float32) error { return DefaultClient.SetSampleRate(rate func (client *Client) worker() { for outgoingPacket := range client.queue { - client.mu.RLock() url, authHeader := client.url, client.authHeader client.mu.RUnlock() outgoingPacket.ch <- client.Transport.Send(url, authHeader, outgoingPacket.packet) + client.wg.Done() } } @@ -606,6 +646,7 @@ func (client *Client) Capture(packet *Packet, captureTags map[string]string) (ev // Initialize any required packet fields client.mu.RLock() packet.AddTags(client.context.tags) + projectID := client.projectID release := client.release environment := client.environment @@ -620,7 +661,9 @@ func (client *Client) Capture(packet *Packet, captureTags map[string]string) (ev err := packet.Init(projectID) if err != nil { ch <- err + client.wg.Done() + return } @@ -648,6 +691,7 @@ func (client *Client) Capture(packet *Packet, captureTags map[string]string) (ev client.DropHandler(packet) } ch <- ErrPacketDropped + client.wg.Done() } @@ -693,6 +737,7 @@ func (client *Client) CaptureMessageAndWait(message string, tags map[string]stri } packet := NewPacket(message, append(append(interfaces, client.context.interfaces()...), &Message{message, nil})...) + eventID, ch := client.Capture(packet, tags) if eventID != "" { <-ch @@ -750,6 +795,7 @@ func (client *Client) CaptureErrorAndWait(err error, tags map[string]string, int cause := pkgErrors.Cause(err) packet := NewPacketWithExtra(err.Error(), extra, append(append(interfaces, client.context.interfaces()...), NewException(cause, GetOrNewStacktrace(cause, 1, 3, client.includePaths)))...) + eventID, ch := client.Capture(packet, tags) if eventID != "" { <-ch @@ -772,6 +818,7 @@ func (client *Client) CapturePanic(f func(), tags map[string]string, interfaces // be completely noop though if we cared. defer func() { var packet *Packet + err = recover() switch rval := err.(type) { case nil: @@ -780,12 +827,14 @@ func (client *Client) CapturePanic(f func(), tags map[string]string, interfaces if client.shouldExcludeErr(rval.Error()) { return } + packet = NewPacket(rval.Error(), append(append(interfaces, client.context.interfaces()...), NewException(rval, NewStacktrace(2, 3, client.includePaths)))...) default: rvalStr := fmt.Sprint(rval) if client.shouldExcludeErr(rvalStr) { return } + packet = NewPacket(rvalStr, append(append(interfaces, client.context.interfaces()...), NewException(errors.New(rvalStr), NewStacktrace(2, 3, client.includePaths)))...) } @@ -793,6 +842,7 @@ func (client *Client) CapturePanic(f func(), tags map[string]string, interfaces }() f() + return } @@ -810,6 +860,7 @@ func (client *Client) CapturePanicAndWait(f func(), tags map[string]string, inte // be completely noop though if we cared. defer func() { var packet *Packet + err = recover() switch rval := err.(type) { case nil: @@ -818,16 +869,19 @@ func (client *Client) CapturePanicAndWait(f func(), tags map[string]string, inte if client.shouldExcludeErr(rval.Error()) { return } + packet = NewPacket(rval.Error(), append(append(interfaces, client.context.interfaces()...), NewException(rval, NewStacktrace(2, 3, client.includePaths)))...) default: rvalStr := fmt.Sprint(rval) if client.shouldExcludeErr(rvalStr) { return } + packet = NewPacket(rvalStr, append(append(interfaces, client.context.interfaces()...), NewException(errors.New(rvalStr), NewStacktrace(2, 3, client.includePaths)))...) } var ch chan error + errorID, ch = client.Capture(packet, tags) if errorID != "" { <-ch @@ -835,6 +889,7 @@ func (client *Client) CapturePanicAndWait(f func(), tags map[string]string, inte }() f() + return } @@ -946,22 +1001,28 @@ func (t *HTTPTransport) Send(url, authHeader string, packet *Packet) error { if err != nil { return fmt.Errorf("error serializing packet: %v", err) } - req, err := http.NewRequest("POST", url, body) + + req, err := http.NewRequest(http.MethodPost, url, body) if err != nil { return fmt.Errorf("can't create new request: %v", err) } + req.Header.Set("X-Sentry-Auth", authHeader) req.Header.Set("User-Agent", userAgent) req.Header.Set("Content-Type", contentType) + res, err := t.Do(req) if err != nil { return err } - io.Copy(io.Discard, res.Body) // nolint: errcheck + + io.Copy(io.Discard, res.Body) res.Body.Close() - if res.StatusCode != 200 { + + if res.StatusCode != http.StatusOK { return fmt.Errorf("raven: got http status %d", res.StatusCode) } + return nil } @@ -980,11 +1041,13 @@ func serializedPacket(packet *Packet) (io.Reader, string, error) { buf := &bytes.Buffer{} b64 := base64.NewEncoder(base64.StdEncoding, buf) deflate, _ := zlib.NewWriterLevel(b64, zlib.BestCompression) - deflate.Write(packetJSON) // nolint: errcheck + deflate.Write(packetJSON) deflate.Close() b64.Close() + return buf, "application/octet-stream", nil } + return bytes.NewReader(packetJSON), "application/json", nil } diff --git a/maintenance/errors/raven/exception.go b/maintenance/errors/raven/exception.go index 552eaad12..f47e96684 100644 --- a/maintenance/errors/raven/exception.go +++ b/maintenance/errors/raven/exception.go @@ -14,9 +14,11 @@ func NewException(err error, stacktrace *Stacktrace) *Exception { Value: msg, Type: reflect.TypeOf(err).String(), } + if m := errorMsgPattern.FindStringSubmatch(msg); m != nil { ex.Module, ex.Value = m[1], m[2] } + return ex } @@ -37,6 +39,7 @@ func (e *Exception) Culprit() string { if e.Stacktrace == nil { return "" } + return e.Stacktrace.Culprit() } diff --git a/maintenance/errors/raven/http.go b/maintenance/errors/raven/http.go index 0d8fbb112..3d2b76262 100644 --- a/maintenance/errors/raven/http.go +++ b/maintenance/errors/raven/http.go @@ -15,6 +15,7 @@ func NewHttp(req *http.Request) *Http { if req.TLS != nil || req.Header.Get("X-Forwarded-Proto") == "https" { proto = "https" } + h := &Http{ Method: req.Method, Cookies: req.Header.Get("Cookie"), @@ -25,10 +26,13 @@ func NewHttp(req *http.Request) *Http { if addr, port, err := net.SplitHostPort(req.RemoteAddr); err == nil { h.Env = map[string]string{"REMOTE_ADDR": addr, "REMOTE_PORT": port} } + for k, v := range req.Header { h.Headers[k] = strings.Join(v, ",") } + h.Headers["Host"] = req.Host + return h } @@ -42,6 +46,7 @@ func sanitizeQuery(query url.Values) url.Values { } } } + return query } @@ -74,13 +79,17 @@ func RecoveryHandler(handler func(http.ResponseWriter, *http.Request)) func(http defer func() { if rval := recover(); rval != nil { debug.PrintStack() + rvalStr := fmt.Sprint(rval) + var packet *Packet + if err, ok := rval.(error); ok { packet = NewPacket(rvalStr, NewException(errors.New(rvalStr), GetOrNewStacktrace(err, 2, 3, nil)), NewHttp(r)) } else { packet = NewPacket(rvalStr, NewException(errors.New(rvalStr), NewStacktrace(2, 3, nil)), NewHttp(r)) } + Capture(packet, nil) w.WriteHeader(http.StatusInternalServerError) } diff --git a/maintenance/errors/raven/stacktrace.go b/maintenance/errors/raven/stacktrace.go index aab2a3429..824b2b442 100644 --- a/maintenance/errors/raven/stacktrace.go +++ b/maintenance/errors/raven/stacktrace.go @@ -32,6 +32,7 @@ func (s *Stacktrace) Culprit() string { return frame.Module + "." + frame.Function } } + return "" } @@ -58,21 +59,26 @@ func GetOrNewStacktrace(err error, skip int, context int, appPackagePrefixes []s }) if errHasStacktrace { var frames []*StacktraceFrame + for _, f := range stacktracer.StackTrace() { pc := uintptr(f) - 1 fn := runtime.FuncForPC(pc) + var file string + var line int if fn != nil { file, line = fn.FileLine(pc) } else { file = "unknown" } + frame := NewStacktraceFrame(pc, file, line, context, appPackagePrefixes) if frame != nil { frames = append([]*StacktraceFrame{frame}, frames...) } } + return &Stacktrace{Frames: frames} } else { return NewStacktrace(skip+1, context, appPackagePrefixes) @@ -89,11 +95,13 @@ func GetOrNewStacktrace(err error, skip int, context int, appPackagePrefixes []s // be considered "in app". func NewStacktrace(skip int, context int, appPackagePrefixes []string) *Stacktrace { var frames []*StacktraceFrame + for i := 1 + skip; ; i++ { pc, file, line, ok := runtime.Caller(i) if !ok { break } + frame := NewStacktraceFrame(pc, file, line, context, appPackagePrefixes) if frame != nil { frames = append(frames, frame) @@ -111,6 +119,7 @@ func NewStacktrace(skip int, context int, appPackagePrefixes []string) *Stacktra for i, j := 0, len(frames)-1; i < j; i, j = i+1, j-1 { frames[i], frames[j] = frames[j], frames[i] } + return &Stacktrace{frames} } @@ -162,6 +171,7 @@ func NewStacktraceFrame(pc uintptr, file string, line, context int, appPackagePr frame.ContextLine = string(contextLine[0]) } } + return frame } @@ -199,6 +209,7 @@ var ( func fileContext(filename string, line, context int) ([][]byte, int) { fileCacheLock.Lock() defer fileCacheLock.Unlock() + lines, ok := fileCache[filename] if !ok { data, err := os.ReadFile(filename) @@ -209,6 +220,7 @@ func fileContext(filename string, line, context int) ([][]byte, int) { fileCache[filename] = nil return nil, 0 } + lines = bytes.Split(data, []byte{'\n'}) fileCache[filename] = lines } @@ -220,20 +232,26 @@ func fileContext(filename string, line, context int) ([][]byte, int) { line-- // stack trace lines are 1-indexed start := line - context + var idx int + if start < 0 { start = 0 idx = line } else { idx = context } + end := line + context + 1 + if line >= len(lines) { return nil, 0 } + if end > len(lines) { end = len(lines) } + return lines[start:end], idx } @@ -246,6 +264,7 @@ func trimPath(filename string) string { return trimmed } } + return filename } @@ -256,6 +275,7 @@ func init() { if prefix[len(prefix)-1] != filepath.Separator { prefix += string(filepath.Separator) } + trimPaths = append(trimPaths, prefix) } } diff --git a/maintenance/failover/failover.go b/maintenance/failover/failover.go index 36d41c7c6..8ba5809ef 100644 --- a/maintenance/failover/failover.go +++ b/maintenance/failover/failover.go @@ -11,11 +11,12 @@ import ( "time" "github.com/bsm/redislock" + "github.com/redis/go-redis/v9" + "github.com/pace/bricks/backend/k8sapi" "github.com/pace/bricks/maintenance/errors" "github.com/pace/bricks/maintenance/health" "github.com/pace/bricks/maintenance/log" - "github.com/redis/go-redis/v9" ) const waitRetry = time.Millisecond * 500 @@ -138,6 +139,7 @@ func (a *ActivePassive) Run(ctx context.Context) error { // try to acquire the lock, as we are not the active if a.getState() != ACTIVE { var err error + lock, err = a.locker.Obtain(ctx, lockName, a.timeToFailover, &redislock.Options{ RetryStrategy: redislock.LimitRetry(redislock.LinearBackoff(a.timeToFailover/3), 3), }) @@ -160,12 +162,14 @@ func (a *ActivePassive) Run(ctx context.Context) error { if err != nil { logger.Debug().Err(err).Msg("failed to get TTL") } + if d == 0 { // TTL seems to be expired, retry to get lock or become // passive in next iteration logger.Debug().Msg("ttl expired") a.becomeUndefined(ctx) } + refreshTime := d / 2 logger.Debug().Msgf("set refresh to %v", refreshTime) @@ -185,8 +189,12 @@ func (a *ActivePassive) Stop() { // Handler implements the readiness http endpoint func (a *ActivePassive) Handler(w http.ResponseWriter, r *http.Request) { label := a.label(a.getState()) + w.WriteHeader(http.StatusOK) - fmt.Fprintln(w, strings.ToUpper(label)) + + if _, err := fmt.Fprintln(w, strings.ToUpper(label)); err != nil { + w.WriteHeader(http.StatusInternalServerError) + } } func (a *ActivePassive) label(s status) string { @@ -228,11 +236,14 @@ func (a *ActivePassive) setState(ctx context.Context, state status) bool { a.stateMu.Lock() a.state = UNDEFINED a.stateMu.Unlock() + return false } + a.stateMu.Lock() a.state = state a.stateMu.Unlock() + return true } @@ -240,5 +251,6 @@ func (a *ActivePassive) getState() status { a.stateMu.RLock() state := a.state a.stateMu.RUnlock() + return state } diff --git a/maintenance/health/health.go b/maintenance/health/health.go index 495dee15c..5682f8059 100644 --- a/maintenance/health/health.go +++ b/maintenance/health/health.go @@ -30,7 +30,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { )) } -// ReadinessCheck allows to set a different function for the readiness check. The default readiness check +// SetCustomReadinessCheck allows to set a different function for the readiness check. The default readiness check // is the same as the liveness check and does always return OK func SetCustomReadinessCheck(check func(http.ResponseWriter, *http.Request)) { readinessCheck.check = check @@ -39,6 +39,7 @@ func SetCustomReadinessCheck(check func(http.ResponseWriter, *http.Request)) { func liveness(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) + if _, err := fmt.Fprint(w, "OK\n"); err != nil { log.Warnf("could not write output: %s", err) } diff --git a/maintenance/health/health_test.go b/maintenance/health/health_test.go index 8947bd1b3..bad08da2b 100644 --- a/maintenance/health/health_test.go +++ b/maintenance/health/health_test.go @@ -8,31 +8,35 @@ import ( "net/http/httptest" "testing" - "github.com/pace/bricks/maintenance/log" "github.com/stretchr/testify/require" + + "github.com/pace/bricks/maintenance/log" ) func TestHandlerLiveness(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/liveness", nil) + req := httptest.NewRequest(http.MethodGet, "/health/liveness", nil) HandlerLiveness().ServeHTTP(rec, req) - checkResult(rec, 200, "OK\n", t) + checkResult(rec, http.StatusOK, "OK\n", t) } func TestHandlerReadiness(t *testing.T) { // check the default rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/health/readiness", nil) + req := httptest.NewRequest(http.MethodGet, "/health/readiness", nil) HandlerReadiness().ServeHTTP(rec, req) // check another readiness check - checkResult(rec, 200, "OK\n", t) + checkResult(rec, http.StatusOK, "OK\n", t) + rec = httptest.NewRecorder() + SetCustomReadinessCheck(func(w http.ResponseWriter, request *http.Request) { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusNotFound) + if _, err := w.Write([]byte("Err\n")); err != nil { log.Warnf("could not write output: %s", err) } @@ -44,9 +48,11 @@ func TestHandlerReadiness(t *testing.T) { func checkResult(rec *httptest.ResponseRecorder, expCode int, expBody string, t *testing.T) { resp := rec.Result() defer resp.Body.Close() + if resp.StatusCode != expCode { t.Errorf("Expected /health to respond with %d, got: %d", expCode, resp.StatusCode) } + data, err := io.ReadAll(resp.Body) require.NoError(t, err) require.Equal(t, expBody, string(data)) diff --git a/maintenance/health/servicehealthcheck/connection_state.go b/maintenance/health/servicehealthcheck/connection_state.go index f9b485587..504029664 100644 --- a/maintenance/health/servicehealthcheck/connection_state.go +++ b/maintenance/health/servicehealthcheck/connection_state.go @@ -17,6 +17,7 @@ type ConnectionState struct { func (cs *ConnectionState) setConnectionState(result HealthCheckResult) { cs.m.Lock() defer cs.m.Unlock() + cs.result = result cs.lastCheck = time.Now() } @@ -36,6 +37,7 @@ func (cs *ConnectionState) SetHealthy() { func (cs *ConnectionState) GetState() HealthCheckResult { cs.m.Lock() defer cs.m.Unlock() + return cs.result } @@ -43,5 +45,6 @@ func (cs *ConnectionState) GetState() HealthCheckResult { func (cs *ConnectionState) LastChecked() time.Time { cs.m.Lock() defer cs.m.Unlock() + return cs.lastCheck } diff --git a/maintenance/health/servicehealthcheck/health_handler.go b/maintenance/health/servicehealthcheck/health_handler.go index 45e84ae01..2f19b886a 100644 --- a/maintenance/health/servicehealthcheck/health_handler.go +++ b/maintenance/health/servicehealthcheck/health_handler.go @@ -14,7 +14,9 @@ import ( func HealthHandler() http.HandlerFunc { return func(w http.ResponseWriter, _ *http.Request) { var errors []string + var warnings []string + for name, res := range checksResults(&requiredChecks) { if res.State == Err { errors = append(errors, fmt.Sprintf("%s: %s", name, res.Msg)) @@ -22,12 +24,16 @@ func HealthHandler() http.HandlerFunc { warnings = append(warnings, fmt.Sprintf("%s: %s", name, res.Msg)) } } + if len(errors) > 0 { log.Logger().Info().Strs("errors", errors).Strs("warnings", warnings).Msg("Health check failed") + msg := fmt.Sprintf("ERR: %d errors and %d warnings", len(errors), len(warnings)) writeResult(w, http.StatusServiceUnavailable, msg) + return } + writeResult(w, http.StatusOK, string(Ok)) } } @@ -35,6 +41,7 @@ func HealthHandler() http.HandlerFunc { func writeResult(w http.ResponseWriter, status int, body string) { w.Header().Set("Content-Type", "text/plain") w.WriteHeader(status) + if _, err := fmt.Fprint(w, body); err != nil { log.Warnf("could not write output: %s", err) } diff --git a/maintenance/health/servicehealthcheck/health_handler_json.go b/maintenance/health/servicehealthcheck/health_handler_json.go index cb54c5029..1747809d6 100644 --- a/maintenance/health/servicehealthcheck/health_handler_json.go +++ b/maintenance/health/servicehealthcheck/health_handler_json.go @@ -22,8 +22,11 @@ func JSONHealthHandler() http.HandlerFunc { checkResponse := make(map[string]serviceStats) var errors []string + var warnings []string + status := http.StatusOK + for name, res := range checksResults(&requiredChecks) { scr := serviceStats{ Status: res.State, @@ -33,10 +36,12 @@ func JSONHealthHandler() http.HandlerFunc { if res.State == Err { scr.Error = res.Msg status = http.StatusServiceUnavailable + errors = append(errors, fmt.Sprintf("%s: %s", name, res.Msg)) } else if res.State == Warn { warnings = append(warnings, fmt.Sprintf("%s: %s", name, res.Msg)) } + checkResponse[name] = scr } @@ -50,6 +55,7 @@ func JSONHealthHandler() http.HandlerFunc { scr.Error = res.Msg status = http.StatusServiceUnavailable } + checkResponse[name] = scr } @@ -59,6 +65,7 @@ func JSONHealthHandler() http.HandlerFunc { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) + err := json.NewEncoder(w).Encode(checkResponse) if err != nil { log.Warnf("json health handler endpoint: encoding failed: %v", err) diff --git a/maintenance/health/servicehealthcheck/health_handler_json_test.go b/maintenance/health/servicehealthcheck/health_handler_json_test.go index fdfad8ca2..78ff3855b 100644 --- a/maintenance/health/servicehealthcheck/health_handler_json_test.go +++ b/maintenance/health/servicehealthcheck/health_handler_json_test.go @@ -76,12 +76,14 @@ func TestJSONHealthHandler(t *testing.T) { for _, hc := range tc.requiredHC { RegisterHealthCheck(hc.name, hc) } + for _, hc := range tc.optionalHC { RegisterOptionalHealthCheck(hc, hc.name) } testRequest(t, handler, tc.expCode, func(t *testing.T, resBody []byte) { var res map[string]serviceStats + err := json.Unmarshal(resBody, &res) require.NoError(t, err) diff --git a/maintenance/health/servicehealthcheck/health_handler_readable.go b/maintenance/health/servicehealthcheck/health_handler_readable.go index 822bbbf8d..df7b97866 100644 --- a/maintenance/health/servicehealthcheck/health_handler_readable.go +++ b/maintenance/health/servicehealthcheck/health_handler_readable.go @@ -24,13 +24,17 @@ func ReadableHealthHandler() http.HandlerFunc { table := "%-" + strconv.Itoa(longestCheckName) + "s %-3s %s\n" bodyBuilder := &strings.Builder{} bodyBuilder.WriteString("Required Services: \n") + for name, res := range reqChecks { bodyBuilder.WriteString(fmt.Sprintf(table, name, res.State, res.Msg)) + if res.State == Err { status = http.StatusServiceUnavailable } } + bodyBuilder.WriteString("Optional Services: \n") + for name, res := range optChecks { bodyBuilder.WriteString(fmt.Sprintf(table, name, res.State, res.Msg)) // do not change status, as this is optional diff --git a/maintenance/health/servicehealthcheck/health_handler_readable_test.go b/maintenance/health/servicehealthcheck/health_handler_readable_test.go index 210f7c5f4..1e11af72b 100644 --- a/maintenance/health/servicehealthcheck/health_handler_readable_test.go +++ b/maintenance/health/servicehealthcheck/health_handler_readable_test.go @@ -65,6 +65,7 @@ func TestReadableHealthHandler(t *testing.T) { for _, hc := range tc.req { RegisterHealthCheck(hc.name, hc) } + for _, hc := range tc.opt { RegisterOptionalHealthCheck(hc, hc.name) } @@ -75,6 +76,7 @@ func TestReadableHealthHandler(t *testing.T) { results := strings.Split(string(resBody), "Optional Services: \n") reqRes := strings.Split(strings.Split(results[0], "Required Services: \n")[1], "\n") optRes := strings.Split(results[1], "\n") + testListHealthChecks(t, tc.expReq, reqRes) testListHealthChecks(t, tc.expOpt, optRes) }) diff --git a/maintenance/health/servicehealthcheck/healthcheck.go b/maintenance/health/servicehealthcheck/healthcheck.go index 881ef8c02..eee147559 100755 --- a/maintenance/health/servicehealthcheck/healthcheck.go +++ b/maintenance/health/servicehealthcheck/healthcheck.go @@ -59,12 +59,15 @@ var optionalChecks sync.Map func checksResults(checks *sync.Map) map[string]HealthCheckResult { results := make(map[string]HealthCheckResult) + checks.Range(func(key, value interface{}) bool { name := key.(string) result := value.(*ConnectionState).GetState() results[name] = result + return true }) + return results } @@ -110,6 +113,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts log.Warnf("tried to register health check with name %q twice", name) return } + if _, inOpt := optionalChecks.Load(name); inOpt { log.Warnf("tried to register health check with name %q twice", name) return @@ -119,7 +123,9 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts if len(name) > longestCheckName { longestCheckName = len(name) } + var bgState ConnectionState + checks.Store(name, &bgState) go func() { @@ -135,6 +141,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts // calculate when the warmup phase should be finished healthCheckStart := time.Now() warmupDeadline := healthCheckStart.Add(hcCfg.warmupDelay) + for { <-timer.C func() { @@ -143,6 +150,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts ctx, cancel := context.WithTimeout(ctx, hcCfg.maxWait) defer cancel() + span, ctx := opentracing.StartSpanFromContext(ctx, "BackgroundHealthCheck") defer span.Finish() @@ -151,6 +159,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts // Too soon, leave the same state return } + initErr := initHealthCheck(ctx, initHC) if initErr != nil { // Init failed again @@ -173,6 +182,7 @@ func registerHealthCheck(checks *sync.Map, name string, check HealthCheck, opts }) // sanity trigger a health check, since we can not guarantee what the real implementation does ... go check.HealthCheck(ctx) + return } } @@ -189,6 +199,7 @@ func initHealthCheck(ctx context.Context, initHC Initializable) (err error) { defer func() { if rp := recover(); rp != nil { err = fmt.Errorf("panic: %v", rp) + errors.Handle(ctx, rp) } }() diff --git a/maintenance/health/servicehealthcheck/healthcheck_test.go b/maintenance/health/servicehealthcheck/healthcheck_test.go index 197722d54..711040b4b 100644 --- a/maintenance/health/servicehealthcheck/healthcheck_test.go +++ b/maintenance/health/servicehealthcheck/healthcheck_test.go @@ -51,7 +51,7 @@ func TestHandlerHealthCheck(t *testing.T) { for _, tc := range testCases { t.Run(tc.title, func(t *testing.T) { resetHealthChecks() - // set warmup for unit testing explicitely to 0 + // set warmup for unit testing explicitly to 0 RegisterHealthCheck(tc.check.name, tc.check, UseWarmup(0)) testRequest(t, handler, tc.expCode, expBody(tc.expBody)) }) @@ -60,6 +60,7 @@ func TestHandlerHealthCheck(t *testing.T) { func TestInitErrorRetryAndCaching(t *testing.T) { handler := HealthHandler() + resetHealthChecks() bgInterval := time.Second @@ -89,6 +90,7 @@ func TestInitErrorRetryAndCaching(t *testing.T) { } // No init err, but expect err because of cache hc.initErr = false + waitForBackgroundCheck(bgInterval) testRequest(t, handler, http.StatusServiceUnavailable, expBody("ERR: 1 errors and 0 warnings")) } @@ -124,6 +126,7 @@ func TestInitErrorRetryAndCaching(t *testing.T) { // Remove init err, no caching, expect OK hc.initErr = false + waitForBackgroundCheck(bgInterval) testRequest(t, handler, http.StatusOK, expBody("OK")) } @@ -133,6 +136,7 @@ func TestInitErrorRetryAndCaching(t *testing.T) { func TestHandlerHealthCheckOptional(t *testing.T) { checkOpt := &mockHealthCheck{name: "TestHandlerHealthCheckErr", healthCheckErr: true} checkReq := &mockHealthCheck{name: "TestOk"} + resetHealthChecks() RegisterHealthCheck(checkReq.name, checkReq) @@ -162,11 +166,15 @@ func testRequest(t *testing.T, handler http.Handler, expCode int, expBody resBod rec := httptest.NewRecorder() handler.ServeHTTP(rec, nil) + resp := rec.Result() assert.Equal(t, expCode, resp.StatusCode) + defer resp.Body.Close() + data, err := io.ReadAll(resp.Body) require.NoError(t, err) + if expBody != nil { expBody(t, data) } @@ -179,6 +187,7 @@ func waitForBackgroundCheck(additionalWait ...time.Duration) { if len(additionalWait) > 0 { t += additionalWait[0] } + time.Sleep(t) } diff --git a/maintenance/health/servicehealthcheck/mocks_test.go b/maintenance/health/servicehealthcheck/mocks_test.go index 682b92732..18467c192 100644 --- a/maintenance/health/servicehealthcheck/mocks_test.go +++ b/maintenance/health/servicehealthcheck/mocks_test.go @@ -21,6 +21,7 @@ func (t *mockHealthCheck) Init(_ context.Context) error { if t.initErr { return errors.New("initError") } + return nil } @@ -28,5 +29,6 @@ func (t *mockHealthCheck) HealthCheck(_ context.Context) HealthCheckResult { if t.healthCheckErr { return HealthCheckResult{State: Err, Msg: "healthCheckErr"} } + return HealthCheckResult{State: Ok} } diff --git a/maintenance/log/handler.go b/maintenance/log/handler.go index 4f3cc3bdd..af28633b7 100755 --- a/maintenance/log/handler.go +++ b/maintenance/log/handler.go @@ -9,12 +9,12 @@ import ( "time" "github.com/opentracing/opentracing-go" - "github.com/uber/jaeger-client-go" - - "github.com/pace/bricks/maintenance/log/hlog" "github.com/rs/xid" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/uber/jaeger-client-go" + + "github.com/pace/bricks/maintenance/log/hlog" ) // RequestIDHeader name of the header that can contain a request ID @@ -73,10 +73,12 @@ func ProxyAwareRemote(r *http.Request) string { addresses := strings.Split(r.Header.Get(h), ",") for i := len(addresses) - 1; i >= 0; i-- { ip := strings.TrimSpace(addresses[i]) + realIP := net.ParseIP(ip) if !realIP.IsGlobalUnicast() || isPrivate(realIP) { continue // bad address, go to next } + return ip } } @@ -87,6 +89,7 @@ func ProxyAwareRemote(r *http.Request) string { log.Ctx(r.Context()).Warn().Err(err).Msg("failed to decode the remote address") return "" } + return host } @@ -109,6 +112,7 @@ func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http. return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + var id xid.ID // try extract of xid from header diff --git a/maintenance/log/handler_test.go b/maintenance/log/handler_test.go index ca87ad6ea..4bfe57303 100644 --- a/maintenance/log/handler_test.go +++ b/maintenance/log/handler_test.go @@ -11,21 +11,24 @@ import ( func TestLoggingHandler(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) mux := http.NewServeMux() mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { if RequestID(r) == "" { t.Error("Request should have request id") } - w.WriteHeader(201) + + w.WriteHeader(http.StatusCreated) }) Handler()(mux).ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 201 { + if resp.StatusCode != http.StatusCreated { t.Error("expected 201 status code") } } diff --git a/maintenance/log/hlog/hlog.go b/maintenance/log/hlog/hlog.go index 65f801a4a..0c8daecd5 100644 --- a/maintenance/log/hlog/hlog.go +++ b/maintenance/log/hlog/hlog.go @@ -86,6 +86,7 @@ func RemoteAddrHandler(fieldKey string) func(next http.Handler) http.Handler { return c.Str(fieldKey, host) }) } + next.ServeHTTP(w, r) }) } @@ -102,6 +103,7 @@ func UserAgentHandler(fieldKey string) func(next http.Handler) http.Handler { return c.Str(fieldKey, ua) }) } + next.ServeHTTP(w, r) }) } @@ -118,6 +120,7 @@ func RefererHandler(fieldKey string) func(next http.Handler) http.Handler { return c.Str(fieldKey, ref) }) } + next.ServeHTTP(w, r) }) } @@ -133,6 +136,7 @@ func IDFromRequest(r *http.Request) (id xid.ID, ok bool) { if r == nil { return } + return IDFromCtx(r.Context()) } @@ -161,21 +165,25 @@ func RequestIDHandler(fieldKey, headerName string) func(next http.Handler) http. return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + id, ok := IDFromRequest(r) if !ok { id = xid.New() ctx = context.WithValue(ctx, idKey{}, id) r = r.WithContext(ctx) } + if fieldKey != "" { log := zerolog.Ctx(ctx) log.UpdateContext(func(c zerolog.Context) zerolog.Context { return c.Str(fieldKey, id.String()) }) } + if headerName != "" { w.Header().Set(headerName, id.String()) } + next.ServeHTTP(w, r) }) } @@ -192,6 +200,7 @@ func CustomHeaderHandler(fieldKey, header string) func(next http.Handler) http.H return c.Str(fieldKey, val) }) } + next.ServeHTTP(w, r) }) } @@ -218,5 +227,6 @@ func ContextTransfer(parentCtx, out context.Context) context.Context { if !found { return out } + return WithValue(out, id) } diff --git a/maintenance/log/log.go b/maintenance/log/log.go index c14e948fd..124ba1091 100644 --- a/maintenance/log/log.go +++ b/maintenance/log/log.go @@ -12,12 +12,11 @@ import ( "time" "github.com/caarlos0/env/v10" - "github.com/pace/bricks/maintenance/log/hlog" - + isatty "github.com/mattn/go-isatty" "github.com/rs/zerolog" "github.com/rs/zerolog/log" - isatty "github.com/mattn/go-isatty" + "github.com/pace/bricks/maintenance/log/hlog" ) type config struct { @@ -54,7 +53,9 @@ func init() { if !ok { Fatalf("Unknown log level: %q", cfg.LogLevel) } + zerolog.SetGlobalLevel(v) + log.Logger = log.Logger.Level(v) // auto detect log format @@ -86,6 +87,7 @@ func RequestID(r *http.Request) string { if ok { return id.String() } + return "" } diff --git a/maintenance/log/log_test.go b/maintenance/log/log_test.go index e4780aeb7..ce934d206 100644 --- a/maintenance/log/log_test.go +++ b/maintenance/log/log_test.go @@ -4,12 +4,13 @@ package log import ( "context" + "net/http" "net/http/httptest" "testing" ) func TestLog(t *testing.T) { - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) if RequestID(req) != "" { t.Error("Request without set error ID can't have a request id") } diff --git a/maintenance/log/sink.go b/maintenance/log/sink.go index c47bef166..748a01ecc 100644 --- a/maintenance/log/sink.go +++ b/maintenance/log/sink.go @@ -22,6 +22,7 @@ const defaultSinkSize = 1000 func ContextWithSink(ctx context.Context, sink *Sink) context.Context { l := log.Ctx(ctx).Output(sink) ctx = l.WithContext(ctx) + return context.WithValue(ctx, sinkKey{}, sink) } @@ -68,6 +69,7 @@ func NewSink(opts ...SinkOption) *Sink { if sink.customSize > 0 { sinkSize = sink.customSize } + sink.ring = newStringRing(sinkSize) return sink @@ -84,6 +86,7 @@ func handlerWithSink(silentPrefixes ...string) mux.MiddlewareFunc { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var sink Sink + for _, prefix := range silentPrefixes { if strings.HasPrefix(r.URL.Path, prefix) { sink.Silent = true @@ -118,6 +121,7 @@ func (s *Sink) Pretty() string { s.rwmutex.Lock() defer s.rwmutex.Unlock() + for _, str := range s.ring.GetContent() { n, err := strings.NewReader(str).WriteTo(writer) if err != nil { @@ -180,6 +184,7 @@ func (r *stringRing) writeString(c string) { r.data = append(r.data, c) return } + if len(r.data) < r.size-1 { // default case: ring has not reached maximum size yet // so just append and increase @@ -201,6 +206,7 @@ func (r *stringRing) GetContent() []string { } else { out := r.data[r.nextPos:] out = append(out, r.data[:r.nextPos]...) + return out } } diff --git a/maintenance/log/sink_test.go b/maintenance/log/sink_test.go index b5432265a..425e2d5fa 100644 --- a/maintenance/log/sink_test.go +++ b/maintenance/log/sink_test.go @@ -12,30 +12,35 @@ import ( func Test_Sink(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) var sink *Sink + mux := http.NewServeMux() mux.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { require.NotEqual(t, "", RequestID(r), "request should have request id") var ok bool + sink, ok = SinkFromContext(r.Context()) require.True(t, ok, "SinkFromContext() returned false unexpectedly") Req(r).Info().Msg("this is a test message for the sink") - w.WriteHeader(201) + w.WriteHeader(http.StatusCreated) }) Handler()(mux).ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - require.Equal(t, 201, resp.StatusCode, "wrong status code") + require.Equal(t, http.StatusCreated, resp.StatusCode, "wrong status code") logs := sink.ToJSON() var result []interface{} + require.NoError(t, json.Unmarshal(logs, &result), "could not unmarshal logs") require.Len(t, result, 2, "expecting exactly one log, but got %d", len(result)) @@ -46,12 +51,15 @@ func TestOverflowRing(t *testing.T) { for i := 0; i < 2; i++ { ring.writeString(fmt.Sprintf("%02d", i)) } + require.Equal(t, []string{"00", "01"}, ring.data) ring.writeString("02") require.Equal(t, []string{"00", "01", "02"}, ring.data) + for i := 3; i < 5; i++ { ring.writeString(fmt.Sprintf("%02d", i)) } + require.Equal(t, []string{"03", "04", "02"}, ring.data) } @@ -60,11 +68,14 @@ func TestRingGetContent(t *testing.T) { for i := 0; i < 2; i++ { ring.writeString(fmt.Sprintf("%02d", i)) } + require.Equal(t, []string{"00", "01"}, ring.GetContent()) ring.writeString("02") require.Equal(t, []string{"00", "01", "02"}, ring.GetContent()) + for i := 3; i < 5; i++ { ring.writeString(fmt.Sprintf("%02d", i)) } + require.Equal(t, []string{"02", "03", "04"}, ring.GetContent()) } diff --git a/maintenance/metric/handler_test.go b/maintenance/metric/handler_test.go index 27d46ae15..b9e3edd06 100644 --- a/maintenance/metric/handler_test.go +++ b/maintenance/metric/handler_test.go @@ -3,20 +3,24 @@ package metric import ( + "net/http" "net/http/httptest" "testing" ) func TestHandler(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) Handler().ServeHTTP(rec, req) resp := rec.Result() - defer resp.Body.Close() - if resp.StatusCode != 200 { + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { t.Errorf("Failed to respond with prometheus metrics: %v", resp.StatusCode) } } diff --git a/maintenance/metric/jsonapi/jsonapi.go b/maintenance/metric/jsonapi/jsonapi.go index 52f8284d2..25dd4fb70 100644 --- a/maintenance/metric/jsonapi/jsonapi.go +++ b/maintenance/metric/jsonapi/jsonapi.go @@ -105,6 +105,7 @@ func NewMetric(serviceName, path string, w http.ResponseWriter, r *http.Request) func (m *Metric) WriteHeader(statusCode int) { clientID, _ := oauth2.ClientID(m.request.Context()) IncPaceAPIHTTPRequestTotal(strconv.Itoa(statusCode), m.request.Method, m.path, m.serviceName, clientID) + duration := float64(time.Since(m.requestStart).Nanoseconds()) / float64(time.Second) AddPaceAPIHTTPRequestDurationSeconds(duration, m.request.Method, m.path, m.serviceName) m.ResponseWriter.WriteHeader(statusCode) @@ -114,6 +115,7 @@ func (m *Metric) WriteHeader(statusCode int) { func (m *Metric) Write(p []byte) (int, error) { size, err := m.ResponseWriter.Write(p) m.sizeWritten += size + return size, err } @@ -157,6 +159,7 @@ type lenCallbackReader struct { func (r *lenCallbackReader) Read(p []byte) (int, error) { n, err := r.reader.Read(p) r.size += n + return n, err } @@ -165,5 +168,6 @@ func (r *lenCallbackReader) Close() error { n, _ := io.Copy(io.Discard, r.reader) r.size += int(n) r.onEOF(r.size) + return r.reader.Close() } diff --git a/maintenance/metric/jsonapi/jsonapi_test.go b/maintenance/metric/jsonapi/jsonapi_test.go index 1b97b7237..8cd6a25f5 100644 --- a/maintenance/metric/jsonapi/jsonapi_test.go +++ b/maintenance/metric/jsonapi/jsonapi_test.go @@ -16,26 +16,30 @@ func TestMetric(t *testing.T) { t.Run("capture metrics", func(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/test/1234567", nil) + req := httptest.NewRequest(http.MethodGet, "/test/1234567", nil) handler := func(w http.ResponseWriter, r *http.Request) { w = NewMetric("simple", "/test/{id}", w, r) - w.WriteHeader(204) + w.WriteHeader(http.StatusNoContent) } handler(rec, req) - req.Body.Close() // that's something the server does + if err := req.Body.Close(); err != nil { // that's something the server does + panic(err) + } resp := rec.Result() - defer resp.Body.Close() + defer func() { + _ = resp.Body.Close() + }() - if resp.StatusCode != 204 { + if resp.StatusCode != http.StatusNoContent { t.Errorf("Failed to return correct 204 response status, got: %v", resp.StatusCode) } }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() @@ -54,22 +58,25 @@ func TestMetric(t *testing.T) { t.Run("capture request size", func(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("POST", "/noop", strings.NewReader("some static request body")) + req := httptest.NewRequest(http.MethodPost, "/noop", strings.NewReader("some static request body")) handler := func(w http.ResponseWriter, r *http.Request) { NewMetric("noop", "/noop", w, r) } handler(rec, req) - req.Body.Close() // that's something the server does + if err := req.Body.Close(); err != nil { // that's something the server does + panic(err) + } }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() wantMetric := `pace_api_http_size_bytes_sum{method="POST",path="/noop",service="noop",type="req"} 24` + if !strings.Contains(body, wantMetric) { t.Errorf("Expected metric %q, got: %v", wantMetric, body) } @@ -80,10 +87,11 @@ func TestMetric(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() reqBody := strings.NewReader("some request body") - req := httptest.NewRequest("POST", "/foobar", readerWithoutLen{reqBody}) + req := httptest.NewRequest(http.MethodPost, "/foobar", readerWithoutLen{reqBody}) handler := func(w http.ResponseWriter, r *http.Request) { NewMetric("foobar", "/foobar", w, r) + _, err := io.Copy(io.Discard, r.Body) // read request body if err != nil { panic(err) @@ -91,15 +99,18 @@ func TestMetric(t *testing.T) { } handler(rec, req) - req.Body.Close() // that's something the server does + if err := req.Body.Close(); err != nil { // that's something the server does + panic(err) + } }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() wantMetric := `pace_api_http_size_bytes_sum{method="POST",path="/foobar",service="foobar",type="req"} 17` + if !strings.Contains(body, wantMetric) { t.Errorf("Expected metric %q, got: %v", wantMetric, body) } @@ -110,7 +121,7 @@ func TestMetric(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() reqBody := strings.NewReader("some request body that noone ever reads") - req := httptest.NewRequest("POST", "/barfoo", readerWithoutLen{reqBody}) + req := httptest.NewRequest(http.MethodPost, "/barfoo", readerWithoutLen{reqBody}) handler := func(w http.ResponseWriter, r *http.Request) { NewMetric("barfoo", "/barfoo", w, r) @@ -118,15 +129,16 @@ func TestMetric(t *testing.T) { } handler(rec, req) - req.Body.Close() // that's something the server does + _ = req.Body.Close() // that's something the server does }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() wantMetric := `pace_api_http_size_bytes_sum{method="POST",path="/barfoo",service="barfoo",type="req"} 39` + if !strings.Contains(body, wantMetric) { t.Errorf("Expected metric %q, got: %v", wantMetric, body) } @@ -136,10 +148,11 @@ func TestMetric(t *testing.T) { t.Run("capture response size", func(t *testing.T) { t.Run("api request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/lalala", nil) + req := httptest.NewRequest(http.MethodGet, "/lalala", nil) handler := func(w http.ResponseWriter, r *http.Request) { w = NewMetric("lalala", "/lalala", w, r) + _, err := w.Write([]byte("hehehehe")) if err != nil { panic(err) @@ -147,15 +160,16 @@ func TestMetric(t *testing.T) { } handler(rec, req) - req.Body.Close() // that's something the server does + _ = req.Body.Close() // that's something the server does }) t.Run("get metrics request", func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) metric.Handler().ServeHTTP(rec, req) body := rec.Body.String() wantMetric := `pace_api_http_size_bytes_sum{method="GET",path="/lalala",service="lalala",type="resp"} 8` + if !strings.Contains(body, wantMetric) { t.Errorf("Expected metric %q, got: %v", wantMetric, body) } diff --git a/maintenance/terminationlog/termlog.go b/maintenance/terminationlog/termlog.go index 5bfc206a7..a688f3217 100644 --- a/maintenance/terminationlog/termlog.go +++ b/maintenance/terminationlog/termlog.go @@ -19,7 +19,7 @@ var logFile *os.File // Fatalf implements log Fatalf interface func Fatalf(format string, v ...interface{}) { if logFile != nil { - fmt.Fprintf(logFile, format, v...) + _, _ = fmt.Fprintf(logFile, format, v...) } log.Fatal().Msg(fmt.Sprintf(format, v...)) @@ -28,7 +28,7 @@ func Fatalf(format string, v ...interface{}) { // Fatal implements log Fatal interface func Fatal(v ...interface{}) { if logFile != nil { - fmt.Fprint(logFile, v...) + _, _ = fmt.Fprint(logFile, v...) } log.Fatal().Msg(fmt.Sprint(v...)) diff --git a/maintenance/terminationlog/termlog_linux_amd64.go b/maintenance/terminationlog/termlog_linux_amd64.go index 352253392..519b82b59 100644 --- a/maintenance/terminationlog/termlog_linux_amd64.go +++ b/maintenance/terminationlog/termlog_linux_amd64.go @@ -8,6 +8,7 @@ package terminationlog import ( + "log" "os" "syscall" ) @@ -17,11 +18,12 @@ const termLog = "/dev/termination-log" func init() { file, err := os.OpenFile(termLog, os.O_RDWR, 0o666) - if err == nil { logFile = file // redirect stderr to the termLog - syscall.Dup2(int(logFile.Fd()), 2) // nolint: errcheck + if err := syscall.Dup2(int(logFile.Fd()), 2); err != nil { + log.Fatal(err) + } } } diff --git a/maintenance/terminationlog/termlog_linux_arm64.go b/maintenance/terminationlog/termlog_linux_arm64.go index a4590d876..05dc1a4b5 100644 --- a/maintenance/terminationlog/termlog_linux_arm64.go +++ b/maintenance/terminationlog/termlog_linux_arm64.go @@ -22,6 +22,6 @@ func init() { logFile = file // redirect stderr to the termLog - syscall.Dup3(int(logFile.Fd()), 2, 0) // nolint: errcheck + syscall.Dup3(int(logFile.Fd()), 2, 0) } } diff --git a/maintenance/tracing/tracing.go b/maintenance/tracing/tracing.go index 2d957ff35..5a2859d06 100755 --- a/maintenance/tracing/tracing.go +++ b/maintenance/tracing/tracing.go @@ -9,12 +9,13 @@ import ( opentracing "github.com/opentracing/opentracing-go" olog "github.com/opentracing/opentracing-go/log" - "github.com/pace/bricks/maintenance/log" - "github.com/pace/bricks/maintenance/tracing/wire" - "github.com/pace/bricks/maintenance/util" "github.com/uber/jaeger-client-go/config" "github.com/uber/jaeger-lib/metrics/prometheus" "github.com/zenazn/goji/web/mutil" + + "github.com/pace/bricks/maintenance/log" + "github.com/pace/bricks/maintenance/tracing/wire" + "github.com/pace/bricks/maintenance/util" ) // Closer can be used in shutdown hooks to ensure that the internal queue of @@ -30,6 +31,7 @@ func init() { log.Warnf("Unable to load Jaeger config from ENV: %v", err) return } + if cfg.ServiceName == "" { log.Warn("Using Jaeger noop tracer since no JAEGER_SERVICE_NAME is present") return @@ -39,6 +41,7 @@ func init() { config.Metrics(prometheus.New()), ) opentracing.SetGlobalTracer(Tracer) + if err != nil { log.Fatal(err) } @@ -80,6 +83,7 @@ func (h *traceLogHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { handlerSpan.LogFields(olog.String("req_id", log.RequestID(r)), olog.String("path", r.URL.Path), olog.String("method", r.Method)) + ww := mutil.WrapWriter(w) h.next.ServeHTTP(ww, r.WithContext(ctx)) handlerSpan.LogFields(olog.Int("bytes", ww.BytesWritten()), olog.Int("status_code", ww.Status())) diff --git a/maintenance/tracing/tracing_test.go b/maintenance/tracing/tracing_test.go index b81dd1486..81845c9f0 100644 --- a/maintenance/tracing/tracing_test.go +++ b/maintenance/tracing/tracing_test.go @@ -8,11 +8,11 @@ import ( "net/http/httptest" "testing" + "github.com/gorilla/mux" "github.com/opentracing/opentracing-go" - "github.com/pace/bricks/maintenance/util" "github.com/stretchr/testify/require" - "github.com/gorilla/mux" + "github.com/pace/bricks/maintenance/util" ) func TestHandlerIgnore(t *testing.T) { @@ -21,7 +21,7 @@ func TestHandlerIgnore(t *testing.T) { r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {}) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/test", nil) + req := httptest.NewRequest(http.MethodGet, "/test", nil) // This test does not tests if any prefix is ignored r.ServeHTTP(rec, req) @@ -31,27 +31,28 @@ func TestHandler(t *testing.T) { r := mux.NewRouter() r.Use(Handler()) r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(200) + w.WriteHeader(http.StatusOK) }) rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", "/", nil) + req := httptest.NewRequest(http.MethodGet, "/", nil) r.ServeHTTP(rec, req) // This test does not tests the tracing - require.Equal(t, 200, rec.Result().StatusCode) + require.Equal(t, http.StatusOK, rec.Result().StatusCode) } func TestRequest(t *testing.T) { - r := Request(httptest.NewRequest("GET", "/test", nil)) + r := Request(httptest.NewRequest(http.MethodGet, "/test", nil)) // check that header is empty if len(r.Header["Uber-Trace-Id"]) != 0 { t.Errorf("expected no tracing id but got one") } - r = httptest.NewRequest("GET", "/test", nil) + r = httptest.NewRequest(http.MethodGet, "/test", nil) _, ctx := opentracing.StartSpanFromContext(context.Background(), "foo") + r = Request(r.WithContext(ctx)) if len(r.Header["Uber-Trace-Id"]) != 1 { t.Errorf("expected one tracing id but got none (JAEGER_SERVICE_NAME not in env?)") diff --git a/maintenance/tracing/wire/wire.go b/maintenance/tracing/wire/wire.go index a925719bc..7b915df50 100755 --- a/maintenance/tracing/wire/wire.go +++ b/maintenance/tracing/wire/wire.go @@ -18,5 +18,6 @@ func ToWire(spanCtx opentracing.SpanContext, r *http.Request) error { spanCtx, opentracing.HTTPHeaders, carrier) + return err } diff --git a/maintenance/util/ignore_prefix_handler.go b/maintenance/util/ignore_prefix_handler.go index a456dc7b4..d5d8a2ade 100644 --- a/maintenance/util/ignore_prefix_handler.go +++ b/maintenance/util/ignore_prefix_handler.go @@ -44,6 +44,7 @@ func NewConfigurableHandler(next, actualHandler http.Handler, cfgs ...Configurab log.Fatal(err) } } + return middleware } @@ -56,5 +57,6 @@ func (m configurableHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } } + m.actualHandler.ServeHTTP(w, r) } diff --git a/maintenance/util/ignore_prefix_handler_test.go b/maintenance/util/ignore_prefix_handler_test.go index ee6890c05..b6f0df1af 100644 --- a/maintenance/util/ignore_prefix_handler_test.go +++ b/maintenance/util/ignore_prefix_handler_test.go @@ -35,9 +35,15 @@ func TestMiddlewareWithBlacklist(t *testing.T) { for _, tc := range testCases { t.Run(tc.title, func(t *testing.T) { rec := httptest.NewRecorder() - req := httptest.NewRequest("GET", tc.path, nil) + req := httptest.NewRequest(http.MethodGet, tc.path, nil) r.ServeHTTP(rec, req) + resp := rec.Result() + + defer func() { + _ = resp.Body.Close() + }() + require.Equal(t, tc.statusCodeExpected, resp.StatusCode) }) } diff --git a/pkg/cache/example_test.go b/pkg/cache/example_test.go index bbed062d6..e5c86956e 100644 --- a/pkg/cache/example_test.go +++ b/pkg/cache/example_test.go @@ -27,6 +27,7 @@ func Example_inMemory() { if err != nil { panic(err) } + fmt.Println(string(v)) // forget diff --git a/pkg/cache/memory.go b/pkg/cache/memory.go index 017cd0c9e..c66d718cb 100644 --- a/pkg/cache/memory.go +++ b/pkg/cache/memory.go @@ -36,12 +36,15 @@ func InMemory() *Memory { func (c *Memory) Put(_ context.Context, key string, value []byte, ttl time.Duration) error { v := inMemoryValue{value: make([]byte, len(value))} copy(v.value, value) + if ttl != 0 { v.expiresAt = time.Now().Add(ttl) } + c.mx.Lock() c.values[key] = v c.mx.Unlock() + return nil } @@ -53,9 +56,11 @@ func (c *Memory) Get(ctx context.Context, key string) ([]byte, time.Duration, er c.mx.RLock() v, ok := c.values[key] c.mx.RUnlock() + if !ok { return nil, 0, fmt.Errorf("key %q: %w", key, ErrNotFound) } + var ttl time.Duration if !v.expiresAt.IsZero() { ttl = time.Until(v.expiresAt) @@ -64,8 +69,10 @@ func (c *Memory) Get(ctx context.Context, key string) ([]byte, time.Duration, er return nil, 0, fmt.Errorf("key %q: %w", key, ErrNotFound) } } + value := make([]byte, len(v.value)) copy(value, v.value) + return value, ttl, nil } diff --git a/pkg/cache/memory_test.go b/pkg/cache/memory_test.go index af57a9537..e56ac6798 100644 --- a/pkg/cache/memory_test.go +++ b/pkg/cache/memory_test.go @@ -5,9 +5,10 @@ package cache_test import ( "testing" + "github.com/stretchr/testify/suite" + "github.com/pace/bricks/pkg/cache" "github.com/pace/bricks/pkg/cache/testsuite" - "github.com/stretchr/testify/suite" ) func TestMemory(t *testing.T) { diff --git a/pkg/cache/redis.go b/pkg/cache/redis.go index fe3c96ebd..0f867a379 100644 --- a/pkg/cache/redis.go +++ b/pkg/cache/redis.go @@ -35,6 +35,7 @@ func (c *Redis) Put(ctx context.Context, key string, value []byte, ttl time.Dura if err != nil { return fmt.Errorf("%w: redis: %s", ErrBackend, err) } + return nil } @@ -51,26 +52,32 @@ var redisGETAndPTTL = redis.NewScript(`return { // non-nil. func (c *Redis) Get(ctx context.Context, key string) ([]byte, time.Duration, error) { key = c.prefix + key + r, err := redisGETAndPTTL.Run(ctx, c.client, []string{key}).Result() if err != nil { return nil, 0, fmt.Errorf("%w: redis: %s", ErrBackend, err) } + result, ok := r.([]interface{}) if !ok { return nil, 0, fmt.Errorf("%w: redis returned unexpected type %T, expected %T", ErrBackend, r, result) } + v := result[0] if v == nil { return nil, 0, fmt.Errorf("key %q: %w", key, ErrNotFound) } + value, ok := v.(string) if !ok { return nil, 0, fmt.Errorf("%w: redis returned unexpected type %T, expected %T", ErrBackend, v, value) } + ttl, ok := result[1].(int64) if !ok { return nil, 0, fmt.Errorf("%w: redis returned unexpected type %T, expected %T", ErrBackend, result[1], ttl) } + switch { case ttl == -1: // key exists but has no associated expire return []byte(value), 0, nil @@ -90,5 +97,6 @@ func (c *Redis) Forget(ctx context.Context, key string) error { if err != nil { return fmt.Errorf("%w: redis: %s", ErrBackend, err) } + return nil } diff --git a/pkg/cache/redis_test.go b/pkg/cache/redis_test.go index a555c6871..e54fc2193 100644 --- a/pkg/cache/redis_test.go +++ b/pkg/cache/redis_test.go @@ -5,16 +5,18 @@ package cache_test import ( "testing" + "github.com/stretchr/testify/suite" + "github.com/pace/bricks/backend/redis" "github.com/pace/bricks/pkg/cache" "github.com/pace/bricks/pkg/cache/testsuite" - "github.com/stretchr/testify/suite" ) func TestIntegrationRedis(t *testing.T) { if testing.Short() { t.SkipNow() } + suite.Run(t, &testsuite.CacheTestSuite{ Cache: cache.InRedis(redis.Client(), "test:cache:"), }) diff --git a/pkg/cache/testsuite/cache.go b/pkg/cache/testsuite/cache.go index 3ca906c1f..1f31c25a5 100644 --- a/pkg/cache/testsuite/cache.go +++ b/pkg/cache/testsuite/cache.go @@ -9,9 +9,10 @@ import ( "sync" "time" + "github.com/stretchr/testify/suite" + "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/pkg/cache" - "github.com/stretchr/testify/suite" ) type CacheTestSuite struct { @@ -24,24 +25,30 @@ func (suite *CacheTestSuite) TestPut() { ctx := log.WithContext(context.Background()) _ = c.Forget(ctx, "foo") // make sure it doesn't exist + suite.Run("does not error", func() { err := c.Put(ctx, "foo", []byte("bar"), time.Second) suite.NoError(err) }) + _ = c.Forget(ctx, "foo") // clean up _ = c.Forget(ctx, "") // make sure it doesn't exist + suite.Run("accepts all null values", func() { err := c.Put(ctx, "", nil, 0) suite.NoError(err) }) + _ = c.Forget(ctx, "") // clean up _ = c.Forget(ctx, "中文پنجابی🥰🥸") // make sure it doesn't exist + suite.Run("supports unicode", func() { err := c.Put(ctx, "中文پنجابی🥰🥸", []byte("🦤ᐃᓄᒃᑎᑐᑦລາວ"), 0) suite.NoError(err) }) + _ = c.Forget(ctx, "中文پنجابی🥰🥸") // clean up suite.Run("does not error when repeated", func() { @@ -49,6 +56,7 @@ func (suite *CacheTestSuite) TestPut() { err := c.Put(ctx, "foo", []byte("bar"), time.Second) suite.NoError(err) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("stores a value", func() { @@ -56,6 +64,7 @@ func (suite *CacheTestSuite) TestPut() { value, _, _ := c.Get(ctx, "foo") suite.Equal([]byte("bar"), value) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("is unaffected from manipulating the input", func() { @@ -65,15 +74,18 @@ func (suite *CacheTestSuite) TestPut() { value, _, _ := c.Get(ctx, "foo") suite.Equal([]byte("bar"), value) }) + _ = c.Forget(ctx, "foo") // clean up for i := 0; i <= 5; i++ { // make sure it doesn't exist _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } + suite.Run("does not error on simultaneous use", func() { var wg sync.WaitGroup for i := 0; i <= 5; i++ { wg.Add(1) + go func() { err := c.Put(ctx, fmt.Sprintf("foo%d", i), []byte("bar"), 0) suite.NoError(err) @@ -82,6 +94,7 @@ func (suite *CacheTestSuite) TestPut() { wg.Wait() } }) + for i := 0; i <= 5; i++ { // clean up _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } @@ -92,6 +105,7 @@ func (suite *CacheTestSuite) TestGet() { ctx := log.WithContext(context.Background()) _ = c.Forget(ctx, "foo") // make sure it doesn't exist + suite.Run("returns the ttl if set", func() { _ = c.Put(ctx, "foo", []byte("bar"), time.Minute) _, ttl, _ := c.Get(ctx, "foo") @@ -99,6 +113,7 @@ func (suite *CacheTestSuite) TestGet() { suite.LessOrEqual(int64(ttl), int64(time.Minute)) suite.Greater(int64(ttl), int64(time.Minute-time.Second)) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("returns 0 as ttl if ttl not set", func() { @@ -106,33 +121,41 @@ func (suite *CacheTestSuite) TestGet() { _, ttl, _ := c.Get(ctx, "foo") suite.Equal(time.Duration(0), ttl) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("returns not found error", func() { _, _, err := c.Get(ctx, "foo") suite.True(errors.Is(err, cache.ErrNotFound)) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("returns not found if ttl ran out", func() { err := c.Put(ctx, "foo", []byte("bar"), time.Millisecond) // minimum ttl suite.NoError(err) + <-time.After(2 * time.Millisecond) + _, _, err = c.Get(ctx, "foo") suite.True(errors.Is(err, cache.ErrNotFound)) }) + _ = c.Forget(ctx, "foo") // clean up _ = c.Forget(ctx, "foo1") // make sure it doesn't exist _ = c.Forget(ctx, "foo2") // make sure it doesn't exist + suite.Run("retrieves the right value", func() { _ = c.Put(ctx, "foo1", []byte("bar1"), 0) _ = c.Put(ctx, "foo2", []byte("bar2"), 0) value1, _, _ := c.Get(ctx, "foo1") value2, _, _ := c.Get(ctx, "foo2") + suite.Equal([]byte("bar1"), value1) suite.Equal([]byte("bar2"), value2) }) + _ = c.Forget(ctx, "foo1") // clean up _ = c.Forget(ctx, "foo2") // clean up @@ -143,6 +166,7 @@ func (suite *CacheTestSuite) TestGet() { value, _, _ := c.Get(ctx, "foo") suite.Equal([]byte("bar"), value) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("does not produce nil", func() { @@ -150,34 +174,42 @@ func (suite *CacheTestSuite) TestGet() { value, _, _ := c.Get(ctx, "foo") suite.NotNil(value) }) + _ = c.Forget(ctx, "foo") // clean up _ = c.Forget(ctx, "") // make sure it doesn't exist + suite.Run("returns value stored with an empty key", func() { _ = c.Put(ctx, "", []byte("bar"), 0) value, _, _ := c.Get(ctx, "") suite.Equal([]byte("bar"), value) }) + _ = c.Forget(ctx, "") // clean up _ = c.Forget(ctx, "中文پنجابی🥰🥸") // make sure it doesn't exist + suite.Run("supports unicode", func() { _ = c.Put(ctx, "中文پنجابی🥰🥸", []byte("🦤ᐃᓄᒃᑎᑐᑦລາວ\x00"), 0) value, _, _ := c.Get(ctx, "中文پنجابی🥰🥸") suite.Equal([]byte("🦤ᐃᓄᒃᑎᑐᑦລາວ\x00"), value) }) + _ = c.Forget(ctx, "中文پنجابی🥰🥸") // clean up for i := 0; i <= 5; i++ { // make sure it doesn't exist _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } + suite.Run("does not error on simultaneous use", func() { for i := 0; i <= 5; i++ { _ = c.Put(ctx, fmt.Sprintf("foo%d", i), []byte("bar"), 0) } + var wg sync.WaitGroup for i := 0; i <= 5; i++ { wg.Add(1) + go func() { _, _, err := c.Get(ctx, fmt.Sprintf("foo%d", i)) suite.NoError(err) @@ -186,6 +218,7 @@ func (suite *CacheTestSuite) TestGet() { wg.Wait() } }) + for i := 0; i <= 5; i++ { // clean up _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } @@ -196,12 +229,14 @@ func (suite *CacheTestSuite) TestForget() { ctx := log.WithContext(context.Background()) _ = c.Forget(ctx, "foo") // make sure it doesn't exist + suite.Run("works", func() { _ = c.Put(ctx, "foo", []byte("bar"), 0) _ = c.Forget(ctx, "foo") _, _, err := c.Get(ctx, "foo") suite.True(errors.Is(err, cache.ErrNotFound)) }) + _ = c.Forget(ctx, "foo") // clean up suite.Run("does not error when repeated", func() { @@ -209,25 +244,31 @@ func (suite *CacheTestSuite) TestForget() { err := c.Forget(ctx, "foo") suite.NoError(err) }) + _ = c.Forget(ctx, "foo") // clean up _ = c.Forget(ctx, "中文پنجابی🥰🥸") // make sure it doesn't exist + suite.Run("supports unicode", func() { err := c.Forget(ctx, "中文پنجابی🥰🥸") suite.NoError(err) }) + _ = c.Forget(ctx, "中文پنجابی🥰🥸") // clean up for i := 0; i <= 5; i++ { // make sure it doesn't exist _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } + suite.Run("does not error on simultaneous use", func() { for i := 0; i <= 5; i++ { _ = c.Put(ctx, fmt.Sprintf("foo%d", i), []byte("bar"), 0) } + var wg sync.WaitGroup for i := 0; i <= 5; i++ { wg.Add(1) + go func() { err := c.Forget(ctx, fmt.Sprintf("foo%d", i)) suite.NoError(err) @@ -236,6 +277,7 @@ func (suite *CacheTestSuite) TestForget() { wg.Wait() } }) + for i := 0; i <= 5; i++ { // clean up _ = c.Forget(ctx, fmt.Sprintf("foo%d", i)) } diff --git a/pkg/cache/testsuite/cache_test.go b/pkg/cache/testsuite/cache_test.go index b58fcdd20..f5826caac 100644 --- a/pkg/cache/testsuite/cache_test.go +++ b/pkg/cache/testsuite/cache_test.go @@ -5,9 +5,10 @@ package testsuite_test import ( "testing" + "github.com/stretchr/testify/suite" + "github.com/pace/bricks/pkg/cache" . "github.com/pace/bricks/pkg/cache/testsuite" - "github.com/stretchr/testify/suite" ) // TestStringsTestSuite tests the reference in-memory cache implementation. diff --git a/pkg/context/transfer.go b/pkg/context/transfer.go index a4e122672..f88b6cb96 100755 --- a/pkg/context/transfer.go +++ b/pkg/context/transfer.go @@ -4,6 +4,7 @@ import ( "context" "github.com/opentracing/opentracing-go" + http "github.com/pace/bricks/http/middleware" "github.com/pace/bricks/http/oauth2" "github.com/pace/bricks/locale" @@ -39,6 +40,7 @@ func TransferTracingContext(in, out context.Context) context.Context { if span != nil { out = opentracing.ContextWithSpan(out, span) } + return out } @@ -47,5 +49,6 @@ func TransferExternalDependencyContext(in, out context.Context) context.Context if edc == nil { return out } + return http.ContextWithExternalDependency(out, edc) } diff --git a/pkg/isotime/isotime.go b/pkg/isotime/isotime.go index bd831e0a1..ca6f0e5c6 100644 --- a/pkg/isotime/isotime.go +++ b/pkg/isotime/isotime.go @@ -23,6 +23,7 @@ func ParseISO8601(str string) (time.Time, error) { } var t time.Time + var err error for _, l := range iso8601Layouts { diff --git a/pkg/isotime/isotime_test.go b/pkg/isotime/isotime_test.go index 310e694a8..2a74d7d69 100644 --- a/pkg/isotime/isotime_test.go +++ b/pkg/isotime/isotime_test.go @@ -87,6 +87,7 @@ func TestParseISO8601(t *testing.T) { t.Errorf("ParseISO8601() error = %v, wantErr %v", err, tt.wantErr) return } + if !got.Equal(tt.want) { t.Errorf("ParseISO8601() = %v, want %v", got, tt.want) } diff --git a/pkg/lock/redis/lock.go b/pkg/lock/redis/lock.go index 6fbe55978..b3e289b3a 100644 --- a/pkg/lock/redis/lock.go +++ b/pkg/lock/redis/lock.go @@ -10,12 +10,12 @@ import ( "sync" "time" - redisbackend "github.com/pace/bricks/backend/redis" - pberrors "github.com/pace/bricks/maintenance/errors" - "github.com/bsm/redislock" "github.com/redis/go-redis/v9" "github.com/rs/zerolog/log" + + redisbackend "github.com/pace/bricks/backend/redis" + pberrors "github.com/pace/bricks/maintenance/errors" ) var ( @@ -43,6 +43,7 @@ type LockOption func(l *Lock) func NewLock(name string, opts ...LockOption) *Lock { initClient() + l := &Lock{Name: name} for _, opt := range []LockOption{ // default options SetTTL(5 * time.Second), @@ -50,9 +51,11 @@ func NewLock(name string, opts ...LockOption) *Lock { } { opt(l) } + for _, opt := range opts { opt(l) } + return l } @@ -67,6 +70,7 @@ func (l *Lock) Acquire(ctx context.Context) (bool, error) { lock, err := l.locker.Obtain(ctx, l.Name, l.lockTTL, opts) if err != nil { log.Ctx(ctx).Debug().Err(err).Str("lockName", l.Name).Msg("Could not acquire lock") + switch { case errors.Is(err, redislock.ErrNotObtained): return false, nil @@ -76,6 +80,7 @@ func (l *Lock) Acquire(ctx context.Context) (bool, error) { } l.lock = lock + return true, nil } @@ -94,6 +99,7 @@ func (l *Lock) AcquireWait(ctx context.Context) error { } l.lock = lock + return nil } @@ -122,6 +128,7 @@ func (l *Lock) AcquireAndKeepUp(ctx context.Context) (context.Context, context.C defer cancelLock() keepUpLock(lockCtx, lock, l.lockTTL) + err := lock.Release(ctx) if err != nil && err != redislock.ErrLockNotHeld { log.Ctx(lockCtx).Debug().Err(err).Msgf("could not release lock %q", l.Name) @@ -136,6 +143,7 @@ func (l *Lock) AcquireAndKeepUp(ctx context.Context) (context.Context, context.C func keepUpLock(ctx context.Context, lock *redislock.Lock, refreshTTL time.Duration) { refreshInterval := refreshTTL / 5 lockRunsOutIn := refreshTTL // initial value after obtaining the lock + for { select { case <-ctx.Done(): @@ -149,13 +157,15 @@ func keepUpLock(ctx context.Context, lock *redislock.Lock, refreshTTL time.Durat // Try to refresh lock. case <-time.After(refreshInterval): } - if err := lock.Refresh(ctx, refreshTTL, nil); err == redislock.ErrNotObtained { + + if err := lock.Refresh(ctx, refreshTTL, nil); errors.Is(err, redislock.ErrNotObtained) { // Don't return just yet. Get the TTL of the lock and try to // refresh for as long as the TTL is not over. if lockRunsOutIn, err = lock.TTL(ctx); err != nil { log.Ctx(ctx).Debug().Err(err).Msg("could not get ttl of lock") return // assuming we lost the lock } + continue } else if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("could not refresh lock") @@ -177,6 +187,7 @@ func (l *Lock) Release(ctx context.Context) error { if err := l.lock.Release(ctx); err != nil { log.Ctx(ctx).Debug().Err(err).Msg("error releasing redis lock") + switch { case errors.Is(err, redislock.ErrLockNotHeld): // well, since our only goal is that the lock is released, this will suffice @@ -186,6 +197,7 @@ func (l *Lock) Release(ctx context.Context) error { } l.lock = nil + return nil } diff --git a/pkg/lock/redis/lock_test.go b/pkg/lock/redis/lock_test.go index 93c3ca24c..803e31665 100644 --- a/pkg/lock/redis/lock_test.go +++ b/pkg/lock/redis/lock_test.go @@ -29,14 +29,18 @@ func TestIntegration_RedisLock(t *testing.T) { for try := 0; true; try++ { lockCtx, releaseLock, err = lock.AcquireAndKeepUp(ctx) require.NoError(t, err) + if lockCtx == nil { t.Log("Not obtained, try again in 1sec") time.Sleep(time.Second) + continue } + require.NotNil(t, lockCtx) require.NotNil(t, releaseLock) releaseLock() + break } } diff --git a/pkg/redact/context.go b/pkg/redact/context.go index 3c6d5d97a..4a184a878 100644 --- a/pkg/redact/context.go +++ b/pkg/redact/context.go @@ -17,6 +17,7 @@ func Ctx(ctx context.Context) *PatternRedactor { if rd, ok := ctx.Value(patternRedactorKey{}).(*PatternRedactor); ok { return rd.Clone() } + return NewPatternRedactor(RedactionSchemeDoNothing()) } @@ -25,5 +26,6 @@ func ContextTransfer(ctx, targetCtx context.Context) context.Context { if redactor := Ctx(ctx); redactor != nil { return context.WithValue(targetCtx, patternRedactorKey{}, redactor) } + return targetCtx } diff --git a/pkg/redact/default.go b/pkg/redact/default.go index e731e0914..6d302f439 100644 --- a/pkg/redact/default.go +++ b/pkg/redact/default.go @@ -2,7 +2,7 @@ package redact -// redactionSafe last 4 digits are usually concidered safe (e.g. credit cards, iban, ...) +// redactionSafe last 4 digits are usually considered safe (e.g. credit cards, iban, ...) const redactionSafe = 4 var Default *PatternRedactor diff --git a/pkg/redact/redact.go b/pkg/redact/redact.go index 797655aed..2f3bd1be6 100644 --- a/pkg/redact/redact.go +++ b/pkg/redact/redact.go @@ -23,8 +23,10 @@ func (r *PatternRedactor) Mask(data string) string { if pattern == nil { continue } + data = pattern.ReplaceAllStringFunc(data, r.scheme) } + return data } @@ -36,12 +38,14 @@ func (r *PatternRedactor) AddPatterns(patterns ...*regexp.Regexp) { // RemovePattern deletes a pattern from the redactor func (r *PatternRedactor) RemovePattern(pattern *regexp.Regexp) { index := -1 + for i, p := range r.patterns { if p == pattern || p.String() == pattern.String() { index = i break } } + if index >= 0 { r.patterns = append(r.patterns[:index], r.patterns[index+1:]...) } @@ -55,5 +59,6 @@ func (r *PatternRedactor) Clone() *PatternRedactor { rc := NewPatternRedactor(r.scheme) rc.patterns = make([]*regexp.Regexp, len(r.patterns)) copy(rc.patterns, r.patterns) + return rc } diff --git a/pkg/redact/redact_test.go b/pkg/redact/redact_test.go index df9106814..69e1d5f31 100644 --- a/pkg/redact/redact_test.go +++ b/pkg/redact/redact_test.go @@ -6,9 +6,9 @@ import ( "regexp" "testing" - "github.com/pace/bricks/pkg/redact" - "github.com/stretchr/testify/assert" + + "github.com/pace/bricks/pkg/redact" ) func TestRedactionSchemeKeepLast(t *testing.T) { @@ -30,6 +30,7 @@ and a ********************ring, as well as ****************cret` res := redactor.Mask(originalString) assert.Equal(t, expectedString1, res) redactor.RemovePattern(regexp.MustCompile("DE12345678909876543210")) + res = redactor.Mask(originalString) assert.Equal(t, expectedString2, res) } diff --git a/pkg/redact/scheme.go b/pkg/redact/scheme.go index 960acd799..7a522ca8b 100644 --- a/pkg/redact/scheme.go +++ b/pkg/redact/scheme.go @@ -22,6 +22,7 @@ func RedactionSchemeKeepLast(num int) func(string) string { for i := 0; i < len(runes)-num; i++ { runes[i] = '*' } + return string(runes) } } @@ -35,6 +36,7 @@ func RedactionSchemeKeepLastJWTNoSignature(num int) func(string) string { if PatternJWT.Match([]byte(s)) { parts := strings.Split(s, ".") parts[2] = defaultScheme(parts[2]) + return strings.Join(parts, ".") } diff --git a/pkg/routine/backoff.go b/pkg/routine/backoff.go index 3b7baf3e0..98f9d4361 100644 --- a/pkg/routine/backoff.go +++ b/pkg/routine/backoff.go @@ -28,5 +28,6 @@ func (all combinedExponentialBackoff) Duration(key interface{}) (dur time.Durati backoff.Reset() } } + return } diff --git a/pkg/routine/cluster_background_task_test.go b/pkg/routine/cluster_background_task_test.go index f89bc8352..21b1d0e15 100644 --- a/pkg/routine/cluster_background_task_test.go +++ b/pkg/routine/cluster_background_task_test.go @@ -15,8 +15,9 @@ import ( "testing" "time" - "github.com/pace/bricks/pkg/routine" "github.com/stretchr/testify/assert" + + "github.com/pace/bricks/pkg/routine" ) func Example_clusterBackgroundTask() { @@ -40,6 +41,7 @@ func Example_clusterBackgroundTask() { default: } out <- fmt.Sprintf("task run %d", i) + time.Sleep(100 * time.Millisecond) } }, @@ -56,6 +58,7 @@ func Example_clusterBackgroundTask() { for i := 0; i < 3; i++ { println(<-out) } + cancel() // Output: @@ -83,11 +86,13 @@ func TestIntegrationRunNamed_clusterBackgroundTask(t *testing.T) { var wg sync.WaitGroup for i := 0; i < 2; i++ { wg.Add(1) + go func() { spawnProcess(&buf) wg.Done() }() } + wg.Wait() // until both processes are done exp := `task run 0 @@ -101,16 +106,18 @@ task run 2 } func spawnProcess(w io.Writer) { - cmd := exec.Command(os.Args[0], + cmd := exec.Command(os.Args[0], //nolint:gosec "-test.timeout=2s", "-test.run=Example_clusterBackgroundTask", ) + cmd.Env = append(os.Environ(), "TEST_SUBPROCESS=1", "ROUTINE_REDIS_LOCK_TTL=200ms", ) cmd.Stdout = w cmd.Stderr = w + err := cmd.Run() if err != nil { _, _ = w.Write([]byte("error starting subprocess: " + err.Error())) @@ -134,12 +141,14 @@ func (b *subprocessOutputBuffer) Write(p []byte) (int, error) { strings.Contains(s, "Redis connection pool created"): return len(p), nil } + return b.buf.Write(p) } func (b *subprocessOutputBuffer) String() string { b.mx.Lock() defer b.mx.Unlock() + return b.buf.String() } @@ -151,5 +160,6 @@ func println(s string) { // go around the test runner _, _ = log.Writer().Write([]byte(s + "\n")) } + fmt.Println(s) } diff --git a/pkg/routine/instance.go b/pkg/routine/instance.go index c1aaa32c2..608029e1a 100755 --- a/pkg/routine/instance.go +++ b/pkg/routine/instance.go @@ -7,11 +7,11 @@ import ( "fmt" "time" - "github.com/pace/bricks/maintenance/errors" - "github.com/pace/bricks/pkg/lock/redis" - exponential "github.com/jpillora/backoff" "github.com/opentracing/opentracing-go" + + "github.com/pace/bricks/maintenance/errors" + "github.com/pace/bricks/pkg/lock/redis" ) type routineThatKeepsRunningOneInstance struct { @@ -39,7 +39,9 @@ func (r *routineThatKeepsRunningOneInstance) Run(ctx context.Context) { } r.num = ctx.Value(ctxNumKey{}).(int64) + var tryAgainIn time.Duration // zero on first run + for { select { case <-ctx.Done(): @@ -50,6 +52,7 @@ func (r *routineThatKeepsRunningOneInstance) Run(ctx context.Context) { // after the routine returned. singleRunCtx, cancel := context.WithCancel(ctx) tryAgainIn = r.singleRun(singleRunCtx) + cancel() } } @@ -62,23 +65,31 @@ func (r *routineThatKeepsRunningOneInstance) singleRun(ctx context.Context) time defer span.Finish() l := redis.NewLock("routine:lock:"+r.Name, redis.SetTTL(r.lockTTL)) + lockCtx, cancel, err := l.AcquireAndKeepUp(ctx) if err != nil { go errors.Handle(ctx, err) // report error to Sentry, non-blocking return r.backoff.Duration("lock") } + if lockCtx != nil { defer cancel() + routinePanicked := true + func() { defer errors.HandleWithCtx(ctx, fmt.Sprintf("routine %d", r.num)) // handle panics r.Routine(lockCtx) + routinePanicked = false }() + if routinePanicked { return r.backoff.Duration("routine") } } + r.backoff.ResetAll() + return r.retryInterval } diff --git a/pkg/routine/routine.go b/pkg/routine/routine.go index a85586357..f5216ad83 100755 --- a/pkg/routine/routine.go +++ b/pkg/routine/routine.go @@ -93,8 +93,10 @@ func Run(parentCtx context.Context, routine func(context.Context)) (cancel conte // add routine number to context and logger num := atomic.AddInt64(&ctr, 1) + span, ctx := opentracing.StartSpanFromContext(ctx, fmt.Sprintf("Routine %d", num)) defer span.Finish() + ctx = context.WithValue(ctx, ctxNumKey{}, num) logger := log.Ctx(ctx).With().Int64("routine", num).Logger() ctx = logger.WithContext(ctx) @@ -118,7 +120,8 @@ func Run(parentCtx context.Context, routine func(context.Context)) (cancel conte defer cancel() routine(ctx) }() - return + + return //nolint:nakedret } type ctxNumKey struct{} @@ -135,6 +138,7 @@ var ( func init() { c := make(chan os.Signal, 1) signal.Notify(c, syscall.SIGINT, syscall.SIGTERM) + go func() { <-c // block until SIGINT/SIGTERM is received signal.Stop(c) @@ -146,6 +150,7 @@ func init() { Int("count", len(contexts)). Ints64("routines", routineNumbers()). Msg("received shutdown signal, canceling all running routines") + for _, cancel := range contexts { cancel() } @@ -157,5 +162,6 @@ func routineNumbers() []int64 { for num := range contexts { routines = append(routines, num) } + return routines } diff --git a/pkg/routine/routine_test.go b/pkg/routine/routine_test.go index 788aef69a..538138665 100644 --- a/pkg/routine/routine_test.go +++ b/pkg/routine/routine_test.go @@ -45,6 +45,7 @@ func TestRun_transfersLogger(t *testing.T) { func TestRun_transfersSink(t *testing.T) { var sink log.Sink + logger := log.Logger() ctx := log.ContextWithSink(logger.WithContext(context.Background()), &sink) waitForRun(ctx, func(ctx context.Context) { @@ -75,6 +76,7 @@ func TestRun_transfersOAuth2Token(t *testing.T) { func TestRun_cancelsContextAfterRoutineIsFinished(t *testing.T) { routineCtx := contextAfterRun(context.Background(), nil) + require.Eventually(t, func() bool { return routineCtx.Err() == context.Canceled }, time.Second, time.Millisecond) @@ -86,12 +88,15 @@ func TestRun_blocksAfterShutdown(t *testing.T) { func testRunBlocksAfterShutdown(t *testing.T) { var endOfTest sync.WaitGroup + endOfTest.Add(1) // start routine that gets canceled by the shutdown routineCtx := make(chan context.Context) + Run(context.Background(), func(ctx context.Context) { routineCtx <- ctx + endOfTest.Wait() }) @@ -134,19 +139,24 @@ func TestRun_cancelsContextsOnSIGTERM(t *testing.T) { func testRunCancelsContextsOn(t *testing.T, signum syscall.Signal) { var endOfTest, routinesStarted sync.WaitGroup + endOfTest.Add(1) // start a few routines routineContexts := [3]context.Context{} routinesStarted.Add(len(routineContexts)) + for i := range routineContexts { i := i + Run(context.Background(), func(ctx context.Context) { routineContexts[i] = ctx + routinesStarted.Done() endOfTest.Wait() }) } + routinesStarted.Wait() // kill this process @@ -170,7 +180,8 @@ func exitAfterTest(t *testing.T, name string, testFunc func(*testing.T)) { testFunc(t) os.Exit(0) } - cmd := exec.Command(os.Args[0], "-test.run="+name) + + cmd := exec.Command(os.Args[0], "-test.run="+name) //nolint:gosec cmd.Env = append(os.Environ(), "ROUTINE_EXIT_AFTER_TEST=1") require.NoError(t, cmd.Run()) } @@ -178,6 +189,7 @@ func exitAfterTest(t *testing.T, name string, testFunc func(*testing.T)) { // Calls Run and returns once the routine is finished. func waitForRun(ctx context.Context, routine func(context.Context)) { done := make(chan struct{}) + Run(ctx, func(ctx context.Context) { defer func() { done <- struct{}{} }() routine(ctx) @@ -189,12 +201,15 @@ func waitForRun(ctx context.Context, routine func(context.Context)) { // routine is finished. func contextAfterRun(ctx context.Context, routine func(context.Context)) context.Context { var routineCtx context.Context + waitForRun(ctx, func(ctx context.Context) { if routine != nil { routine(ctx) } + routineCtx = ctx }) + return routineCtx } diff --git a/pkg/synctx/wg.go b/pkg/synctx/wg.go index 44f9d5faf..50723b009 100644 --- a/pkg/synctx/wg.go +++ b/pkg/synctx/wg.go @@ -15,5 +15,6 @@ type WaitGroup struct { func (wg *WaitGroup) Finish() <-chan struct{} { ch := make(chan struct{}) go func() { wg.Wait(); close(ch) }() + return ch } diff --git a/pkg/synctx/work_queue.go b/pkg/synctx/work_queue.go index d5b3693f5..2f2fda884 100644 --- a/pkg/synctx/work_queue.go +++ b/pkg/synctx/work_queue.go @@ -27,6 +27,7 @@ type WorkQueue struct { // the passed context for cancellation func NewWorkQueue(ctx context.Context) *WorkQueue { ctx, cancel := context.WithCancel(ctx) + return &WorkQueue{ ctx: ctx, done: make(chan struct{}), @@ -39,14 +40,14 @@ func NewWorkQueue(ctx context.Context) *WorkQueue { // will be immediately executed. func (queue *WorkQueue) Add(description string, fn WorkFunc) { queue.wg.Add(1) + go func() { err := fn(queue.ctx) - // if one of the work queue items fails the whole - // queue will be canceled if err != nil { queue.setErr(fmt.Errorf("failed to %s: %v", description, err)) queue.cancel() } + queue.wg.Done() }() } @@ -60,8 +61,6 @@ func (queue *WorkQueue) Wait() { case <-queue.wg.Finish(): case <-queue.ctx.Done(): err := queue.ctx.Err() - // if the queue was canceled and no error was set already - // store the error if err != nil { queue.setErr(err) } diff --git a/pkg/synctx/work_queue_test.go b/pkg/synctx/work_queue_test.go index 871b89218..cb8f87a66 100644 --- a/pkg/synctx/work_queue_test.go +++ b/pkg/synctx/work_queue_test.go @@ -13,6 +13,7 @@ func TestWorkQueueNoTask(t *testing.T) { ctx := context.Background() q := NewWorkQueue(ctx) q.Wait() + if q.Err() != nil { t.Error("expected no error") } @@ -25,10 +26,12 @@ func TestWorkQueueOneTask(t *testing.T) { if ctx1 == ctx { t.Error("should not directly pass the context") } + return nil }) q.Wait() + if q.Err() != nil { t.Error("expected no error") } @@ -42,10 +45,12 @@ func TestWorkQueueOneTaskWithErr(t *testing.T) { }) q.Wait() + if q.Err() == nil { t.Error("expected error") return } + expected := "failed to some work: Some error" if q.Err().Error() != expected { t.Errorf("expected error %q, got: %q", q.Err().Error(), expected) @@ -55,6 +60,7 @@ func TestWorkQueueOneTaskWithErr(t *testing.T) { func TestWorkQueueOneTaskWithCancel(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() + q := NewWorkQueue(ctx) q.Add("some work", func(ctx context.Context) error { time.Sleep(10 * time.Millisecond) @@ -62,10 +68,12 @@ func TestWorkQueueOneTaskWithCancel(t *testing.T) { }) q.Wait() + if q.Err() == nil { t.Error("expected error") return } + expected := "context canceled" if q.Err().Error() != expected { t.Errorf("expected error %q, got: %q", q.Err().Error(), expected) diff --git a/pkg/tracking/utm/context.go b/pkg/tracking/utm/context.go index 256e43bdd..2e13e5d2e 100644 --- a/pkg/tracking/utm/context.go +++ b/pkg/tracking/utm/context.go @@ -46,6 +46,7 @@ func ContextWithUTMData(parentCtx context.Context, data UTMData) context.Context func FromContext(ctx context.Context) (UTMData, bool) { val := ctx.Value(key) data, found := val.(UTMData) + return data, found } @@ -54,5 +55,6 @@ func ContextTransfer(in, out context.Context) context.Context { if !exists { return out // do nothing } + return ContextWithUTMData(out, utmData) } diff --git a/pkg/tracking/utm/context_test.go b/pkg/tracking/utm/context_test.go index b355d476c..2acfef17e 100644 --- a/pkg/tracking/utm/context_test.go +++ b/pkg/tracking/utm/context_test.go @@ -20,6 +20,7 @@ func TestContextWithUTMData(t *testing.T) { ctxWithData := ContextWithUTMData(ctx, data) _, found := FromContext(ctx) assert.False(t, found) + dataFromCtx, found := FromContext(ctxWithData) assert.True(t, found) assert.Equal(t, data, dataFromCtx) diff --git a/pkg/tracking/utm/http.go b/pkg/tracking/utm/http.go index 6cc7e45ab..7ab7cf744 100644 --- a/pkg/tracking/utm/http.go +++ b/pkg/tracking/utm/http.go @@ -21,6 +21,7 @@ func FromRequest(req *http.Request) (UTMData, error) { if data == emptyData { return emptyData, ErrNotFound } + return data, nil } @@ -28,6 +29,7 @@ func AttachToRequest(data UTMData, req *http.Request) *http.Request { if data == emptyData { return req } + q := req.URL.Query() q.Set("utm_source", data.Source) q.Set("utm_medium", data.Medium) @@ -35,7 +37,9 @@ func AttachToRequest(data UTMData, req *http.Request) *http.Request { q.Set("utm_term", data.Term) q.Set("utm_content", data.Content) q.Set("utm_partner_client", data.Client) + req.URL.RawQuery = q.Encode() + return req } @@ -48,10 +52,12 @@ func Middleware() func(http.Handler) http.Handler { next.ServeHTTP(w, r) return } + clientID, found := oauth2.ClientID(r.Context()) if found && data.Client == "" { data.Client = clientID } + ctx := ContextWithUTMData(r.Context(), data) r = r.WithContext(ctx) next.ServeHTTP(w, r) @@ -73,8 +79,10 @@ func (r *RoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { if !found { // no utm data found, skip directly to next roundtripper return r.transport.RoundTrip(req) } + newReq := cloneRequest(req) newReq = AttachToRequest(data, newReq) + return r.transport.RoundTrip(newReq) } @@ -98,5 +106,6 @@ func cloneRequest(r *http.Request) *http.Request { for k, s := range r.Header { r2.Header[k] = append([]string(nil), s...) } + return r2 } diff --git a/pkg/tracking/utm/http_test.go b/pkg/tracking/utm/http_test.go index 6fcffbed5..51fa45b3e 100644 --- a/pkg/tracking/utm/http_test.go +++ b/pkg/tracking/utm/http_test.go @@ -54,6 +54,7 @@ func TestRoundTripper_RoundTrip(t *testing.T) { resp, err := tripper.RoundTrip(req) require.NoError(t, err) require.NotNil(t, resp) + _ = resp.Body.Close() } var _ http.RoundTripper = (*mockTripper)(nil) @@ -71,5 +72,6 @@ func (m *mockTripper) RoundTrip(req *http.Request) (*http.Response, error) { assert.Equal(m.t, v, h, fmt.Sprintf("expected query paramater %q to match value", k)) } } + return m.resp, nil } diff --git a/test/livetest/livetest.go b/test/livetest/livetest.go index 313e3236e..1bc0a64e6 100644 --- a/test/livetest/livetest.go +++ b/test/livetest/livetest.go @@ -8,6 +8,7 @@ import ( "time" opentracing "github.com/opentracing/opentracing-go" + "github.com/pace/bricks/maintenance/log" ) @@ -60,10 +61,12 @@ func executeTest(ctx context.Context, t TestFunc, name string) error { // setup tracing span, ctx := opentracing.StartSpanFromContext(ctx, "Livetest") defer span.Finish() + logger := log.Ctx(ctx) proxy := NewTestProxy(ctx, name) startTime := time.Now() + func() { defer func() { err := recover() @@ -76,7 +79,9 @@ func executeTest(ctx context.Context, t TestFunc, name string) error { t(proxy) }() + duration := float64(time.Since(startTime)) / float64(time.Second) + proxy.okIfNoSkipFail() paceLivetestDurationSeconds.WithLabelValues(cfg.ServiceName).Observe(duration) diff --git a/test/livetest/livetest_test.go b/test/livetest/livetest_test.go index bc32a2a76..34e87db19 100644 --- a/test/livetest/livetest_test.go +++ b/test/livetest/livetest_test.go @@ -4,6 +4,7 @@ package livetest import ( "context" + "net/http" "net/http/httptest" "strings" "testing" @@ -61,9 +62,10 @@ func TestIntegrationExample(t *testing.T) { return } - req := httptest.NewRequest("GET", "/metrics", nil) + req := httptest.NewRequest(http.MethodGet, "/metrics", nil) resp := httptest.NewRecorder() metric.Handler().ServeHTTP(resp, req) + body := resp.Body.String() sn := cfg.ServiceName diff --git a/test/livetest/test_proxy.go b/test/livetest/test_proxy.go index 90f35b3d0..73ec688a0 100644 --- a/test/livetest/test_proxy.go +++ b/test/livetest/test_proxy.go @@ -65,6 +65,7 @@ func (t *T) Errorf(format string, args ...interface{}) { // Fail marks the test as failed func (t *T) Fail() { log.Ctx(t.ctx).Info().Msg("Fail...") + if t.state == StateRunning { t.state = StateFailed } @@ -112,6 +113,7 @@ func (t *T) Name() string { func (t *T) Skip(args ...interface{}) { log.Ctx(t.ctx).Info().Msg("Skip...") log.Ctx(t.ctx).Info().Msg(fmt.Sprint(args...)) + if t.state == StateRunning { t.state = StateSkipped } @@ -120,9 +122,11 @@ func (t *T) Skip(args ...interface{}) { // SkipNow skips the test immediately func (t *T) SkipNow() { log.Ctx(t.ctx).Info().Msg("Skip...") + if t.state == StateRunning { t.state = StateSkipped } + panic(ErrSkipNow) } @@ -130,6 +134,7 @@ func (t *T) SkipNow() { func (t *T) Skipf(format string, args ...interface{}) { log.Ctx(t.ctx).Info().Msg("Skip...") log.Ctx(t.ctx).Info().Msgf(format, args...) + if t.state == StateRunning { t.state = StateSkipped } diff --git a/tools/jsonapigen/main.go b/tools/jsonapigen/main.go index a2851902c..cfe08e488 100644 --- a/tools/jsonapigen/main.go +++ b/tools/jsonapigen/main.go @@ -26,7 +26,7 @@ func main() { log.Fatal(err) } - f, err := os.Create(path) + f, err := os.Create(path) //nolint:gosec if err != nil { log.Fatal(err) } diff --git a/tools/testserver/main.go b/tools/testserver/main.go index c5d6ac323..8a4eeb95a 100755 --- a/tools/testserver/main.go +++ b/tools/testserver/main.go @@ -9,23 +9,22 @@ import ( "net/http" "time" - "github.com/pace/bricks/grpc" - "github.com/pace/bricks/http/security" - "github.com/pace/bricks/http/transport" - "github.com/pace/bricks/locale" - - "github.com/pace/bricks/maintenance/failover" - "github.com/pace/bricks/maintenance/health/servicehealthcheck" - "github.com/opentracing/opentracing-go" olog "github.com/opentracing/opentracing-go/log" + "github.com/pace/bricks/backend/couchdb" "github.com/pace/bricks/backend/objstore" "github.com/pace/bricks/backend/postgres" "github.com/pace/bricks/backend/redis" + "github.com/pace/bricks/grpc" pacehttp "github.com/pace/bricks/http" "github.com/pace/bricks/http/oauth2" + "github.com/pace/bricks/http/security" + "github.com/pace/bricks/http/transport" + "github.com/pace/bricks/locale" "github.com/pace/bricks/maintenance/errors" + "github.com/pace/bricks/maintenance/failover" + "github.com/pace/bricks/maintenance/health/servicehealthcheck" "github.com/pace/bricks/maintenance/log" _ "github.com/pace/bricks/maintenance/tracing" "github.com/pace/bricks/test/livetest" @@ -52,8 +51,9 @@ func (*OauthBackend) IntrospectToken(ctx context.Context, token string) (*oauth2 type TestService struct{} -func (*TestService) GetTest(ctx context.Context, w simple.GetTestResponseWriter, r *simple.GetTestRequest) error { +func (*TestService) GetTest(ctx context.Context, _ simple.GetTestResponseWriter, _ *simple.GetTestRequest) error { log.Debug("Request in flight, this will wait 5 min....") + for t := 0; t < 360; t++ { select { case <-ctx.Done(): @@ -62,16 +62,19 @@ func (*TestService) GetTest(ctx context.Context, w simple.GetTestResponseWriter, time.Sleep(time.Second) } } + return nil } func main() { db := postgres.DefaultConnectionPool() rdb := redis.Client() + cdb, err := couchdb.DefaultDatabase() if err != nil { log.Fatal(err) } + _, err = objstore.Client() if err != nil { log.Fatal(err) @@ -81,15 +84,23 @@ func main() { if err != nil { log.Fatal(err) } - go ap.Run(log.WithContext(context.Background())) // nolint: errcheck + + go func() { + if err := ap.Run(log.WithContext(context.Background())); err != nil { + log.Println(err) + } + }() h := pacehttp.Router() + servicehealthcheck.RegisterHealthCheckFunc("fail-50", func(ctx context.Context) (r servicehealthcheck.HealthCheckResult) { if time.Now().Unix()%2 == 0 { panic("boom") } + r.Msg = "Foo" r.State = servicehealthcheck.Ok + return }) @@ -101,14 +112,17 @@ func main() { // do dummy database query cdb := db.WithContext(ctx) + var result struct { Calc int //nolint } + res, err := cdb.QueryOne(&result, `SELECT ? + ? AS Calc`, 10, 10) if err != nil { log.Ctx(ctx).Debug().Err(err).Msg("Calc failed") return } + log.Ctx(ctx).Debug().Int("rows_affected", res.RowsAffected()).Msg("Calc done") // do dummy redis query @@ -121,21 +135,30 @@ func main() { // do dummy call to external service log.Ctx(ctx).Debug().Msg("Test before JSON") w.Header().Set("Content-Type", "application/json") - fmt.Fprintf(w, `{"street":"Haid-und-Neu-Straße 18, 76131 Karlsruhe", "sunset": "%s"}`, fetchSunsetandSunrise(ctx)) + + if _, err := fmt.Fprintf(w, `{"street":"Haid-und-Neu-Straße 18, 76131 Karlsruhe", "sunset": "%s"}`, fetchSunsetandSunrise(ctx)); err != nil { + log.Ctx(ctx).Warn().Err(err).Msg("Failed writing message") + } }) h.HandleFunc("/grpc", func(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() - conn, err := grpc.DialContext(ctx, ":3001") + conn, err := grpc.NewClient(":3001") if err != nil { log.Fatalf("did not connect: %s", err) } - defer conn.Close() + + defer func() { + if err := conn.Close(); err != nil { + log.Printf("Failed closing connection: %v", err) + } + }() ctx = security.ContextWithToken(ctx, security.TokenString("test")) c := math.NewMathServiceClient(conn) + o, err := c.Add(ctx, &math.Input{ A: 1, B: 23, @@ -144,20 +167,21 @@ func main() { log.Ctx(ctx).Debug().Err(err).Msg("failed to add") return } + log.Ctx(ctx).Info().Msgf("C: %d", o.C) ctx = locale.WithLocale(ctx, locale.NewLocale("fr-CH", "Europe/Paris")) _, err = c.Add(ctx, &math.Input{}) if err != nil { - log.Ctx(ctx).Debug().Err(err).Msg("failed to substract") + log.Ctx(ctx).Debug().Err(err).Msg("failed to add") return } if r.URL.Query().Get("error") != "" { - _, err = c.Substract(ctx, &math.Input{}) + _, err = c.Subtract(ctx, &math.Input{}) if err != nil { - log.Ctx(ctx).Debug().Err(err).Msg("failed to substract") + log.Ctx(ctx).Debug().Err(err).Msg("failed to subtract") return } } @@ -168,20 +192,28 @@ func main() { if row.Err != nil { log.Println(err) w.WriteHeader(http.StatusInternalServerError) + return } + var doc interface{} - row.ScanDoc(&doc) // nolint: errcheck + + if err := row.ScanDoc(&doc); err != nil { + log.Printf("Failed scanning document: %v", err) + } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(doc) // nolint: errcheck + + if err := json.NewEncoder(w).Encode(doc); err != nil { + log.Printf("Failed encoding document: %v", err) + } }) h.HandleFunc("/panic", func(w http.ResponseWriter, r *http.Request) { go func() { defer errors.HandleWithCtx(r.Context(), "Some worker") - panic(fmt.Errorf("Something went wrong %d - times", 100)) + panic(fmt.Errorf("something went wrong %d - times", 100)) }() panic("Test for sentry") @@ -195,7 +227,7 @@ func main() { // Test OAuth // // This middleware is configured against an Oauth application dummy - m := oauth2.NewMiddleware(new(OauthBackend)) // nolint: staticcheck + m := oauth2.NewMiddleware(new(OauthBackend)) //nolint:staticcheck sr := h.PathPrefix("/test").Subrouter() sr.Use(m.Handler) @@ -204,29 +236,39 @@ func main() { // // curl -H "Authorization: Bearer 83142f1b767e910e78ba2d554b6708c371f053d13d6075bcc39766853a932253" localhost:3000/test/auth sr.HandleFunc("/oauth", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Oauth test successful.\n") + if _, err := fmt.Fprintf(w, "Oauth test successful.\n"); err != nil { + log.Logger().Warn().Err(err).Msg("Failed testing OAuth") + } }) s := pacehttp.Server(h) log.Logger().Info().Str("addr", s.Addr).Msg("Starting testserver ...") - // nolint:errcheck - go livetest.Test(context.Background(), []livetest.TestFunc{ - func(t *livetest.T) { - t.Log("Test /test query") - - resp, err := http.Get("http://localhost:3000/test") - if err != nil { - t.Error(err) - t.Fail() - return - } - if resp.StatusCode != 200 { - t.Logf("Received status code: %d", resp.StatusCode) - t.Fail() - } - }, - }) + go func() { + if err := livetest.Test(context.Background(), []livetest.TestFunc{ + func(t *livetest.T) { + t.Log("Test /test query") + + resp, err := http.Get("http://localhost:3000/test") + if err != nil { + t.Error(err) + t.Fail() + return + } + + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + t.Logf("Received status code: %d", resp.StatusCode) + t.Fail() + } + }, + }); err != nil { + log.Logger().Warn().Err(err).Msg("Failure during livetest") + } + }() log.Fatal(s.ListenAndServe()) } @@ -237,7 +279,8 @@ func fetchSunsetandSunrise(ctx context.Context) string { span.LogFields(olog.Float64("lat", lat), olog.Float64("lon", lon)) url := fmt.Sprintf("https://api.sunrise-sunset.org/json?lat=%f&lng=%f&date=today", lat, lon) - req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { log.Fatal(err) } @@ -245,10 +288,12 @@ func fetchSunsetandSunrise(ctx context.Context) string { c := &http.Client{ Transport: transport.NewDefaultTransportChain(), } + resp, err := c.Do(req) if err != nil { log.Fatal(err) } + defer resp.Body.Close() var r struct { @@ -266,8 +311,10 @@ func fetchSunsetandSunrise(ctx context.Context) string { if err != nil { log.Fatal(err) } + sunset = sunset.Local() log.Ctx(ctx).Debug().Time("sunset", sunset).Str("str", r.Results.Sunset).Msg("Parsed sunset time") + return sunset.String() } diff --git a/tools/testserver/math/math.pb.go b/tools/testserver/math/math.pb.go index c4bfd46b9..69aa2a8a5 100644 --- a/tools/testserver/math/math.pb.go +++ b/tools/testserver/math/math.pb.go @@ -1,8 +1,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.27.1 -// protoc v3.17.3 -// source: tools/testserver/math/math.proto +// protoc-gen-go v1.34.2 +// protoc v5.27.2 +// source: math.proto package math @@ -32,7 +32,7 @@ type Input struct { func (x *Input) Reset() { *x = Input{} if protoimpl.UnsafeEnabled { - mi := &file_tools_testserver_math_math_proto_msgTypes[0] + mi := &file_math_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -45,7 +45,7 @@ func (x *Input) String() string { func (*Input) ProtoMessage() {} func (x *Input) ProtoReflect() protoreflect.Message { - mi := &file_tools_testserver_math_math_proto_msgTypes[0] + mi := &file_math_proto_msgTypes[0] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -58,7 +58,7 @@ func (x *Input) ProtoReflect() protoreflect.Message { // Deprecated: Use Input.ProtoReflect.Descriptor instead. func (*Input) Descriptor() ([]byte, []int) { - return file_tools_testserver_math_math_proto_rawDescGZIP(), []int{0} + return file_math_proto_rawDescGZIP(), []int{0} } func (x *Input) GetA() int64 { @@ -86,7 +86,7 @@ type Output struct { func (x *Output) Reset() { *x = Output{} if protoimpl.UnsafeEnabled { - mi := &file_tools_testserver_math_math_proto_msgTypes[1] + mi := &file_math_proto_msgTypes[1] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -99,7 +99,7 @@ func (x *Output) String() string { func (*Output) ProtoMessage() {} func (x *Output) ProtoReflect() protoreflect.Message { - mi := &file_tools_testserver_math_math_proto_msgTypes[1] + mi := &file_math_proto_msgTypes[1] if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -112,7 +112,7 @@ func (x *Output) ProtoReflect() protoreflect.Message { // Deprecated: Use Output.ProtoReflect.Descriptor instead. func (*Output) Descriptor() ([]byte, []int) { - return file_tools_testserver_math_math_proto_rawDescGZIP(), []int{1} + return file_math_proto_rawDescGZIP(), []int{1} } func (x *Output) GetC() int64 { @@ -122,47 +122,46 @@ func (x *Output) GetC() int64 { return 0 } -var File_tools_testserver_math_math_proto protoreflect.FileDescriptor +var File_math_proto protoreflect.FileDescriptor -var file_tools_testserver_math_math_proto_rawDesc = []byte{ - 0x0a, 0x20, 0x74, 0x6f, 0x6f, 0x6c, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x2f, 0x6d, 0x61, 0x74, 0x68, 0x2f, 0x6d, 0x61, 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x12, 0x04, 0x4d, 0x61, 0x74, 0x68, 0x22, 0x23, 0x0a, 0x05, 0x49, 0x6e, 0x70, 0x75, - 0x74, 0x12, 0x0c, 0x0a, 0x01, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x01, 0x61, 0x12, - 0x0c, 0x0a, 0x01, 0x62, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x01, 0x62, 0x22, 0x16, 0x0a, - 0x06, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x0c, 0x0a, 0x01, 0x63, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x01, 0x63, 0x32, 0x57, 0x0a, 0x0b, 0x4d, 0x61, 0x74, 0x68, 0x53, 0x65, 0x72, - 0x76, 0x69, 0x63, 0x65, 0x12, 0x20, 0x0a, 0x03, 0x41, 0x64, 0x64, 0x12, 0x0b, 0x2e, 0x4d, 0x61, - 0x74, 0x68, 0x2e, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x1a, 0x0c, 0x2e, 0x4d, 0x61, 0x74, 0x68, 0x2e, - 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x26, 0x0a, 0x09, 0x53, 0x75, 0x62, 0x73, 0x74, 0x72, - 0x61, 0x63, 0x74, 0x12, 0x0b, 0x2e, 0x4d, 0x61, 0x74, 0x68, 0x2e, 0x49, 0x6e, 0x70, 0x75, 0x74, - 0x1a, 0x0c, 0x2e, 0x4d, 0x61, 0x74, 0x68, 0x2e, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x42, 0x17, - 0x5a, 0x15, 0x74, 0x6f, 0x6f, 0x6c, 0x73, 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x2f, 0x6d, 0x61, 0x74, 0x68, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +var file_math_proto_rawDesc = []byte{ + 0x0a, 0x0a, 0x6d, 0x61, 0x74, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x04, 0x4d, 0x61, + 0x74, 0x68, 0x22, 0x23, 0x0a, 0x05, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x12, 0x0c, 0x0a, 0x01, 0x61, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x01, 0x61, 0x12, 0x0c, 0x0a, 0x01, 0x62, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x03, 0x52, 0x01, 0x62, 0x22, 0x16, 0x0a, 0x06, 0x4f, 0x75, 0x74, 0x70, 0x75, + 0x74, 0x12, 0x0c, 0x0a, 0x01, 0x63, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x01, 0x63, 0x32, + 0x56, 0x0a, 0x0b, 0x4d, 0x61, 0x74, 0x68, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x20, + 0x0a, 0x03, 0x41, 0x64, 0x64, 0x12, 0x0b, 0x2e, 0x4d, 0x61, 0x74, 0x68, 0x2e, 0x49, 0x6e, 0x70, + 0x75, 0x74, 0x1a, 0x0c, 0x2e, 0x4d, 0x61, 0x74, 0x68, 0x2e, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, + 0x12, 0x25, 0x0a, 0x08, 0x53, 0x75, 0x62, 0x74, 0x72, 0x61, 0x63, 0x74, 0x12, 0x0b, 0x2e, 0x4d, + 0x61, 0x74, 0x68, 0x2e, 0x49, 0x6e, 0x70, 0x75, 0x74, 0x1a, 0x0c, 0x2e, 0x4d, 0x61, 0x74, 0x68, + 0x2e, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x42, 0x17, 0x5a, 0x15, 0x74, 0x6f, 0x6f, 0x6c, 0x73, + 0x2f, 0x74, 0x65, 0x73, 0x74, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x2f, 0x6d, 0x61, 0x74, 0x68, + 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( - file_tools_testserver_math_math_proto_rawDescOnce sync.Once - file_tools_testserver_math_math_proto_rawDescData = file_tools_testserver_math_math_proto_rawDesc + file_math_proto_rawDescOnce sync.Once + file_math_proto_rawDescData = file_math_proto_rawDesc ) -func file_tools_testserver_math_math_proto_rawDescGZIP() []byte { - file_tools_testserver_math_math_proto_rawDescOnce.Do(func() { - file_tools_testserver_math_math_proto_rawDescData = protoimpl.X.CompressGZIP(file_tools_testserver_math_math_proto_rawDescData) +func file_math_proto_rawDescGZIP() []byte { + file_math_proto_rawDescOnce.Do(func() { + file_math_proto_rawDescData = protoimpl.X.CompressGZIP(file_math_proto_rawDescData) }) - return file_tools_testserver_math_math_proto_rawDescData + return file_math_proto_rawDescData } -var file_tools_testserver_math_math_proto_msgTypes = make([]protoimpl.MessageInfo, 2) -var file_tools_testserver_math_math_proto_goTypes = []interface{}{ +var file_math_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_math_proto_goTypes = []any{ (*Input)(nil), // 0: Math.Input (*Output)(nil), // 1: Math.Output } -var file_tools_testserver_math_math_proto_depIdxs = []int32{ +var file_math_proto_depIdxs = []int32{ 0, // 0: Math.MathService.Add:input_type -> Math.Input - 0, // 1: Math.MathService.Substract:input_type -> Math.Input + 0, // 1: Math.MathService.Subtract:input_type -> Math.Input 1, // 2: Math.MathService.Add:output_type -> Math.Output - 1, // 3: Math.MathService.Substract:output_type -> Math.Output + 1, // 3: Math.MathService.Subtract:output_type -> Math.Output 2, // [2:4] is the sub-list for method output_type 0, // [0:2] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name @@ -170,13 +169,13 @@ var file_tools_testserver_math_math_proto_depIdxs = []int32{ 0, // [0:0] is the sub-list for field type_name } -func init() { file_tools_testserver_math_math_proto_init() } -func file_tools_testserver_math_math_proto_init() { - if File_tools_testserver_math_math_proto != nil { +func init() { file_math_proto_init() } +func file_math_proto_init() { + if File_math_proto != nil { return } if !protoimpl.UnsafeEnabled { - file_tools_testserver_math_math_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + file_math_proto_msgTypes[0].Exporter = func(v any, i int) any { switch v := v.(*Input); i { case 0: return &v.state @@ -188,7 +187,7 @@ func file_tools_testserver_math_math_proto_init() { return nil } } - file_tools_testserver_math_math_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { + file_math_proto_msgTypes[1].Exporter = func(v any, i int) any { switch v := v.(*Output); i { case 0: return &v.state @@ -205,18 +204,18 @@ func file_tools_testserver_math_math_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_tools_testserver_math_math_proto_rawDesc, + RawDescriptor: file_math_proto_rawDesc, NumEnums: 0, NumMessages: 2, NumExtensions: 0, NumServices: 1, }, - GoTypes: file_tools_testserver_math_math_proto_goTypes, - DependencyIndexes: file_tools_testserver_math_math_proto_depIdxs, - MessageInfos: file_tools_testserver_math_math_proto_msgTypes, + GoTypes: file_math_proto_goTypes, + DependencyIndexes: file_math_proto_depIdxs, + MessageInfos: file_math_proto_msgTypes, }.Build() - File_tools_testserver_math_math_proto = out.File - file_tools_testserver_math_math_proto_rawDesc = nil - file_tools_testserver_math_math_proto_goTypes = nil - file_tools_testserver_math_math_proto_depIdxs = nil + File_math_proto = out.File + file_math_proto_rawDesc = nil + file_math_proto_goTypes = nil + file_math_proto_depIdxs = nil } diff --git a/tools/testserver/math/math.proto b/tools/testserver/math/math.proto index 4970e0fcc..549b1612d 100644 --- a/tools/testserver/math/math.proto +++ b/tools/testserver/math/math.proto @@ -15,5 +15,5 @@ message Output { service MathService { rpc Add(Input) returns (Output); - rpc Substract(Input) returns (Output); + rpc Subtract(Input) returns (Output); } \ No newline at end of file diff --git a/tools/testserver/math/math_grpc.pb.go b/tools/testserver/math/math_grpc.pb.go index d234815b2..781d70dff 100644 --- a/tools/testserver/math/math_grpc.pb.go +++ b/tools/testserver/math/math_grpc.pb.go @@ -1,4 +1,8 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.4.0 +// - protoc v5.27.2 +// source: math.proto package math @@ -11,15 +15,20 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.62.0 or later. +const _ = grpc.SupportPackageIsVersion8 + +const ( + MathService_Add_FullMethodName = "/Math.MathService/Add" + MathService_Subtract_FullMethodName = "/Math.MathService/Subtract" +) // MathServiceClient is the client API for MathService service. // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type MathServiceClient interface { Add(ctx context.Context, in *Input, opts ...grpc.CallOption) (*Output, error) - Substract(ctx context.Context, in *Input, opts ...grpc.CallOption) (*Output, error) + Subtract(ctx context.Context, in *Input, opts ...grpc.CallOption) (*Output, error) } type mathServiceClient struct { @@ -31,17 +40,19 @@ func NewMathServiceClient(cc grpc.ClientConnInterface) MathServiceClient { } func (c *mathServiceClient) Add(ctx context.Context, in *Input, opts ...grpc.CallOption) (*Output, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(Output) - err := c.cc.Invoke(ctx, "/Math.MathService/Add", in, out, opts...) + err := c.cc.Invoke(ctx, MathService_Add_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } return out, nil } -func (c *mathServiceClient) Substract(ctx context.Context, in *Input, opts ...grpc.CallOption) (*Output, error) { +func (c *mathServiceClient) Subtract(ctx context.Context, in *Input, opts ...grpc.CallOption) (*Output, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(Output) - err := c.cc.Invoke(ctx, "/Math.MathService/Substract", in, out, opts...) + err := c.cc.Invoke(ctx, MathService_Subtract_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -53,7 +64,7 @@ func (c *mathServiceClient) Substract(ctx context.Context, in *Input, opts ...gr // for forward compatibility type MathServiceServer interface { Add(context.Context, *Input) (*Output, error) - Substract(context.Context, *Input) (*Output, error) + Subtract(context.Context, *Input) (*Output, error) mustEmbedUnimplementedMathServiceServer() } @@ -64,8 +75,8 @@ type UnimplementedMathServiceServer struct { func (UnimplementedMathServiceServer) Add(context.Context, *Input) (*Output, error) { return nil, status.Errorf(codes.Unimplemented, "method Add not implemented") } -func (UnimplementedMathServiceServer) Substract(context.Context, *Input) (*Output, error) { - return nil, status.Errorf(codes.Unimplemented, "method Substract not implemented") +func (UnimplementedMathServiceServer) Subtract(context.Context, *Input) (*Output, error) { + return nil, status.Errorf(codes.Unimplemented, "method Subtract not implemented") } func (UnimplementedMathServiceServer) mustEmbedUnimplementedMathServiceServer() {} @@ -90,7 +101,7 @@ func _MathService_Add_Handler(srv interface{}, ctx context.Context, dec func(int } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/Math.MathService/Add", + FullMethod: MathService_Add_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { return srv.(MathServiceServer).Add(ctx, req.(*Input)) @@ -98,20 +109,20 @@ func _MathService_Add_Handler(srv interface{}, ctx context.Context, dec func(int return interceptor(ctx, in, info, handler) } -func _MathService_Substract_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { +func _MathService_Subtract_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(Input) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(MathServiceServer).Substract(ctx, in) + return srv.(MathServiceServer).Subtract(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/Math.MathService/Substract", + FullMethod: MathService_Subtract_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(MathServiceServer).Substract(ctx, req.(*Input)) + return srv.(MathServiceServer).Subtract(ctx, req.(*Input)) } return interceptor(ctx, in, info, handler) } @@ -128,10 +139,10 @@ var MathService_ServiceDesc = grpc.ServiceDesc{ Handler: _MathService_Add_Handler, }, { - MethodName: "Substract", - Handler: _MathService_Substract_Handler, + MethodName: "Subtract", + Handler: _MathService_Subtract_Handler, }, }, Streams: []grpc.StreamDesc{}, - Metadata: "tools/testserver/math/math.proto", + Metadata: "math.proto", } diff --git a/tools/testserver/simplemath/main.go b/tools/testserver/simplemath/main.go index 92ac19051..d4e424248 100644 --- a/tools/testserver/simplemath/main.go +++ b/tools/testserver/simplemath/main.go @@ -5,12 +5,13 @@ import ( "fmt" "github.com/opentracing/opentracing-go" + "github.com/uber/jaeger-client-go" + "github.com/pace/bricks/grpc" "github.com/pace/bricks/http/security" "github.com/pace/bricks/locale" "github.com/pace/bricks/maintenance/log" "github.com/pace/bricks/tools/testserver/math" - "github.com/uber/jaeger-client-go" ) type GrpcAuthBackend struct{} @@ -26,6 +27,7 @@ func (*GrpcAuthBackend) AuthorizeUnary(ctx context.Context) (context.Context, er } else { return nil, fmt.Errorf("unauthenticated") } + return ctx, nil } @@ -37,18 +39,21 @@ func (*SimpleMathServer) Add(ctx context.Context, i *math.Input) (*math.Output, if loc, ok := locale.FromCtx(ctx); ok { log.Ctx(ctx).Debug().Msgf("Locale: %q", loc.Serialize()) } + span := opentracing.SpanFromContext(ctx) if sc, ok := span.Context().(jaeger.SpanContext); ok { log.Ctx(ctx).Debug().Msgf("Span: %q", sc.String()) } var o math.Output + o.C = i.A + i.B log.Ctx(ctx).Debug().Msgf("A: %d + B: %d = C: %d", i.A, i.B, o.C) + return &o, nil } -func (*SimpleMathServer) Substract(ctx context.Context, i *math.Input) (*math.Output, error) { +func (*SimpleMathServer) Subtract(ctx context.Context, i *math.Input) (*math.Output, error) { panic("not implemented") }