Skip to content

Commit

Permalink
fix: handle arguments and defaults consistently
Browse files Browse the repository at this point in the history
Default function arguments were not being handled the same way as
function arguments. It was possible that an expression that could be
passed as an argument could not have been used as a default value for
the same argument. This patch fixes this issue, and simplifies some of
the parsing code for arguments and function signatures.
  • Loading branch information
govereau committed Jan 10, 2025
1 parent 13b16c8 commit 76e7364
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 95 deletions.
66 changes: 30 additions & 36 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ missing arguments, but this is just how it works in the python AST.
structure Args where
posonlyargs : List String
args : List String
defaults: List Expr
defaults: List Expr'
vararg : Option String
kwonlyargs : List String
kw_defaults: List Expr
kw_defaults: List (String × Expr')
kwarg : Option String
deriving Repr

Expand All @@ -129,16 +129,14 @@ def Args.names (ax : Args) : List String :=
let xs := match ax.kwarg with | none => xs | some x => xs.append [x]
xs

/-
In addition to the defaults above from the AST, we also collect
the values from f.__defaults__ here in the Fun structure. These
values are evaluated in a different context from the other names
in the function, so we need to capture them on the Python side.
-/
def Args.all_defaults (ax : Args) : List (String × Expr') :=
let args := ax.posonlyargs ++ ax.args
let dflt := args.reverse.zip ax.defaults.reverse
dflt ++ ax.kw_defaults

structure Fun where
source : String
args : Args
defaults: List Const
body: List Stmt
deriving Repr

Expand Down Expand Up @@ -315,33 +313,6 @@ where
| "If" => return (.ifStm (<- expr "test") (<- stmts "body") (<- stmts "orelse"))
| _ => throw s!"unsupported python construct {key}"

def arguments (j : Json) : Parser Args := do
let obj <- j.getObjVal? "arguments"
let arg? := field (opt arg) obj
let args := field (list arg) obj
let exprs := field (list expr) obj
return {
posonlyargs := (<- args "posonlyargs")
args := (<- args "args")
defaults := (<- exprs "defaults")
vararg := (<- arg? "vararg")
kwonlyargs := (<- args "kwonlyargs")
kw_defaults := (<- exprs "kw_defaults")
kwarg := (<- arg? "kwarg")
}
where
arg (j : Json) : Parser String := do
let obj <- j.getObjVal? "arg"
return (<- field str obj "arg")

def function (j : Json) : Parser Fun := do
let source <- field str j "source"
withSrc source do
let args <- field arguments j "args"
let defaults <- field (list const) j "defaults"
let body <- field (list stmt) j "body"
return Fun.mk source args defaults body

-- Both global references and arguments are processed in the global
-- environment. These terms do not have a position, and must be
-- evaluable in the default environment.
Expand All @@ -366,6 +337,29 @@ where
globals (arr : Array Json) : Parser (List Expr) :=
arr.toList.mapM fun x => return .exprPos (<- global x) {}

def arguments (j : Json) : Parser Args := do
let obj <- j.getObjVal? "arguments"
return {
posonlyargs := (<- field (list arg) obj "posonlyargs")
args := (<- field (list arg) obj "args")
defaults := (<- field (list global) obj "defaults")
vararg := (<- field (opt arg) obj "vararg")
kwonlyargs := (<- field (list arg) obj "kwonlyargs")
kw_defaults := (<- field (dict global) obj "kw_defaults")
kwarg := (<- field (opt arg) obj "kwarg")
}
where
arg (j : Json) : Parser String := do
let obj <- j.getObjVal? "arg"
return (<- field str obj "arg")

def function (j : Json) : Parser Fun := do
let source <- field str j "source"
withSrc source do
let args <- field arguments j "args"
let body <- field (list stmt) j "body"
return Fun.mk source args body

def kernel (j : Json) : Parser Kernel := do
let name <- field str j "entry"
let funcs <- field (dict function) j "funcs"
Expand Down
1 change: 0 additions & 1 deletion interop/nkl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
# Released under Apache 2.0 license as described in the file LICENSE.
# Authors: Paul Govereau

from .lean import load, to_json
from .parser import Parser
26 changes: 0 additions & 26 deletions interop/nkl/lean.py

This file was deleted.

53 changes: 21 additions & 32 deletions interop/nkl/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from textwrap import dedent
from itertools import chain
from collections import deque
from nkl.lean import py_to_lean
from nkl.lean_rffi import py_to_lean

# This is a custom JSON encoder for use with AST nodes.
# The AST nodes are not handled by the default encoder.
Expand Down Expand Up @@ -97,25 +97,29 @@ def json(self):
def load(self):
py_to_lean(self.json())

def apply_args(self, *args, **kwargs):
self.args = []
self.kwargs = {}
def process_args(self, args, kwargs):
l = []
d = {}
for arg in args:
self.reference(d, '_', arg)
try: self.args.append(d.popitem()[1])
except Exception:
raise Exception("Unsupported argument type")
for k,v in kwargs.items():
self.ref_arg(k, v)
if args:
for arg in args:
self.reference(d, '_', arg)
try: l.append(d.popitem()[1])
except Exception:
raise Exception("Unsupported argument type")
if kwargs:
for k,v in kwargs.items():
self.reference(d, k, v)
return l, d

def apply_args(self, *args, **kwargs):
l, d = self.process_args(args, kwargs)
self.args = l
self.kwargs = d

def __call__(self, *args, **kwargs):
self.apply_args(*args, **kwargs)
py_to_lean(self.json())

def ref_arg(self, refname, val):
return self.reference(self.kwargs, refname, val)

def ref_global(self, refname, val):
return self.reference(self.globals, refname, val)

Expand Down Expand Up @@ -160,9 +164,11 @@ def translate(self, f: types.FunctionType, args: ast.arguments, body: [ast.AST])
self.f = f
for s in body:
self.visit(s)
l, d = self.process_args(f.__defaults__, f.__kwdefaults__)
args.defaults = l
args.kw_defaults = d
return { 'source': inspect.getsource(f)
, 'args': args
, 'defaults': list(self.fun_defaults(f))
, 'body': body
}

Expand Down Expand Up @@ -201,20 +207,3 @@ def visit_Attribute(self, node):
raise e
except Exception:
return

def fun_defaults(self, f: types.FunctionType):
if f.__defaults__ is None:
return dict()
names = f.__code__.co_varnames[:f.__code__.co_argcount]
tbl = { n:v for (n,v) in zip(reversed(names), reversed(f.__defaults__)) }
if f.__kwdefaults__ is not None:
tbl.update(f.__kwdefaults__)
def is_ok(x):
if x is None or isinstance(x, (int, float, str)):
return True
if isinstance(x, types.FunctionType):
# TODO: this could be incorrect if default
# is using an alternate name for the function
self.ref_global(x.__name__, x)
return False
return { n:v for (n,v) in tbl.items() if is_ok(v) }

0 comments on commit 76e7364

Please sign in to comment.