diff --git a/pkg/compiler/loops.go b/pkg/compiler/loops.go index 4aac23dd..7e27c1f8 100644 --- a/pkg/compiler/loops.go +++ b/pkg/compiler/loops.go @@ -3,13 +3,21 @@ package compiler import "github.com/MontFerret/ferret/pkg/runtime" type ( + LoopType int + + LoopKind int + Loop struct { - PassThrough bool - Distinct bool - Result runtime.Operand - Iterator runtime.Operand + Type LoopType + Kind LoopKind + Distinct bool Allocate bool Next int + Src runtime.Operand + Iterator runtime.Operand + Value runtime.Operand + Key runtime.Operand + Result runtime.Operand } LoopTable struct { @@ -18,6 +26,18 @@ type ( } ) +const ( + NormalLoop LoopType = iota + PassThroughLoop + TemporalLoop +) + +const ( + ForLoop LoopKind = iota + WhileLoop + DoWhileLoop +) + func NewLoopTable(registers *RegisterAllocator) *LoopTable { return &LoopTable{ loops: make([]*Loop, 0), @@ -25,19 +45,19 @@ func NewLoopTable(registers *RegisterAllocator) *LoopTable { } } -func (lt *LoopTable) EnterLoop(passThrough, distinct bool) *Loop { +func (lt *LoopTable) EnterLoop(loopType LoopType, kind LoopKind, distinct bool) *Loop { var allocate bool var state runtime.Operand // top loop if len(lt.loops) == 0 { allocate = true - } else if !passThrough { + } else if loopType != PassThroughLoop { // nested with explicit RETURN expression prev := lt.loops[len(lt.loops)-1] // if the loop above does not do pass through // we allocate a new state for this loop - allocate = !prev.PassThrough + allocate = prev.Type != PassThroughLoop state = prev.Result } else { // nested with implicit RETURN expression @@ -50,10 +70,11 @@ func (lt *LoopTable) EnterLoop(passThrough, distinct bool) *Loop { } lt.loops = append(lt.loops, &Loop{ - PassThrough: passThrough, - Distinct: distinct, - Result: state, - Allocate: allocate, + Type: loopType, + Kind: kind, + Distinct: distinct, + Result: state, + Allocate: allocate, }) return lt.loops[len(lt.loops)-1] diff --git a/pkg/compiler/visitor.go b/pkg/compiler/visitor.go index 89a7b404..b45a2c60 100644 --- a/pkg/compiler/visitor.go +++ b/pkg/compiler/visitor.go @@ -103,12 +103,9 @@ func (v *visitor) VisitHead(_ *fql.HeadContext) interface{} { func (v *visitor) VisitForExpression(ctx *fql.ForExpressionContext) interface{} { v.symbols.EnterScope() - var passThrough bool var distinct bool var returnRuleCtx antlr.RuleContext var jumpOffset int - // identify whether it's WHILE or FOR loop - isForLoop := ctx.While() == nil returnCtx := ctx.ForExpressionReturn() if c := returnCtx.ReturnExpression(); c != nil { @@ -116,25 +113,23 @@ func (v *visitor) VisitForExpression(ctx *fql.ForExpressionContext) interface{} distinct = c.Distinct() != nil } else if c := returnCtx.ForExpression(); c != nil { returnRuleCtx = c - passThrough = true } - loop := v.loops.EnterLoop(passThrough, distinct) + loop := v.loops.EnterLoop(v.loopType(ctx), v.loopKind(ctx), distinct) dsReg := loop.Result if loop.Allocate { v.emitter.EmitAb(runtime.OpLoopBegin, dsReg, distinct) } - if isForLoop { + if loop.Kind == ForLoop { // Loop data source to iterate over - src1 := ctx.ForExpressionSource().Accept(v).(runtime.Operand) + loop.Src = ctx.ForExpressionSource().Accept(v).(runtime.Operand) + loop.Iterator = v.registers.Allocate(State) - iterReg := v.registers.Allocate(State) - - v.emitter.EmitAB(runtime.OpForLoopInit, iterReg, src1) + v.emitter.EmitAB(runtime.OpForLoopInit, loop.Iterator, loop.Src) // jumpPlaceholder is a placeholder for the exit jump position - loop.Next = v.emitter.EmitJumpc(runtime.OpForLoopNext, jumpPlaceholder, iterReg) + loop.Next = v.emitter.EmitJumpc(runtime.OpForLoopNext, jumpPlaceholder, loop.Iterator) valVar := ctx.GetValueVariable().GetText() counterVarCtx := ctx.GetCounterVariable() @@ -148,20 +143,16 @@ func (v *visitor) VisitForExpression(ctx *fql.ForExpressionContext) interface{} hasCounterVar = true } - var valReg runtime.Operand - // declare value variable if hasValVar { - valReg = v.symbols.DefineVariable(valVar) - v.emitter.EmitAB(runtime.OpForLoopValue, valReg, iterReg) + loop.Value = v.symbols.DefineVariable(valVar) + v.emitter.EmitAB(runtime.OpForLoopValue, loop.Value, loop.Iterator) } - var keyReg runtime.Operand - if hasCounterVar { // declare counter variable - keyReg = v.symbols.DefineVariable(counterVar) - v.emitter.EmitAB(runtime.OpForLoopKey, keyReg, iterReg) + loop.Key = v.symbols.DefineVariable(counterVar) + v.emitter.EmitAB(runtime.OpForLoopKey, loop.Key, loop.Iterator) } } else { counterReg := v.registers.Allocate(State) @@ -195,7 +186,7 @@ func (v *visitor) VisitForExpression(ctx *fql.ForExpressionContext) interface{} } // RETURN - if !passThrough { + if loop.Type != PassThroughLoop { c := returnRuleCtx.(*fql.ReturnExpressionContext) expReg := c.Expression().Accept(v).(runtime.Operand) @@ -213,13 +204,13 @@ func (v *visitor) VisitForExpression(ctx *fql.ForExpressionContext) interface{} // TODO: Reuse the dsReg register v.emitter.EmitAB(runtime.OpLoopEnd, dst, dsReg) - if isForLoop { + if loop.Kind == ForLoop { v.emitter.PatchJump(jumpIndex) } else { v.emitter.PatchJumpAB(jumpIndex) } } else { - if isForLoop { + if loop.Kind == ForLoop { v.emitter.PatchJumpNext(jumpIndex) } else { v.emitter.PatchJumpNextAB(jumpIndex) @@ -1081,6 +1072,26 @@ func (v *visitor) functionName(ctx *fql.FunctionCallContext) values.String { return values.NewString(strings.ToUpper(name)) } +func (v *visitor) loopType(ctx *fql.ForExpressionContext) LoopType { + if c := ctx.ForExpressionReturn().ForExpression(); c == nil { + return NormalLoop + } + + return PassThroughLoop +} + +func (v *visitor) loopKind(ctx *fql.ForExpressionContext) LoopKind { + if ctx.While() == nil { + return ForLoop + } + + if ctx.Do() == nil { + return WhileLoop + } + + return DoWhileLoop +} + func (v *visitor) loadConstant(constant core.Value) runtime.Operand { reg := v.registers.Allocate(Temp) v.emitter.EmitAB(runtime.OpLoadConst, reg, v.symbols.AddConstant(constant))