Skip to content

Commit

Permalink
feat: tracing for python source functions
Browse files Browse the repository at this point in the history
This patch adds basic tracing for user python functions. The main code
is in Python.lean, and depends on definitions in Basic.lean and
NKI.lean, which are incomplete. As more primitives are implemented,
more user kernels will be supported.
  • Loading branch information
govereau committed Jan 9, 2025
1 parent c68b33e commit f3e587f
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 24 deletions.
13 changes: 4 additions & 9 deletions NKL/FFI.lean
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ Authors: Paul Govereau
-/
import Lean
import NKL.Python
import NKL.Trace

namespace NKL

Expand All @@ -16,12 +17,6 @@ local instance : MonadLift (Except String) IO where
@[export parse_json]
def parse_json (s : String) : IO Unit := do
let kernel <- Python.Parsing.parse s
let names := kernel.funcs.map fun x => x.fst
let names := String.intercalate "," names
IO.println s!"Found functions: {names}"
for x in kernel.args do
IO.println s!"arg: {repr x}"
for x in kernel.kwargs do
IO.println s!"arg: {repr x}"
for x in kernel.globals do
IO.println s!"global: {repr x}"
let stmts <- NKL.Trace.runNKIKernel kernel
for s in stmts do
IO.println s!"{repr s}"
24 changes: 10 additions & 14 deletions NKL/Python.lean
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ then the structure will be populated with:
defaults = [1, 2]
vararg = "args"
kwonlyargs = [d, e]
kw_defaults = [None, 3]
kw_defaults = [("e", 3)]
kwarg = "kwargs"
Note, defaults and kw_defaults are inconsistent in how they treat
missing arguments, but this is just how it works in the python AST.
Note, this is slightly different from the official Python AST, which
encodes the kw_defaults as a list with None for missing defaults.
-/
structure Args where
posonlyargs : List String
Expand All @@ -122,17 +122,13 @@ structure Args where
kwarg : Option String
deriving Repr

def Args.names (ax : Args) : List String :=
let xs := ax.posonlyargs.append ax.args
let xs := match ax.vararg with | none => xs | some x => xs.append [x]
let xs := xs.append ax.kwonlyargs
let xs := match ax.kwarg with | none => xs | some x => xs.append [x]
xs

