From 69fd34d4caa094d593179cd2ab53b5fb0b6349fd Mon Sep 17 00:00:00 2001 From: glyh Date: Sat, 5 Oct 2024 17:32:58 +0800 Subject: [PATCH] CPS IR, done, TODO: closure conversion --- src/bin/main.mbt | 198 +++++++++++++++++++++----------------- src/bin/moon.pkg.json | 2 + src/cps/cps_env.mbt | 12 +++ src/cps/cps_ir.mbt | 86 +++++++++++++++++ src/cps/cps_ir_string.mbt | 63 ++++++++++++ src/cps/moon.pkg.json | 15 +++ src/cps/optimizations.mbt | 28 ++++++ src/cps/precps2cps.mbt | 138 ++++++++++++++++++++++++++ src/cps/types.mbt | 5 + src/cps/utils.mbt | 11 +++ src/precps/ast2precps.mbt | 110 +++++++++++++++++++++ src/precps/moon.pkg.json | 14 +++ src/precps/precps_ir.mbt | 72 ++++++++++++++ src/precps/preprocess.mbt | 22 +++++ src/precps/tyenv.mbt | 56 +++++++++++ src/precps/types.mbt | 10 ++ src/precps/var.mbt | 19 ++++ src/top.mbt | 4 +- 18 files changed, 777 insertions(+), 88 deletions(-) create mode 100644 src/cps/cps_env.mbt create mode 100644 src/cps/cps_ir.mbt create mode 100644 src/cps/cps_ir_string.mbt create mode 100644 src/cps/moon.pkg.json create mode 100644 src/cps/optimizations.mbt create mode 100644 src/cps/precps2cps.mbt create mode 100644 src/cps/types.mbt create mode 100644 src/cps/utils.mbt create mode 100644 src/precps/ast2precps.mbt create mode 100644 src/precps/moon.pkg.json create mode 100644 src/precps/precps_ir.mbt create mode 100644 src/precps/preprocess.mbt create mode 100644 src/precps/tyenv.mbt create mode 100644 src/precps/types.mbt create mode 100644 src/precps/var.mbt diff --git a/src/bin/main.mbt b/src/bin/main.mbt index a745586..23983a0 100644 --- a/src/bin/main.mbt +++ b/src/bin/main.mbt @@ -2,9 +2,11 @@ enum Stages { Parse Typecheck - Knf - KnfOpt - Closure + PreCps + Cps + //Knf + //KnfOpt + //Closure Asm Finished } derive(Show, Eq, Compare) @@ -13,9 +15,11 @@ fn Stages::from_string(s : String) -> Stages? 
{ match s { "parse" => Some(Stages::Parse) "typecheck" => Some(Stages::Typecheck) - "knf" => Some(Stages::Knf) - "knf-opt" => Some(Stages::KnfOpt) - "closure" => Some(Stages::Closure) + "precps" => Some(Stages::PreCps) + "cps" => Some(Stages::Cps) + //"knf" => Some(Stages::Knf) + //"knf-opt" => Some(Stages::KnfOpt) + //"closure" => Some(Stages::Closure) "riscv" => Some(Stages::Asm) "finished" => Some(Stages::Finished) _ => None @@ -25,11 +29,9 @@ fn Stages::from_string(s : String) -> Stages? { fn Stages::next(self : Stages) -> Stages { match self { Stages::Parse => Stages::Typecheck - Stages::Typecheck => Stages::Knf - Stages::Knf => Stages::KnfOpt - Stages::KnfOpt => Stages::Closure - Stages::Closure => // TODO - Stages::Asm + Stages::Typecheck => Stages::PreCps + Stages::PreCps => Stages::Cps + Stages::Cps => Stages::Asm Stages::Asm => Stages::Finished Stages::Finished => Stages::Finished } @@ -41,10 +43,13 @@ struct CompileStatus { mut source_code : String? mut ast : @types.Syntax? mut typechecked : @types.Syntax? - knf_env : @knf.KnfEnv - mut knf : @knf.Knf? - mut opt_knf : @knf.Knf? - mut closure_ir : @closure.Program? + mut precps : @precps.PreCps? + mut precpsenv : @precps.TyEnv? + mut cps : @cps.Cps? + //knf_env : @knf.KnfEnv + //mut knf : @knf.Knf? + //mut opt_knf : @knf.Knf? + //mut closure_ir : @closure.Program? mut asm : Array[@riscv.AssemblyFunction]? 
} @@ -59,24 +64,28 @@ fn CompileStatus::initialize( source_code: None, ast: None, typechecked: None, - knf_env: @knf.KnfEnv::new(@types.externals), - knf: None, - opt_knf: None, - closure_ir: None, + precps: None, + precpsenv: None, + cps: None, + //knf_env: @knf.KnfEnv::new(@types.externals), + //knf: None, + //opt_knf: None, + //closure_ir: None, asm: None, } match start_stage { Parse => v.source_code = Some(file) Typecheck => v.ast = Some(@types.Syntax::from_json!(@json.parse!(file))) - Knf => v.typechecked = Some(@types.Syntax::from_json!(@json.parse!(file))) - KnfOpt => { - v.knf = Some(@knf.Knf::from_json!(@json.parse!(file))) - v.knf_env.init_counter_from_existing(v.knf.unwrap()) - } - Closure => { - v.opt_knf = Some(@knf.Knf::from_json!(@json.parse!(file))) - v.knf_env.init_counter_from_existing(v.opt_knf.unwrap()) - } + PreCps => + v.typechecked = Some(@types.Syntax::from_json!(@json.parse!(file))) + //KnfOpt => { + // v.knf = Some(@knf.Knf::from_json!(@json.parse!(file))) + // v.knf_env.init_counter_from_existing(v.knf.unwrap()) + //} + //Closure => { + // v.opt_knf = Some(@knf.Knf::from_json!(@json.parse!(file))) + // v.knf_env.init_counter_from_existing(v.opt_knf.unwrap()) + //} _ => fail!("invalid start stage") } v @@ -97,27 +106,39 @@ fn CompileStatus::step(self : CompileStatus) -> Bool { let to_check = self.ast.unwrap().clone() self.typechecked = Some(@typing.infer_type(to_check)) } - Knf => { - let preprocessed = self.knf_env.syntax_preprocess( - self.typechecked.unwrap(), - ) - let knf = self.knf_env.to_knf(preprocessed) - self.knf = Some(knf) - } - KnfOpt => { - let knf = self.knf.unwrap() - // TODO: optimize - self.opt_knf = Some(knf) + PreCps => { + let tyenv = @precps.TyEnv::new(@types.externals) + let entry_inlined = @precps.inline_entry(self.typechecked.unwrap()) + self.precps = Some(tyenv.ast2precps(entry_inlined)) + self.precpsenv = Some(tyenv) } - Closure => { - let closure_ir = @closure.knf_program_to_closure( - self.opt_knf.unwrap(), - 
Map::from_iter(@types.externals.iter()), - ) - self.closure_ir = Some(closure_ir) + Cps => { + let cpsenv = @cps.CpsEnv::new(self.precpsenv.unwrap().counter.val) + let mut cps = cpsenv.precps2cps(self.precps.unwrap(), @cps.Cps::Just) + cps = @cps.alias_analysis(cps) + self.cps = Some(cps) } + //Knf => { + // let preprocessed = self.knf_env.syntax_preprocess( + // self.typechecked.unwrap(), + // ) + // let knf = self.knf_env.to_knf(preprocessed) + // self.knf = Some(knf) + //} + //KnfOpt => { + // let knf = self.knf.unwrap() + // // TODO: optimize + // self.opt_knf = Some(knf) + //} + //Closure => { + // let closure_ir = @closure.knf_program_to_closure( + // self.opt_knf.unwrap(), + // Map::from_iter(@types.externals.iter()), + // ) + // self.closure_ir = Some(closure_ir) + //} Asm => { - let real_asm = @riscv.emit(@util.die("TODO")) + let real_asm = @riscv.emit(@util.die("TODO5")) self.asm = Some(real_asm) } Finished => () @@ -131,20 +152,25 @@ fn CompileStatus::output(self : CompileStatus, json : Bool) -> String { match self.curr_stage { Parse => self.source_code.unwrap() Typecheck => @json.stringify(self.ast.unwrap().to_json()) - Knf => @json.stringify(self.typechecked.unwrap().to_json()) - KnfOpt => @json.stringify(self.knf.unwrap().to_json()) - Closure => @json.stringify(self.opt_knf.unwrap().to_json()) - Asm => @util.die("TODO") + PreCps => @util.die("TODO1") + Cps => @util.die("TODO2") + //Knf => @json.stringify(self.typechecked.unwrap().to_json()) + //KnfOpt => @json.stringify(self.knf.unwrap().to_json()) + //Closure => @json.stringify(self.opt_knf.unwrap().to_json()) + Asm => @util.die("TODO3") Finished => @riscv.print_functions(self.asm.unwrap()) } } else { match self.curr_stage { Parse => self.source_code.unwrap() Typecheck => self.ast.unwrap().to_string() - Knf => self.typechecked.unwrap().to_string() - KnfOpt => self.knf.unwrap().to_string() - Closure => self.opt_knf.unwrap().to_string() - Asm => self.closure_ir.unwrap().to_string() + PreCps => 
self.typechecked.unwrap().to_string() + Cps => self.precps.unwrap().to_string() + //Knf => self.typechecked.unwrap().to_string() + //KnfOpt => self.knf.unwrap().to_string() + //Closure => self.opt_knf.unwrap().to_string() + Asm => self.cps.unwrap().to_string() + //self.closure_ir.unwrap().to_string() Finished => @riscv.print_functions(self.asm.unwrap()) } } @@ -261,12 +287,12 @@ fn main { ) // Configure pipeline - if knf_interpreter.val { - end_stage.val = Stages::Knf - } - if closure_interpreter.val { - end_stage.val = Stages::Closure - } + //if knf_interpreter.val { + // end_stage.val = Stages::Knf + //} + //if closure_interpreter.val { + // end_stage.val = Stages::Closure + //} let stages_to_print = print.val.map( fn(s) { match Stages::from_string(s) { @@ -309,32 +335,32 @@ fn main { } // Output - if knf_interpreter.val { - let knfi = @knf_eval.KnfInterpreter::new() - add_interpreter_fns(knfi) - match knfi.eval_full?(status.knf.unwrap()) { - Ok(_) => () - Err(Failure(e)) => { - println(e) - @util.die("KNF interpreter error") - } - } - } else if closure_interpreter.val { - let clsi = @closure_eval.ClosureInterpreter::new() - add_closure_interpreter_fns(clsi) - match clsi.eval_full?(status.closure_ir.unwrap()) { - Ok(_) => () - Err(Failure(e)) => { - println(e) - @util.die("Closure interpreter error") - } - } + //if knf_interpreter.val { + // let knfi = @knf_eval.KnfInterpreter::new() + // add_interpreter_fns(knfi) + // match knfi.eval_full?(status.knf.unwrap()) { + // Ok(_) => () + // Err(Failure(e)) => { + // println(e) + // @util.die("KNF interpreter error") + // } + // } + //} else if closure_interpreter.val { + // let clsi = @closure_eval.ClosureInterpreter::new() + // add_closure_interpreter_fns(clsi) + // match clsi.eval_full?(status.closure_ir.unwrap()) { + // Ok(_) => () + // Err(Failure(e)) => { + // println(e) + // @util.die("Closure interpreter error") + // } + // } + //} else { + let out_string = status.output(json.val) + if out_file.val == "-" { + 
println(out_string) } else { - let out_string = status.output(json.val) - if out_file.val == "-" { - println(out_string) - } else { - @fs.write_to_string(out_file.val, out_string) - } + @fs.write_to_string(out_file.val, out_string) } + //} } diff --git a/src/bin/moon.pkg.json b/src/bin/moon.pkg.json index 48156ab..cdeb84d 100644 --- a/src/bin/moon.pkg.json +++ b/src/bin/moon.pkg.json @@ -10,6 +10,8 @@ "path": "moonbitlang/minimbt", "alias": "types" }, + "moonbitlang/minimbt/precps", + "moonbitlang/minimbt/cps", "moonbitlang/minimbt/knf", "moonbitlang/minimbt/typing", "moonbitlang/minimbt/knf_eval", diff --git a/src/cps/cps_env.mbt b/src/cps/cps_env.mbt new file mode 100644 index 0000000..85b9600 --- /dev/null +++ b/src/cps/cps_env.mbt @@ -0,0 +1,12 @@ +struct CpsEnv { + counter : Ref[Int] +} + +pub fn CpsEnv::new(counter : Int) -> CpsEnv { + { counter: { val: counter } } +} + +fn CpsEnv::new_tmp(self : CpsEnv, t : T) -> Var { + self.counter.val = self.counter.val + 1 + { name: { val: None }, id: self.counter.val, ty: t } +} diff --git a/src/cps/cps_ir.mbt b/src/cps/cps_ir.mbt new file mode 100644 index 0000000..1affa83 --- /dev/null +++ b/src/cps/cps_ir.mbt @@ -0,0 +1,86 @@ +struct Var { + name : Ref[String?] 
+  id : Int
+  ty : T
+}
+
+fn Var::op_equal(lhs : Var, rhs : Var) -> Bool {
+  lhs.id == rhs.id
+}
+
+fn Var::from_precps(v : @precps.Var, t : T) -> Var {
+  { id: v.id, name: { val: v.name }, ty: t }
+}
+
+enum Value {
+  Var(Var)
+  Label(Var)
+  Unit
+  Int(Int)
+  Double(Double)
+} derive(Eq)
+
+fn Value::replace_var_bind(self : Value, from : Var, to : Value) -> Value {
+  match self {
+    Var(v) => if v == from { to } else { self }
+    _ => self
+  }
+}
+
+enum AccessPath {
+  OffP(Int)
+  SelP(Int, AccessPath)
+} derive(Show)
+
+pub enum Cps {
+  // T marks the binding's type
+  Record(Array[(Value, AccessPath)], Var, Cps)
+  Select(Int, Value, Var, Cps)
+  Offset(Int, Value, Var, Cps)
+  Fix(Var, Array[Var], Cps, Cps)
+  Switch(Value, Array[Cps])
+  Prim(PrimOp, Array[Value], Var, Cps)
+  // T marks the return type
+  App(Value, Array[Value])
+  Just(Value)
+}
+
+/// Replaces every free occurrence of `from` in `self` with `to`; a binder
+/// that rebinds `from` shadows it, so that binder's scope is left untouched.
+fn Cps::replace_var_bind(self : Cps, from : Var, to : Value) -> Cps {
+  fn rec(s : Cps) {
+    s.replace_var_bind(from, to)
+  }
+
+  fn recv(v : Value) {
+    v.replace_var_bind(from, to)
+  }
+
+  fn scoped(bind : Var, rest : Cps) -> Cps { // substitute unless shadowed
+    if from != bind { rec(rest) } else { rest }
+  }
+
+  match self {
+    Record(record, bind, rest) => {
+      let fields = record.map(fn { (v, path) => (recv(v), path) })
+      Record(fields, bind, scoped(bind, rest))
+    }
+    Select(idx, v, bind, rest) => Select(idx, recv(v), bind, scoped(bind, rest))
+    Offset(idx, v, bind, rest) => Offset(idx, recv(v), bind, scoped(bind, rest))
+    Fix(name, args, body, rest) => {
+      // `name` is in scope in `body` and `rest`; `args` only in `body`.
+      let body_new = if from != name && not(args.contains(from)) {
+        rec(body)
+      } else {
+        body
+      }
+      // If `name` shadows `from`, `rest` must be kept as is (not `body`).
+      Fix(name, args, body_new, scoped(name, rest))
+    }
+    Switch(v, branches) => Switch(recv(v), branches.map(rec))
+    Prim(op, args, bind, rest) =>
+      Prim(op, args.map(recv), bind, scoped(bind, rest))
+    App(f, args) => App(recv(f), args.map(recv))
+    Just(v) => Just(recv(v)) // the returned value may mention `from` too
+  }
+}
diff --git a/src/cps/cps_ir_string.mbt b/src/cps/cps_ir_string.mbt new file mode 100644 index 0000000..29a197b --- /dev/null +++ b/src/cps/cps_ir_string.mbt @@ -0,0 +1,63 @@ +impl Show for Var with output(self, logger) { + logger.write_string(self.to_string()) +} + +pub fn Var::to_string(self : Var) -> String { + match self.name.val { + None => "?\{self.id}" + Some(n) => if self.id < 0 { n } else { "\{n}.\{self.id}" } + } +} + +impl Show for Value with output(self, logger) { + logger.write_string(self.to_string()) +} + +pub fn Value::to_string(self : Value) -> String { + match self { + Var(v) => v.to_string() + Label(v) => ":" + v.to_string() + Unit => "()" + Int(i) => i.to_string() + Double(f) => f.to_string() + } +} + +impl Show for Cps with output(self, logger) { + logger.write_string(self.to_string()) +} + +pub fn Cps::to_string(self : Cps) -> String { + to_str(self) +} + +fn to_str(cps : Cps, ~ident : String = "") -> String { + fn rec(c : Cps) { + to_str(c, ~ident) + } + + match cps { + Record(arr, bind, rest) => ident + "\{bind} = \{arr}\n" + rec(rest) + Select(idx, v, bind, rest) => ident + "\{bind} = \{v}[\{idx}]\n" + rec(rest) + Offset(idx, v, bind, rest) => ident + "\{bind} = \{v}+\{idx}\n" + rec(rest) + Fix(name, args, body, rest) => + ident + + "fn \{name}(\{args}) {\n" + + to_str(body, ident=ident + " ") + + "\n" + + ident + + "}\n" + + rec(rest) + Switch(v, branches) => + ident + + "switch(\{v}){\n" + + branches.map(fn { c => to_str(c, ident=ident + " ") }).join(";\n") + + "\n" + + ident + + "}" + Prim(op, args, bind, rest) => + ident + "prim \{bind} = \{op}(\{args})\n" + rec(rest) + App(f, args) => ident + "\{f}(\{args})" + Just(v) => ident + "return \{v}" + } +} diff --git a/src/cps/moon.pkg.json b/src/cps/moon.pkg.json new file mode 100644 index 0000000..32d754f --- /dev/null +++ b/src/cps/moon.pkg.json @@ -0,0 +1,15 @@ +{ + "import": [ + { + "path": "moonbitlang/minimbt", + "alias": "top" + }, + "moonbitlang/minimbt/util", + 
"moonbitlang/minimbt/precps" + ], + "test-import": [ + "moonbitlang/minimbt/parser", + "moonbitlang/minimbt/lex", + "moonbitlang/minimbt/typing" + ] +} diff --git a/src/cps/optimizations.mbt b/src/cps/optimizations.mbt new file mode 100644 index 0000000..cd13d47 --- /dev/null +++ b/src/cps/optimizations.mbt @@ -0,0 +1,28 @@ +pub fn alias_analysis(c : Cps) -> Cps { + let rec = alias_analysis + match c { + Record(arr, bind, inner) => Record(arr, bind, rec(inner)) + Select(idx, v, bind, inner) => Select(idx, v, bind, rec(inner)) + Offset(idx, v, bind, inner) => Offset(idx, v, bind, rec(inner)) + Fix(f1, args1, App(f2, args2), rest) => { + let args1_fix = args1.map(fn { v => Var(v) }) + if args1_fix == args2 { + // f1 is an alias of f2 + rec(rest.replace_var_bind(f1, f2)) + } else { + Fix(f1, args1, App(f2, args2), rec(rest)) + } + } + Fix(f, args, body, rest) => Fix(f, args, rec(body), rec(rest)) + Switch(v, branches) => Switch(v, branches.map(rec)) + Prim(op, vs, bind, rest) => Prim(op, vs, bind, rec(rest)) + c => c + } +} + +test "array equal" { + let a1 = [1, 2, 3, 4] + let a2 = [1, 2, 3] + a2.push(4) + assert_eq!(a1, a2) +} diff --git a/src/cps/precps2cps.mbt b/src/cps/precps2cps.mbt new file mode 100644 index 0000000..4297b71 --- /dev/null +++ b/src/cps/precps2cps.mbt @@ -0,0 +1,138 @@ +// REF: Compiling with Continuations +typealias Cont = (Value) -> Cps + +fn CpsEnv::precps2cps_list( + self : CpsEnv, + a : Array[P], + c : (@immut/list.T[Value]) -> Cps +) -> Cps { + fn g(a : ArrayView[P], w : @immut/list.T[Value]) { + match a { + [e, .. 
as r] => self.precps2cps(e, fn(v) { g(r, @immut/list.Cons(v, w)) }) + [] => c(w.rev()) + } + } + + g(a[:], @immut/list.of([])) +} + +pub fn CpsEnv::precps2cps(self : CpsEnv, s : P, c : Cont) -> Cps { + match s { + Unit => c(Unit) + Int(i) => c(Int(i)) + Double(f) => c(Double(f)) + Let(ty, name, rhs, rest) => { + fn c1(v : Value) { + match v { + Var(v) => v.name.val = name.name // mark their name + _ => () + } + let rest = self.precps2cps(rest, c) + let orig_binding = Var::from_precps(name, ty) + rest.replace_var_bind(orig_binding, v) + } + + self.precps2cps(rhs, c1) + } + // NOTE: + // Any function f of type (a1, a2, a3, .., an) -> r has been transformed into + // (a1, a2, a3, .., an, (r) -> Unit) + LetRec(ty, fn_name, fn_args, body, rest) => { + guard let Fun(arg_tys, ret_ty) = ty else { + _ => @util.die("Calling a non function") + } + () // generate the type for cps converted function + let new_arg_tys = arg_tys.copy() + let k_type = T::Fun([ret_ty], Unit) + new_arg_tys.push(k_type) + let new_f_type = T::Fun(new_arg_tys, Unit) + // reference to the continuation + let new_vars = [] + for var in zip2(fn_args, arg_tys) { + let var = Var::from_precps(var.0, var.1) + new_vars.push(var) + } + // create the wrapper + let k_ref = self.new_tmp(k_type) + new_vars.push(k_ref) + let f_ref = Var::from_precps(fn_name, new_f_type) + fn fn_cont(returned : Value) { + App(Var(k_ref), [returned]) + } + + Fix( + f_ref, + new_vars, + self.precps2cps(body, fn_cont), + self.precps2cps(rest, c), + ) + } + // (a1, a2, a3, .., an, (r) -> Unit) + App(ret_ty, f, args) => { + let k_ref = self.new_tmp(Fun([ret_ty], Unit)) + let x_ref = self.new_tmp(ret_ty) + fn c1(f : Value) { + fn c2(es : @immut/list.T[Value]) { + App(f, es.iter().append(Var(k_ref)).collect()) + } + + self.precps2cps_list(args, c2) + } + + Fix(k_ref, [x_ref], c(Var(x_ref)), self.precps2cps(f, c1)) + } + Var(ty, v) => c(Var(Var::from_precps(v, ty))) + Tuple(tup_ty, elements) => { + fn c1(vs : @immut/list.T[Value]) { + let tmp 
= self.new_tmp(tup_ty) + let record_inner = vs.iter().map(fn { v => (v, OffP(0)) }).collect() + Record(record_inner, tmp, c(Var(tmp))) + } + + self.precps2cps_list(elements, c1) + } + Prim(ty, rator, rands) => { + fn c1(a : @immut/list.T[Value]) { + let tmp = self.new_tmp(ty) + Prim(rator, a.iter().collect(), tmp, c(Var(tmp))) + } + + self.precps2cps_list(rands, c1) + } + KthTuple(ret_ty, offset, tup) => { + fn c1(v : Value) { + let tmp = self.new_tmp(ret_ty) + Select(offset, v, tmp, c(Var(tmp))) + } + + self.precps2cps(tup, c1) + } + If(ret_ty, cond, _then, _else) => { + fn c1(cond : Value) -> Cps { + // To avoid exponential growth in CPS ir, we abstract the outer `c` out. + let k_ref = self.new_tmp(Fun([ret_ty], Unit)) + let x_ref = self.new_tmp(ret_ty) + fn c2(branch : Value) -> Cps { + App(Var(k_ref), [branch]) + } + + Fix( + k_ref, + [x_ref], + c(Var(x_ref)), + Switch( + cond, + [ + // 0: else branch + self.precps2cps(_else, c2), + // 1: then branch + self.precps2cps(_then, c2), + ], + ), + ) + } + + self.precps2cps(cond, c1) + } + } +} diff --git a/src/cps/types.mbt b/src/cps/types.mbt new file mode 100644 index 0000000..ede87ef --- /dev/null +++ b/src/cps/types.mbt @@ -0,0 +1,5 @@ +typealias T = @top.Type + +typealias PrimOp = @precps.PrimOp + +typealias P = @precps.PreCps diff --git a/src/cps/utils.mbt b/src/cps/utils.mbt new file mode 100644 index 0000000..e5b6d67 --- /dev/null +++ b/src/cps/utils.mbt @@ -0,0 +1,11 @@ +fn zip2[A, B](arr1 : Array[A], arr2 : Array[B]) -> Array[(A, B)] { + let out : Array[(A, B)] = [] + loop (arr1[:], arr2[:]) { + ([], []) => break out + ([a, .. as arr1], [b, .. 
as arr2]) => { + out.push((a, b)) + continue (arr1, arr2) + } + _ => @util.die("zipping arrays of different size") + } +} diff --git a/src/precps/ast2precps.mbt b/src/precps/ast2precps.mbt new file mode 100644 index 0000000..b740c12 --- /dev/null +++ b/src/precps/ast2precps.mbt @@ -0,0 +1,110 @@ +pub fn TyEnv::ast2precps(self : TyEnv, s : S) -> PreCps { + fn rec(s : S) { + self.ast2precps(s) + } + + match s { + Unit => Unit + Bool(true) => Int(1) + Bool(false) => Int(0) + Int(i) => Int(i) + Double(f) => Double(f) + Var(name) => + match self.find(name) { + Some(v) => v + None => @util.die("No binding for \{name}") + } + Tuple(tup) => { + let tup_outs = tup.map(rec) + let tys = tup_outs.map(fn { p => p.get_type() }) + let tup_ty = T::Tuple(tys) + Tuple(tup_ty, tup_outs) + } + Not(inner) => Prim(Bool, Not, [rec(inner)]) + Array(len, elem) => { + let elem = rec(elem) + let len = rec(len) + let ret_ty = T::Array(elem.get_type()) + Prim(ret_ty, MakeArray, [len, elem]) + } + Neg(inner, ~kind) => { + let n : Numeric = if kind == T::Int { Int } else { Double } + Prim(kind, Neg(n), [rec(inner)]) + } + App(f, args) => { + let f = rec(f) + guard let Fun(_, ret_ty) = f.get_type() else { + _ => @util.die("Calling a non function") + } + App(ret_ty, f, args.map(rec)) + } + Get(arr, idx) => { + let arr = rec(arr) + guard let Array(ele_ty) = arr.get_type() else { + _ => @util.die("indexing a non array") + } + Prim(ele_ty, Get, [arr, rec(idx)]) + } + If(cond, _then, _else) => { + let cond = rec(cond) + let _then = rec(_then) + let _else = rec(_else) + If(_then.get_type(), cond, _then, _else) + } + Prim(lhs, rhs, op, ~kind) => { + let n : Numeric = if kind == T::Int { Int } else { Double } + Prim(kind, Math(op, n), [rec(lhs), rec(rhs)]) + } + Eq(lhs, rhs) => Prim(Bool, Eq, [rec(lhs), rec(rhs)]) + LE(lhs, rhs) => Prim(Bool, Le, [rec(lhs), rec(rhs)]) + Let((name, ty), rhs, rest) => { + let (bind, env_new) = self.add(name, ty) + Let(ty, bind, rec(rhs), env_new.ast2precps(rest)) + } + 
LetRec(f, rest) => { + let (fvar, env_rest) = self.add(f.name.0, f.name.1) + let mut env_body = env_rest + let args = [] + f.args.each( + fn(arg) { + let (name, ty) = arg + let (argvar, env_body_new) = env_body.add(name, ty) + args.push(argvar) + env_body = env_body_new + }, + ) + LetRec( + f.name.1, + fvar, + args, + env_body.ast2precps(f.body), + env_rest.ast2precps(rest), + ) + } + LetTuple(tup, rhs, rest) => { + let tup_ty = T::Tuple(tup.map(fn { (_, ty) => ty })) + let tup_var = self.gen_tmp() + fn go( + tup : ArrayView[(String, T)], + idx : Int, + env_cur : TyEnv + ) -> PreCps { + match tup { + [] => env_cur.ast2precps(rest) + [ti, .. as tup_rest] => { + let (vari, env_next) = env_cur.add(ti.0, ti.1) + Let( + ti.1, + vari, + KthTuple(ti.1, idx, Var(tup_ty, tup_var)), + go(tup_rest, idx + 1, env_next), + ) + } + } + } + + Let(tup_ty, tup_var, rec(rhs), go(tup[:], 0, self)) + } + Put(arr, idx, rhs) => Prim(Unit, Put, [rec(arr), rec(idx), rec(rhs)]) + } +} diff --git a/src/precps/moon.pkg.json b/src/precps/moon.pkg.json new file mode 100644 index 0000000..da93d1c --- /dev/null +++ b/src/precps/moon.pkg.json @@ -0,0 +1,14 @@ +{ + "import": [ + { + "path": "moonbitlang/minimbt", + "alias": "types" + }, + "moonbitlang/minimbt/util" + ], + "test-import": [ + "moonbitlang/minimbt/parser", + "moonbitlang/minimbt/lex", + "moonbitlang/minimbt/typing" + ] +} diff --git a/src/precps/precps_ir.mbt b/src/precps/precps_ir.mbt new file mode 100644 index 0000000..4f5b317 --- /dev/null +++ b/src/precps/precps_ir.mbt @@ -0,0 +1,72 @@ +// NOTE: +// The PreCPS IR differs from the typed AST in that: +// 1. every expression carries a type, or can have its type inferred without +// an associated type environment +// 2. Booleans are desugared into ints. +// 3. All variables that refer to the same object share one unique ID. 
+pub enum PreCps {
+  Unit
+  Int(Int)
+  Double(Double)
+  // T marks the binding type
+  LetRec(T, Var, Array[Var], PreCps, PreCps)
+  Let(T, Var, PreCps, PreCps)
+  // T marks return type
+  Var(T, Var)
+  Tuple(T, Array[PreCps])
+  Prim(T, PrimOp, Array[PreCps])
+  KthTuple(T, Int, PreCps)
+  App(T, PreCps, Array[PreCps])
+  If(T, PreCps, PreCps, PreCps)
+}
+
+impl Show for PreCps with output(self, logger) {
+  logger.write_string(self.to_string())
+}
+
+pub fn PreCps::to_string(self : PreCps) -> String { // text used by Show above
+  match self {
+    Unit => "()"
+    Int(i) => i.to_string()
+    Double(f) => f.to_string()
+    LetRec(_, name, args, body, rest) =>
+      "letrec \{name}(\{args}){\{body}} in \{rest}"
+    Let(_, name, rhs, rest) => "let \{name} = \{rhs} in \{rest}"
+    Var(_, v) => v.to_string()
+    Tuple(_, tup) => tup.to_string()
+    KthTuple(_, idx, tup) => "(\{tup}).\{idx}"
+    App(_, f, args) => "(\{f} \{args})"
+    Prim(_, op, args) => "#(\{op} \{args})"
+    If(_, cond, _then, _else) => // the ")" closes after the else block
+      "(if { \{cond} } then { \{_then} } else { \{_else} })"
+  }
+}
+
+pub fn PreCps::get_type(self : PreCps) -> T {
+  match self {
+    Unit => Unit
+    Int(_) => Int
+    Double(_) => Double
+    LetRec(_, _, _, _, inner) => inner.get_type()
+    Let(_, _, _, inner) => inner.get_type()
+    Var(t, _)
+    | Tuple(t, _)
+    | Prim(t, _, _) | KthTuple(t, _, _) | App(t, _, _) | If(t, _, _, _) => t
+  }
+}
+
+enum Numeric {
+  Double
+  Int
+} derive(Show)
+
+enum PrimOp {
+  Not
+  MakeArray
+  Neg(Numeric)
+  Get
+  Put
+  Math(@types.Op, Numeric)
+  Eq
+  Le
+} derive(Show)
diff --git a/src/precps/preprocess.mbt b/src/precps/preprocess.mbt
new file mode 100644
index 0000000..e6a7558
--- /dev/null
+++ b/src/precps/preprocess.mbt
@@ -0,0 +1,22 @@
+pub fn inline_entry(
+  s : S,
+  ~has_main : Bool = false,
+  ~has_init : Bool = false
+) -> S {
+  match s {
+    Unit =>
+      match (has_main, has_init) {
+        (false, false) => @util.die("no entrance found")
+        (true, false) => App(Var("main"), [])
+        (false, true) => App(Var("init"), [])
+        (true, true) =>
+          Let(("_", Unit),
App(Var("init"), []), App(Var("main"), [])) + } + LetRec(f, rest) => { + let has_main = has_main || f.name.0 == "main" + let has_init = has_init || f.name.0 == "init" + LetRec(f, inline_entry(rest, ~has_main, ~has_init)) + } + s => @util.die("unexpected toplevel \{s}") + } +} diff --git a/src/precps/tyenv.mbt b/src/precps/tyenv.mbt new file mode 100644 index 0000000..97d2ea5 --- /dev/null +++ b/src/precps/tyenv.mbt @@ -0,0 +1,56 @@ +pub struct TyEnv { + bindings : @immut/hashmap.T[Either[Int, String], (Var, T)] + counter : Ref[Int] +} + +fn to_bind_key(item : (String, T)) -> (Either[Int, String], (Var, T)) { + let ext_name = item.0 + let ty = item.1 + (Right(ext_name), (var_of_external(ext_name), ty)) +} + +fn find_bind_key(v : Var) -> Either[Int, String] { + match v.name { + None => Left(v.id) + Some(s) => Right(s) + } +} + +pub fn TyEnv::new(externals : @immut/hashmap.T[String, T]) -> TyEnv { + let externals = externals.iter().map(to_bind_key) + |> @immut/hashmap.from_iter() + { bindings: externals, counter: { val: 0 } } +} + +pub fn TyEnv::gen_tmp(self : TyEnv) -> Var { + self.counter.val = self.counter.val + 1 + { id: self.counter.val, name: None } +} + +pub fn TyEnv::add(self : TyEnv, name : String, ty : T) -> (Var, TyEnv) { + self.counter.val = self.counter.val + 1 + let to_bind = { id: self.counter.val, name: Some(name) } + let bindings = self.bindings.add(find_bind_key(to_bind), (to_bind, ty)) + (to_bind, { ..self, bindings, }) +} + +pub fn TyEnv::add_many(self : TyEnv, args : Iter[(String, T)]) -> TyEnv { + args.fold( + init=self, + fn(acc, ele) { + let (_, env_new) = acc.add(ele.0, ele.1) + env_new + }, + ) +} + +pub fn TyEnv::find(self : TyEnv, name : String) -> PreCps? 
{ + match self.bindings[Right(name)] { + Some(item) => Some(Var(item.1, item.0)) + None => + match self.bindings[Right("minimbt_" + name)] { + None => None + Some(item) => Some(Var(item.1, item.0)) + } + } +} diff --git a/src/precps/types.mbt b/src/precps/types.mbt new file mode 100644 index 0000000..7938f42 --- /dev/null +++ b/src/precps/types.mbt @@ -0,0 +1,10 @@ +typealias S = @types.Syntax + +typealias T = @types.Type + +typealias N = @types.Name + +enum Either[L, R] { + Left(L) + Right(R) +} derive(Hash, Eq) diff --git a/src/precps/var.mbt b/src/precps/var.mbt new file mode 100644 index 0000000..dcd7502 --- /dev/null +++ b/src/precps/var.mbt @@ -0,0 +1,19 @@ +pub struct Var { + name : String? + id : Int +} derive(Eq, Hash, Compare) + +impl Show for Var with output(self, logger) { + logger.write_string(self.to_string()) +} + +pub fn Var::to_string(self : Var) -> String { + match self.name { + None => "?\{self.id}" + Some(n) => if self.id < 0 { n } else { "\{n}.\{self.id}" } + } +} + +pub fn Var::var_of_external(ext_name : String) -> Var { + { name: Some("minimbt_" + ext_name), id: -1 } +} diff --git a/src/top.mbt b/src/top.mbt index e55f4fb..358f619 100644 --- a/src/top.mbt +++ b/src/top.mbt @@ -17,7 +17,7 @@ pub enum Syntax { Let((String, Type), Syntax, Syntax) // let _: _ = _; _ LetRec(Fundef, Syntax) // fn f() {} ; _ LetTuple(Array[(String, Type)], Syntax, Syntax) // let (_ , _) : (_, _)= _; _ - Put(Syntax, Syntax, Syntax) // lhs = rhs; rest + Put(Syntax, Syntax, Syntax) // _[_] = _ } derive(Show) pub enum Op { @@ -28,7 +28,7 @@ pub enum Op { } derive(Show, Eq) pub struct Fundef { - name : (String, Type) + name : (String, Type) // the Type stores the function type rather than the returned type args : Array[(String, Type)] body : Syntax } derive(Show)