diff --git a/notes.md b/notes.md new file mode 100644 index 0000000..4e43d7e --- /dev/null +++ b/notes.md @@ -0,0 +1,3 @@ +## Call Convention +1. We use `ra` to store the continuation address, as it's otherwise unused in our language. We do need to push it onto stack to preserve it's value when doing a native call, though. +2. We store `closure` pointer after any arguments, so we should be able to work with native functions just fine. This differs from what is being done in the book "Compiling with Continuations". diff --git a/src/bin/main.mbt b/src/bin/main.mbt index 23983a0..5d83a7e 100644 --- a/src/bin/main.mbt +++ b/src/bin/main.mbt @@ -4,9 +4,8 @@ enum Stages { Typecheck PreCps Cps - //Knf - //KnfOpt - //Closure + CloPS + // NOTE: add stages here. Asm Finished } derive(Show, Eq, Compare) @@ -17,9 +16,8 @@ fn Stages::from_string(s : String) -> Stages? { "typecheck" => Some(Stages::Typecheck) "precps" => Some(Stages::PreCps) "cps" => Some(Stages::Cps) - //"knf" => Some(Stages::Knf) - //"knf-opt" => Some(Stages::KnfOpt) - //"closure" => Some(Stages::Closure) + "clops" => Some(Stages::CloPS) + // NOTE: add stages here. "riscv" => Some(Stages::Asm) "finished" => Some(Stages::Finished) _ => None @@ -31,7 +29,9 @@ fn Stages::next(self : Stages) -> Stages { Stages::Parse => Stages::Typecheck Stages::Typecheck => Stages::PreCps Stages::PreCps => Stages::Cps - Stages::Cps => Stages::Asm + Stages::Cps => Stages::CloPS + Stages::CloPS => Stages::Asm + // NOTE: add stages here. Stages::Asm => Stages::Finished Stages::Finished => Stages::Finished } @@ -43,13 +43,11 @@ struct CompileStatus { mut source_code : String? mut ast : @types.Syntax? mut typechecked : @types.Syntax? + mut counter : Int // for unique var generation mut precps : @precps.PreCps? - mut precpsenv : @precps.TyEnv? mut cps : @cps.Cps? - //knf_env : @knf.KnfEnv - //mut knf : @knf.Knf? - //mut opt_knf : @knf.Knf? - //mut closure_ir : @closure.Program? + mut clops : @closureps.ClosurePS? + // NOTE: add stages here. mut asm : Array[@riscv.AssemblyFunction]? } @@ -64,13 +62,11 @@ fn CompileStatus::initialize( source_code: None, ast: None, typechecked: None, + counter: 0, precps: None, - precpsenv: None, cps: None, - //knf_env: @knf.KnfEnv::new(@types.externals), - //knf: None, - //opt_knf: None, - //closure_ir: None, + clops: None, + // NOTE: add stages here. asm: None, } match start_stage { @@ -78,14 +74,6 @@ fn CompileStatus::initialize( Typecheck => v.ast = Some(@types.Syntax::from_json!(@json.parse!(file))) PreCps => v.typechecked = Some(@types.Syntax::from_json!(@json.parse!(file))) - //KnfOpt => { - // v.knf = Some(@knf.Knf::from_json!(@json.parse!(file))) - // v.knf_env.init_counter_from_existing(v.knf.unwrap()) - //} - //Closure => { - // v.opt_knf = Some(@knf.Knf::from_json!(@json.parse!(file))) - // v.knf_env.init_counter_from_existing(v.opt_knf.unwrap()) - //} _ => fail!("invalid start stage") } v @@ -110,33 +98,21 @@ fn CompileStatus::step(self : CompileStatus) -> Bool { let tyenv = @precps.TyEnv::new(@types.externals) let entry_inlined = @precps.inline_entry(self.typechecked.unwrap()) self.precps = Some(tyenv.ast2precps(entry_inlined)) - self.precpsenv = Some(tyenv) + self.counter = tyenv.counter.val } Cps => { - let cpsenv = @cps.CpsEnv::new(self.precpsenv.unwrap().counter.val) + let cpsenv = @cps.CpsEnv::new(self.counter) let mut cps = cpsenv.precps2cps(self.precps.unwrap(), @cps.Cps::Just) cps = @cps.alias_analysis(cps) self.cps = Some(cps) + self.counter = cpsenv.counter.val } - //Knf => { - // let preprocessed = self.knf_env.syntax_preprocess( - // self.typechecked.unwrap(), - // ) - // let knf = self.knf_env.to_knf(preprocessed) - // self.knf = Some(knf) - //} - //KnfOpt => { - // let knf = self.knf.unwrap() - // // TODO: optimize - // self.opt_knf = Some(knf) - //} - //Closure => { - // let closure_ir = @closure.knf_program_to_closure( - // self.opt_knf.unwrap(), - // Map::from_iter(@types.externals.iter()), - // ) - // self.closure_ir = Some(closure_ir) - //} + CloPS => { + let clops = @closureps.cps2clops(self.counter, self.cps.unwrap()) + self.clops = Some(clops) + self.counter = clops.counter.val + } + // NOTE: add stages here. Asm => { let real_asm = @riscv.emit(@util.die("TODO5")) self.asm = Some(real_asm) @@ -154,9 +130,8 @@ fn CompileStatus::output(self : CompileStatus, json : Bool) -> String { Typecheck => @json.stringify(self.ast.unwrap().to_json()) PreCps => @util.die("TODO1") Cps => @util.die("TODO2") - //Knf => @json.stringify(self.typechecked.unwrap().to_json()) - //KnfOpt => @json.stringify(self.knf.unwrap().to_json()) - //Closure => @json.stringify(self.opt_knf.unwrap().to_json()) + CloPS => @util.die("TODO999") + // NOTE: add stages here. Asm => @util.die("TODO3") Finished => @riscv.print_functions(self.asm.unwrap()) } @@ -166,11 +141,9 @@ fn CompileStatus::output(self : CompileStatus, json : Bool) -> String { Typecheck => self.ast.unwrap().to_string() PreCps => self.typechecked.unwrap().to_string() Cps => self.precps.unwrap().to_string() - //Knf => self.typechecked.unwrap().to_string() - //KnfOpt => self.knf.unwrap().to_string() - //Closure => self.opt_knf.unwrap().to_string() - Asm => self.cps.unwrap().to_string() - //self.closure_ir.unwrap().to_string() + CloPS => self.cps.unwrap().to_string() + // NOTE: add stages here. + Asm => self.clops.unwrap().to_string() Finished => @riscv.print_functions(self.asm.unwrap()) } } diff --git a/src/bin/moon.pkg.json b/src/bin/moon.pkg.json index cdeb84d..936d6d8 100644 --- a/src/bin/moon.pkg.json +++ b/src/bin/moon.pkg.json @@ -11,6 +11,7 @@ "alias": "types" }, "moonbitlang/minimbt/precps", + "moonbitlang/minimbt/closureps", "moonbitlang/minimbt/cps", "moonbitlang/minimbt/knf", "moonbitlang/minimbt/typing", diff --git a/src/closureps/cloenv.mbt b/src/closureps/cloenv.mbt new file mode 100644 index 0000000..84ef139 --- /dev/null +++ b/src/closureps/cloenv.mbt @@ -0,0 +1,28 @@ +struct CloEnv { + // NOTE: fundef's type's arg's length should be one more than FuncDef.args, as + // we use the last slot for closure. + // and the types is recursive, BTW + fnblocks : @hashmap.T[Var, FuncDef] + counter : Ref[Int] + bindings : @immut/hashmap.T[Var, Var] +} + +fn CloEnv::add_rebind(self : CloEnv, name : Var, cb : Var) -> CloEnv { + { ..self, bindings: self.bindings.add(name, cb) } +} + +fn CloEnv::new(counter : Int) -> CloEnv { + let counter = { val: counter } + { fnblocks: @hashmap.new(), counter, bindings: @immut/hashmap.new() } +} + +fn CloEnv::new_tmp(self : CloEnv, t : T) -> Var { + self.counter.val = self.counter.val + 1 + { name: { val: None }, id: self.counter.val, ty: t } +} + +// NOTE: no worry of repeated names generated as all vars are marked by an uid +fn CloEnv::new_named(self : CloEnv, name : String, t : T) -> Var { + self.counter.val = self.counter.val + 1 + { name: { val: Some(name) }, id: self.counter.val, ty: t } +} diff --git a/src/closureps/cps2closureps.mbt b/src/closureps/cps2closureps.mbt new file mode 100644 index 0000000..b617b06 --- /dev/null +++ b/src/closureps/cps2closureps.mbt @@ -0,0 +1,124 @@ +// replace any function vars +fn CloEnv::rebind_var(self : CloEnv, v : Var) -> Var { + match self.bindings[v] { + None => v + Some(wrapped) => wrapped + } +} + +fn CloEnv::rebind_value(self : CloEnv, v : Value) -> Value { + match v { + Var(v) => Var(self.rebind_var(v)) + v => v + } +} + +// collect all closures to top level and fix call convention +fn CloEnv::collect_closure(self : CloEnv, s : S) -> S { + fn rec(c : S) { + self.collect_closure(c) + } + + fn recrbva(v : Value) { + self.rebind_value(v) + } + + match s { + Tuple(record, bind, rest) => + Tuple( + record.map(recrbva), + bind, + // NOTE: the reason for add all bindings is to shadow any closure rebind + // so it doesn't accidentally rebind too muach than it should + self.add_rebind(bind, bind).collect_closure(rest), + ) + KthTuple(idx, v, bind, rest) => + KthTuple(idx, v, bind, self.add_rebind(bind, bind).collect_closure(rest)) + Switch(v, branches) => Switch(v, branches.map(rec)) + Prim(op, args, bind, rest) => + Prim(op, args, bind, self.add_rebind(bind, bind).collect_closure(rest)) + Fix(f, args, body, rest) => { + // Step 1. Calculate free variables of body + let fvs = body.free_variables() + fvs.remove(f) + args.each(fn { a => fvs.remove(a) }) + let free_vars = fvs.iter().collect() + + // Step 2. Calculate the free variable tuple we need to pass inside the + // closure + let fv_data_ty = match free_vars { + [] => T::Unit + _ => T::Tuple(free_vars.map(fn { v => v.ty })) + } + + // this is the closure passed into the function + let closure_ref = self.new_named( + "closure_ref_\{f.to_string()}", + T::Tuple([f.ty, fv_data_ty]), + ) + + // fix the type of f to accept an additional closure arg at the end + guard let T::Fun(args_ty, _) = f.ty else { + _ => @util.die("calling non function") + // NOTE: the following alters f's type + } + args_ty.push(closure_ref.ty) // WARN: after this operation our ds is now self-recursive + let body_fixed = match free_vars { + [] => rec(body) + _ => { + let fn_ptr = self.new_named("fn_ptr", f.ty) + let freevars = self.new_named("freevars", fv_data_ty) + let body_to_wrap = self + .add_rebind(f, closure_ref) + .collect_closure(body) + let body_with_freevars_bound = free_vars.foldi( + init=body_to_wrap, + fn(idx, acc, ele) { KthTuple(idx, Var(freevars), ele, acc) }, + ) + KthTuple( + 0, + Var(closure_ref), + fn_ptr, + KthTuple(1, Var(closure_ref), freevars, body_with_freevars_bound), + ) + } + } + self.fnblocks[f] = { args, free_vars, body: body_fixed } + let freevars_captured = self.new_named("freevars_captured", fv_data_ty) + let closure_gen = self.new_named( + "closure_\{f.to_string()}", + T::Tuple([f.ty, fv_data_ty]), + ) + let rest_fixed = self.add_rebind(f, closure_gen).collect_closure(rest) + Tuple( + free_vars.map(Value::Var), + freevars_captured, + Tuple([Label(f), Var(freevars_captured)], closure_gen, rest_fixed), + ) + } + App(f, args) => + match f { + Var(f_var) => + // NOTE: always generate a call as if we're calling a closure. + // Since there's no way for us to decide whether we're calling a + // closure or not. + match self.bindings[f_var] { + Some(maybe_closure) => { + args.push(Var(maybe_closure)) + // we know the called function statically, so we're allowed to mark + // it as a label + App(Label(f_var), args) + } + None => { + args.push(f) + App(f, args) + } + } + // NOTE: must be a native call + // there's no guarantee we always use this case for all native calls + Label(_) => App(f, args.map(recrbva)) + _ => @util.die("Can't invoke call on \{f}") + } + Just(v) => Just(recrbva(v)) + } +} diff --git a/src/closureps/funcdef.mbt b/src/closureps/funcdef.mbt new file mode 100644 index 0000000..67e06f2 --- /dev/null +++ b/src/closureps/funcdef.mbt @@ -0,0 +1,8 @@ +// we don't store the return value type as there's no return in CPS +pub struct FuncDef { + args : Array[Var] + free_vars : Array[Var] + // closure is a tuple of function pointer + // and a tuple of free variables + body : S +} diff --git a/src/closureps/interface.mbt b/src/closureps/interface.mbt new file mode 100644 index 0000000..c29e510 --- /dev/null +++ b/src/closureps/interface.mbt @@ -0,0 +1,13 @@ +pub struct ClosurePS { + fnblocks : @hashmap.T[Var, FuncDef] + root : S + counter : Ref[Int] +} + +pub fn cps2clops(cnt : Int, s : S) -> ClosurePS { + let env = CloEnv::new(cnt) + let root = env.collect_closure(s) + let counter = env.counter + let fnblocks = env.fnblocks + { fnblocks, root, counter } +} diff --git a/src/closureps/moon.pkg.json b/src/closureps/moon.pkg.json new file mode 100644 index 0000000..2874c1e --- /dev/null +++ b/src/closureps/moon.pkg.json @@ -0,0 +1,16 @@ +{ + "import": [ + { + "path": "moonbitlang/minimbt", + "alias": "top" + }, + "moonbitlang/minimbt/util", + "moonbitlang/minimbt/precps", + "moonbitlang/minimbt/cps" + ], + "test-import": [ + "moonbitlang/minimbt/parser", + "moonbitlang/minimbt/lex", + "moonbitlang/minimbt/typing" + ] +} diff --git a/src/closureps/show.mbt b/src/closureps/show.mbt new file mode 100644 index 0000000..5ab7ee8 --- /dev/null +++ b/src/closureps/show.mbt @@ -0,0 +1,14 @@ +impl Show for ClosurePS with output(self, logger) { + logger.write_string(self.to_string()) +} + +pub fn ClosurePS::to_string(self : ClosurePS) -> String { + let mut output = "" + for item in self.fnblocks.iter() { + let (name, def) = item + output += "[\{name}], args: \{def.args}, freevars: \{def.free_vars}\n" + output += "\{def.body}\n\n" + } + output += "[root]\n\{self.root}\n" + output +} diff --git a/src/closureps/types.mbt b/src/closureps/types.mbt new file mode 100644 index 0000000..4974874 --- /dev/null +++ b/src/closureps/types.mbt @@ -0,0 +1,14 @@ +typealias T = @top.Type + +typealias PrimOp = @precps.PrimOp + +typealias S = @cps.Cps + +typealias Var = @cps.Var + +typealias Value = @cps.Value + +enum Either[L, R] { + Left(L) + Right(R) +} diff --git a/src/cps/cps_env.mbt b/src/cps/cps_env.mbt index 85b9600..1757317 100644 --- a/src/cps/cps_env.mbt +++ b/src/cps/cps_env.mbt @@ -1,4 +1,4 @@ -struct CpsEnv { +pub struct CpsEnv { counter : Ref[Int] } @@ -10,3 +10,8 @@ fn CpsEnv::new_tmp(self : CpsEnv, t : T) -> Var { self.counter.val = self.counter.val + 1 { name: { val: None }, id: self.counter.val, ty: t } } + +fn CpsEnv::new_named(self : CpsEnv, s : String, t : T) -> Var { + self.counter.val = self.counter.val + 1 + { name: { val: Some(s) }, id: self.counter.val, ty: t } +} diff --git a/src/cps/cps_ir.mbt b/src/cps/cps_ir.mbt index 1affa83..e184b5a 100644 --- a/src/cps/cps_ir.mbt +++ b/src/cps/cps_ir.mbt @@ -1,10 +1,14 @@ -struct Var { +pub struct Var { name : Ref[String?] id : Int ty : T } -fn Var::op_equal(lhs : Var, rhs : Var) -> Bool { +pub fn Var::hash_combine(self : Var, hasher : Hasher) -> Unit { + hasher.combine(self.id) +} + +pub fn Var::op_equal(lhs : Var, rhs : Var) -> Bool { lhs.id == rhs.id } @@ -12,7 +16,7 @@ fn Var::from_precps(v : @precps.Var, t : T) -> Var { { id: v.id, name: { val: v.name }, ty: t } } -enum Value { +pub enum Value { Var(Var) Label(Var) Unit @@ -27,16 +31,10 @@ fn Value::replace_var_bind(self : Value, from : Var, to : Value) -> Value { } } -enum AccessPath { - OffP(Int) - SelP(Int, AccessPath) -} derive(Show) - pub enum Cps { // T marks the binding's type - Record(Array[(Value, AccessPath)], Var, Cps) - Select(Int, Value, Var, Cps) - Offset(Int, Value, Var, Cps) + Tuple(Array[Value], Var, Cps) + KthTuple(Int, Value, Var, Cps) Fix(Var, Array[Var], Cps, Cps) Switch(Value, Array[Cps]) Prim(PrimOp, Array[Value], Var, Cps) @@ -55,17 +53,13 @@ fn Cps::replace_var_bind(self : Cps, from : Var, to : Value) -> Cps { } match self { - Record(record, bind, rest) => { - let rest_new = if from != bind { rec(rest) } else { rest } - Record(record.map(fn { (v, path) => (recv(v), path) }), bind, rest_new) - } - Select(idx, v, bind, rest) => { + Tuple(record, bind, rest) => { let rest_new = if from != bind { rec(rest) } else { rest } - Select(idx, recv(v), bind, rest_new) + Tuple(record.map(recv), bind, rest_new) } - Offset(idx, v, bind, rest) => { + KthTuple(idx, v, bind, rest) => { let rest_new = if from != bind { rec(rest) } else { rest } - Offset(idx, recv(v), bind, rest_new) + KthTuple(idx, recv(v), bind, rest_new) } Fix(name, args, body, rest) => { let body_new = if from != name && not(args.contains(from)) { @@ -82,5 +76,52 @@ fn Cps::replace_var_bind(self : Cps, from : Var, to : Value) -> Cps { Prim(op, args.map(recv), bind, rest_new) } App(f, args) => App(recv(f), args.map(recv)) + Just(v) => Just(recv(v)) + } +} + +fn Value::free_variables(self : Value) -> @hashset.T[Var] { + match self { + Var(v) => @hashset.of([v]) + _ => @hashset.new() + } +} + +pub fn Cps::free_variables(self : Cps) -> @hashset.T[Var] { + match self { + Tuple(record, bind, rest) => { + let fvs = rest.free_variables() + fvs.remove(bind) + record.fold(init=fvs, fn(acc, ele) { acc.union(ele.free_variables()) }) + } + KthTuple(_, v, bind, rest) => { + let fvs = rest.free_variables() + fvs.remove(bind) + fvs.union(v.free_variables()) + } + Fix(name, args, body, rest) => { + let fv_rest = rest.free_variables() + fv_rest.remove(name) + let fv_body = body.free_variables() + fv_body.remove(name) + args.each(fn { a => fv_body.remove(a) }) + fv_body.union(fv_rest) + } + Switch(v, branches) => + branches.fold( + init=v.free_variables(), + fn(acc, ele) { acc.union(ele.free_variables()) }, + ) + Prim(_, args, bind, rest) => { + let fv_rest = rest.free_variables() + fv_rest.remove(bind) + args.fold(init=fv_rest, fn(acc, ele) { acc.union(ele.free_variables()) }) + } + App(f, args) => + args.fold( + init=f.free_variables(), + fn(acc, ele) { acc.union(ele.free_variables()) }, + ) + Just(v) => v.free_variables() } } diff --git a/src/cps/cps_ir_string.mbt b/src/cps/cps_ir_string.mbt index 29a197b..7123c17 100644 --- a/src/cps/cps_ir_string.mbt +++ b/src/cps/cps_ir_string.mbt @@ -1,4 +1,4 @@ -impl Show for Var with output(self, logger) { +pub impl Show for Var with output(self, logger) { logger.write_string(self.to_string()) } @@ -9,7 +9,7 @@ pub fn Var::to_string(self : Var) -> String { } } -impl Show for Value with output(self, logger) { +pub impl Show for Value with output(self, logger) { logger.write_string(self.to_string()) } @@ -37,9 +37,9 @@ fn to_str(cps : Cps, ~ident : String = "") -> String { } match cps { - Record(arr, bind, rest) => ident + "\{bind} = \{arr}\n" + rec(rest) - Select(idx, v, bind, rest) => ident + "\{bind} = \{v}[\{idx}]\n" + rec(rest) - Offset(idx, v, bind, rest) => ident + "\{bind} = \{v}+\{idx}\n" + rec(rest) + Tuple(arr, bind, rest) => ident + "\{bind} = \{arr}\n" + rec(rest) + KthTuple(idx, v, bind, rest) => + ident + "\{bind} = \{v}.\{idx}\n" + rec(rest) Fix(name, args, body, rest) => ident + "fn \{name}(\{args}) {\n" + diff --git a/src/cps/optimizations.mbt b/src/cps/optimizations.mbt index cd13d47..ae939ed 100644 --- a/src/cps/optimizations.mbt +++ b/src/cps/optimizations.mbt @@ -1,9 +1,8 @@ pub fn alias_analysis(c : Cps) -> Cps { let rec = alias_analysis match c { - Record(arr, bind, inner) => Record(arr, bind, rec(inner)) - Select(idx, v, bind, inner) => Select(idx, v, bind, rec(inner)) - Offset(idx, v, bind, inner) => Offset(idx, v, bind, rec(inner)) + Tuple(arr, bind, inner) => Tuple(arr, bind, rec(inner)) + KthTuple(idx, v, bind, inner) => KthTuple(idx, v, bind, rec(inner)) Fix(f1, args1, App(f2, args2), rest) => { let args1_fix = args1.map(fn { v => Var(v) }) if args1_fix == args2 { diff --git a/src/cps/precps2cps.mbt b/src/cps/precps2cps.mbt index 4297b71..93da96e 100644 --- a/src/cps/precps2cps.mbt +++ b/src/cps/precps2cps.mbt @@ -53,7 +53,7 @@ pub fn CpsEnv::precps2cps(self : CpsEnv, s : P, c : Cont) -> Cps { new_vars.push(var) } // create the wrapper - let k_ref = self.new_tmp(k_type) + let k_ref = self.new_named("kont_\{fn_name}", k_type) new_vars.push(k_ref) let f_ref = Var::from_precps(fn_name, new_f_type) fn fn_cont(returned : Value) { @@ -82,11 +82,12 @@ pub fn CpsEnv::precps2cps(self : CpsEnv, s : P, c : Cont) -> Cps { Fix(k_ref, [x_ref], c(Var(x_ref)), self.precps2cps(f, c1)) } Var(ty, v) => c(Var(Var::from_precps(v, ty))) + Label(ty, v) => c(Label(Var::from_precps(v, ty))) Tuple(tup_ty, elements) => { fn c1(vs : @immut/list.T[Value]) { let tmp = self.new_tmp(tup_ty) - let record_inner = vs.iter().map(fn { v => (v, OffP(0)) }).collect() - Record(record_inner, tmp, c(Var(tmp))) + let record_inner = vs.iter().collect() + Tuple(record_inner, tmp, c(Var(tmp))) } self.precps2cps_list(elements, c1) @@ -102,7 +103,7 @@ pub fn CpsEnv::precps2cps(self : CpsEnv, s : P, c : Cont) -> Cps { KthTuple(ret_ty, offset, tup) => { fn c1(v : Value) { let tmp = self.new_tmp(ret_ty) - Select(offset, v, tmp, c(Var(tmp))) + KthTuple(offset, v, tmp, c(Var(tmp))) } self.precps2cps(tup, c1) diff --git a/src/precps/precps_ir.mbt b/src/precps/precps_ir.mbt index 4f5b317..cb61d06 100644 --- a/src/precps/precps_ir.mbt +++ b/src/precps/precps_ir.mbt @@ -13,6 +13,7 @@ pub enum PreCps { Let(T, Var, PreCps, PreCps) // T marks return type Var(T, Var) + Label(T, Var) Tuple(T, Array[PreCps]) Prim(T, PrimOp, Array[PreCps]) KthTuple(T, Int, PreCps) @@ -33,6 +34,7 @@ pub fn PreCps::to_string(self : PreCps) -> String { "letrec \{name}(\{args}){\{body}} in \{rest}" Let(_, name, rhs, rest) => "let \{name} = \{rhs} in \{rest}" Var(_, v) => v.to_string() + Label(_, v) => ":" + v.to_string() Tuple(_, tup) => tup.to_string() KthTuple(_, idx, tup) => "(\{tup}).\{idx}" App(_, f, args) => "(\{f} \{args})" @@ -50,6 +52,7 @@ pub fn PreCps::get_type(self : PreCps) -> T { LetRec(_, _, _, _, inner) => inner.get_type() Let(_, _, _, inner) => inner.get_type() Var(t, _) + | Label(t, _) | Tuple(t, _) | Prim(t, _, _) | KthTuple(t, _, _) | App(t, _, _) | If(t, _, _, _) => t } diff --git a/src/precps/tyenv.mbt b/src/precps/tyenv.mbt index 82b415c..e37b888 100644 --- a/src/precps/tyenv.mbt +++ b/src/precps/tyenv.mbt @@ -1,12 +1,12 @@ pub struct TyEnv { - bindings : @immut/hashmap.T[Either[Int, String], (Var, T)] + bindings : @immut/hashmap.T[Either[Int, String], PreCps] counter : Ref[Int] } -fn to_bind_key(item : (String, T)) -> (Either[Int, String], (Var, T)) { +fn to_bind_key_label(item : (String, T)) -> (Either[Int, String], PreCps) { let ext_name = item.0 let ty = item.1 - (Right(ext_name), (var_of_external(ext_name), ty)) + (Right(ext_name), Label(ty, Var::var_of_external(ext_name))) } fn find_bind_key(v : Var) -> Either[Int, String] { @@ -17,7 +17,7 @@ fn find_bind_key(v : Var) -> Either[Int, String] { } pub fn TyEnv::new(externals : @immut/hashmap.T[String, T]) -> TyEnv { - let externals = externals.iter().map(to_bind_key) + let externals = externals.iter().map(to_bind_key_label) |> @immut/hashmap.from_iter() { bindings: externals, counter: { val: 0 } } } @@ -30,7 +30,7 @@ pub fn TyEnv::gen_tmp(self : TyEnv) -> Var { pub fn TyEnv::add(self : TyEnv, name : String, ty : T) -> (Var, TyEnv) { self.counter.val = self.counter.val + 1 let to_bind = { id: self.counter.val, name: Some(name) } - let bindings = self.bindings.add(find_bind_key(to_bind), (to_bind, ty)) + let bindings = self.bindings.add(find_bind_key(to_bind), Var(ty, to_bind)) (to_bind, { ..self, bindings, }) } @@ -45,12 +45,5 @@ pub fn TyEnv::add_many(self : TyEnv, args : Iter[(String, T)]) -> TyEnv { } pub fn TyEnv::find(self : TyEnv, name : String) -> PreCps? { - match self.bindings[Right(name)] { - Some(item) => Some(Var(item.1, item.0)) - None => - match self.bindings[Right(name)] { - None => None - Some(item) => Some(Var(item.1, item.0)) - } - } + self.bindings[Right(name)] }