def Args.all_defaults (ax : Args) : List (String × Expr') :=
let args := ax.posonlyargs ++ ax.args
let dflt := args.reverse.zip ax.defaults.reverse
dflt ++ ax.kw_defaults
def Args.names (args : Args) : List String :=
args.posonlyargs ++ args.args ++ args.kwonlyargs

def Args.all_defaults (args : Args) : List (String × Expr') :=
let pargs := args.posonlyargs ++ args.args
let dflt := pargs.reverse.zip args.defaults.reverse
dflt ++ args.kw_defaults

structure Fun where
source : String
Expand Down
13 changes: 12 additions & 1 deletion NKL/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,18 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Python
import NKL.Trace.Types
import NKL.Trace.Basic
import NKL.Trace.Builtin
--import NKL.Trace.Python
import NKL.Trace.Python
import NKL.Trace.NKI

namespace NKL.Trace

def runNKIKernel (k : NKL.Python.Kernel) : Except String (List NKL.KLR.Stmt) :=
tracer ⟨ .ofList nki_env, #[] ⟩ do
traceKernel k
let g <- get
return g.body.toList
32 changes: 32 additions & 0 deletions NKL/Trace/NKI.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.Trace.Types
import NKL.Trace.Builtin

/-
# NKI built-ins
This module defines the builtin constants used by tracing for NKI kernels.
-/
namespace NKL.Trace
open NKL.KLR

private def module (s : String) : Name × Item :=
let name := s.toName
(name, .module name)

private def const_var (s : String) : Name × Item :=
let name := s.toName
(name, .term (.expr (.var s) (.any name)))

def nki_env : List (Name × Item) :=
[ module "nki"
, module "nki.language"
, const_var "nki.language.add"
, const_var "nki.language.load"
, const_var "nki.language.store"
]
214 changes: 214 additions & 0 deletions NKL/Trace/Python.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import Lean
import NKL.KLR
import NKL.Python
import NKL.Trace.Types
import NKL.Trace.Basic

namespace NKL.Trace
open NKL.Python

def const : Const -> TraceM Term
| .none => return .expr (.const $ .none) .none
| .bool b => return .expr (.const $ .bool b) .bool
| .int i => return .expr (.const $ .int i) .int
| .float f => return .expr (.const $ .float f) .float
| .string s => return .expr (.const $ .string s) .string
| .ellipsis => throw "unsupported use of ellipsis"

mutual
def indexExpr : Expr -> Tracer KLR.IndexExpr
| .exprPos e' p => withPos p (indexExpr' e')

def indexExpr' : Expr' -> Tracer KLR.IndexExpr
| .const (.int i) => return .int i
| .const c => throw s!"invalid constant {repr c} in index expression"
| .name id _ => return .var id
| .binOp op l r => return <- indexBinOp op (<- indexExpr l) (<- indexExpr r)
| .unaryOp op e => return <- indexUnOp op (<- indexExpr e)
| _ => throw "invalid index expression"

def indexExpr? : Option Expr -> Tracer (Option KLR.IndexExpr)
| none => return none
| some (.exprPos (.const .none) _) => return none
| some e => indexExpr e

def index : Expr -> Tracer KLR.Index
| .exprPos (.const .ellipsis) _ => return .ellipsis
| .exprPos (.slice l u s) p => withPos p do
return (.slice (<- indexExpr? l) (<- indexExpr? u) (<- indexExpr? s))
| e => return (.coord (<- indexExpr? e))
end


mutual
partial def expr : Expr -> Tracer Item
| .exprPos e' p => withPos p (expr' e')

partial def term (e : Expr) : Tracer Term := do
match (<- expr e) with
| .module n => return .expr (.var n.toString) (.any "?".toName)
| .global g => return .expr (.var g.name.toString) (.any "?".toName)
| .source _ => throw "invalid use of source function"
| .term t => return t

partial def term' (e : Expr') : Tracer Term := do
term (.exprPos e (<- getPos))

partial def klr (e : Expr) : Tracer KLR.Expr := do
match (<- term e) with
| .object obj => return .var obj.name.toString
| .tuple _ => throw "tuple cannot be converted to a KLR term"
| .list _ => throw "list cannot be converted to a KLR term"
| .expr e _ => return e

partial def integer (e : Expr) : Tracer Int := do
match (<- term e) with
| .expr (.const c) _ => return (<- c.toInt)
| _ => throw "invalid tensor dimension"

partial def expr' : Expr' -> Tracer Item
| .const c => return .term (<- const c)
| .tensor s dty => do
let shape <- s.mapM integer
return .term (.expr (.tensor ⟨ dty, shape ⟩) (.tensor dty shape))
| .name id _ => lookup_item id.toName
| .attr (.exprPos e p) id _ => do withPos p ((<- expr' e).attr id)
| .tuple l _ => return .term (.tuple (<- l.mapM term))
| .list l _ => return .term (.list (<- l.mapM term))
| .subscript t [ .exprPos (.tuple ix _) _ ] _
| .subscript t ix _ => return .term (.expr (.access (<- klr t) (<- ix.mapM index)) (.any "?".toName))
| .slice _ _ _ => throw "syntax error"
| .boolOp op xs => return .term (<- boolOp op (<- xs.mapM term))
| .binOp op l r => return .term (<- binOp op (<- term l) (<- term r))
| .unaryOp op e => return .term (<- unOp op (<- term e))
| .compare l ops cs => return .term (<- compare (<- term l) ops (<- cs.mapM term))
| .ifExp tst tru fls => do
let tst <- (<- term tst).isTrue
let tru <- expr tru -- eagerly evaluate both branches
let fls <- expr fls -- to report errors to user
return if tst then tru else fls
| .call f args kws => do
match <- expr f with
| .module n => throw s!"module {n} not callable"
| .global g => return .term (<- g.call (<- args.mapM term) (<- kws.mapM (keyword term)))
| .term t => return .term (<- t.call (<- args.mapM klr) (<- kws.mapM (keyword klr)))
| .source f => do
function_call f (<- args.mapM term) (<- kws.mapM (keyword term))
return .term (.expr (.const .none) .none)

partial def keyword (f : Expr -> Tracer a) : Keyword -> Tracer (String × a)
| .keyword id e p => withPos p do return (id, (<- f e))


partial def var (e : Expr) : Tracer String := do
match (<- klr e) with
| .var s => return s
| _ => throw "expecting variable"

partial def assign (xs : List Expr) (e : Expr) : Tracer Unit := do
let xs <- xs.mapM var
let e <- term e
xs.forM fun x => extend x.toName e
if let .expr e _ := e then
xs.forM fun x => add_stmt (KLR.Stmt.assign x e)

partial def stmt : Stmt -> Tracer Unit
| .stmtPos s' p => withPos p (stmt' s')

partial def stmt' : Stmt' -> Tracer Unit
| .expr (.exprPos (.const _) _) => return ()
| .expr e => do
match <- term e with
| .expr e _ => add_stmt (.expr e)
| _ => return () -- effects are done, can be removed from KLR
| .assert e => do
let t <- term e
if (<- t.isFalse) then throw "assertion failed"
| .assign xs e => assign xs e
| .augAssign x op e => do
stmt' (.assign [x] (.exprPos (.binOp op x e) (<- getPos)))
| .annAssign _ _ .none => return ()
| .annAssign x _ (.some e) => stmt' (.assign [x] e)
| _s => throw "not yet implemented" --s!"unimp {repr s}"

-- Bind positional and keyword arguments to a Python function.
-- Note: default arguments should be evaluated in the global environment,
-- however we know that each source function begins with an empty local
-- environment, so it is OK to evaluate the default arguments in the
-- functions initial environment.

partial def bind_args (f : Fun)
(args : List Term)
(kwargs : List (String × Term))
: Tracer (List (String × Term)) := do
if f.args.vararg != none || f.args.kwarg != none then
throw "var args not supported"
if args.length < f.args.posonlyargs.length then
throw "not enough arguments"
let dflts := f.args.all_defaults
let names := f.args.names
if args.length + kwargs.length > names.length then
throw "too many arguments supplied (varargs not supported)"
let argmap <- f.args.names.enum.mapM fun (i,x) => do
if h:args.length > i then
return (x, args.get (Fin.mk i h))
else if let some v := kwargs.lookup x then
return (x, v)
else if let some e := dflts.lookup x then
return (x, <- term' e)
else
throw s!"argument {x} not supplied"
return argmap

-- For a function call, first evaluate the argument in the current environment.
-- Then enter a new environment and evaluate the function statements.
partial def function_call (f : Fun)
(args : List Term)
(kwargs : List (String × Term))
: Tracer Unit := do
let args <- bind_args f args kwargs
let args <- args.mapM fun (x,e) => return (x, e)
withSrc f.source $ enterFun $ do
args.forM fun (x,e) => do extend x.toName e
f.body.forM stmt

end

-- Evaluate each global in the current environment, skipping any globals that
-- are already defined. Note, we may have globals or functions with dummy
-- implementations, e.g.
-- def add(x,y): pass
-- If we have an internal definition, we will use this over anything found
-- during parsing.

private def globals (k : Kernel) : Tracer Unit := do
let s <- get
for (n, f) in k.funcs do
let n := n.toName
if not (s.env.contains n) then
extend_global n (.source f)
for (n,e) in k.globals do
let n := n.toName
if not (s.env.contains n) then
extend_global n (<- expr' e)

-- Call the top-level kernel function
def traceKernel (k : Kernel) : Tracer Unit := do
globals k
match k.funcs.lookup k.entry with
| none => throw s!"function {k.entry} not found"
| some f => do
let args <- k.args.mapM term'
let kwargs <- k.kwargs.mapM fun (x,e) => return (x, <- term' e)
function_call f args kwargs

def runKernel (k : Kernel) : Except String (List KLR.Stmt) :=
tracer ⟨ ∅, #[] ⟩ do
traceKernel k
let g <- get
return g.body.toList
Loading

0 comments on commit f3e587f

Please sign in to comment.