Skip to content

Commit

Permalink
feature: joining errors
Browse files Browse the repository at this point in the history
  • Loading branch information
strider2038 committed Jul 16, 2023
1 parent 21bcd07 commit d486efa
Show file tree
Hide file tree
Showing 8 changed files with 318 additions and 5 deletions.
20 changes: 19 additions & 1 deletion errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,25 @@ func As[T any](err error) (T, bool) {
return t, true
}
}
err = Unwrap(err)
switch x := err.(type) {
case interface{ Unwrap() error }:
err = x.Unwrap()
if err == nil {
var z T
return z, false
}
case interface{ Unwrap() []error }:
for _, err := range x.Unwrap() {
if t, ok := As[T](err); ok {
return t, ok
}
}
var z T
return z, false
default:
var z T
return z, false
}
}

var z T
Expand Down
93 changes: 92 additions & 1 deletion errors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package errors_test

import (
"encoding/json"
stderrors "errors"
"fmt"
"io/fs"
"os"
Expand Down Expand Up @@ -111,7 +112,7 @@ func TestStackTrace(t *testing.T) {
err: wrap(errors.New("ooh")),
want: []string{
"github.com/muonsoft/errors_test.wrap\n" +
"\t.+/errors/errors_test.go:160",
"\t.+/errors/errors_test.go:200",
"github.com/muonsoft/errors_test.TestStackTrace\n" +
"\t.+/errors/errors_test.go:111",
},
Expand Down Expand Up @@ -140,6 +141,46 @@ func TestStackTrace(t *testing.T) {
"\t.+/errors/errors_test.go:137",
},
},
{
name: "join one std error",
err: errors.Join(fmt.Errorf("ooh")),
want: []string{
"github.com/muonsoft/errors_test.TestStackTrace\n" +
"\t.+/errors/errors_test.go:145",
},
},
{
name: "join two std errors",
err: errors.Join(
fmt.Errorf("ooh"),
fmt.Errorf("ooh"),
),
want: []string{
"github.com/muonsoft/errors_test.TestStackTrace\n" +
"\t.+/errors/errors_test.go:153",
},
},
{
name: "join one stacked error",
err: errors.Join(
errors.Errorf("ooh"),
),
want: []string{
"github.com/muonsoft/errors_test.TestStackTrace\n" +
"\t.+/errors/errors_test.go:165",
},
},
{
name: "join two stacked errors",
err: errors.Join(
errors.Errorf("ooh"),
errors.Errorf("ooh"),
),
want: []string{
"github.com/muonsoft/errors_test.TestStackTrace\n" +
"\t.+/errors/errors_test.go:174",
},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand Down Expand Up @@ -203,6 +244,11 @@ func TestFields(t *testing.T) {
err: errors.Wrap(errors.Errorf("error"), errors.String("key", "value")),
expected: "value",
},
{
name: "stringer interface",
err: errors.Wrap(errors.Errorf("error"), errors.Stringer("key", stringer{s: "value"})),
expected: "value",
},
{
name: "strings",
err: errors.Wrap(errors.Errorf("error"), errors.Strings("key", []string{"value"})),
Expand Down Expand Up @@ -394,6 +440,43 @@ func TestAs(t *testing.T) {
true,
errFileNotFound,
},
{
"wrapped wrapped error",
errors.Wrap(wrapped{"error", errorT{"T"}}),
func(err error) (any, bool) {
return errors.As[errorT](err)
},
true,
errorT{"T"},
},
{
"wrapped joined error",
errors.Wrap(
errors.Join(
wrapped{"error", errorT{"T"}},
wrapped{"error", errorT{"T"}},
),
),
func(err error) (any, bool) {
return errors.As[errorT](err)
},
true,
errorT{"T"},
},
{
"wrapped std joined error",
errors.Wrap(
stderrors.Join(
wrapped{"error", errorT{"T"}},
wrapped{"error", errorT{"T"}},
),
),
func(err error) (any, bool) {
return errors.As[errorT](err)
},
true,
errorT{"T"},
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
Expand Down Expand Up @@ -555,3 +638,11 @@ type wrapped struct {
func (e wrapped) Error() string { return e.msg }

func (e wrapped) Unwrap() error { return e.err }

type stringer struct {
s string
}

func (s stringer) String() string {
return s.s
}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/muonsoft/errors

go 1.18
go 1.20

require github.com/sirupsen/logrus v1.8.1

Expand Down
71 changes: 71 additions & 0 deletions joining.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package errors

// Join returns an error that wraps the given errors with a stack trace
// at the point Join is called. Any nil error values are discarded.
// Join returns nil if errs contains no non-nil values.
// The error formats as the concatenation of the strings obtained
// by calling the Error method of each element of errs, with a newline
// between each string.
// If there is only one error in chain, then it's stack trace will be
// preserved if present.
func Join(errs ...error) error {
n := 0
for _, err := range errs {
if err != nil {
n++
}
}
if n == 0 {
return nil
}
if n == 1 {
for _, err := range errs {
if err != nil {
if isWrapper(err) {
return err
}

return &stacked{
wrapped: &wrapped{wrapped: err},
stack: newStack(0),
}
}
}
}

e := &joinError{errs: make([]error, 0, n)}

for _, err := range errs {
if err != nil {
e.errs = append(e.errs, err)
}
}

return &stacked{
wrapped: &wrapped{wrapped: e},
stack: newStack(0),
}
}

type joinError struct {
errs []error
}

// todo: add marshal json?

func (e *joinError) Error() string {
var b []byte

for i, err := range e.errs {
if i > 0 {
b = append(b, '\n')
}
b = append(b, err.Error()...)
}

return string(b)
}

func (e *joinError) Unwrap() []error {
return e.errs
}
82 changes: 82 additions & 0 deletions joining_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package errors_test

import (
"fmt"
"testing"

"github.com/muonsoft/errors"
)

func TestJoin_ReturnsNil(t *testing.T) {
if err := errors.Join(); err != nil {
t.Errorf("errors.Join() = %v, want nil", err)
}
if err := errors.Join(nil); err != nil {
t.Errorf("errors.Join(nil) = %v, want nil", err)
}
if err := errors.Join(nil, nil); err != nil {
t.Errorf("errors.Join(nil, nil) = %v, want nil", err)
}
}

func TestJoin(t *testing.T) {
err1 := errors.New("err1")
err2 := errors.New("err2")
tests := []struct {
errs []error
want []error
}{
{
errs: []error{err1},
want: []error{err1},
},
{
errs: []error{err1, err2},
want: []error{err1, err2},
},
{
errs: []error{err1, nil, err2},
want: []error{err1, err2},
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%v", test.errs), func(t *testing.T) {
got := errors.Join(test.errs...)
for _, want := range test.want {
if !errors.Is(got, want) {
t.Errorf("want err %v in chain", want)
}
}
})
}
}

func TestJoin_ErrorMethod(t *testing.T) {
err1 := errors.New("err1")
err2 := errors.New("err2")
tests := []struct {
errs []error
want string
}{
{
errs: []error{err1},
want: "err1",
},
{
errs: []error{err1, err2},
want: "err1\nerr2",
},
{
errs: []error{err1, nil, err2},
want: "err1\nerr2",
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("%v", test.errs), func(t *testing.T) {
got := errors.Join(test.errs...).Error()
if got != test.want {
t.Errorf("Join().Error() = %q; want %q", got, test.want)
}
})
}
}
16 changes: 14 additions & 2 deletions logging.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,24 @@ func Log(err error, logger Logger) {
if s, ok := e.(stackTracer); ok {
logger.SetStackTrace(s.StackTrace())
}
}
logFields(err, logger)

logger.Log(err.Error())
}

func logFields(err error, logger Logger) {
for e := err; e != nil; e = errors.Unwrap(e) {
if w, ok := e.(LoggableError); ok {
w.LogFields(logger)
}
}

logger.Log(err.Error())
if joined, ok := e.(interface{ Unwrap() []error }); ok {
for _, u := range joined.Unwrap() {
logFields(u, logger)
}
}
}
}

type BoolField struct {
Expand Down
34 changes: 34 additions & 0 deletions logging_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package errors_test

import (
stderrors "errors"
"testing"

"github.com/muonsoft/errors"
Expand Down Expand Up @@ -45,3 +46,36 @@ func TestLog_errorWithStack(t *testing.T) {
logger.AssertField(t, "deepKey", "deepValue")
logger.AssertField(t, "deepestKey", "deepestValue")
}

func TestLog_joinedErrors(t *testing.T) {
logger := errorstest.NewLogger()

err := errors.Wrap(
errors.Join(
errors.Wrap(
errors.Errorf("error 1", errors.String("key1", "value1")),
errors.String("key2", "value2"),
),
errors.Errorf("error 2", errors.String("key3", "value3")),
stderrors.Join(
errors.Errorf("error 3", errors.String("key4", "value4")),
errors.Errorf("error 4", errors.String("key5", "value5")),
),
),
)
errors.Log(err, logger)

logger.AssertMessage(t, "error 1\nerror 2\nerror 3\nerror 4")
logger.AssertStackTrace(t, errorstest.StackTrace{
{
Function: "github.com/muonsoft/errors_test.TestLog_joinedErrors",
File: ".+errors/logging_test.go",
Line: 54,
},
})
logger.AssertField(t, "key1", "value1")
logger.AssertField(t, "key2", "value2")
logger.AssertField(t, "key3", "value3")
logger.AssertField(t, "key4", "value4")
logger.AssertField(t, "key5", "value5")
}
Loading

0 comments on commit d486efa

Please sign in to comment.