Skip to content

Commit

Permalink
emits ble & bgt now
Browse files Browse the repository at this point in the history
  • Loading branch information
glyh committed Oct 22, 2024
1 parent c42f431 commit f784b4d
Show file tree
Hide file tree
Showing 15 changed files with 302 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/closureps/cloenv.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ fn CloEnv::collect_named_fns(self : CloEnv, c : @cps.Cps) -> Unit {
self.collect_named_fns(body)
self.collect_named_fns(rest)
}
If(_, _then, _else) => {
If(_, _then, _else) | IfEq(_, _, _then, _else) | IfLe(_, _, _then, _else) => {
self.collect_named_fns(_then)
self.collect_named_fns(_else)
}
Expand Down
4 changes: 3 additions & 1 deletion src/closureps/cps2closureps.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn CloEnv::collect_label_closure(
self.collect_label_closure(rest, output)
output.remove(fn_name)
}
If(_, _then, _else) => {
If(_, _then, _else) | IfEq(_, _, _then, _else) | IfLe(_, _, _then, _else) => {
self.collect_label_closure(_then, output)
self.collect_label_closure(_else, output)
}
Expand Down Expand Up @@ -125,6 +125,8 @@ fn CloEnv::collect_closure(
self.add_rebind(bind, bind).collect_closure(rest, func_no_free_vars),
)
If(_cond, _then, _else) => If(_cond, rec(_then), rec(_else))
IfEq(lhs, rhs, _then, _else) => IfEq(lhs, rhs, rec(_then), rec(_else))
IfLe(lhs, rhs, _then, _else) => IfLe(lhs, rhs, rec(_then), rec(_else))
Prim(op, args, bind, rest) =>
Prim(
op,
Expand Down
11 changes: 11 additions & 0 deletions src/closureps_eval/interpreter.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,17 @@ pub fn CloPSInterpreter::eval(
Bool(b) => continue if b { _then } else { _else }
v => @util.die("unexpected condition \{v} for `if`")
}
IfLe(lhs, rhs, _then, _else) =>
match (self.eval_v!(lhs), self.eval_v!(rhs)) {
(Double(a), Double(b)) => continue if a <= b { _then } else { _else }
(Int(a), Int(b)) => continue if a <= b { _then } else { _else }
(lhs, rhs) => @util.die("unexpected input \{lhs}, \{rhs} for `le`")
}
IfEq(lhs, rhs, _then, _else) => {
let lhs = self.eval_v!(lhs)
let rhs = self.eval_v!(rhs)
continue if lhs == rhs { _then } else { _else }
}
Prim(Not, [v], bind, rest) => {
match self.eval_v!(v) {
Bool(b) => self.cur_env[bind] = Bool(not(b))
Expand Down
12 changes: 12 additions & 0 deletions src/cps/cps_ir.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ pub enum Cps {
KthTuple(Int, Value, Var, Cps)
Fix(Var, Array[Var], Cps, Cps)
If(Value, Cps, Cps)
IfLe(Value, Value, Cps, Cps)
IfEq(Value, Value, Cps, Cps)
Prim(PrimOp, Array[Value], Var, Cps)
// T marks the return type
App(Value, Array[Value])
Expand Down Expand Up @@ -98,6 +100,10 @@ fn Cps::replace_var_bind(self : Cps, from : Var, to : Value) -> Cps {
Fix(name, args, body_new, rest_new)
}
If(cond, _then, _else) => If(recv(cond), rec(_then), rec(_else))
IfEq(lhs, rhs, _then, _else) =>
IfEq(recv(lhs), recv(rhs), rec(_then), rec(_else))
IfLe(lhs, rhs, _then, _else) =>
IfLe(recv(lhs), recv(rhs), rec(_then), rec(_else))
Prim(op, args, bind, rest) => {
let rest_new = if from != bind { rec(rest) } else { rest }
Prim(op, args.map(recv), bind, rest_new)
Expand Down Expand Up @@ -140,6 +146,12 @@ pub fn Cps::free_variables(self : Cps) -> @hashset.T[Var] {
.free_variables()
.union(_then.free_variables())
.union(_else.free_variables())
IfEq(lhs, rhs, _then, _else) | IfLe(lhs, rhs, _then, _else) =>
lhs
.free_variables()
.union(rhs.free_variables())
.union(_then.free_variables())
.union(_else.free_variables())
Prim(_, args, bind, rest) => {
let fv_rest = rest.free_variables()
fv_rest.remove(bind)
Expand Down
22 changes: 22 additions & 0 deletions src/cps/cps_ir_string.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,28 @@ fn to_str(cps : Cps, ~ident : String = "") -> String {
ident +
"}\n" +
rec(rest)
IfEq(lhs, rhs, _then, _else) =>
ident +
"if(\{lhs} == \{rhs}){\n" +
to_str(_then, ident=ident + " ") +
"\n" +
ident +
"} else {\n" +
to_str(_else, ident=ident + " ") +
"\n" +
ident +
"}"
IfLe(lhs, rhs, _then, _else) =>
ident +
"if(\{lhs} <= \{rhs}){\n" +
to_str(_then, ident=ident + " ") +
"\n" +
ident +
"} else {\n" +
to_str(_else, ident=ident + " ") +
"\n" +
ident +
"}"
If(cond, _then, _else) =>
ident +
"if(\{cond}){\n" +
Expand Down
56 changes: 56 additions & 0 deletions src/cps/precps2cps.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,62 @@ pub fn CpsEnv::precps2cps(self : CpsEnv, s : P, c : Cont) -> Cps {

self.precps2cps(tup, c1)
}
If(ret_ty, Prim(Bool, Eq, [lhs, rhs]), _then, _else) => {
fn c1(lhs : Value) -> Cps {
// To avoid exponential growth in CPS ir, we abstract the outer `c` out.
fn c2(rhs : Value) -> Cps {
let k_ref = self.new_tmp(Fun([ret_ty], Unit))
let x_ref = self.new_tmp(ret_ty)
fn c3(branch : Value) -> Cps {
App(Var(k_ref), [branch].map(fix_label_to_var))
}

Fix(
k_ref,
[x_ref],
c(Var(x_ref)),
IfEq(
lhs,
rhs,
self.precps2cps(_then, c3),
self.precps2cps(_else, c3),
),
)
}

self.precps2cps(rhs, c2)
}

self.precps2cps(lhs, c1)
}
If(ret_ty, Prim(Bool, Le, [lhs, rhs]), _then, _else) => {
fn c1(lhs : Value) -> Cps {
// To avoid exponential growth in CPS ir, we abstract the outer `c` out.
fn c2(rhs : Value) -> Cps {
let k_ref = self.new_tmp(Fun([ret_ty], Unit))
let x_ref = self.new_tmp(ret_ty)
fn c3(branch : Value) -> Cps {
App(Var(k_ref), [branch].map(fix_label_to_var))
}

Fix(
k_ref,
[x_ref],
c(Var(x_ref)),
IfLe(
lhs,
rhs,
self.precps2cps(_then, c3),
self.precps2cps(_else, c3),
),
)
}

self.precps2cps(rhs, c2)
}

self.precps2cps(lhs, c1)
}
If(ret_ty, cond, _then, _else) => {
fn c1(cond : Value) -> Cps {
// To avoid exponential growth in CPS ir, we abstract the outer `c` out.
Expand Down
18 changes: 18 additions & 0 deletions src/js/clops2js.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,24 @@ pub fn JsEmitter::emit_cps(self : JsEmitter, cps : @cps.Cps) -> String {
output += self.indent().emit_cps(_else)
output += line_start + "}"
}
IfEq(lhs, rhs, _then, _else) => {
let lhs_emit = self.emit_val(lhs)
let rhs_emit = self.emit_val(rhs)
output += line_start + "if (\{lhs_emit} === \{rhs_emit}) {"
output += self.indent().emit_cps(_then)
output += line_start + "} else { "
output += self.indent().emit_cps(_else)
output += line_start + "}"
}
IfLe(lhs, rhs, _then, _else) => {
let lhs_emit = self.emit_val(lhs)
let rhs_emit = self.emit_val(rhs)
output += line_start + "if (\{lhs_emit} <= \{rhs_emit}) {"
output += self.indent().emit_cps(_then)
output += line_start + "} else { "
output += self.indent().emit_cps(_else)
output += line_start + "}"
}
Prim(Not, [b], bind, rest) => {
let bool_emit = self.emit_val(b)
output += line_start + "const \{emit_var(bind)} = !\{bool_emit};"
Expand Down
40 changes: 37 additions & 3 deletions src/riscv/before_alloc.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,39 @@ fn reserve_fregs(cfg : @ssacfg.SsaCfg, block : @ssacfg.Block) -> Unit {
inst => insts.push(inst)
}
}
match block.last_inst.val {
BranchEq(lhs, rhs, _then, _else) => {
let lhs_is_imm = match lhs {
Var(_) => false
_ => true
}
let rhs_is_imm = match rhs {
Var(_) => false
_ => true
}
if lhs_is_imm && rhs_is_imm {
let tmp = cfg.new_named("tmp", ty=lhs.get_type())
insts.push(@ssacfg.Inst::Copy(tmp, lhs))
block.last_inst.val = BranchEq(Var(tmp), rhs, _then, _else)
}
}
BranchLe(lhs, rhs, _then, _else) => {
let lhs_is_imm = match lhs {
Var(_) => false
_ => true
}
let rhs_is_imm = match rhs {
Var(_) => false
_ => true
}
if lhs_is_imm && rhs_is_imm {
let tmp = cfg.new_named("tmp", ty=lhs.get_type())
insts.push(@ssacfg.Inst::Copy(tmp, lhs))
block.last_inst.val = BranchLe(Var(tmp), rhs, _then, _else)
}
}
_ => ()
}
block.insts = insts
}

Expand Down Expand Up @@ -63,10 +96,11 @@ fn freeze_closure(
}
block.insts = insts
block.last_inst.val = match block.last_inst.val {
Branch(cond, _then, _else) => Branch(fix_val(cond), _then, _else)
Branch(cond, _then, _else) => Branch(cond, _then, _else)
BranchEq(lhs, rhs, _then, _else) => BranchEq(lhs, rhs, _then, _else)
BranchLe(lhs, rhs, _then, _else) => BranchLe(lhs, rhs, _then, _else)
Call(f, args) => Call(fix_val(f), args.map(fix_val))
MakeArray(len, elem, kont) =>
MakeArray(fix_val(len), fix_val(elem), fix_val(kont))
MakeArray(len, elem, kont) => MakeArray(len, fix_val(elem), fix_val(kont))
Exit => Exit
}
}
Expand Down
82 changes: 82 additions & 0 deletions src/riscv/codegen.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,88 @@ fn CodegenBlock::codegen(self : CodegenBlock) -> Unit {
)
self.branch_to(blk_else).codegen()
}
BranchEq(lhs, rhs, blk_then, blk_else) =>
match get_reg_ty(lhs) {
// TODO: just like Eq & Le, we need to ensure the 2 regs doesn't crash
F64 => {
let reg_lhs = self.pull_val_f(lhs)
let reg_rhs = self.pull_val_f(rhs)
// NOTE: since we've reserve freg first, this won't happen
if reg_lhs == reg_rhs {
@util.die("pulling same reg for comparing floats")
}
self.insert_asms(
[
FeqD(reg_swap, reg_lhs, reg_rhs),
Beq(reg_swap, Zero, blk_else.to_string()),
],
)
self.branch_to(blk_then).codegen()
self.insert_asm(
Comment(
"we have tail call so no point generating a jump to skip the else branch",
),
)
self.branch_to(blk_else).codegen()
}
_ => {
let reg_lhs = self.pull_val_i(lhs)
let reg_rhs = self.pull_val_i(rhs)
// NOTE: since we've reserve freg first, this won't happen
if reg_lhs == reg_rhs {
@util.die("pulling same reg for comparing floats")
}
self.insert_asm(Bne(reg_lhs, reg_rhs, blk_else.to_string()))
self.branch_to(blk_then).codegen()
self.insert_asm(
Comment(
"we have tail call so no point generating a jump to skip the else branch",
),
)
self.branch_to(blk_else).codegen()
}
}
BranchLe(lhs, rhs, blk_then, blk_else) =>
match get_reg_ty(lhs) {
// TODO: just like Eq & Le, we need to ensure the 2 regs doesn't crash
F64 => {
let reg_lhs = self.pull_val_f(lhs)
let reg_rhs = self.pull_val_f(rhs)
// NOTE: since we've reserve freg first, this won't happen
if reg_lhs == reg_rhs {
@util.die("pulling same reg for comparing floats")
}
self.insert_asms(
[
FleD(reg_swap, reg_lhs, reg_rhs),
Beq(reg_swap, Zero, blk_else.to_string()),
],
)
self.branch_to(blk_then).codegen()
self.insert_asm(
Comment(
"we have tail call so no point generating a jump to skip the else branch",
),
)
self.branch_to(blk_else).codegen()
}
_ => {
let reg_lhs = self.pull_val_i(lhs)
let reg_rhs = self.pull_val_i(rhs)
// NOTE: since we've reserve freg first, this won't happen
if reg_lhs == reg_rhs {
@util.die("pulling same reg for comparing floats")
}
self.insert_asm(Bgt(reg_lhs, reg_rhs, blk_else.to_string()))
self.branch_to(blk_then).codegen()
self.insert_asm(
Comment(
"we have tail call so no point generating a jump to skip the else branch",
),
)
self.branch_to(blk_else).codegen()
}
}
Exit => self.insert_asms([Li(A0, "0"), Li(A7, "93"), Ecall])
MakeArray(len, elem, continuation) => {
// 1. call to generate an array
Expand Down
6 changes: 6 additions & 0 deletions src/riscv/collect_labels.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ fn collect_externals(cfg : @ssacfg.SsaCfg) -> ExternalLabels {
collect_label_var(_then)
collect_label_var(_else)
}
BranchEq(lhs, rhs, _then, _else) | BranchLe(lhs, rhs, _then, _else) => {
collect_label_val(lhs)
collect_label_val(rhs)
collect_label_var(_then)
collect_label_var(_else)
}
Call(f, args) => {
collect_label_val(f)
args.each(collect_label_val)
Expand Down
8 changes: 7 additions & 1 deletion src/riscv/interference_graph_build.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ fn LiveVarAnalysis::collect_pc_inst(
) -> Unit {
match inst {
Branch(cond, _then, _else) => self.collect_val(cond)
BranchEq(lhs, rhs, _then, _else) | BranchLe(lhs, rhs, _then, _else) => {
self.collect_val(lhs)
self.collect_val(rhs)
}
Call(f, args) => {
self.collect_val(f)
args.each(fn(arg) { self.collect_val(arg) })
Expand Down Expand Up @@ -127,7 +131,9 @@ fn collect_blocks_of_fn(cfg : @ssacfg.SsaCfg, fn_name : Var) -> Array[Var] {
let cur_blk = cfg.blocks[label].unwrap()
stack_process.push(label)
match cur_blk.last_inst.val {
Branch(_, blk_then, blk_else) => {
Branch(_, blk_then, blk_else)
| BranchEq(_, _, blk_then, blk_else)
| BranchLe(_, _, blk_then, blk_else) => {
q_collect.push(blk_then)
q_collect.push(blk_else)
}
Expand Down
2 changes: 2 additions & 0 deletions src/riscv/reg_spill.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ fn reg_spill_block(blk : @ssacfg.Block, spilled_var : Var) -> @ssacfg.Block {
}
let load_before_exit = match blk.last_inst.val {
Branch(cond, _, _) => val_spilled(cond)
BranchEq(lhs, rhs, _, _) | BranchLe(lhs, rhs, _, _) =>
val_spilled(lhs) || val_spilled(rhs)
Call(f, args) => val_spilled(f) || vals_spilled(args)
MakeArray(len, elem, kont) =>
val_spilled(len) || val_spilled(elem) || val_spilled(kont)
Expand Down
Loading

0 comments on commit f784b4d

Please sign in to comment.