Skip to content

Commit

Permalink
implement CPS & Closure Conversion on CPS
Browse files Browse the repository at this point in the history
  • Loading branch information
glyh committed Oct 6, 2024
1 parent e082182 commit 5807a5c
Show file tree
Hide file tree
Showing 17 changed files with 335 additions and 99 deletions.
3 changes: 3 additions & 0 deletions notes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Call Convention
1. We use `ra` to store the continuation address, as it's otherwise unused in our language. We do need to push it onto the stack to preserve its value when doing a native call, though.
2. We store the `closure` pointer after all other arguments, so we should be able to work with native functions just fine. This differs from what is being done in the book "Compiling with Continuations".
81 changes: 27 additions & 54 deletions src/bin/main.mbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@ enum Stages {
Typecheck
PreCps
Cps
//Knf
//KnfOpt
//Closure
CloPS
// NOTE: add stages here.
Asm
Finished
} derive(Show, Eq, Compare)
Expand All @@ -17,9 +16,8 @@ fn Stages::from_string(s : String) -> Stages? {
"typecheck" => Some(Stages::Typecheck)
"precps" => Some(Stages::PreCps)
"cps" => Some(Stages::Cps)
//"knf" => Some(Stages::Knf)
//"knf-opt" => Some(Stages::KnfOpt)
//"closure" => Some(Stages::Closure)
"clops" => Some(Stages::CloPS)
// NOTE: add stages here.
"riscv" => Some(Stages::Asm)
"finished" => Some(Stages::Finished)
_ => None
Expand All @@ -31,7 +29,9 @@ fn Stages::next(self : Stages) -> Stages {
Stages::Parse => Stages::Typecheck
Stages::Typecheck => Stages::PreCps
Stages::PreCps => Stages::Cps
Stages::Cps => Stages::Asm
Stages::Cps => Stages::CloPS
Stages::CloPS => Stages::Asm
// NOTE: add stages here.
Stages::Asm => Stages::Finished
Stages::Finished => Stages::Finished
}
Expand All @@ -43,13 +43,11 @@ struct CompileStatus {
mut source_code : String?
mut ast : @types.Syntax?
mut typechecked : @types.Syntax?
mut counter : Int // for unique var generation
mut precps : @precps.PreCps?
mut precpsenv : @precps.TyEnv?
mut cps : @cps.Cps?
//knf_env : @knf.KnfEnv
//mut knf : @knf.Knf?
//mut opt_knf : @knf.Knf?
//mut closure_ir : @closure.Program?
mut clops : @closureps.ClosurePS?
// NOTE: add stages here.
mut asm : Array[@riscv.AssemblyFunction]?
}

Expand All @@ -64,28 +62,18 @@ fn CompileStatus::initialize(
source_code: None,
ast: None,
typechecked: None,
counter: 0,
precps: None,
precpsenv: None,
cps: None,
//knf_env: @knf.KnfEnv::new(@types.externals),
//knf: None,
//opt_knf: None,
//closure_ir: None,
clops: None,
// NOTE: add stages here.
asm: None,
}
match start_stage {
Parse => v.source_code = Some(file)
Typecheck => v.ast = Some(@types.Syntax::from_json!(@json.parse!(file)))
PreCps =>
v.typechecked = Some(@types.Syntax::from_json!(@json.parse!(file)))
//KnfOpt => {
// v.knf = Some(@knf.Knf::from_json!(@json.parse!(file)))
// v.knf_env.init_counter_from_existing(v.knf.unwrap())
//}
//Closure => {
// v.opt_knf = Some(@knf.Knf::from_json!(@json.parse!(file)))
// v.knf_env.init_counter_from_existing(v.opt_knf.unwrap())
//}
_ => fail!("invalid start stage")
}
v
Expand All @@ -110,33 +98,21 @@ fn CompileStatus::step(self : CompileStatus) -> Bool {
let tyenv = @precps.TyEnv::new(@types.externals)
let entry_inlined = @precps.inline_entry(self.typechecked.unwrap())
self.precps = Some(tyenv.ast2precps(entry_inlined))
self.precpsenv = Some(tyenv)
self.counter = tyenv.counter.val
}
Cps => {
let cpsenv = @cps.CpsEnv::new(self.precpsenv.unwrap().counter.val)
let cpsenv = @cps.CpsEnv::new(self.counter)
let mut cps = cpsenv.precps2cps(self.precps.unwrap(), @cps.Cps::Just)
cps = @cps.alias_analysis(cps)
self.cps = Some(cps)
self.counter = cpsenv.counter.val
}
//Knf => {
// let preprocessed = self.knf_env.syntax_preprocess(
// self.typechecked.unwrap(),
// )
// let knf = self.knf_env.to_knf(preprocessed)
// self.knf = Some(knf)
//}
//KnfOpt => {
// let knf = self.knf.unwrap()
// // TODO: optimize
// self.opt_knf = Some(knf)
//}
//Closure => {
// let closure_ir = @closure.knf_program_to_closure(
// self.opt_knf.unwrap(),
// Map::from_iter(@types.externals.iter()),
// )
// self.closure_ir = Some(closure_ir)
//}
CloPS => {
let clops = @closureps.cps2clops(self.counter, self.cps.unwrap())
self.clops = Some(clops)
self.counter = clops.counter.val
}
// NOTE: add stages here.
Asm => {
let real_asm = @riscv.emit(@util.die("TODO5"))
self.asm = Some(real_asm)
Expand All @@ -154,9 +130,8 @@ fn CompileStatus::output(self : CompileStatus, json : Bool) -> String {
Typecheck => @json.stringify(self.ast.unwrap().to_json())
PreCps => @util.die("TODO1")
Cps => @util.die("TODO2")
//Knf => @json.stringify(self.typechecked.unwrap().to_json())
//KnfOpt => @json.stringify(self.knf.unwrap().to_json())
//Closure => @json.stringify(self.opt_knf.unwrap().to_json())
CloPS => @util.die("TODO999")
// NOTE: add stages here.
Asm => @util.die("TODO3")
Finished => @riscv.print_functions(self.asm.unwrap())
}
Expand All @@ -166,11 +141,9 @@ fn CompileStatus::output(self : CompileStatus, json : Bool) -> String {
Typecheck => self.ast.unwrap().to_string()
PreCps => self.typechecked.unwrap().to_string()
Cps => self.precps.unwrap().to_string()
//Knf => self.typechecked.unwrap().to_string()
//KnfOpt => self.knf.unwrap().to_string()
//Closure => self.opt_knf.unwrap().to_string()
Asm => self.cps.unwrap().to_string()
//self.closure_ir.unwrap().to_string()
CloPS => self.cps.unwrap().to_string()
// NOTE: add stages here.
Asm => self.clops.unwrap().to_string()
Finished => @riscv.print_functions(self.asm.unwrap())
}
}
Expand Down
1 change: 1 addition & 0 deletions src/bin/moon.pkg.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"alias": "types"
},
"moonbitlang/minimbt/precps",
"moonbitlang/minimbt/closureps",
"moonbitlang/minimbt/cps",
"moonbitlang/minimbt/knf",
"moonbitlang/minimbt/typing",
Expand Down
28 changes: 28 additions & 0 deletions src/closureps/cloenv.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// State carried through closure conversion.
struct CloEnv {
// Lifted function definitions, keyed by the function's variable.
// NOTE: the function type carried by the key has one more argument than
// `FuncDef.args`, because the last slot is used for the closure. The closure
// type mentions the function type again, so the type is recursive.
fnblocks : @hashmap.T[Var, FuncDef]
// Most recently issued variable id; `new_tmp` / `new_named` bump it before use.
counter : Ref[Int]
// Substitutions applied by `rebind_var`: a variable found here is replaced by
// the variable it maps to.
bindings : @immut/hashmap.T[Var, Var]
}

