From 8f2957e6ca17af13f7ad19269fff5e6f2e49b7cb Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Wed, 8 Sep 2021 21:01:22 -0400 Subject: [PATCH] Feature/optimized member expression (#653) * Added new member path resolution logic * Updated Getter and Setter interfaces * Added ssupport of pre-compiled static member path * Improved error handling --- pkg/compiler/compiler_member_test.go | 16 +- pkg/compiler/visitor.go | 16 ++ pkg/drivers/cdp/dom/document.go | 8 +- pkg/drivers/cdp/dom/element.go | 8 +- pkg/drivers/cdp/page.go | 8 +- pkg/drivers/common/getter.go | 212 +++++++++++------- pkg/drivers/common/setter.go | 73 +++--- pkg/drivers/cookie.go | 11 +- pkg/drivers/cookies.go | 9 +- pkg/drivers/headers.go | 15 +- pkg/drivers/http/document.go | 8 +- pkg/drivers/http/element.go | 8 +- pkg/drivers/http/page.go | 8 +- pkg/drivers/response.go | 19 +- pkg/runtime/core/errors.go | 63 +++++- pkg/runtime/core/expression.go | 18 +- pkg/runtime/core/getter.go | 21 ++ pkg/runtime/core/setter.go | 12 + pkg/runtime/core/value.go | 14 -- pkg/runtime/expressions/member.go | 63 ++++-- pkg/runtime/expressions/member_test.go | 295 +++++++++++++++++++++++++ pkg/runtime/values/array.go | 24 +- pkg/runtime/values/helpers.go | 104 +++++++-- pkg/runtime/values/helpers_test.go | 4 +- pkg/runtime/values/object.go | 23 +- pkg/runtime/values/object_test.go | 7 +- pkg/runtime/values/string.go | 4 + pkg/runtime/values/string_test.go | 7 + 28 files changed, 837 insertions(+), 241 deletions(-) create mode 100644 pkg/runtime/core/getter.go create mode 100644 pkg/runtime/core/setter.go create mode 100644 pkg/runtime/expressions/member_test.go diff --git a/pkg/compiler/compiler_member_test.go b/pkg/compiler/compiler_member_test.go index 9fb25b61..79d488ea 100644 --- a/pkg/compiler/compiler_member_test.go +++ b/pkg/compiler/compiler_member_test.go @@ -557,9 +557,21 @@ func BenchmarkMemberObject(b *testing.B) { func BenchmarkMemberObjectComputed(b *testing.B) { p := compiler.New().MustCompile(` - LET obj = { "foo": "bar"} + LET obj = { + first: { + second: { + third: { + fourth: { + fifth: { + bottom: true + } + } + } + } + } + } - RETURN obj["foo"] + RETURN obj["first"]["second"]["third"]["fourth"]["fifth"]["bottom"] `) for n := 0; n < b.N; n++ { diff --git a/pkg/compiler/visitor.go b/pkg/compiler/visitor.go index 9adfdcfb..f307fcea 100644 --- a/pkg/compiler/visitor.go +++ b/pkg/compiler/visitor.go @@ -10,6 +10,7 @@ import ( "github.com/MontFerret/ferret/pkg/runtime/expressions/clauses" "github.com/MontFerret/ferret/pkg/runtime/expressions/literals" "github.com/MontFerret/ferret/pkg/runtime/expressions/operators" + "github.com/MontFerret/ferret/pkg/runtime/values" "github.com/antlr/antlr4/runtime/Go/antlr" "github.com/pkg/errors" "regexp" @@ -877,6 +878,8 @@ func (v *visitor) doVisitMemberExpression(ctx *fql.MemberExpressionContext, scop children := ctx.AllMemberExpressionPath() path := make([]*expressions.MemberPathSegment, 0, len(children)) + preCompiledPath := make([]core.Value, 0, len(children)) + skipOptimization := false for _, memberPath := range children { var exp core.Expression @@ -903,6 +906,18 @@ func (v *visitor) doVisitMemberExpression(ctx *fql.MemberExpressionContext, scop return nil, err } + if !skipOptimization { + switch t := exp.(type) { + case literals.StringLiteral: + preCompiledPath = append(preCompiledPath, values.NewString(string(t))) + case literals.IntLiteral: + preCompiledPath = append(preCompiledPath, values.NewInt(int(t))) + default: + skipOptimization = true + preCompiledPath = nil + } + } + path = append(path, segment) } @@ -910,6 +925,7 @@ func (v *visitor) doVisitMemberExpression(ctx *fql.MemberExpressionContext, scop v.getSourceMap(ctx), source, path, + preCompiledPath, ) } diff --git a/pkg/drivers/cdp/dom/document.go b/pkg/drivers/cdp/dom/document.go index fcf8b6e5..7dd2cc79 100644 --- a/pkg/drivers/cdp/dom/document.go +++ b/pkg/drivers/cdp/dom/document.go @@ -184,12 +184,12 @@ func (doc *HTMLDocument) Iterate(ctx context.Context) (core.Iterator, error) { return doc.element.Iterate(ctx) } -func (doc *HTMLDocument) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { - return common.GetInDocument(ctx, doc, path) +func (doc *HTMLDocument) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { + return common.GetInDocument(ctx, path, doc) } -func (doc *HTMLDocument) SetIn(ctx context.Context, path []core.Value, value core.Value) error { - return common.SetInDocument(ctx, doc, path, value) +func (doc *HTMLDocument) SetIn(ctx context.Context, path []core.Value, value core.Value) core.PathError { + return common.SetInDocument(ctx, path, doc, value) } func (doc *HTMLDocument) Close() error { diff --git a/pkg/drivers/cdp/dom/element.go b/pkg/drivers/cdp/dom/element.go index 2b4eebfb..b9106cd5 100644 --- a/pkg/drivers/cdp/dom/element.go +++ b/pkg/drivers/cdp/dom/element.go @@ -185,12 +185,12 @@ func (el *HTMLElement) Iterate(_ context.Context) (core.Iterator, error) { return common.NewIterator(el) } -func (el *HTMLElement) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { - return common.GetInElement(ctx, el, path) +func (el *HTMLElement) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { + return common.GetInElement(ctx, path, el) } -func (el *HTMLElement) SetIn(ctx context.Context, path []core.Value, value core.Value) error { - return common.SetInElement(ctx, el, path, value) +func (el *HTMLElement) SetIn(ctx context.Context, path []core.Value, value core.Value) core.PathError { + return common.SetInElement(ctx, path, el, value) } func (el *HTMLElement) GetValue(ctx context.Context) (core.Value, error) { diff --git a/pkg/drivers/cdp/page.go b/pkg/drivers/cdp/page.go index a77f866a..dc289df2 100644 --- a/pkg/drivers/cdp/page.go +++ b/pkg/drivers/cdp/page.go @@ -257,12 +257,12 @@ func (p *HTMLPage) Copy() core.Value { return values.None } -func (p *HTMLPage) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { - return common.GetInPage(ctx, p, path) +func (p *HTMLPage) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { + return common.GetInPage(ctx, path, p) } -func (p *HTMLPage) SetIn(ctx context.Context, path []core.Value, value core.Value) error { - return common.SetInPage(ctx, p, path, value) +func (p *HTMLPage) SetIn(ctx context.Context, path []core.Value, value core.Value) core.PathError { + return common.SetInPage(ctx, path, p, value) } func (p *HTMLPage) Iterate(ctx context.Context) (core.Iterator, error) { diff --git a/pkg/drivers/common/getter.go b/pkg/drivers/common/getter.go index cded8a2b..6799722f 100644 --- a/pkg/drivers/common/getter.go +++ b/pkg/drivers/common/getter.go @@ -11,12 +11,13 @@ import ( "github.com/MontFerret/ferret/pkg/runtime/values/types" ) -func GetInPage(ctx context.Context, page drivers.HTMLPage, path []core.Value) (core.Value, error) { +func GetInPage(ctx context.Context, path []core.Value, page drivers.HTMLPage) (core.Value, core.PathError) { if len(path) == 0 { return page, nil } - segment := path[0] + segmentIdx := 0 + segment := path[segmentIdx] if segment.Type() == types.String { segment := segment.(values.String) @@ -26,27 +27,55 @@ func GetInPage(ctx context.Context, page drivers.HTMLPage, path []core.Value) (c resp, err := page.GetResponse(ctx) if err != nil { - return nil, errors.Wrap(err, "get response") + return nil, core.NewPathError( + errors.Wrap(err, "get response"), + 0, + ) } - return resp.GetIn(ctx, path[1:]) + out, pathErr := resp.GetIn(ctx, path[segmentIdx+1:]) + + if pathErr != nil { + return values.None, core.NewPathErrorFrom(pathErr, segmentIdx) + } + + return out, nil case "mainFrame", "document": - return GetInDocument(ctx, page.GetMainFrame(), path[1:]) + out, pathErr := GetInDocument(ctx, path[segmentIdx+1:], page.GetMainFrame()) + + if pathErr != nil { + return values.None, core.NewPathErrorFrom(pathErr, segmentIdx) + } + + return out, nil case "frames": if len(path) == 1 { - return page.GetFrames(ctx) + out, err := page.GetFrames(ctx) + + if err != nil { + return nil, core.NewPathError( + errors.Wrap(err, "get response"), + segmentIdx, + ) + } + + return out, nil } - idx := path[1] + segmentIdx = +1 + idx := path[segmentIdx] if !values.IsNumber(idx) { - return values.None, core.TypeError(idx.Type(), types.Int, types.Float) + return values.None, core.NewPathError( + core.TypeError(idx.Type(), types.Int, types.Float), + segmentIdx, + ) } value, err := page.GetFrame(ctx, values.ToInt(idx)) if err != nil { - return values.None, err + return values.None, core.NewPathError(err, segmentIdx) } if len(path) == 2 { @@ -56,42 +85,57 @@ func GetInPage(ctx context.Context, page drivers.HTMLPage, path []core.Value) (c frame, err := drivers.ToDocument(value) if err != nil { - return values.None, err + return values.None, core.NewPathError(err, segmentIdx) + } + + out, pathErr := GetInDocument(ctx, path[segmentIdx+1:], frame) + + if err != nil { + return values.None, core.NewPathErrorFrom(pathErr, segmentIdx) } - return GetInDocument(ctx, frame, path[2:]) + return out, nil case "url", "URL": return page.GetURL(), nil case "cookies": cookies, err := page.GetCookies(ctx) if err != nil { - return values.None, err + return values.None, core.NewPathError(err, segmentIdx) } if len(path) == 1 { return cookies, nil } - return cookies.GetIn(ctx, path[1:]) + out, pathErr := cookies.GetIn(ctx, path[segmentIdx+1:]) + + if err != nil { + return values.None, core.NewPathErrorFrom(pathErr, segmentIdx) + } + + return out, nil case "title": return page.GetMainFrame().GetTitle(), nil case "isClosed": return page.IsClosed(), nil default: - return GetInDocument(ctx, page.GetMainFrame(), path) + return GetInDocument(ctx, path, page.GetMainFrame()) } } - return GetInDocument(ctx, page.GetMainFrame(), path) + return GetInDocument(ctx, path, page.GetMainFrame()) } -func GetInDocument(ctx context.Context, doc drivers.HTMLDocument, path []core.Value) (core.Value, error) { +func GetInDocument(ctx context.Context, path []core.Value, doc drivers.HTMLDocument) (core.Value, core.PathError) { if len(path) == 0 { return doc, nil } - segment := path[0] + var out core.Value + var err error + segmentIdx := 0 + segment := path[segmentIdx] if segment.Type() == types.String { segment := segment.(values.String) @@ -107,7 +151,7 @@ func GetInDocument(ctx context.Context, doc drivers.HTMLDocument, path []core.Va parent, err := doc.GetParentDocument(ctx) if err != nil { - return values.None, err + return values.None, core.NewPathError(err, segmentIdx) } if parent == nil { @@ -118,12 +162,18 @@ func GetInDocument(ctx context.Context, doc drivers.HTMLDocument, path []core.Va return parent, nil } - return GetInDocument(ctx, parent, path[1:]) + out, pathErr := GetInDocument(ctx, path[segmentIdx+1:], parent) + + if pathErr != nil { + return values.None, core.NewPathErrorFrom(pathErr, segmentIdx) + } + + return out, nil case "body", "head": out, err := doc.QuerySelector(ctx, segment) if err != nil { - return values.None, err + return values.None, core.NewPathError(err, segmentIdx) } if out == values.None { @@ -137,121 +187,127 @@ func GetInDocument(ctx context.Context, doc drivers.HTMLDocument, path []core.Va el, err := drivers.ToElement(out) if err != nil { - return values.None, err + return values.None, core.NewPathError(err, segmentIdx) + } + + out, pathErr := GetInElement(ctx, path[segmentIdx+1:], el) + + if pathErr != nil { + return values.None, core.NewPathErrorFrom(pathErr, segmentIdx) } - return GetInElement(ctx, el, path[1:]) + return out, nil case "innerHTML": - return doc.GetElement().GetInnerHTML(ctx) + out, err = doc.GetElement().GetInnerHTML(ctx) case "innerText": - return doc.GetElement().GetInnerText(ctx) + out, err = doc.GetElement().GetInnerText(ctx) default: - return GetInNode(ctx, doc.GetElement(), path) + return GetInNode(ctx, path, doc.GetElement()) } + + return values.ReturnOrNext(ctx, path, segmentIdx, out, err) } - return GetInNode(ctx, doc.GetElement(), path) + return GetInNode(ctx, path, doc.GetElement()) } -func GetInElement(ctx context.Context, el drivers.HTMLElement, path []core.Value) (core.Value, error) { +func GetInElement(ctx context.Context, path []core.Value, el drivers.HTMLElement) (core.Value, core.PathError) { if len(path) == 0 { return el, nil } - segment := path[0] + segmentIdx := 0 + segment := path[segmentIdx] if segment.Type() == types.String { + var out core.Value + var err error + segment := segment.(values.String) switch segment { case "innerText": - return el.GetInnerText(ctx) + out, err = el.GetInnerText(ctx) case "innerHTML": - return el.GetInnerHTML(ctx) + out, err = el.GetInnerHTML(ctx) case "value": - return el.GetValue(ctx) + out, err = el.GetValue(ctx) case "attributes": - attrs, err := el.GetAttributes(ctx) - - if err != nil { - return values.None, err - } - if len(path) == 1 { - return attrs, nil - } + out, err = el.GetAttributes(ctx) + } else { + // e.g. attributes.href + segmentIdx++ + attrName := path[segmentIdx] - return values.GetIn(ctx, attrs, path[1:]) - case "style": - styles, err := el.GetStyles(ctx) - - if err != nil { - return values.None, err + out, err = el.GetAttribute(ctx, values.ToString(attrName)) } - + case "style": if len(path) == 1 { - return styles, nil - } + out, err = el.GetStyles(ctx) + } else { + // e.g. style.color + segmentIdx++ + styleName := path[segmentIdx] - return values.GetIn(ctx, styles, path[1:]) + out, err = el.GetStyle(ctx, values.ToString(styleName)) + } case "previousElementSibling": - return el.GetPreviousElementSibling(ctx) + out, err = el.GetPreviousElementSibling(ctx) case "nextElementSibling": - return el.GetNextElementSibling(ctx) + out, err = el.GetNextElementSibling(ctx) case "parentElement": - return el.GetParentElement(ctx) + out, err = el.GetParentElement(ctx) default: - return GetInNode(ctx, el, path) + return GetInNode(ctx, path, el) } + + return values.ReturnOrNext(ctx, path, segmentIdx, out, err) } - return GetInNode(ctx, el, path) + return GetInNode(ctx, path, el) } -func GetInNode(ctx context.Context, node drivers.HTMLNode, path []core.Value) (core.Value, error) { +func GetInNode(ctx context.Context, path []core.Value, node drivers.HTMLNode) (core.Value, core.PathError) { if len(path) == 0 { return node, nil } - nt := node.Type() - segment := path[0] - st := segment.Type() - - switch st { - case types.Int: - if nt == drivers.HTMLElementType || nt == drivers.HTMLDocumentType { - re := node.(drivers.HTMLNode) + segmentIdx := 0 + segment := path[segmentIdx] - return re.GetChildNode(ctx, values.ToInt(segment)) - } + var out core.Value + var err error - return values.GetIn(ctx, node, path[1:]) + switch segment.Type() { + case types.Int: + out, err = node.GetChildNode(ctx, values.ToInt(segment)) case types.String: segment := segment.(values.String) switch segment { case "nodeType": - return node.GetNodeType(ctx) + out, err = node.GetNodeType(ctx) case "nodeName": - return node.GetNodeName(ctx) + out, err = node.GetNodeName(ctx) case "children": - children, err := node.GetChildNodes(ctx) - - if err != nil { - return values.None, err - } - if len(path) == 1 { - return children, nil + out, err = node.GetChildNodes(ctx) + } else { + segmentIdx++ + out, err = node.GetChildNode(ctx, values.ToInt(path[segmentIdx])) } - - return values.GetIn(ctx, children, path[1:]) case "length": return node.Length(), nil default: return values.None, nil } default: - return values.None, core.TypeError(st, types.Int, types.String) + return values.None, core.NewPathError( + core.TypeError(segment.Type(), types.Int, types.String), + segmentIdx, + ) } + + return values.ReturnOrNext(ctx, path, segmentIdx, out, err) } diff --git a/pkg/drivers/common/setter.go b/pkg/drivers/common/setter.go index cf775d08..40e6749d 100644 --- a/pkg/drivers/common/setter.go +++ b/pkg/drivers/common/setter.go @@ -9,28 +9,29 @@ import ( "github.com/MontFerret/ferret/pkg/runtime/values/types" ) -func SetInPage(ctx context.Context, page drivers.HTMLPage, path []core.Value, value core.Value) error { +func SetInPage(ctx context.Context, path []core.Value, page drivers.HTMLPage, value core.Value) core.PathError { if len(path) == 0 { return nil } - return SetInDocument(ctx, page.GetMainFrame(), path, value) + return SetInDocument(ctx, path, page.GetMainFrame(), value) } -func SetInDocument(ctx context.Context, doc drivers.HTMLDocument, path []core.Value, value core.Value) error { +func SetInDocument(ctx context.Context, path []core.Value, doc drivers.HTMLDocument, value core.Value) core.PathError { if len(path) == 0 { return nil } - return SetInNode(ctx, doc, path, value) + return SetInNode(ctx, path, doc, value) } -func SetInElement(ctx context.Context, el drivers.HTMLElement, path []core.Value, value core.Value) error { +func SetInElement(ctx context.Context, path []core.Value, el drivers.HTMLElement, value core.Value) core.PathError { if len(path) == 0 { return nil } - segment := path[0] + segmentIdx := 0 + segment := path[segmentIdx] if segment.Type() == types.String { segment := segment.(values.String) @@ -39,27 +40,32 @@ func SetInElement(ctx context.Context, el drivers.HTMLElement, path []core.Value case "attributes": if len(path) > 1 { attrName := path[1] + err := el.SetAttribute(ctx, values.NewString(attrName.String()), values.NewString(value.String())) - return el.SetAttribute(ctx, values.NewString(attrName.String()), values.NewString(value.String())) + if err != nil { + return core.NewPathError(err, segmentIdx) + } + + return nil } err := core.ValidateType(value, types.Object) if err != nil { - return err + return core.NewPathError(err, segmentIdx) } curr, err := el.GetAttributes(ctx) if err != nil { - return err + return core.NewPathError(err, segmentIdx) } // remove all previous attributes err = el.RemoveAttribute(ctx, curr.Keys()...) if err != nil { - return err + return core.NewPathError(err, segmentIdx) } obj := value.(*values.Object) @@ -69,24 +75,34 @@ func SetInElement(ctx context.Context, el drivers.HTMLElement, path []core.Value return err == nil }) - return err + if err != nil { + return core.NewPathError(err, segmentIdx) + } + + return nil case "style": if len(path) > 1 { attrName := path[1] - return el.SetStyle(ctx, values.NewString(attrName.String()), values.NewString(value.String())) + err := el.SetStyle(ctx, values.NewString(attrName.String()), values.NewString(value.String())) + + if err != nil { + return core.NewPathError(err, segmentIdx) + } + + return nil } err := core.ValidateType(value, types.Object) if err != nil { - return err + return core.NewPathError(err, segmentIdx) } styles, err := el.GetStyles(ctx) if err != nil { - return err + return core.NewPathError(err, segmentIdx) } err = el.RemoveStyle(ctx, styles.Keys()...) @@ -98,30 +114,33 @@ func SetInElement(ctx context.Context, el drivers.HTMLElement, path []core.Value return err == nil }) - return err + if err != nil { + return core.NewPathError(err, segmentIdx) + } + + return nil case "value": if len(path) > 1 { - return core.Error(ErrInvalidPath, PathToString(path[1:])) + return core.NewPathError(ErrInvalidPath, segmentIdx+1) + } + + err := el.SetValue(ctx, value) + + if err != nil { + return core.NewPathError(err, segmentIdx) } - return el.SetValue(ctx, value) + return nil } } - return SetInNode(ctx, el, path, value) + return SetInNode(ctx, path, el, value) } -func SetInNode(_ context.Context, _ drivers.HTMLNode, path []core.Value, _ core.Value) error { +func SetInNode(_ context.Context, path []core.Value, _ drivers.HTMLNode, _ core.Value) core.PathError { if len(path) == 0 { return nil } - segment := path[0] - st := segment.Type() - - if st == types.Int { - return core.Error(core.ErrInvalidOperation, "children are read-only") - } - - return core.Error(ErrReadOnly, PathToString(path)) + return core.NewPathError(ErrReadOnly, 0) } diff --git a/pkg/drivers/cookie.go b/pkg/drivers/cookie.go index 3a3a3226..9560aacb 100644 --- a/pkg/drivers/cookie.go +++ b/pkg/drivers/cookie.go @@ -12,7 +12,6 @@ import ( "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/values" - "github.com/MontFerret/ferret/pkg/runtime/values/types" ) type ( @@ -162,20 +161,14 @@ func (c HTTPCookie) MarshalJSON() ([]byte, error) { return out, err } -func (c HTTPCookie) GetIn(_ context.Context, path []core.Value) (core.Value, error) { +func (c HTTPCookie) GetIn(_ context.Context, path []core.Value) (core.Value, core.PathError) { if len(path) == 0 { return values.None, nil } segment := path[0] - err := core.ValidateType(segment, types.String) - - if err != nil { - return values.None, err - } - - switch segment.(values.String) { + switch values.ToString(segment) { case "name": return values.NewString(c.Name), nil case "value": diff --git a/pkg/drivers/cookies.go b/pkg/drivers/cookies.go index 4798e1bd..a8954cba 100644 --- a/pkg/drivers/cookies.go +++ b/pkg/drivers/cookies.go @@ -175,17 +175,18 @@ func (c *HTTPCookies) Set(cookie HTTPCookie) { c.values[cookie.Name] = cookie } -func (c *HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { +func (c *HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { if len(path) == 0 { return values.None, nil } - segment := path[0] + segmentIdx := 0 + segment := path[segmentIdx] err := core.ValidateType(segment, types.String) if err != nil { - return values.None, err + return values.None, core.NewPathError(err, segmentIdx) } cookie, found := c.values[segment.String()] @@ -195,7 +196,7 @@ func (c *HTTPCookies) GetIn(ctx context.Context, path []core.Value) (core.Value, return cookie, nil } - return values.GetIn(ctx, cookie, path[1:]) + return values.GetIn(ctx, cookie, path[segmentIdx+1:]) } return values.None, nil diff --git a/pkg/drivers/headers.go b/pkg/drivers/headers.go index c89a5118..2757dbca 100644 --- a/pkg/drivers/headers.go +++ b/pkg/drivers/headers.go @@ -11,8 +11,6 @@ import ( "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/values" - "github.com/MontFerret/ferret/pkg/runtime/values/types" - "github.com/wI2L/jettison" ) @@ -159,20 +157,15 @@ func (h *HTTPHeaders) Get(key string) string { return textproto.MIMEHeader(h.values).Get(key) } -func (h *HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, error) { +func (h *HTTPHeaders) GetIn(_ context.Context, path []core.Value) (core.Value, core.PathError) { if len(path) == 0 { return values.None, nil } - segment := path[0] - - err := core.ValidateType(segment, types.String) - - if err != nil { - return values.None, err - } + segmentIx := 0 + segment := path[segmentIx] - return values.NewString(h.Get(segment.String())), nil + return values.NewString(h.Get(string(values.ToString(segment)))), nil } func (h *HTTPHeaders) ForEach(predicate func(value []string, key string) bool) { diff --git a/pkg/drivers/http/document.go b/pkg/drivers/http/document.go index 30710949..a4e08025 100644 --- a/pkg/drivers/http/document.go +++ b/pkg/drivers/http/document.go @@ -134,12 +134,12 @@ func (doc *HTMLDocument) Iterate(_ context.Context) (core.Iterator, error) { return common.NewIterator(doc.element) } -func (doc *HTMLDocument) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { - return common.GetInDocument(ctx, doc, path) +func (doc *HTMLDocument) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { + return common.GetInDocument(ctx, path, doc) } -func (doc *HTMLDocument) SetIn(ctx context.Context, path []core.Value, value core.Value) error { - return common.SetInDocument(ctx, doc, path, value) +func (doc *HTMLDocument) SetIn(ctx context.Context, path []core.Value, value core.Value) core.PathError { + return common.SetInDocument(ctx, path, doc, value) } func (doc *HTMLDocument) GetNodeType(_ context.Context) (values.Int, error) { diff --git a/pkg/drivers/http/element.go b/pkg/drivers/http/element.go index c5a9bb52..91774e02 100644 --- a/pkg/drivers/http/element.go +++ b/pkg/drivers/http/element.go @@ -481,12 +481,12 @@ func (el *HTMLElement) ExistsBySelector(_ context.Context, selector values.Strin return values.True, nil } -func (el *HTMLElement) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { - return common.GetInElement(ctx, el, path) +func (el *HTMLElement) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { + return common.GetInElement(ctx, path, el) } -func (el *HTMLElement) SetIn(ctx context.Context, path []core.Value, value core.Value) error { - return common.SetInElement(ctx, el, path, value) +func (el *HTMLElement) SetIn(ctx context.Context, path []core.Value, value core.Value) core.PathError { + return common.SetInElement(ctx, path, el, value) } func (el *HTMLElement) Iterate(_ context.Context) (core.Iterator, error) { diff --git a/pkg/drivers/http/page.go b/pkg/drivers/http/page.go index c69e782c..8b88fd79 100644 --- a/pkg/drivers/http/page.go +++ b/pkg/drivers/http/page.go @@ -108,12 +108,12 @@ func (p *HTMLPage) Iterate(ctx context.Context) (core.Iterator, error) { return p.document.Iterate(ctx) } -func (p *HTMLPage) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { - return common.GetInPage(ctx, p, path) +func (p *HTMLPage) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { + return common.GetInPage(ctx, path, p) } -func (p *HTMLPage) SetIn(ctx context.Context, path []core.Value, value core.Value) error { - return common.SetInPage(ctx, p, path, value) +func (p *HTMLPage) SetIn(ctx context.Context, path []core.Value, value core.Value) core.PathError { + return common.SetInPage(ctx, path, p, value) } func (p *HTMLPage) Length() values.Int { diff --git a/pkg/drivers/response.go b/pkg/drivers/response.go index 8c4aaf63..c976f4f2 100644 --- a/pkg/drivers/response.go +++ b/pkg/drivers/response.go @@ -77,16 +77,19 @@ func (resp *HTTPResponse) MarshalJSON() ([]byte, error) { return jettison.MarshalOpts(responseMarshal(*resp), jettison.NoHTMLEscaping()) } -func (resp *HTTPResponse) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { +func (resp *HTTPResponse) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { if len(path) == 0 { return resp, nil } - if typ := path[0].Type(); typ != types.String { - return values.None, core.TypeError(typ, types.String) + segmentIdx := 0 + segment := path[segmentIdx] + + if typ := segment.Type(); typ != types.String { + return values.None, core.NewPathError(core.TypeError(typ, types.String), segmentIdx) } - field := path[0].(values.String).String() + field := segment.String() switch field { case "url", "URL": @@ -100,7 +103,13 @@ func (resp *HTTPResponse) GetIn(ctx context.Context, path []core.Value) (core.Va return resp.Headers, nil } - return resp.Headers.GetIn(ctx, path[1:]) + out, pathErr := resp.Headers.GetIn(ctx, path[1:]) + + if pathErr != nil { + return values.None, core.NewPathErrorFrom(pathErr, segmentIdx) + } + + return out, nil case "responseTime": return values.NewFloat(resp.ResponseTime), nil diff --git a/pkg/runtime/core/errors.go b/pkg/runtime/core/errors.go index 93b34593..6ccbe91f 100644 --- a/pkg/runtime/core/errors.go +++ b/pkg/runtime/core/errors.go @@ -7,9 +7,65 @@ import ( "github.com/pkg/errors" ) -type SourceErrorDetail struct { - BaseError error - ComputeError error +type ( + SourceErrorDetail struct { + error + BaseError error + ComputeError error + } + + // PathError represents an interface of + // error type which returned when an error occurs during an execution of Getter.GetIn or Setter.SetIn functions + // and contains segment of a given path that caused the error. + PathError interface { + error + Cause() error + Segment() int + Format(path []Value) string + } + + // NativePathError represents a default implementation of GetterError interface. + NativePathError struct { + cause error + segment int + } +) + +// NewPathError is a constructor function of NativePathError struct. +func NewPathError(err error, segment int) PathError { + return &NativePathError{ + cause: err, + segment: segment, + } +} + +// NewPathErrorFrom is a constructor function of NativePathError struct +// that accepts nested PathError and original segment index. +// It sums indexes to get the correct one that points to original path. +func NewPathErrorFrom(err PathError, segment int) PathError { + return NewPathError(err.Cause(), err.Segment()+segment) +} + +func (e *NativePathError) Cause() error { + return e.cause +} + +func (e *NativePathError) Error() string { + return e.cause.Error() +} + +func (e *NativePathError) Segment() int { + return e.segment +} + +func (e *NativePathError) Format(path []Value) string { + err := e.cause + + if err == ErrInvalidPath && len(path) > e.segment { + return err.Error() + " '" + path[e.segment].String() + "'" + } + + return err.Error() } func (e *SourceErrorDetail) Error() string { @@ -30,6 +86,7 @@ var ( ErrNotImplemented = errors.New("not implemented") ErrNotSupported = errors.New("not supported") ErrNoMoreData = errors.New("no more data") + ErrInvalidPath = errors.New("cannot read property") ) const typeErrorTemplate = "expected %s, but got %s" diff --git a/pkg/runtime/core/expression.go b/pkg/runtime/core/expression.go index e170e0a8..3e6f62a5 100644 --- a/pkg/runtime/core/expression.go +++ b/pkg/runtime/core/expression.go @@ -2,6 +2,20 @@ package core import "context" -type Expression interface { - Exec(ctx context.Context, scope *Scope) (Value, error) +type ( + Expression interface { + Exec(ctx context.Context, scope *Scope) (Value, error) + } + + ExpressionFn struct { + fn func(ctx context.Context, scope *Scope) (Value, error) + } +) + +func NewExpressionFn(fn func(ctx context.Context, scope *Scope) (Value, error)) Expression { + return &ExpressionFn{fn} +} + +func (f *ExpressionFn) Exec(ctx context.Context, scope *Scope) (Value, error) { + return f.fn(ctx, scope) } diff --git a/pkg/runtime/core/getter.go b/pkg/runtime/core/getter.go new file mode 100644 index 00000000..ef73c80c --- /dev/null +++ b/pkg/runtime/core/getter.go @@ -0,0 +1,21 @@ +package core + +import "context" + +type ( + GetterPathIterator interface { + Path() []Value + Current() Value + CurrentIndex() int + } + + // Getter represents an interface of + // complex types that needs to be used to read values by path. + // The interface is created to let user-defined types be used in dot notation data access. + Getter interface { + GetIn(ctx context.Context, path []Value) (Value, PathError) + } + + // GetterFn represents a type of helper functions that implement complex path resolutions. + GetterFn func(ctx context.Context, path []Value, src Getter) (Value, PathError) +) diff --git a/pkg/runtime/core/setter.go b/pkg/runtime/core/setter.go new file mode 100644 index 00000000..d262c8dd --- /dev/null +++ b/pkg/runtime/core/setter.go @@ -0,0 +1,12 @@ +package core + +import "context" + +type ( + // Setter represents an interface of + // complex types that needs to be used to write values by path. + // The interface is created to let user-defined types be used in dot notation assignment. + Setter interface { + SetIn(ctx context.Context, path []Value, value Value) PathError + } +) diff --git a/pkg/runtime/core/value.go b/pkg/runtime/core/value.go index 8933f233..7fda37a4 100644 --- a/pkg/runtime/core/value.go +++ b/pkg/runtime/core/value.go @@ -28,18 +28,4 @@ type ( Iterator interface { Next(ctx context.Context) (value Value, key Value, err error) } - - // Getter represents an interface of - // complex types that needs to be used to read values by path. - // The interface is created to let user-defined types be used in dot notation data access. - Getter interface { - GetIn(ctx context.Context, path []Value) (Value, error) - } - - // Setter represents an interface of - // complex types that needs to be used to write values by path. - // The interface is created to let user-defined types be used in dot notation assignment. - Setter interface { - SetIn(ctx context.Context, path []Value, value Value) error - } ) diff --git a/pkg/runtime/expressions/member.go b/pkg/runtime/expressions/member.go index 44caf4d2..5c7e6244 100644 --- a/pkg/runtime/expressions/member.go +++ b/pkg/runtime/expressions/member.go @@ -2,18 +2,20 @@ package expressions import ( "context" + "github.com/pkg/errors" "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/values" ) type MemberExpression struct { - src core.SourceMap - source core.Expression - path []*MemberPathSegment + src core.SourceMap + source core.Expression + path []*MemberPathSegment + preCompiledPath []core.Value } -func NewMemberExpression(src core.SourceMap, source core.Expression, path []*MemberPathSegment) (*MemberExpression, error) { +func NewMemberExpression(src core.SourceMap, source core.Expression, path []*MemberPathSegment, preCompiledPath []core.Value) (*MemberExpression, error) { if source == nil { return nil, core.Error(core.ErrMissedArgument, "source") } @@ -22,11 +24,11 @@ func NewMemberExpression(src core.SourceMap, source core.Expression, path []*Mem return nil, core.Error(core.ErrMissedArgument, "path expressions") } - return &MemberExpression{src, source, path}, nil + return &MemberExpression{src, source, path, preCompiledPath}, nil } func (e *MemberExpression) Exec(ctx context.Context, scope *core.Scope) (core.Value, error) { - val, err := e.source.Exec(ctx, scope) + member, err := e.source.Exec(ctx, scope) if err != nil { if e.path[0].optional { @@ -39,28 +41,49 @@ func (e *MemberExpression) Exec(ctx context.Context, scope *core.Scope) (core.Va ) } - out := val - path := make([]core.Value, 1) + var segments = e.preCompiledPath - for _, seg := range e.path { - segment, err := seg.exp.Exec(ctx, scope) + if e.preCompiledPath == nil { + segments = make([]core.Value, len(e.path)) - if err != nil { - return values.None, err + // unfold the path + for i, seg := range e.path { + segment, err := seg.exp.Exec(ctx, scope) + + if err != nil { + return values.None, err + } + + segments[i] = segment } + } - path[0] = segment - c, err := values.GetIn(ctx, out, path) + var pathErr core.PathError + var out core.Value = values.None - if err != nil { - if !seg.optional { - return values.None, core.SourceError(e.src, err) - } + getter, ok := member.(core.Getter) - return values.None, nil + if ok { + out, pathErr = getter.GetIn(ctx, segments) + } else { + out, pathErr = values.GetIn(ctx, member, segments) + } + + if pathErr != nil { + segmentIdx := pathErr.Segment() + // if invalid index is returned, we ignore the optionality check + // and return the pathErr + if segmentIdx >= len(e.path) { + return values.None, errors.New(pathErr.Format(segments)) + } + + segment := e.path[segmentIdx] + + if !segment.optional { + return values.None, errors.New(pathErr.Format(segments)) } - out = c + return values.None, nil } return out, nil diff --git a/pkg/runtime/expressions/member_test.go b/pkg/runtime/expressions/member_test.go new file mode 100644 index 00000000..1dbd5e08 --- /dev/null +++ b/pkg/runtime/expressions/member_test.go @@ -0,0 +1,295 @@ +package expressions_test + +import ( + "context" + "github.com/MontFerret/ferret/pkg/runtime/core" + "github.com/MontFerret/ferret/pkg/runtime/expressions" + "github.com/MontFerret/ferret/pkg/runtime/values" + "github.com/MontFerret/ferret/pkg/runtime/values/types" + "github.com/stretchr/testify/mock" + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +type ( + TestObject struct { + mock.Mock + *values.Object + failAt string + } +) + +func NewTestObject() *TestObject { + o := new(TestObject) + o.Object = values.NewObject() + + return o +} + +func (to *TestObject) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { + to.Mock.Called(path) + + var current core.Value = to.Object + + for i, segment := range path { + if segment.String() == to.failAt { + return values.None, core.NewPathError(core.ErrInvalidPath, i) + } + + next, err := values.GetIn(ctx, current, []core.Value{segment}) + + if err != nil { + return values.None, core.NewPathError(err, i) + } + + current = next + } + + return current, nil +} + +func TestMemberExpression(t *testing.T) { + Convey(".Exec", t, func() { + Convey("Should use .Getter interface if a source implements it", func() { + o := NewTestObject() + o.Set("foo", values.NewObjectWith( + values.NewObjectProperty("bar", values.NewObjectWith( + values.NewObjectProperty("baz", values.NewObject()), + )), + )) + + args := []core.Value{ + values.NewString("foo"), + values.NewString("bar"), + values.NewString("baz"), + } + + o.On("GetIn", args) + + s1, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[0], nil + }), + false, + ) + + s2, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[1], nil + }), + false, + ) + + s3, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[2], nil + }), + false, + ) + + segments := []*expressions.MemberPathSegment{s1, s2, s3} + + exp, err := expressions.NewMemberExpression( + core.SourceMap{}, + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return o, nil + }), + segments, + nil, + ) + + So(err, ShouldBeNil) + + root, cancel := core.NewRootScope() + + defer func() { + if err := cancel(); err != nil { + panic(err) + } + }() + + out, err := exp.Exec(context.Background(), root.Fork()) + So(err, ShouldBeNil) + So(out.Type().String(), ShouldNotEqual, types.None.String()) + + o.AssertExpectations(t) + }) + + Convey("Should use generic traverse logic if a source does not implement Getter interface", func() { + o := values.NewString("abcdefg") + + args := []core.Value{ + values.NewInt(0), + } + + s1, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[0], nil + }), + false, + ) + + segments := []*expressions.MemberPathSegment{s1} + + exp, err := expressions.NewMemberExpression( + core.SourceMap{}, + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return o, nil + }), + segments, + nil, + ) + + So(err, ShouldBeNil) + + root, cancel := core.NewRootScope() + + defer func() { + if err := cancel(); err != nil { + panic(err) + } + }() + + out, err := exp.Exec(context.Background(), root.Fork()) + So(err, ShouldBeNil) + So(out.String(), ShouldEqual, "a") + }) + + Convey("When path is not optional", func() { + Convey("Should return an error if it occurs during path resolution", func() { + o := NewTestObject() + o.failAt = "bar" + o.Set("foo", values.NewObjectWith( + values.NewObjectProperty("bar", values.NewObjectWith( + values.NewObjectProperty("baz", values.NewObject()), + )), + )) + + args := []core.Value{ + values.NewString("foo"), + values.NewString("bar"), + values.NewString("baz"), + } + + o.On("GetIn", mock.Anything) + + s1, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[0], nil + }), + false, + ) + + s2, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[1], nil + }), + false, + ) + + s3, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[2], nil + }), + false, + ) + + segments := []*expressions.MemberPathSegment{s1, s2, s3} + + exp, err := expressions.NewMemberExpression( + core.SourceMap{}, + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return o, nil + }), + segments, + nil, + ) + + So(err, ShouldBeNil) + + root, cancel := core.NewRootScope() + + defer func() { + if err := cancel(); err != nil { + panic(err) + } + }() + + _, err = exp.Exec(context.Background(), root.Fork()) + So(err, ShouldNotBeNil) + So(err.Error(), ShouldEqual, core.NewPathError(core.ErrInvalidPath, 1).Format(args)) + + o.AssertExpectations(t) + }) + }) + + Convey("When path is optional", func() { + Convey("Should return None if it occurs during path resolution", func() { + o := NewTestObject() + o.failAt = "bar" + o.Set("foo", values.NewObjectWith( + values.NewObjectProperty("bar", values.NewObjectWith( + values.NewObjectProperty("baz", values.NewObject()), + )), + )) + + args := []core.Value{ + values.NewString("foo"), + values.NewString("bar"), + values.NewString("baz"), + } + + o.On("GetIn", mock.Anything) + + s1, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[0], nil + }), + true, + ) + + s2, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[1], nil + }), + true, + ) + + s3, _ := expressions.NewMemberPathSegment( + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return args[2], nil + }), + true, + ) + + segments := []*expressions.MemberPathSegment{s1, s2, s3} + + exp, err := expressions.NewMemberExpression( + core.SourceMap{}, + core.NewExpressionFn(func(ctx context.Context, scope *core.Scope) (core.Value, error) { + return o, nil + }), + segments, + nil, + ) + + So(err, ShouldBeNil) + + root, cancel := core.NewRootScope() + + defer func() { + if err := cancel(); err != nil { + panic(err) + } + }() + + out, err := exp.Exec(context.Background(), root.Fork()) + So(err, ShouldBeNil) + So(out.Type().String(), ShouldEqual, values.None.Type().String()) + + o.AssertExpectations(t) + }) + }) + }) +} diff --git a/pkg/runtime/values/array.go b/pkg/runtime/values/array.go index 7b523eb3..579bf412 100644 --- a/pkg/runtime/values/array.go +++ b/pkg/runtime/values/array.go @@ -277,28 +277,34 @@ func (t *Array) SortWith(sorter ArraySorter) *Array { return res } -func (t *Array) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { +func (t *Array) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { if len(path) == 0 { return None, nil } - if typ := path[0].Type(); typ != types.Int { - return None, core.TypeError(typ, types.Int) + segmentIdx := 0 + + if typ := path[segmentIdx].Type(); typ != types.Int { + return None, core.NewPathError(core.TypeError(typ, types.Int), segmentIdx) } - first := t.Get(path[0].(Int)) + first := t.Get(path[segmentIdx].(Int)) if len(path) == 1 { return first, nil } + segmentIdx++ + + if first == None || first == nil { + return None, core.NewPathError(core.ErrInvalidPath, segmentIdx) + } + getter, ok := first.(core.Getter) + if !ok { - return None, core.TypeError( - first.Type(), - core.NewType("Getter"), - ) + return GetIn(ctx, first, path[segmentIdx:]) } - return getter.GetIn(ctx, path[1:]) + return getter.GetIn(ctx, path[segmentIdx:]) } diff --git a/pkg/runtime/values/helpers.go b/pkg/runtime/values/helpers.go index 979a1ccf..4f6e4939 100644 --- a/pkg/runtime/values/helpers.go +++ b/pkg/runtime/values/helpers.go @@ -15,21 +15,53 @@ import ( ) // GetIn checks that from implements core.Getter interface. If it implements, -// GetIn call from.GetIn method, otherwise return error. -func GetIn(ctx context.Context, from core.Value, byPath []core.Value) (core.Value, error) { - getter, ok := from.(core.Getter) - - if !ok { - return None, core.TypeError( - from.Type(), - core.NewType("Getter"), - ) +// GetIn call from.GetIn method, otherwise iterates over values and tries to resolve a given path. +func GetIn(ctx context.Context, from core.Value, byPath []core.Value) (core.Value, core.PathError) { + if len(byPath) == 0 { + return None, nil + } + + var result = from + + for i, segment := range byPath { + if result == None || result == nil { + break + } + + segType := segment.Type() + + switch curVal := result.(type) { + case *Object: + result, _ = curVal.Get(ToString(segment)) + case *Array: + if segType != types.Int { + return nil, core.NewPathError( + core.TypeError(segType, types.Int), + i, + ) + } + + result = curVal.Get(segment.(Int)) + case String: + if segType != types.Int { + return nil, core.NewPathError( + core.TypeError(segType, types.Int), + i, + ) + } + + result = curVal.At(ToInt(segment)) + case core.Getter: + return curVal.GetIn(ctx, byPath[i:]) + default: + return None, core.NewPathError(core.ErrInvalidPath, i) + } } - return getter.GetIn(ctx, byPath) + return result, nil } -func SetIn(ctx context.Context, to core.Value, byPath []core.Value, value core.Value) error { +func SetIn(ctx context.Context, to core.Value, byPath []core.Value, value core.Value) core.PathError { if len(byPath) == 0 { return nil } @@ -46,7 +78,10 @@ func SetIn(ctx context.Context, to core.Value, byPath []core.Value, value core.V switch parVal := parent.(type) { case *Object: if segmentType != types.String { - return core.TypeError(segmentType, types.String) + return core.NewPathError( + core.TypeError(segmentType, types.String), + idx, + ) } if !isTarget { @@ -56,14 +91,17 @@ func SetIn(ctx context.Context, to core.Value, byPath []core.Value, value core.V } case *Array: if segmentType != types.Int { - return core.TypeError(segmentType, types.Int) + return core.NewPathError( + core.TypeError(segmentType, types.Int), + idx, + ) } if !isTarget { current = parVal.Get(segment.(Int)) } else { if err := parVal.Set(segment.(Int), value); err != nil { - return err + return core.NewPathError(err, idx) } } case core.Setter: @@ -78,7 +116,7 @@ func SetIn(ctx context.Context, to core.Value, byPath []core.Value, value core.V parent = obj if segmentType != types.String { - return core.TypeError(segmentType, types.String) + return core.NewPathError(core.TypeError(segmentType, types.String), idx) } if isTarget { @@ -90,13 +128,21 @@ func SetIn(ctx context.Context, to core.Value, byPath []core.Value, value core.V if isTarget { if err := arr.Set(segment.(Int), value); err != nil { - return err + return core.NewPathError(err, idx) } } } // set new parent - if err := SetIn(ctx, to, byPath[0:idx-1], parent); err != nil { + nextPath := byPath + + if idx > 0 { + nextPath = byPath[0 : idx-1] + } else { + nextPath = byPath[0:] + } + + if err := SetIn(ctx, to, nextPath, parent); err != nil { return err } @@ -109,6 +155,30 @@ func SetIn(ctx context.Context, to core.Value, byPath []core.Value, value core.V return nil } +func ReturnOrNext(ctx context.Context, path []core.Value, idx int, out core.Value, err error) (core.Value, core.PathError) { + if err != nil { + pathErr, ok := err.(core.PathError) + + if ok { + return None, core.NewPathErrorFrom(pathErr, idx) + } + + return None, core.NewPathError(err, idx) + } + + if len(path) > (idx + 1) { + out, pathErr := GetIn(ctx, out, path[idx+1:]) + + if pathErr != nil { + return None, core.NewPathErrorFrom(pathErr, idx) + } + + return out, nil + } + + return out, nil +} + func Parse(input interface{}) core.Value { switch value := input.(type) { case bool: diff --git a/pkg/runtime/values/helpers_test.go b/pkg/runtime/values/helpers_test.go index 61a7277e..5340efb8 100644 --- a/pkg/runtime/values/helpers_test.go +++ b/pkg/runtime/values/helpers_test.go @@ -46,7 +46,7 @@ func (t *CustomValue) Copy() core.Value { return values.None } -func (t *CustomValue) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { +func (t *CustomValue) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { if len(path) == 0 { return values.None, nil } @@ -65,7 +65,7 @@ func (t *CustomValue) GetIn(ctx context.Context, path []core.Value) (core.Value, return values.GetIn(ctx, propValue, path[1:]) } -func (t *CustomValue) SetIn(ctx context.Context, path []core.Value, value core.Value) error { +func (t *CustomValue) SetIn(ctx context.Context, path []core.Value, value core.Value) core.PathError { if len(path) == 0 { return nil } diff --git a/pkg/runtime/values/object.go b/pkg/runtime/values/object.go index fd5d49ed..59adbc6e 100644 --- a/pkg/runtime/values/object.go +++ b/pkg/runtime/values/object.go @@ -296,28 +296,29 @@ func (t *Object) Clone() core.Cloneable { return cloned } -func (t *Object) GetIn(ctx context.Context, path []core.Value) (core.Value, error) { +func (t *Object) GetIn(ctx context.Context, path []core.Value) (core.Value, core.PathError) { if len(path) == 0 { return None, nil } - if typ := path[0].Type(); typ != types.String { - return None, core.TypeError(typ, types.String) - } - - first, _ := t.Get(path[0].(String)) + segmentIdx := 0 + first, _ := t.Get(ToString(path[segmentIdx])) if len(path) == 1 { return first, nil } + segmentIdx++ + + if first == None || first == nil { + return None, core.NewPathError(core.ErrInvalidPath, segmentIdx) + } + getter, ok := first.(core.Getter) + if !ok { - return None, core.TypeError( - first.Type(), - core.NewType("Getter"), - ) + return GetIn(ctx, first, path[segmentIdx:]) } - return getter.GetIn(ctx, path[1:]) + return getter.GetIn(ctx, path[segmentIdx:]) } diff --git a/pkg/runtime/values/object_test.go b/pkg/runtime/values/object_test.go index 7f4fcf9c..a9fac337 100644 --- a/pkg/runtime/values/object_test.go +++ b/pkg/runtime/values/object_test.go @@ -444,14 +444,15 @@ func TestObject(t *testing.T) { Convey("Should error when input is not correct", func() { - Convey("Should error when path[0] is not a string", func() { + Convey("Should return None when path[0] is not a string", func() { obj := values.NewObject() path := []core.Value{values.NewInt(0)} el, err := obj.GetIn(ctx, path) - So(err, ShouldBeError) - So(el.Compare(values.None), ShouldEqual, 0) + So(err, ShouldBeNil) + So(el, ShouldNotBeNil) + So(el.Type().String(), ShouldEqual, types.None.String()) }) Convey("Should error when first received item is not a Getter and len(path) > 1", func() { diff --git a/pkg/runtime/values/string.go b/pkg/runtime/values/string.go index db8bb836..82bd6518 100644 --- a/pkg/runtime/values/string.go +++ b/pkg/runtime/values/string.go @@ -121,3 +121,7 @@ func (t String) IndexOf(other String) Int { func (t String) Concat(other core.Value) String { return String(string(t) + other.String()) } + +func (t String) At(index Int) String { + return String([]rune(t)[index]) +} diff --git a/pkg/runtime/values/string_test.go b/pkg/runtime/values/string_test.go index 46dd6e32..c33e0329 100644 --- a/pkg/runtime/values/string_test.go +++ b/pkg/runtime/values/string_test.go @@ -63,5 +63,12 @@ func TestString(t *testing.T) { So(string(json2), ShouldEqual, fmt.Sprintf(`"%s"`, value)) }) }) + Convey(".At", t, func() { + Convey("It should return a character", func() { + v := values.NewString("abcdefg") + c := v.At(2) + So(string(c), ShouldEqual, "c") + }) + }) }