diff --git a/src/closureps/cloenv.mbt b/src/closureps/cloenv.mbt index f6fbc04..8261415 100644 --- a/src/closureps/cloenv.mbt +++ b/src/closureps/cloenv.mbt @@ -21,7 +21,7 @@ fn CloEnv::collect_named_fns(self : CloEnv, c : @cps.Cps) -> Unit { self.collect_named_fns(body) self.collect_named_fns(rest) } - If(_, _then, _else) => { + If(_, _then, _else) | IfEq(_, _, _then, _else) | IfLe(_, _, _then, _else) => { self.collect_named_fns(_then) self.collect_named_fns(_else) } diff --git a/src/closureps/cps2closureps.mbt b/src/closureps/cps2closureps.mbt index e6deeee..31aa0ea 100644 --- a/src/closureps/cps2closureps.mbt +++ b/src/closureps/cps2closureps.mbt @@ -52,7 +52,7 @@ fn CloEnv::collect_label_closure( self.collect_label_closure(rest, output) output.remove(fn_name) } - If(_, _then, _else) => { + If(_, _then, _else) | IfEq(_, _, _then, _else) | IfLe(_, _, _then, _else) => { self.collect_label_closure(_then, output) self.collect_label_closure(_else, output) } @@ -125,6 +125,8 @@ fn CloEnv::collect_closure( self.add_rebind(bind, bind).collect_closure(rest, func_no_free_vars), ) If(_cond, _then, _else) => If(_cond, rec(_then), rec(_else)) + IfEq(lhs, rhs, _then, _else) => IfEq(lhs, rhs, rec(_then), rec(_else)) + IfLe(lhs, rhs, _then, _else) => IfLe(lhs, rhs, rec(_then), rec(_else)) Prim(op, args, bind, rest) => Prim( op, diff --git a/src/closureps_eval/interpreter.mbt b/src/closureps_eval/interpreter.mbt index c0a5f61..4bb3101 100644 --- a/src/closureps_eval/interpreter.mbt +++ b/src/closureps_eval/interpreter.mbt @@ -172,6 +172,17 @@ pub fn CloPSInterpreter::eval( Bool(b) => continue if b { _then } else { _else } v => @util.die("unexpected condition \{v} for `if`") } + IfLe(lhs, rhs, _then, _else) => + match (self.eval_v!(lhs), self.eval_v!(rhs)) { + (Double(a), Double(b)) => continue if a <= b { _then } else { _else } + (Int(a), Int(b)) => continue if a <= b { _then } else { _else } + (lhs, rhs) => @util.die("unexpected input \{lhs}, \{rhs} for `le`") + } + IfEq(lhs, rhs, _then, _else) => { + let lhs = self.eval_v!(lhs) + let rhs = self.eval_v!(rhs) + continue if lhs == rhs { _then } else { _else } + } Prim(Not, [v], bind, rest) => { match self.eval_v!(v) { Bool(b) => self.cur_env[bind] = Bool(not(b)) diff --git a/src/cps/cps_ir.mbt b/src/cps/cps_ir.mbt index d2102d7..8c424be 100644 --- a/src/cps/cps_ir.mbt +++ b/src/cps/cps_ir.mbt @@ -62,6 +62,8 @@ pub enum Cps { KthTuple(Int, Value, Var, Cps) Fix(Var, Array[Var], Cps, Cps) If(Value, Cps, Cps) + IfLe(Value, Value, Cps, Cps) + IfEq(Value, Value, Cps, Cps) Prim(PrimOp, Array[Value], Var, Cps) // T marks the return type App(Value, Array[Value]) @@ -98,6 +100,10 @@ fn Cps::replace_var_bind(self : Cps, from : Var, to : Value) -> Cps { Fix(name, args, body_new, rest_new) } If(cond, _then, _else) => If(recv(cond), rec(_then), rec(_else)) + IfEq(lhs, rhs, _then, _else) => + IfEq(recv(lhs), recv(rhs), rec(_then), rec(_else)) + IfLe(lhs, rhs, _then, _else) => + IfLe(recv(lhs), recv(rhs), rec(_then), rec(_else)) Prim(op, args, bind, rest) => { let rest_new = if from != bind { rec(rest) } else { rest } Prim(op, args.map(recv), bind, rest_new) @@ -140,6 +146,12 @@ pub fn Cps::free_variables(self : Cps) -> @hashset.T[Var] { .free_variables() .union(_then.free_variables()) .union(_else.free_variables()) + IfEq(lhs, rhs, _then, _else) | IfLe(lhs, rhs, _then, _else) => + lhs + .free_variables() + .union(rhs.free_variables()) + .union(_then.free_variables()) + .union(_else.free_variables()) Prim(_, args, bind, rest) => { let fv_rest = rest.free_variables() fv_rest.remove(bind) diff --git a/src/cps/cps_ir_string.mbt b/src/cps/cps_ir_string.mbt index 6bd8073..a7486f1 100644 --- a/src/cps/cps_ir_string.mbt +++ b/src/cps/cps_ir_string.mbt @@ -49,6 +49,28 @@ fn to_str(cps : Cps, ~ident : String = "") -> String { ident + "}\n" + rec(rest) + IfEq(lhs, rhs, _then, _else) => + ident + + "if(\{lhs} == \{rhs}){\n" + + to_str(_then, ident=ident + " ") + + "\n" + + ident + + "} else {\n" + + to_str(_else, ident=ident + " ") + + "\n" + + ident + + "}" + IfLe(lhs, rhs, _then, _else) => + ident + + "if(\{lhs} <= \{rhs}){\n" + + to_str(_then, ident=ident + " ") + + "\n" + + ident + + "} else {\n" + + to_str(_else, ident=ident + " ") + + "\n" + + ident + + "}" If(cond, _then, _else) => ident + "if(\{cond}){\n" + diff --git a/src/cps/precps2cps.mbt b/src/cps/precps2cps.mbt index b2687ba..35675e2 100644 --- a/src/cps/precps2cps.mbt +++ b/src/cps/precps2cps.mbt @@ -144,6 +144,62 @@ pub fn CpsEnv::precps2cps(self : CpsEnv, s : P, c : Cont) -> Cps { self.precps2cps(tup, c1) } + If(ret_ty, Prim(Bool, Eq, [lhs, rhs]), _then, _else) => { + fn c1(lhs : Value) -> Cps { + // To avoid exponential growth in CPS ir, we abstract the outer `c` out. + fn c2(rhs : Value) -> Cps { + let k_ref = self.new_tmp(Fun([ret_ty], Unit)) + let x_ref = self.new_tmp(ret_ty) + fn c3(branch : Value) -> Cps { + App(Var(k_ref), [branch].map(fix_label_to_var)) + } + + Fix( + k_ref, + [x_ref], + c(Var(x_ref)), + IfEq( + lhs, + rhs, + self.precps2cps(_then, c3), + self.precps2cps(_else, c3), + ), + ) + } + + self.precps2cps(rhs, c2) + } + + self.precps2cps(lhs, c1) + } + If(ret_ty, Prim(Bool, Le, [lhs, rhs]), _then, _else) => { + fn c1(lhs : Value) -> Cps { + // To avoid exponential growth in CPS ir, we abstract the outer `c` out. + fn c2(rhs : Value) -> Cps { + let k_ref = self.new_tmp(Fun([ret_ty], Unit)) + let x_ref = self.new_tmp(ret_ty) + fn c3(branch : Value) -> Cps { + App(Var(k_ref), [branch].map(fix_label_to_var)) + } + + Fix( + k_ref, + [x_ref], + c(Var(x_ref)), + IfLe( + lhs, + rhs, + self.precps2cps(_then, c3), + self.precps2cps(_else, c3), + ), + ) + } + + self.precps2cps(rhs, c2) + } + + self.precps2cps(lhs, c1) + } If(ret_ty, cond, _then, _else) => { fn c1(cond : Value) -> Cps { // To avoid exponential growth in CPS ir, we abstract the outer `c` out. diff --git a/src/js/clops2js.mbt b/src/js/clops2js.mbt index 8549c80..28e11d9 100644 --- a/src/js/clops2js.mbt +++ b/src/js/clops2js.mbt @@ -73,6 +73,24 @@ pub fn JsEmitter::emit_cps(self : JsEmitter, cps : @cps.Cps) -> String { output += self.indent().emit_cps(_else) output += line_start + "}" } + IfEq(lhs, rhs, _then, _else) => { + let lhs_emit = self.emit_val(lhs) + let rhs_emit = self.emit_val(rhs) + output += line_start + "if (\{lhs_emit} === \{rhs_emit}) {" + output += self.indent().emit_cps(_then) + output += line_start + "} else { " + output += self.indent().emit_cps(_else) + output += line_start + "}" + } + IfLe(lhs, rhs, _then, _else) => { + let lhs_emit = self.emit_val(lhs) + let rhs_emit = self.emit_val(rhs) + output += line_start + "if (\{lhs_emit} <= \{rhs_emit}) {" + output += self.indent().emit_cps(_then) + output += line_start + "} else { " + output += self.indent().emit_cps(_else) + output += line_start + "}" + } Prim(Not, [b], bind, rest) => { let bool_emit = self.emit_val(b) output += line_start + "const \{emit_var(bind)} = !\{bool_emit};" diff --git a/src/riscv/before_alloc.mbt b/src/riscv/before_alloc.mbt index b32f416..c796ab5 100644 --- a/src/riscv/before_alloc.mbt +++ b/src/riscv/before_alloc.mbt @@ -28,6 +28,39 @@ fn reserve_fregs(cfg : @ssacfg.SsaCfg, block : @ssacfg.Block) -> Unit { inst => insts.push(inst) } } + match block.last_inst.val { + BranchEq(lhs, rhs, _then, _else) => { + let lhs_is_imm = match lhs { + Var(_) => false + _ => true + } + let rhs_is_imm = match rhs { + Var(_) => false + _ => true + } + if lhs_is_imm && rhs_is_imm { + let tmp = cfg.new_named("tmp", ty=lhs.get_type()) + insts.push(@ssacfg.Inst::Copy(tmp, lhs)) + block.last_inst.val = BranchEq(Var(tmp), rhs, _then, _else) + } + } + BranchLe(lhs, rhs, _then, _else) => { + let lhs_is_imm = match lhs { + Var(_) => false + _ => true + } + let rhs_is_imm = match rhs { + Var(_) => false + _ => true + } + if lhs_is_imm && rhs_is_imm { + let tmp = cfg.new_named("tmp", ty=lhs.get_type()) + insts.push(@ssacfg.Inst::Copy(tmp, lhs)) + block.last_inst.val = BranchLe(Var(tmp), rhs, _then, _else) + } + } + _ => () + } block.insts = insts } @@ -63,10 +96,11 @@ fn freeze_closure( } block.insts = insts block.last_inst.val = match block.last_inst.val { - Branch(cond, _then, _else) => Branch(fix_val(cond), _then, _else) + Branch(cond, _then, _else) => Branch(cond, _then, _else) + BranchEq(lhs, rhs, _then, _else) => BranchEq(lhs, rhs, _then, _else) + BranchLe(lhs, rhs, _then, _else) => BranchLe(lhs, rhs, _then, _else) Call(f, args) => Call(fix_val(f), args.map(fix_val)) - MakeArray(len, elem, kont) => - MakeArray(fix_val(len), fix_val(elem), fix_val(kont)) + MakeArray(len, elem, kont) => MakeArray(len, fix_val(elem), fix_val(kont)) Exit => Exit } } diff --git a/src/riscv/codegen.mbt b/src/riscv/codegen.mbt index b6ef1f3..33c9d42 100644 --- a/src/riscv/codegen.mbt +++ b/src/riscv/codegen.mbt @@ -854,6 +854,88 @@ fn CodegenBlock::codegen(self : CodegenBlock) -> Unit { ) self.branch_to(blk_else).codegen() } + BranchEq(lhs, rhs, blk_then, blk_else) => + match get_reg_ty(lhs) { + // TODO: just like Eq & Le, we need to ensure the 2 regs doesn't crash + F64 => { + let reg_lhs = self.pull_val_f(lhs) + let reg_rhs = self.pull_val_f(rhs) + // NOTE: since we've reserve freg first, this won't happen + if reg_lhs == reg_rhs { + @util.die("pulling same reg for comparing floats") + } + self.insert_asms( + [ + FeqD(reg_swap, reg_lhs, reg_rhs), + Beq(reg_swap, Zero, blk_else.to_string()), + ], + ) + self.branch_to(blk_then).codegen() + self.insert_asm( + Comment( + "we have tail call so no point generating a jump to skip the else branch", + ), + ) + self.branch_to(blk_else).codegen() + } + _ => { + let reg_lhs = self.pull_val_i(lhs) + let reg_rhs = self.pull_val_i(rhs) + // NOTE: since we've reserve freg first, this won't happen + if reg_lhs == reg_rhs { + @util.die("pulling same reg for comparing floats") + } + self.insert_asm(Bne(reg_lhs, reg_rhs, blk_else.to_string())) + self.branch_to(blk_then).codegen() + self.insert_asm( + Comment( + "we have tail call so no point generating a jump to skip the else branch", + ), + ) + self.branch_to(blk_else).codegen() + } + } + BranchLe(lhs, rhs, blk_then, blk_else) => + match get_reg_ty(lhs) { + // TODO: just like Eq & Le, we need to ensure the 2 regs doesn't crash + F64 => { + let reg_lhs = self.pull_val_f(lhs) + let reg_rhs = self.pull_val_f(rhs) + // NOTE: since we've reserve freg first, this won't happen + if reg_lhs == reg_rhs { + @util.die("pulling same reg for comparing floats") + } + self.insert_asms( + [ + FleD(reg_swap, reg_lhs, reg_rhs), + Beq(reg_swap, Zero, blk_else.to_string()), + ], + ) + self.branch_to(blk_then).codegen() + self.insert_asm( + Comment( + "we have tail call so no point generating a jump to skip the else branch", + ), + ) + self.branch_to(blk_else).codegen() + } + _ => { + let reg_lhs = self.pull_val_i(lhs) + let reg_rhs = self.pull_val_i(rhs) + // NOTE: since we've reserve freg first, this won't happen + if reg_lhs == reg_rhs { + @util.die("pulling same reg for comparing floats") + } + self.insert_asm(Bgt(reg_lhs, reg_rhs, blk_else.to_string())) + self.branch_to(blk_then).codegen() + self.insert_asm( + Comment( + "we have tail call so no point generating a jump to skip the else branch", + ), + ) + self.branch_to(blk_else).codegen() + } + } Exit => self.insert_asms([Li(A0, "0"), Li(A7, "93"), Ecall]) MakeArray(len, elem, continuation) => { // 1. call to generate an array diff --git a/src/riscv/collect_labels.mbt b/src/riscv/collect_labels.mbt index 057b055..508f4f8 100644 --- a/src/riscv/collect_labels.mbt +++ b/src/riscv/collect_labels.mbt @@ -46,6 +46,12 @@ fn collect_externals(cfg : @ssacfg.SsaCfg) -> ExternalLabels { collect_label_var(_then) collect_label_var(_else) } + BranchEq(lhs, rhs, _then, _else) | BranchLe(lhs, rhs, _then, _else) => { + collect_label_val(lhs) + collect_label_val(rhs) + collect_label_var(_then) + collect_label_var(_else) + } Call(f, args) => { collect_label_val(f) args.each(collect_label_val) diff --git a/src/riscv/interference_graph_build.mbt b/src/riscv/interference_graph_build.mbt index 8138172..bfd45a4 100644 --- a/src/riscv/interference_graph_build.mbt +++ b/src/riscv/interference_graph_build.mbt @@ -79,6 +79,10 @@ fn LiveVarAnalysis::collect_pc_inst( ) -> Unit { match inst { Branch(cond, _then, _else) => self.collect_val(cond) + BranchEq(lhs, rhs, _then, _else) | BranchLe(lhs, rhs, _then, _else) => { + self.collect_val(lhs) + self.collect_val(rhs) + } Call(f, args) => { self.collect_val(f) args.each(fn(arg) { self.collect_val(arg) }) @@ -127,7 +131,9 @@ fn collect_blocks_of_fn(cfg : @ssacfg.SsaCfg, fn_name : Var) -> Array[Var] { let cur_blk = cfg.blocks[label].unwrap() stack_process.push(label) match cur_blk.last_inst.val { - Branch(_, blk_then, blk_else) => { + Branch(_, blk_then, blk_else) + | BranchEq(_, _, blk_then, blk_else) + | BranchLe(_, _, blk_then, blk_else) => { q_collect.push(blk_then) q_collect.push(blk_else) } diff --git a/src/riscv/reg_spill.mbt b/src/riscv/reg_spill.mbt index c59572e..3c477c9 100644 --- a/src/riscv/reg_spill.mbt +++ b/src/riscv/reg_spill.mbt @@ -77,6 +77,8 @@ fn reg_spill_block(blk : @ssacfg.Block, spilled_var : Var) -> @ssacfg.Block { } let load_before_exit = match blk.last_inst.val { Branch(cond, _, _) => val_spilled(cond) + BranchEq(lhs, rhs, _, _) | BranchLe(lhs, rhs, _, _) => + val_spilled(lhs) || val_spilled(rhs) Call(f, args) => val_spilled(f) || vals_spilled(args) MakeArray(len, elem, kont) => val_spilled(len) || val_spilled(elem) || val_spilled(kont) diff --git a/src/ssacfg/clops2ssacfg.mbt b/src/ssacfg/clops2ssacfg.mbt index b41ce7f..1bb0c79 100644 --- a/src/ssacfg/clops2ssacfg.mbt +++ b/src/ssacfg/clops2ssacfg.mbt @@ -52,6 +52,44 @@ fn SsaCfg::cps2block( } Prim(Put, args, _, _) => @util.die("unexpect args \{args} for put") Fix(_) => @util.die("unexpected nested function") + IfLe(lhs, rhs, _then, _else) => { + // NOTE: for control flow connection points there's no reason for us to + // generate a type for it. + // NOTE: complete cur block; phis need to be fixed, but we'll deal with it in a later pass + let label_then = self.new_named("then") + let label_else = self.new_named("else") + cur_block.last_inst.val = BranchLe(lhs, rhs, label_then, label_else) + // NOTE: convert then block to ssa block + let block_then = Block::new_from(cur_block.fn_name, cur_label) + self.blocks[label_then] = block_then + self.cps2block(_then, block_then, label_then) + // NOTE: convert else block to ssa block + let block_else = Block::new_from(cur_block.fn_name, cur_label) + self.blocks[label_else] = block_else + cur_block = block_else + cur_label = label_else + continue _else + // no need to add merge point after if statements as we always have a tail call + } + IfEq(lhs, rhs, _then, _else) => { + // NOTE: for control flow connection points there's no reason for us to + // generate a type for it. + // NOTE: complete cur block; phis need to be fixed, but we'll deal with it in a later pass + let label_then = self.new_named("then") + let label_else = self.new_named("else") + cur_block.last_inst.val = BranchEq(lhs, rhs, label_then, label_else) + // NOTE: convert then block to ssa block + let block_then = Block::new_from(cur_block.fn_name, cur_label) + self.blocks[label_then] = block_then + self.cps2block(_then, block_then, label_then) + // NOTE: convert else block to ssa block + let block_else = Block::new_from(cur_block.fn_name, cur_label) + self.blocks[label_else] = block_else + cur_block = block_else + cur_label = label_else + continue _else + // no need to add merge point after if statements as we always have a tail call + } If(cond, _then, _else) => { // NOTE: for control flow connection points there's no reason for us to // generate a type for it. diff --git a/src/ssacfg/ssa_ir.mbt b/src/ssacfg/ssa_ir.mbt index 862c8c2..e3fab27 100644 --- a/src/ssacfg/ssa_ir.mbt +++ b/src/ssacfg/ssa_ir.mbt @@ -25,6 +25,8 @@ pub enum Inst { pub enum PCInst { // join points (val, block ref) Branch(Value, Var, Var) + BranchEq(Value, Value, Var, Var) + BranchLe(Value, Value, Var, Var) Call(Value, Array[Value]) // len, elem, continuation MakeArray(Value, Value, Value) diff --git a/test/build_all b/test/build_all index c6f4c24..ff74a4b 100755 --- a/test/build_all +++ b/test/build_all @@ -1,11 +1,16 @@ #!/usr/bin/env bash +echo "building CPS IR" moon run ../src/bin/main.mbt -- --end-stage cps ./test_src/$1.mbt -o $1.cps +echo "building CLOPS IR" moon run ../src/bin/main.mbt -- --end-stage clops ./test_src/$1.mbt -o $1.clops +echo "building ASM" moon run ../src/bin/main.mbt -- ./test_src/$1.mbt -o $1.s +echo "building JS" moon run ../src/bin/main.mbt -- ./test_src/$1.mbt -o $1.js --js +echo "building Binary" zig build-exe -target riscv64-linux -femit-bin=$1 \ $1.s ../riscv_rt/zig-out/lib/libmincaml.a \ -O Debug -fno-strip -mcpu=baseline_rv64