// Returns a copy of this environment in which `name` is rebound to `cb`.
// `fnblocks` and `counter` are carried over unchanged.
fn CloEnv::add_rebind(self : CloEnv, name : Var, cb : Var) -> CloEnv {
  let updated = self.bindings.add(name, cb)
  { ..self, bindings: updated }
}

// Creates an empty environment: no lifted functions, no rebindings, and an id
// counter starting from `counter`.
fn CloEnv::new(counter : Int) -> CloEnv {
  {
    fnblocks: @hashmap.new(),
    counter: { val: counter },
    bindings: @immut/hashmap.new(),
  }
}

// Creates a fresh unnamed variable of type `t`, advancing the id counter first.
fn CloEnv::new_tmp(self : CloEnv, t : T) -> Var {
  let id = self.counter.val + 1
  self.counter.val = id
  { name: { val: None }, id, ty: t }
}

// Creates a fresh variable of type `t` labelled `name`, advancing the id
// counter first.
// NOTE: reusing the same `name` is harmless because every variable also
// carries a unique id.
fn CloEnv::new_named(self : CloEnv, name : String, t : T) -> Var {
  let id = self.counter.val + 1
  self.counter.val = id
  { name: { val: Some(name) }, id, ty: t }
}
124 changes: 124 additions & 0 deletions src/closureps/cps2closureps.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// Looks `v` up in the rebinding table and returns its replacement, or `v`
// itself when no rebinding is recorded.
fn CloEnv::rebind_var(self : CloEnv, v : Var) -> Var {
  match self.bindings[v] {
    Some(replacement) => replacement
    None => v
  }
}

