diff --git a/Main.lean b/Main.lean index 85265a8..f9275c3 100644 --- a/Main.lean +++ b/Main.lean @@ -1,8 +1,5 @@ import NKL -def main (args : List String) : IO Unit := - match args with - | .nil => IO.println s!"Hello, NKL!" - | .cons x _ => do - let s <- IO.FS.readFile x - NKL.parse_json s +def main : List String -> IO Unit + | [ file ] => IO.FS.readFile file >>= NKL.parse_json + | _ => IO.println "invalid arguments" diff --git a/NKL.lean b/NKL.lean index 0607eeb..a8d195a 100644 --- a/NKL.lean +++ b/NKL.lean @@ -6,3 +6,4 @@ Authors: Paul Govereau import NKL.Encode import NKL.FFI import NKL.NKI +import NKL.Python diff --git a/NKL/FFI.lean b/NKL/FFI.lean index 1a968ec..1ad6410 100644 --- a/NKL/FFI.lean +++ b/NKL/FFI.lean @@ -6,16 +6,24 @@ Authors: Paul Govereau import Lean import NKL.NKI import NKL.PrettyPrint +import NKL.Python namespace NKL --- temporary for testing +local instance : MonadLift (Except String) IO where + monadLift + | .ok x => return x + | .error s => throw $ .userError s + +@[export parse_json_old] +def parse_json_old (json : String) : IO Unit := do + let jsn <- Lean.Json.parse json + let f:Fun <- Lean.fromJson? jsn + print_nki f @[export parse_json] -def parse_json (json : String) : IO Unit := do - match Lean.Json.parse json with - | .error str => throw $ .userError str - | .ok jsn => do - match Lean.fromJson? jsn with - | .error str => throw $ .userError str - | .ok (f:Fun) => print_nki f +def parse_json (s : String) : IO Unit := do + let kernel <- Python.Parsing.parse s + for (n,f) in kernel.funcs do + IO.println s!"found {n}" + IO.println s!"{repr f}" diff --git a/NKL/Python.lean b/NKL/Python.lean new file mode 100644 index 0000000..da806a8 --- /dev/null +++ b/NKL/Python.lean @@ -0,0 +1,288 @@ +/- +Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Paul Govereau +-/ +import Lean + +/-! +# Abstract syntax of Python functions + +Mostly 1-to-1 translation of the Python AST to lean. +see: https://docs.python.org/3/library/ast.html +-/ + +namespace NKL +namespace Python + +deriving instance Repr for Lean.JsonNumber + +structure Pos where + lineno : Nat + end_lineno : Nat := 0 + col_offset : Nat := 0 + end_col_offset : Nat := 0 + deriving Repr + +inductive Const where + | none + | bool (value: Bool) + | num (value: Lean.JsonNumber) + | string (value: String) + | ellipsis + deriving Repr + +inductive Ctx where + | load | store | del + deriving Repr + +mutual +inductive Expr where + | exprPos (expr : Expr') (pos : Pos) + deriving Repr + +inductive Expr' where + | const (value: Const) + | name (id: String) (ctx : Ctx) + | attr (value : Expr) (id : String) (ctx : Ctx) + | tuple (xs: List Expr) (ctx : Ctx) + | list (xs: List Expr) (ctx : Ctx) + | subscript (tensor: Expr) (ix: List Expr) (ctx : Ctx) + | slice (l u step: Option Expr) + | boolOp (op : String) (values : List Expr) + | binOp (op : String) (left right : Expr) + | unaryOp (op : String) (operand : Expr) + | compare (left : Expr) (ops : List String) (comparators : List Expr) + | ifExp (test body orelse : Expr) + | call (f: Expr) (args: List Expr) (keywords : List Keyword) + deriving Repr + +inductive Keyword where + | keyword (id : String) (value : Expr) (pos : Pos) + deriving Repr +end + +mutual +inductive Stmt where + | stmtPos (stmt : Stmt') (pos : Pos) + deriving Repr + +inductive Stmt' where + | pass + | expr (e : Expr) + | assert (e : Expr) + | ret (e: Expr) + | assign (xs: List Expr) (e: Expr) + | augAssign (x : Expr) (op : String) (e : Expr) + | annAssign (x : Expr) (annotation : Expr) (value : Option Expr) + | forLoop (x : Expr) (iter: Expr) (body: List Stmt) (orelse : List Stmt) + | ifStm (e : Expr) (thn els: List Stmt) + deriving Repr +end + +structure Args where + posonlyargs : List String + args : List String + defaults: List Expr + vararg : Option String + kwonlyargs : List String + kw_defaults: List Expr + kwarg : Option String + deriving Repr + +structure Fun where + source : String + args : Args + defaults: List Const + body: List Stmt + deriving Repr + +structure Kernel where + entry : String + funcs : List (String × Fun) + globals : List (String × Option String) + +------------------------------------------------------------------------------- +-- Converting Python AST from Json + +namespace Parsing +open Lean + +-- I am using a state monad only to provide better error messages: the source +-- span (Pos) is saved while traversing the tree to identify the location +-- of any errors in the original program + +abbrev Parser := EStateM String Pos + +local instance : MonadLift (Except String) Parser where + monadLift + | .ok x => return x + | .error s => throw s + +private def str : Json -> Parser String := + monadLift ∘ Json.getStr? + +private def field (f: Json -> Parser a) (j : Json) (name : String) : Parser a := + j.getObjVal? name >>= f + +private def field? (f: Json -> Parser a) (j : Json) (name : String) : Parser (Option a) := + try let x <- field f j name; return (some x) + catch _ => return none + +private def list (f: Json -> Parser a) : Json -> Parser (List a) + | .arr arr => arr.toList.mapM f + | json => return [(<- f json )] + +private def dict (f : Json -> Parser a) : Json -> Parser (List (String × a)) + | .obj kvs => kvs.toArray.toList.mapM fun p => return (p.1, (<- f p.2)) + | _ => throw s!"expecting dictionary" + +private def opt (p : Json -> Parser a) : Json -> Parser (Option a) + | .null => return none + | j => return (some (<- p j)) + +-- Note: this will not fail, but can produce an invalid Pos +private def pos (j: Json) : Parser Pos := + return { + lineno := (<- nat "lineno") + end_lineno := (<- nat "end_lineno") + col_offset := (<- nat "col_offset") + end_col_offset := (<- nat "end_col_offset") + } +where + nat (name : String) : Parser Nat := + tryCatch (nat' name) fun _ => return 0 + nat' (name : String) : Parser Nat := do + let obj <- j.getObjVal? name + Json.getNat? obj + +private def withPos (p : String -> Json -> Parser b) (f : b -> Pos -> a) : Json -> Parser a + | .obj (.node _ _ key val _) => do + let pos <- pos val + set pos + let exp <- p key val + return (f exp pos) + | _ => throw "expecting object" + +private def withSrc (source : String) (p : Parser a) : Parser a := + try set { lineno := 0 : Pos } ; p + catch e => get >>= throw ∘ genError e +where + genError (err : String) (pos : Pos) : String := + let lines := source.splitOn "\n" + let lineno := pos.lineno - 1 + let colno := pos.col_offset + let line := if lines.length < lineno + then "" + else lines[lineno]! + let indent := (Nat.repeat (List.cons ' ') colno List.nil).asString + s!"line {lineno}:\n{line}\n{indent}^-- {err}" + +------------------------------------------------------------------------------- +-- Python AST Json objects + +def const : Json -> Parser Const + | .null => return .none + | .bool b => return (.bool b) + | .num jn => return (.num jn) + | .str "..." => return .ellipsis + | .str s => return (.string s) + | _ => throw "expecting constant" + +def exprCtx : Json -> Parser Ctx + | .str "Load" => return .load + | .str "Store" => return .store + | .str "Del" => return .del + | _ => throw "expecting ctx" + +partial def expr (j : Json) : Parser Expr := + withPos expr' Expr.exprPos j +where + expr' (key : String) (j : Json) : Parser Expr' := do + let strs := field (list str) j + let str := field str j + let ctx := field exprCtx j + let const := field const j + let exprs := field (list expr) j + let expr? := field (opt expr) j + let expr := field expr j + let keywords := field (list keyword) j + match key with + | "Constant" => return (.const (<- const "value")) + | "Name" => return (.name (<- str "id") (<- ctx "ctx")) + | "Attribute" => return (.attr (<- expr "value") (<- str "attr") (<- ctx "ctx")) + | "Tuple" => return (.tuple (<- exprs "elts") (<- ctx "ctx")) + | "List" => return (.list (<- exprs "elts") (<- ctx "ctx")) + | "Subscript" => return (.subscript (<- expr "value") (<- exprs "slice") (<- ctx "ctx")) + | "Slice" => return (.slice (<- expr? "lower") (<- expr? "upper") (<- expr? "step")) + | "BoolOp" => return (.boolOp (<- str "op") (<- exprs "values")) + | "BinOp" => return (.binOp (<- str "op") (<- expr "left") (<- expr "right")) + | "UnaryOp" => return (.unaryOp (<- str "op") (<- expr "operand")) + | "Compare" => return (.compare (<- expr "left") (<- strs "ops") (<- exprs "comparators")) + | "IfExp" => return (.ifExp (<- expr "test") (<- expr "body") (<- expr "orelse")) + | "Call" => return (.call (<- expr "func") (<- exprs "args") (<- keywords "keywords")) + | _ => throw s!"unsupported python construct {key}" + + keyword (j: Json) : Parser Keyword := do + let j <- j.getObjVal? "keyword" + return ⟨ <- field str j "arg", <- field expr j "value", <- pos j ⟩ + +partial def stmt (j : Json) : Parser Stmt := + withPos stmt' Stmt.stmtPos j +where + stmt' (key : String) (j : Json) : Parser Stmt' := do + let str := field str j + let exprs := field (list expr) j + let expr? := field (opt expr) j + let expr := field expr j + let stmts := field (list stmt) j + match key with + | "Pass" => return .pass + | "Expr" => return (.expr (<- expr "value")) + | "Assert" => return (.assert (<- expr "test")) + | "Return" => return (.ret (<- expr "value")) + | "Assign" => return (.assign (<- exprs "targets") (<- expr "value")) + | "AugAssign" => return (.augAssign (<- expr "target") (<- str "op") (<- expr "value")) + | "AnnAssign" => return (.annAssign (<- expr "target") (<- expr "annotation") (<- expr? "value")) + | "For" => return (.forLoop (<- expr "target") (<- expr "iter") (<- stmts "body") (<- stmts "orelse")) + | "If" => return (.ifStm (<- expr "test") (<- stmts "body") (<- stmts "orelse")) + | _ => throw s!"unsupported python construct {key}" + +def arguments (j : Json) : Parser Args := do + let obj <- j.getObjVal? "arguments" + let arg? := field (opt arg) obj + let args := field (list arg) obj + let exprs := field (list expr) obj + return { + posonlyargs := (<- args "posonlyargs") + args := (<- args "args") + defaults := (<- exprs "defaults") + vararg := (<- arg? "vararg") + kwonlyargs := (<- args "kwonlyargs") + kw_defaults := (<- exprs "kw_defaults") + kwarg := (<- arg? "kwarg") + } +where + arg (j : Json) : Parser String := do + let obj <- j.getObjVal? "arg" + return (<- field str obj "arg") + +def function (j : Json) : Parser Fun := do + let source <- field str j "source" + withSrc source do + let args <- field arguments j "args" + let defaults <- field (list const) j "defaults" + let body <- field (list stmt) j "body" + return Fun.mk source args defaults body + +def kernel (j : Json) : Parser Kernel := do + let name <- field str j "entry" + let funcs <- field (dict function) j "funcs" + let globals <- field (dict (opt str)) j "globals" + return Kernel.mk name funcs globals + +def parse (s : String) : Except String Kernel := do + let jsn <- Json.parse s + match kernel jsn { lineno := 0 } with + | .ok x _ => .ok x + | .error s _ => .error s diff --git a/interop/nkl/parser.py b/interop/nkl/parser.py new file mode 100644 index 0000000..a86547c --- /dev/null +++ b/interop/nkl/parser.py @@ -0,0 +1,134 @@ +# Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Released under Apache 2.0 license as described in the file LICENSE. +# Authors: Paul Govereau + +import types +import inspect +import ast +import json + +from textwrap import dedent +from itertools import chain +from collections import deque + +class Enc(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, ast.AST): + if len(obj.__dict__) == 0: + return obj.__class__.__name__ + else: + return { obj.__class__.__name__:obj.__dict__ } + try: + return super().default(obj) + except Exception: + return "..." + +class Parser(ast.NodeVisitor): + def __init__(self, f: types.FunctionType): + super().__init__() + self.workq = deque() + self.funcs = {} + self.globals = {} + self.entry = f.__module__ + "." + f.__name__ + self.reference(self.entry, f) + self.do_work() + + def json(self): + d = { 'entry': self.entry + , 'funcs': self.funcs + , 'globals': self.globals + } + return json.dumps(d, cls=Enc) + + # resolve a reference: either populating the environment, + # or adding new items to the work queue + def reference(self, refname, val): + f = None + if isinstance(val, types.FunctionType): + f = val + val = f.__module__ + "." + f.__name__ + elif isinstance(val, types.ModuleType): + val = val.__name__ + + if refname in self.globals: + if val != self.globals[refname]: + assert 0, "global mismatch" + else: + self.globals[refname] = val + + if f is None: + return + try: + match ast.parse(dedent(inspect.getsource(f))): + case ast.Module([ast.FunctionDef(_, args, body)]): + self.workq.append((val, f, args, body)) + case _: + assert 0, "expecting function definition" + except Exception as e: + pass + + def do_work(self): + while len(self.workq) > 0: + fullname, f, args, body = self.workq.popleft() + if fullname in self.funcs: + continue + self.funcs[fullname] = self.translate(f, args, body) + + def translate(self, f: types.FunctionType, args: ast.arguments, body: [ast.AST]): + self.f = f + for s in body: + self.visit(s) + return { 'source': inspect.getsource(f) + , 'args': args + , 'defaults': list(self.fun_defaults(f)) + , 'body': body + } + + # A best-effort dependency finder. + # This is a valid approach because we only need to find + # the expressions that refer to external names, it is ok + # if we find other uses of potentially global names + # and fail to understand them; as long as we find and record + # the "real" uses into the environment for the Lean code. + def lookup(self, s): + return self.f.__globals__.get(s) or self.f.__builtins__.get(s) + + def visit_Name(self, node): + if node.id not in self.f.__code__.co_names: + return + try: + y = self.lookup(node.id) + self.reference(node.id, y) + return node.id, y + except Exception as e: + return + + def visit_Attribute(self, node): + if node.ctx == ast.Store() or \ + node.attr not in self.f.__code__.co_names: + return + try: + n, x = self.visit(node.value) + n = n + "." + node.attr + y = getattr(x, node.attr) + self.reference(n, y) + return n, y + except Exception as e: + return + + def fun_defaults(self, f: types.FunctionType): + if f.__defaults__ is None: + return dict() + names = f.__code__.co_varnames[:f.__code__.co_argcount] + tbl = { n:v for (n,v) in zip(reversed(names), reversed(f.__defaults__)) } + if f.__kwdefaults__ is not None: + tbl.update(f.__kwdefaults__) + def is_ok(x): + if x is None or isinstance(x, (int, float, str)): + return True + if isinstance(x, types.FunctionType): + # TODO: this could be incorrect if default + # is using an alternate name for the function + self.reference(x.__name__, x) + return False + return { n:v for (n,v) in tbl.items() if is_ok(v) }