Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle arguments and defaults consistently #16

Merged
merged 1 commit into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 30 additions & 36 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ missing arguments, but this is just how it works in the python AST.
structure Args where
posonlyargs : List String
args : List String
defaults: List Expr
defaults: List Expr'
vararg : Option String
kwonlyargs : List String
kw_defaults: List Expr
kw_defaults: List (String × Expr')
kwarg : Option String
deriving Repr

Expand All @@ -129,16 +129,14 @@ def Args.names (ax : Args) : List String :=
let xs := match ax.kwarg with | none => xs | some x => xs.append [x]
xs

/-
In addition to the defaults above from the AST, we also collect
the values from f.__defaults__ here in the Fun structure. These
values are evaluated in a different context from the other names
in the function, so we need to capture them on the Python side.
-/
def Args.all_defaults (ax : Args) : List (String × Expr') :=
govereau marked this conversation as resolved.
Show resolved Hide resolved
let args := ax.posonlyargs ++ ax.args
let dflt := args.reverse.zip ax.defaults.reverse
govereau marked this conversation as resolved.
Show resolved Hide resolved
dflt ++ ax.kw_defaults
govereau marked this conversation as resolved.
Show resolved Hide resolved

structure Fun where
source : String
args : Args
defaults: List Const
body: List Stmt
deriving Repr

Expand Down Expand Up @@ -315,33 +313,6 @@ where
| "If" => return (.ifStm (<- expr "test") (<- stmts "body") (<- stmts "orelse"))
| _ => throw s!"unsupported python construct {key}"

def arguments (j : Json) : Parser Args := do
let obj <- j.getObjVal? "arguments"
let arg? := field (opt arg) obj
let args := field (list arg) obj
let exprs := field (list expr) obj
return {
posonlyargs := (<- args "posonlyargs")
args := (<- args "args")
defaults := (<- exprs "defaults")
vararg := (<- arg? "vararg")
kwonlyargs := (<- args "kwonlyargs")
kw_defaults := (<- exprs "kw_defaults")
kwarg := (<- arg? "kwarg")
}
where
arg (j : Json) : Parser String := do
let obj <- j.getObjVal? "arg"
return (<- field str obj "arg")

def function (j : Json) : Parser Fun := do
let source <- field str j "source"
withSrc source do
let args <- field arguments j "args"
let defaults <- field (list const) j "defaults"
let body <- field (list stmt) j "body"
return Fun.mk source args defaults body

-- Both global references and arguments are processed in the global
-- environment. These terms do not have a position, and must be
-- evaluable in the default environment.
Expand All @@ -366,6 +337,29 @@ where
globals (arr : Array Json) : Parser (List Expr) :=
arr.toList.mapM fun x => return .exprPos (<- global x) {}

def arguments (j : Json) : Parser Args := do
let obj <- j.getObjVal? "arguments"
return {
posonlyargs := (<- field (list arg) obj "posonlyargs")
args := (<- field (list arg) obj "args")
defaults := (<- field (list global) obj "defaults")
vararg := (<- field (opt arg) obj "vararg")
kwonlyargs := (<- field (list arg) obj "kwonlyargs")
kw_defaults := (<- field (dict global) obj "kw_defaults")
kwarg := (<- field (opt arg) obj "kwarg")
}
where
arg (j : Json) : Parser String := do
let obj <- j.getObjVal? "arg"
return (<- field str obj "arg")

def function (j : Json) : Parser Fun := do
let source <- field str j "source"
withSrc source do
let args <- field arguments j "args"
let body <- field (list stmt) j "body"
return Fun.mk source args body

def kernel (j : Json) : Parser Kernel := do
let name <- field str j "entry"
let funcs <- field (dict function) j "funcs"
Expand Down
1 change: 0 additions & 1 deletion interop/nkl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
# Released under Apache 2.0 license as described in the file LICENSE.
# Authors: Paul Govereau

from .lean import load, to_json
from .parser import Parser
26 changes: 0 additions & 26 deletions interop/nkl/lean.py

This file was deleted.

53 changes: 21 additions & 32 deletions interop/nkl/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from textwrap import dedent
from itertools import chain
from collections import deque
from nkl.lean import py_to_lean
from nkl.lean_rffi import py_to_lean

# This is a custom JSON encoder for use with AST nodes.
# The AST nodes are not handled by the default encoder.
Expand Down Expand Up @@ -97,25 +97,29 @@ def json(self):
def load(self):
py_to_lean(self.json())

def apply_args(self, *args, **kwargs):
self.args = []
self.kwargs = {}
def process_args(self, args, kwargs):
l = []
d = {}
for arg in args:
self.reference(d, '_', arg)
try: self.args.append(d.popitem()[1])
except Exception:
raise Exception("Unsupported argument type")
for k,v in kwargs.items():
self.ref_arg(k, v)
if args:
for arg in args:
self.reference(d, '_', arg)
try: l.append(d.popitem()[1])
except Exception:
raise Exception("Unsupported argument type")
if kwargs:
for k,v in kwargs.items():
self.reference(d, k, v)
return l, d

def apply_args(self, *args, **kwargs):
l, d = self.process_args(args, kwargs)
self.args = l
self.kwargs = d

def __call__(self, *args, **kwargs):
self.apply_args(*args, **kwargs)
py_to_lean(self.json())

def ref_arg(self, refname, val):
return self.reference(self.kwargs, refname, val)

def ref_global(self, refname, val):
return self.reference(self.globals, refname, val)

Expand Down Expand Up @@ -160,9 +164,11 @@ def translate(self, f: types.FunctionType, args: ast.arguments, body: [ast.AST])
self.f = f
for s in body:
self.visit(s)
l, d = self.process_args(f.__defaults__, f.__kwdefaults__)
args.defaults = l
args.kw_defaults = d
return { 'source': inspect.getsource(f)
, 'args': args
, 'defaults': list(self.fun_defaults(f))
, 'body': body
}

Expand Down Expand Up @@ -201,20 +207,3 @@ def visit_Attribute(self, node):
raise e
except Exception:
return

def fun_defaults(self, f: types.FunctionType):
if f.__defaults__ is None:
return dict()
names = f.__code__.co_varnames[:f.__code__.co_argcount]
tbl = { n:v for (n,v) in zip(reversed(names), reversed(f.__defaults__)) }
if f.__kwdefaults__ is not None:
tbl.update(f.__kwdefaults__)
def is_ok(x):
if x is None or isinstance(x, (int, float, str)):
return True
if isinstance(x, types.FunctionType):
# TODO: this could be incorrect if default
# is using an alternate name for the function
self.ref_global(x.__name__, x)
return False
return { n:v for (n,v) in tbl.items() if is_ok(v) }
Loading