// Applies `rebind_var` to a value when it is a variable; every other kind of
// value is returned unchanged.
fn CloEnv::rebind_value(self : CloEnv, v : Value) -> Value {
  match v {
    Var(inner) => Var(self.rebind_var(inner))
    other => other
  }
}

// Closure-converts the CPS term `s`: collects all closures to top level and
// fixes the call convention.
// Every `Fix` is stored into `self.fnblocks` and replaced by code that builds
// its closure (a tuple of the function label and a tuple of its free
// variables); every call through a variable gets one extra trailing argument
// carrying the closure.
fn CloEnv::collect_closure(self : CloEnv, s : S) -> S {
// shorthand: convert a sub-term with the current environment
fn rec(c : S) {
self.collect_closure(c)
}

// shorthand: rebind a value with the current environment
fn recrbva(v : Value) {
self.rebind_value(v)
}

match s {
Tuple(record, bind, rest) =>
Tuple(
record.map(recrbva),
bind,
// NOTE: every newly bound variable is rebound to itself so that it shadows
// any earlier closure rebind and nothing gets rebound more than it should
self.add_rebind(bind, bind).collect_closure(rest),
)
// NOTE(review): `v` is not passed through `rebind_value` here, unlike the
// record values of `Tuple` above -- confirm this is intended.
KthTuple(idx, v, bind, rest) =>
KthTuple(idx, v, bind, self.add_rebind(bind, bind).collect_closure(rest))
Switch(v, branches) => Switch(v, branches.map(rec))
// NOTE(review): `args` are likewise left un-rebound here -- confirm.
Prim(op, args, bind, rest) =>
Prim(op, args, bind, self.add_rebind(bind, bind).collect_closure(rest))
Fix(f, args, body, rest) => {
// Step 1. Calculate free variables of body (excluding `f` itself and its
// own parameters)
let fvs = body.free_variables()
fvs.remove(f)
args.each(fn { a => fvs.remove(a) })
let free_vars = fvs.iter().collect()

// Step 2. Calculate the type of the free variable tuple we need to pass
// inside the closure (`Unit` when nothing is captured)
let fv_data_ty = match free_vars {
[] => T::Unit
_ => T::Tuple(free_vars.map(fn { v => v.ty }))
}

// this is the closure passed into the function
let closure_ref = self.new_named(
"closure_ref_\{f.to_string()}",
T::Tuple([f.ty, fv_data_ty]),
)

// fix the type of f to accept an additional closure arg at the end
guard let T::Fun(args_ty, _) = f.ty else {
_ => @util.die("calling non function")
// NOTE: the following alters f's type
}
args_ty.push(closure_ref.ty) // WARN: after this operation our ds is now self-recursive
// Step 3. Convert the body. With captured variables, the body is prefixed
// with projections that unpack the closure argument.
let body_fixed = match free_vars {
// nothing captured: convert the body with the current environment
// (`f` is not rebound to `closure_ref` on this path)
[] => rec(body)
_ => {
// NOTE(review): `fn_ptr` is bound below but not referenced again in this
// function.
let fn_ptr = self.new_named("fn_ptr", f.ty)
let freevars = self.new_named("freevars", fv_data_ty)
// inside the body, `f` refers to the closure received as argument
let body_to_wrap = self
.add_rebind(f, closure_ref)
.collect_closure(body)
// re-bind each captured variable from its slot in the `freevars` tuple
let body_with_freevars_bound = free_vars.foldi(
init=body_to_wrap,
fn(idx, acc, ele) { KthTuple(idx, Var(freevars), ele, acc) },
)
// closure_ref.0 is the function, closure_ref.1 is the captured tuple
KthTuple(
0,
Var(closure_ref),
fn_ptr,
KthTuple(1, Var(closure_ref), freevars, body_with_freevars_bound),
)
}
}
// Step 4. Register the lifted function and build its closure at the
// definition site; in `rest`, `f` is rebound to that closure.
self.fnblocks[f] = { args, free_vars, body: body_fixed }
let freevars_captured = self.new_named("freevars_captured", fv_data_ty)
let closure_gen = self.new_named(
"closure_\{f.to_string()}",
T::Tuple([f.ty, fv_data_ty]),
)
let rest_fixed = self.add_rebind(f, closure_gen).collect_closure(rest)
Tuple(
free_vars.map(Value::Var),
freevars_captured,
Tuple([Label(f), Var(freevars_captured)], closure_gen, rest_fixed),
)
}
App(f, args) =>
match f {
Var(f_var) =>
// NOTE: always generate a call as if we're calling a closure.
// Since there's no way for us to decide whether we're calling a
// closure or not.
// NOTE(review): both branches below push onto `args` in place, so the
// array held by the incoming `App` node is mutated; the existing
// arguments are also not rebound, unlike the `Label` case -- confirm.
match self.bindings[f_var] {
// NOTE(review): `maybe_closure` can be `f_var` itself when the variable
// was self-rebound by a `Tuple` / `KthTuple` / `Prim` binder above; the
// call is still emitted as a `Label` in that case -- confirm.
Some(maybe_closure) => {
args.push(Var(maybe_closure))
// we know the called function statically, so we're allowed to mark
// it as a label
App(Label(f_var), args)
}
// no rebinding recorded: pass the callee itself as the closure argument
None => {
args.push(f)
App(f, args)
}
}
// NOTE: must be a native call
// there's no guarantee we always use this case for all native calls
Label(_) => App(f, args.map(recrbva))
_ => @util.die("Can't invoke call on \{f}")
}
Just(v) => Just(recrbva(v))
}
}
8 changes: 8 additions & 0 deletions src/closureps/funcdef.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// A function definition stored in the closure-conversion function table.
// we don't store the return value type as there's no return in CPS
pub struct FuncDef {
// declared parameters
args : Array[Var]
// variables captured by the function
free_vars : Array[Var]
// closure is a tuple of function pointer
// and a tuple of free variables
body : S
}
13 changes: 13 additions & 0 deletions src/closureps/interface.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Result of closure conversion, as produced by `cps2clops`.
pub struct ClosurePS {
// function definitions collected during conversion, keyed by function variable
fnblocks : @hashmap.T[Var, FuncDef]
// the converted top-level term
root : S
// the conversion environment's id counter, exposed so callers can read its
// final value
counter : Ref[Int]
}

