From d486efaad56c7a59f6dbe1b6892839cc9184cc42 Mon Sep 17 00:00:00 2001 From: Igor Lazarev Date: Sun, 16 Jul 2023 17:03:57 +0300 Subject: [PATCH] feature: joining errors --- errors.go | 20 ++++++++++- errors_test.go | 93 ++++++++++++++++++++++++++++++++++++++++++++++++- go.mod | 2 +- joining.go | 71 +++++++++++++++++++++++++++++++++++++ joining_test.go | 82 +++++++++++++++++++++++++++++++++++++++++++ logging.go | 16 +++++++-- logging_test.go | 34 ++++++++++++++++++ options.go | 5 +++ 8 files changed, 318 insertions(+), 5 deletions(-) create mode 100644 joining.go create mode 100644 joining_test.go diff --git a/errors.go b/errors.go index 5fbaf94..1a70136 100644 --- a/errors.go +++ b/errors.go @@ -69,7 +69,25 @@ func As[T any](err error) (T, bool) { return t, true } } - err = Unwrap(err) + switch x := err.(type) { + case interface{ Unwrap() error }: + err = x.Unwrap() + if err == nil { + var z T + return z, false + } + case interface{ Unwrap() []error }: + for _, err := range x.Unwrap() { + if t, ok := As[T](err); ok { + return t, ok + } + } + var z T + return z, false + default: + var z T + return z, false + } } var z T diff --git a/errors_test.go b/errors_test.go index ad2a517..42163b4 100644 --- a/errors_test.go +++ b/errors_test.go @@ -2,6 +2,7 @@ package errors_test import ( "encoding/json" + stderrors "errors" "fmt" "io/fs" "os" @@ -111,7 +112,7 @@ func TestStackTrace(t *testing.T) { err: wrap(errors.New("ooh")), want: []string{ "github.com/muonsoft/errors_test.wrap\n" + - "\t.+/errors/errors_test.go:160", + "\t.+/errors/errors_test.go:200", "github.com/muonsoft/errors_test.TestStackTrace\n" + "\t.+/errors/errors_test.go:111", }, @@ -140,6 +141,46 @@ func TestStackTrace(t *testing.T) { "\t.+/errors/errors_test.go:137", }, }, + { + name: "join one std error", + err: errors.Join(fmt.Errorf("ooh")), + want: []string{ + "github.com/muonsoft/errors_test.TestStackTrace\n" + + "\t.+/errors/errors_test.go:145", + }, + }, + { + name: "join two std errors", + err: errors.Join( + fmt.Errorf("ooh"), + fmt.Errorf("ooh"), + ), + want: []string{ + "github.com/muonsoft/errors_test.TestStackTrace\n" + + "\t.+/errors/errors_test.go:153", + }, + }, + { + name: "join one stacked error", + err: errors.Join( + errors.Errorf("ooh"), + ), + want: []string{ + "github.com/muonsoft/errors_test.TestStackTrace\n" + + "\t.+/errors/errors_test.go:165", + }, + }, + { + name: "join two stacked errors", + err: errors.Join( + errors.Errorf("ooh"), + errors.Errorf("ooh"), + ), + want: []string{ + "github.com/muonsoft/errors_test.TestStackTrace\n" + + "\t.+/errors/errors_test.go:174", + }, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -203,6 +244,11 @@ func TestFields(t *testing.T) { err: errors.Wrap(errors.Errorf("error"), errors.String("key", "value")), expected: "value", }, + { + name: "stringer interface", + err: errors.Wrap(errors.Errorf("error"), errors.Stringer("key", stringer{s: "value"})), + expected: "value", + }, { name: "strings", err: errors.Wrap(errors.Errorf("error"), errors.Strings("key", []string{"value"})), @@ -394,6 +440,43 @@ func TestAs(t *testing.T) { true, errFileNotFound, }, + { + "wrapped wrapped error", + errors.Wrap(wrapped{"error", errorT{"T"}}), + func(err error) (any, bool) { + return errors.As[errorT](err) + }, + true, + errorT{"T"}, + }, + { + "wrapped joined error", + errors.Wrap( + errors.Join( + wrapped{"error", errorT{"T"}}, + wrapped{"error", errorT{"T"}}, + ), + ), + func(err error) (any, bool) { + return errors.As[errorT](err) + }, + true, + errorT{"T"}, + }, + { + "wrapped std joined error", + errors.Wrap( + stderrors.Join( + wrapped{"error", errorT{"T"}}, + wrapped{"error", errorT{"T"}}, + ), + ), + func(err error) (any, bool) { + return errors.As[errorT](err) + }, + true, + errorT{"T"}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -555,3 +638,11 @@ type wrapped struct { func (e wrapped) Error() string { return e.msg } func (e wrapped) Unwrap() error { return e.err } + +type stringer struct { + s string +} + +func (s stringer) String() string { + return s.s +} diff --git a/go.mod b/go.mod index cedc942..82c9dbc 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/muonsoft/errors -go 1.18 +go 1.20 require github.com/sirupsen/logrus v1.8.1 diff --git a/joining.go b/joining.go new file mode 100644 index 0000000..17cd604 --- /dev/null +++ b/joining.go @@ -0,0 +1,71 @@ +package errors + +// Join returns an error that wraps the given errors with a stack trace +// at the point Join is called. Any nil error values are discarded. +// Join returns nil if errs contains no non-nil values. +// The error formats as the concatenation of the strings obtained +// by calling the Error method of each element of errs, with a newline +// between each string. +// If there is only one error in chain, then it's stack trace will be +// preserved if present. +func Join(errs ...error) error { + n := 0 + for _, err := range errs { + if err != nil { + n++ + } + } + if n == 0 { + return nil + } + if n == 1 { + for _, err := range errs { + if err != nil { + if isWrapper(err) { + return err + } + + return &stacked{ + wrapped: &wrapped{wrapped: err}, + stack: newStack(0), + } + } + } + } + + e := &joinError{errs: make([]error, 0, n)} + + for _, err := range errs { + if err != nil { + e.errs = append(e.errs, err) + } + } + + return &stacked{ + wrapped: &wrapped{wrapped: e}, + stack: newStack(0), + } +} + +type joinError struct { + errs []error +} + +// todo: add marshal json? + +func (e *joinError) Error() string { + var b []byte + + for i, err := range e.errs { + if i > 0 { + b = append(b, '\n') + } + b = append(b, err.Error()...) + } + + return string(b) +} + +func (e *joinError) Unwrap() []error { + return e.errs +} diff --git a/joining_test.go b/joining_test.go new file mode 100644 index 0000000..e93e4c4 --- /dev/null +++ b/joining_test.go @@ -0,0 +1,82 @@ +package errors_test + +import ( + "fmt" + "testing" + + "github.com/muonsoft/errors" +) + +func TestJoin_ReturnsNil(t *testing.T) { + if err := errors.Join(); err != nil { + t.Errorf("errors.Join() = %v, want nil", err) + } + if err := errors.Join(nil); err != nil { + t.Errorf("errors.Join(nil) = %v, want nil", err) + } + if err := errors.Join(nil, nil); err != nil { + t.Errorf("errors.Join(nil, nil) = %v, want nil", err) + } +} + +func TestJoin(t *testing.T) { + err1 := errors.New("err1") + err2 := errors.New("err2") + tests := []struct { + errs []error + want []error + }{ + { + errs: []error{err1}, + want: []error{err1}, + }, + { + errs: []error{err1, err2}, + want: []error{err1, err2}, + }, + { + errs: []error{err1, nil, err2}, + want: []error{err1, err2}, + }, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.errs), func(t *testing.T) { + got := errors.Join(test.errs...) + for _, want := range test.want { + if !errors.Is(got, want) { + t.Errorf("want err %v in chain", want) + } + } + }) + } +} + +func TestJoin_ErrorMethod(t *testing.T) { + err1 := errors.New("err1") + err2 := errors.New("err2") + tests := []struct { + errs []error + want string + }{ + { + errs: []error{err1}, + want: "err1", + }, + { + errs: []error{err1, err2}, + want: "err1\nerr2", + }, + { + errs: []error{err1, nil, err2}, + want: "err1\nerr2", + }, + } + for _, test := range tests { + t.Run(fmt.Sprintf("%v", test.errs), func(t *testing.T) { + got := errors.Join(test.errs...).Error() + if got != test.want { + t.Errorf("Join().Error() = %q; want %q", got, test.want) + } + }) + } +} diff --git a/logging.go b/logging.go index 45d3d67..72878d6 100644 --- a/logging.go +++ b/logging.go @@ -43,12 +43,24 @@ func Log(err error, logger Logger) { if s, ok := e.(stackTracer); ok { logger.SetStackTrace(s.StackTrace()) } + } + logFields(err, logger) + + logger.Log(err.Error()) +} + +func logFields(err error, logger Logger) { + for e := err; e != nil; e = errors.Unwrap(e) { if w, ok := e.(LoggableError); ok { w.LogFields(logger) } - } - logger.Log(err.Error()) + if joined, ok := e.(interface{ Unwrap() []error }); ok { + for _, u := range joined.Unwrap() { + logFields(u, logger) + } + } + } } type BoolField struct { diff --git a/logging_test.go b/logging_test.go index 2d46dfb..03f3cb0 100644 --- a/logging_test.go +++ b/logging_test.go @@ -1,6 +1,7 @@ package errors_test import ( + stderrors "errors" "testing" "github.com/muonsoft/errors" @@ -45,3 +46,36 @@ func TestLog_errorWithStack(t *testing.T) { logger.AssertField(t, "deepKey", "deepValue") logger.AssertField(t, "deepestKey", "deepestValue") } + +func TestLog_joinedErrors(t *testing.T) { + logger := errorstest.NewLogger() + + err := errors.Wrap( + errors.Join( + errors.Wrap( + errors.Errorf("error 1", errors.String("key1", "value1")), + errors.String("key2", "value2"), + ), + errors.Errorf("error 2", errors.String("key3", "value3")), + stderrors.Join( + errors.Errorf("error 3", errors.String("key4", "value4")), + errors.Errorf("error 4", errors.String("key5", "value5")), + ), + ), + ) + errors.Log(err, logger) + + logger.AssertMessage(t, "error 1\nerror 2\nerror 3\nerror 4") + logger.AssertStackTrace(t, errorstest.StackTrace{ + { + Function: "github.com/muonsoft/errors_test.TestLog_joinedErrors", + File: ".+errors/logging_test.go", + Line: 54, + }, + }) + logger.AssertField(t, "key1", "value1") + logger.AssertField(t, "key2", "value2") + logger.AssertField(t, "key3", "value3") + logger.AssertField(t, "key4", "value4") + logger.AssertField(t, "key5", "value5") +} diff --git a/options.go b/options.go index ed38df8..a8405bd 100644 --- a/options.go +++ b/options.go @@ -2,6 +2,7 @@ package errors import ( "encoding/json" + "fmt" "time" ) @@ -60,6 +61,10 @@ func String(key string, value string) Option { } } +func Stringer(key string, value fmt.Stringer) Option { + return String(key, value.String()) +} + func Strings(key string, values []string) Option { return func(options *Options) { options.AddField(StringsField{Key: key, Values: values})