From fb4d383a216ff95c28d45985bac85ed551eec3f0 Mon Sep 17 00:00:00 2001 From: Tim Voronov Date: Wed, 6 Nov 2024 11:29:48 -0500 Subject: [PATCH] Fixes --- pkg/compiler/compiler_exec_test.go | 100 ++++++++++++++--------------- pkg/compiler/visitor.go | 54 ++++++++-------- pkg/runtime/opcode.go | 2 + pkg/runtime/values/helpers.go | 16 +++++ pkg/runtime/vm.go | 18 ++++++ pkg/stdlib/collections/length.go | 36 ----------- pkg/stdlib/collections/lib.go | 1 - pkg/stdlib/testing/empty.go | 3 +- pkg/stdlib/testing/len.go | 3 +- 9 files changed, 114 insertions(+), 119 deletions(-) delete mode 100644 pkg/stdlib/collections/length.go diff --git a/pkg/compiler/compiler_exec_test.go b/pkg/compiler/compiler_exec_test.go index 6d655818..d884c115 100644 --- a/pkg/compiler/compiler_exec_test.go +++ b/pkg/compiler/compiler_exec_test.go @@ -151,51 +151,51 @@ func TestVariables(t *testing.T) { }, }) - //Convey("Should compile LET i = (FOR i WHILE COUNTER() < 5 RETURN i) RETURN i", t, func() { - // c := compiler.New() - // counter := -1 - // c.RegisterFunction("COUNTER", func(ctx context.visitor, args ...core.Second) (core.Second, error) { - // counter++ - // - // return values.NewInt(counter), nil - // }) - // - // p, err := c.Compile(` - // LET i = (FOR i WHILE COUNTER() < 5 RETURN i) - // RETURN i - // `) - // - // So(err, ShouldBeNil) - // So(p, ShouldHaveSameTypeAs, &runtime.Program{}) - // - // out, err := p.Run(context.Background()) - // - // So(err, ShouldBeNil) - // So(string(out), ShouldEqual, "[0,1,2,3,4]") - //}) - // - //Convey("Should compile LET i = (FOR i WHILE COUNTER() < 5 T::FAIL() RETURN i)? RETURN i == NONE", t, func() { - // c := compiler.New() - // counter := -1 - // c.RegisterFunction("COUNTER", func(ctx context.visitor, args ...core.Second) (core.Second, error) { - // counter++ - // - // return values.NewInt(counter), nil - // }) - // - // p, err := c.Compile(` - // LET i = (FOR i WHILE COUNTER() < 5 T::FAIL() RETURN i)? - // RETURN i == NONE - // `) - // - // So(err, ShouldBeNil) - // So(p, ShouldHaveSameTypeAs, &runtime.Program{}) - // - // out, err := p.Run(context.Background()) - // - // So(err, ShouldBeNil) - // So(string(out), ShouldEqual, "true") - //}) + Convey("Should compile LET i = (FOR i WHILE COUNTER() < 5 RETURN i) RETURN i", t, func() { + c := compiler.New() + + p, err := c.Compile(` + LET i = (FOR i WHILE COUNTER() < 5 RETURN i) + RETURN i + `) + + So(err, ShouldBeNil) + So(p, ShouldHaveSameTypeAs, &runtime.Program{}) + + counter := -1 + out, err := Run(p, runtime.WithFunction("COUNTER", func(ctx context.Context, args ...core.Value) (core.Value, error) { + counter++ + + return values.NewInt(counter), nil + })) + + So(err, ShouldBeNil) + So(string(out), ShouldEqual, "[0,1,2,3,4]") + }) + + Convey("Should compile LET i = (FOR i WHILE COUNTER() < 5 T::FAIL() RETURN i)? RETURN i == NONE", t, func() { + c := compiler.New() + + p, err := c.Compile(` + LET i = (FOR i WHILE COUNTER() < 5 T::FAIL() RETURN i)? + RETURN length(i) == 0 + `) + + So(err, ShouldBeNil) + So(p, ShouldHaveSameTypeAs, &runtime.Program{}) + + counter := -1 + out, err := Run(p, runtime.WithFunction("COUNTER", func(ctx context.Context, args ...core.Value) (core.Value, error) { + counter++ + + return values.NewInt(counter), nil + }), runtime.WithFunction("T::FAIL", func(ctx context.Context, args ...core.Value) (core.Value, error) { + return values.None, fmt.Errorf("test") + })) + + So(err, ShouldBeNil) + So(string(out), ShouldEqual, "true") + }) Convey("Should not compile FOR foo IN foo", t, func() { c := compiler.New() @@ -1088,11 +1088,11 @@ func TestFor(t *testing.T) { func TestForWhile(t *testing.T) { var counter int64 RunUseCases(t, []UseCase{ - //{ - // "FOR i WHILE false RETURN i", - // []any{}, - // ShouldEqualJSON, - //}, + { + "FOR i WHILE false RETURN i", + []any{}, + ShouldEqualJSON, + }, { "FOR i WHILE UNTIL(5) RETURN i", []any{0, 1, 2, 3, 4}, diff --git a/pkg/compiler/visitor.go b/pkg/compiler/visitor.go index e6959b33..a25b1c0a 100644 --- a/pkg/compiler/visitor.go +++ b/pkg/compiler/visitor.go @@ -105,9 +105,9 @@ func (v *visitor) VisitForExpression(ctx *fql.ForExpressionContext) interface{} var passThrough bool var distinct bool var returnRuleCtx antlr.RuleContext + var jumpOffset int // identify whether it's WHILE or FOR loop isForLoop := ctx.While() == nil - var isWhileFn bool returnCtx := ctx.ForExpressionReturn() if c := returnCtx.ReturnExpression(); c != nil { @@ -168,21 +168,14 @@ func (v *visitor) VisitForExpression(ctx *fql.ForExpressionContext) interface{} // Create initial value for the loop counter v.emitter.EmitA(runtime.OpWhileLoopInit, counterReg) + beforeExp := v.emitter.Size() // Loop data source to iterate over cond := srcExpr.Accept(v).(runtime.Operand) + jumpOffset = v.emitter.Size() - beforeExp // jumpPlaceholder is a placeholder for the exit jump position loop.Next = v.emitter.EmitJumpAB(runtime.OpWhileLoopNext, counterReg, cond, jumpPlaceholder) - // Fix jump for function calls - if predicate := srcExpr.Predicate(); predicate != nil { - if atom := predicate.ExpressionAtom(); atom != nil { - if fcExpr := atom.FunctionCallExpression(); fcExpr != nil { - isWhileFn = true - } - } - } - counterVar := ctx.GetCounterVariable().GetText() // declare counter variable @@ -207,17 +200,7 @@ func (v *visitor) VisitForExpression(ctx *fql.ForExpressionContext) interface{} returnRuleCtx.Accept(v) } - if isForLoop { - v.emitter.EmitJump(runtime.OpJump, loop.Next) - } else { - - if !isWhileFn { - v.emitter.EmitJump(runtime.OpJump, loop.Next-1) - } else { - v.emitter.EmitJump(runtime.OpJump, loop.Next-2) - } - - } + v.emitter.EmitJump(runtime.OpJump, loop.Next-jumpOffset) // TODO: Do not allocate for pass-through loops dst := v.registers.Allocate(Temp) @@ -976,15 +959,30 @@ func (v *visitor) visitFunctionCall(ctx *fql.FunctionCallContext, safeCall bool) } } - nameAndDest := v.loadConstant(v.functionName(ctx)) + name := v.functionName(ctx) - if !safeCall { - v.emitter.EmitAs(runtime.OpCall, nameAndDest, seq) - } else { - v.emitter.EmitAs(runtime.OpCallSafe, nameAndDest, seq) - } + switch name { + case "LENGTH": + dst := v.registers.Allocate(Temp) + + if seq == nil || len(seq.Registers) > 1 { + panic(core.Error(core.ErrInvalidArgument, "LENGTH: expected 1 argument")) + } + + v.emitter.EmitAB(runtime.OpLength, dst, seq.Registers[0]) + + return dst + default: + nameAndDest := v.loadConstant(v.functionName(ctx)) + + if !safeCall { + v.emitter.EmitAs(runtime.OpCall, nameAndDest, seq) + } else { + v.emitter.EmitAs(runtime.OpCallSafe, nameAndDest, seq) + } - return nameAndDest + return nameAndDest + } } func (v *visitor) functionName(ctx *fql.FunctionCallContext) values.String { diff --git a/pkg/runtime/opcode.go b/pkg/runtime/opcode.go index 09b94c0d..55ee251b 100644 --- a/pkg/runtime/opcode.go +++ b/pkg/runtime/opcode.go @@ -48,6 +48,8 @@ const ( OpLoadProperty OpLoadPropertyOptional + OpLength + OpCall OpCallSafe diff --git a/pkg/runtime/values/helpers.go b/pkg/runtime/values/helpers.go index 1a84c19b..82eb3064 100644 --- a/pkg/runtime/values/helpers.go +++ b/pkg/runtime/values/helpers.go @@ -585,3 +585,19 @@ func ToNumberOnly(input core.Value) core.Value { func CompareStrings(a, b String) Int { return Int(strings.Compare(a.String(), b.String())) } + +func Length(value core.Value) (Int, error) { + c, ok := value.(core.Measurable) + + if !ok { + return 0, core.TypeError(value, + types.String, + types.Array, + types.Object, + types.Binary, + types.Measurable, + ) + } + + return Int(c.Length()), nil +} diff --git a/pkg/runtime/vm.go b/pkg/runtime/vm.go index 0529c4ec..f1369435 100644 --- a/pkg/runtime/vm.go +++ b/pkg/runtime/vm.go @@ -26,6 +26,7 @@ func NewVM(program *Program) *VM { } func (vm *VM) Run(ctx context.Context, opts []EnvironmentOption) (core.Value, error) { + // TODO: Return jump position if an error occurred within a wrapped loop tryCatch := func(pos int) bool { for _, pair := range vm.program.CatchTable { if pos >= pair[0] && pos <= pair[1] { @@ -43,6 +44,7 @@ func (vm *VM) Run(ctx context.Context, opts []EnvironmentOption) (core.Value, er vm.pc = 0 program := vm.program + // TODO: Add panic handling and snapshot the last instruction and frame that caused it loop: for vm.pc < len(program.Bytecode) { inst := program.Bytecode[vm.pc] @@ -264,6 +266,22 @@ loop: } else { return nil, err } + case OpLength: + val, ok := reg[src1].(core.Measurable) + + if ok { + reg[dst] = values.NewInt(val.Length()) + } else if tryCatch(vm.pc) { + reg[dst] = values.ZeroInt + } else { + return values.None, core.TypeError(reg[src1], + types.String, + types.Array, + types.Object, + types.Binary, + types.Measurable, + ) + } case OpRange: res, err := operators.Range(reg[src1], reg[src2]) diff --git a/pkg/stdlib/collections/length.go b/pkg/stdlib/collections/length.go deleted file mode 100644 index a598562c..00000000 --- a/pkg/stdlib/collections/length.go +++ /dev/null @@ -1,36 +0,0 @@ -package collections - -import ( - "context" - - "github.com/MontFerret/ferret/pkg/runtime/core" - "github.com/MontFerret/ferret/pkg/runtime/values" - "github.com/MontFerret/ferret/pkg/runtime/values/types" -) - -// LENGTH returns the length of a measurable value. -// @param {Measurable} value - The value to measure. -// @return {Int} - The length of the value. -func Length(_ context.Context, inputs ...core.Value) (core.Value, error) { - err := core.ValidateArgs(inputs, 1, 1) - - if err != nil { - return values.None, err - } - - value := inputs[0] - - c, ok := value.(core.Measurable) - - if !ok { - return values.None, core.TypeError(value, - types.String, - types.Array, - types.Object, - types.Binary, - types.Measurable, - ) - } - - return values.Int(c.Length()), nil -} diff --git a/pkg/stdlib/collections/lib.go b/pkg/stdlib/collections/lib.go index d72beb08..f3eedc6e 100644 --- a/pkg/stdlib/collections/lib.go +++ b/pkg/stdlib/collections/lib.go @@ -6,7 +6,6 @@ func RegisterLib(ns core.Namespace) error { return ns.RegisterFunctions( core.NewFunctionsFromMap(map[string]core.Function{ "INCLUDES": Includes, - "LENGTH": Length, "REVERSE": Reverse, })) } diff --git a/pkg/stdlib/testing/empty.go b/pkg/stdlib/testing/empty.go index e9294205..03149516 100644 --- a/pkg/stdlib/testing/empty.go +++ b/pkg/stdlib/testing/empty.go @@ -5,7 +5,6 @@ import ( "github.com/MontFerret/ferret/pkg/runtime/core" "github.com/MontFerret/ferret/pkg/runtime/values" - "github.com/MontFerret/ferret/pkg/stdlib/collections" "github.com/MontFerret/ferret/pkg/stdlib/testing/base" ) @@ -19,7 +18,7 @@ var Empty = base.Assertion{ MinArgs: 1, MaxArgs: 2, Fn: func(ctx context.Context, args []core.Value) (bool, error) { - size, err := collections.Length(ctx, args[0]) + size, err := values.Length(args[0]) if err != nil { return false, err diff --git a/pkg/stdlib/testing/len.go b/pkg/stdlib/testing/len.go index b307dd50..9d261359 100644 --- a/pkg/stdlib/testing/len.go +++ b/pkg/stdlib/testing/len.go @@ -7,7 +7,6 @@ import ( "github.com/MontFerret/ferret/pkg/runtime/values" "github.com/MontFerret/ferret/pkg/runtime/core" - "github.com/MontFerret/ferret/pkg/stdlib/collections" "github.com/MontFerret/ferret/pkg/stdlib/testing/base" ) @@ -25,7 +24,7 @@ var Len = base.Assertion{ col := args[0] size := args[1] - out, err := collections.Length(ctx, col) + out, err := values.Length(col) if err != nil { return false, err