// Runs closure conversion on the CPS term `s`.
// `cnt` seeds the id counter used for fresh variables; the returned value
// bundles the lifted functions, the converted root term and that counter.
pub fn cps2clops(cnt : Int, s : S) -> ClosurePS {
  let env = CloEnv::new(cnt)
  // convert first: this is what fills `env.fnblocks` and advances the counter
  let root = env.collect_closure(s)
  { fnblocks: env.fnblocks, root, counter: env.counter }
}
16 changes: 16 additions & 0 deletions src/closureps/moon.pkg.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"import": [
{
"path": "moonbitlang/minimbt",
"alias": "top"
},
"moonbitlang/minimbt/util",
"moonbitlang/minimbt/precps",
"moonbitlang/minimbt/cps"
],
"test-import": [
"moonbitlang/minimbt/parser",
"moonbitlang/minimbt/lex",
"moonbitlang/minimbt/typing"
]
}
14 changes: 14 additions & 0 deletions src/closureps/show.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// `Show` delegates to `ClosurePS::to_string` and writes the result out whole.
impl Show for ClosurePS with output(self, logger) {
  let rendered = self.to_string()
  logger.write_string(rendered)
}

// Renders the program as text: one section per lifted function (a header with
// its name, args and free variables, then its body), followed by a `[root]`
// section holding the top-level term.
pub fn ClosurePS::to_string(self : ClosurePS) -> String {
  let mut text = ""
  for entry in self.fnblocks.iter() {
    let (name, def) = entry
    let header = "[\{name}], args: \{def.args}, freevars: \{def.free_vars}\n"
    let listing = "\{def.body}\n\n"
    text = text + header + listing
  }
  text + "[root]\n\{self.root}\n"
}
14 changes: 14 additions & 0 deletions src/closureps/types.mbt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Short local names for types defined in other packages.

// the language's type representation, from the top-level package
typealias T = @top.Type

// primitive operators, from the precps package
typealias PrimOp = @precps.PrimOp

// CPS terms; closure conversion both consumes and produces this type
typealias S = @cps.Cps

// CPS variables
typealias Var = @cps.Var

// CPS values
typealias Value = @cps.Value

// Generic two-way sum: a value is either a `Left(L)` or a `Right(R)`.
// NOTE(review): no use of this type is visible in the surrounding closureps
// code -- confirm it is still needed.
enum Either[L, R] {
Left(L)
Right(R)
}
Loading

0 comments on commit 5807a5c

Please sign in to comment.