Skip to content

Commit

Permalink
refactor: cleanup KLR definitions
Browse files Browse the repository at this point in the history
Remove some unnecessary parts of KLR, and reorganize the source files.
  • Loading branch information
govereau committed Jan 8, 2025
1 parent bcc8182 commit e4dc75f
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 128 deletions.
1 change: 0 additions & 1 deletion NKL.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.Encode
import NKL.FFI
import NKL.KLR
import NKL.Python
104 changes: 2 additions & 102 deletions NKL/KLR.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,105 +3,5 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/


/-!
# Abstract syntax of Core NKL language
This language is the result of "tracing", and is used as the
portable format, a.k.a. Kernel Language Representation (KLR).
-/

namespace NKL.KLR

-- TODO
inductive Ty where

inductive Const where
| none
| bool (value : Bool)
| int (value : Int)
| float (value : Float)
| string (value : String)
deriving Repr, BEq

namespace Const

-- Python-like rules for conversion to boolean
def isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- Python-like rules for conversion to integer
def toInt : Const -> Except String Int
| .none => throw "none cannot be converted to an integer"
| .bool true => return 1
| .bool false => return 0
| .int i => return i
| .float f =>
-- Python is a bit strange here, it truncates both
-- positive and negative numbers toward zero
if f < 0.0 then
return (Int.ofNat (Float.floor (-f)).toUInt64.toBitVec.toNat).neg
else
return Int.ofNat (Float.floor f).toUInt64.toBitVec.toNat
| .string s =>
match s.toInt? with
| .none => throw s!"string {s} cannot be converted to an integer"
| .some i => return i

end Const

inductive IndexExpr where
| var (name : String)
| int (i : Int)
| neg (expr : IndexExpr)
| add (left right : IndexExpr)
| mul (scalar : Int) (expr : IndexExpr)
| floor (expr : IndexExpr) (scalar : Int)
| ceil (expr : IndexExpr) (scalar : Int)
| mod (expr : IndexExpr) (scalar : Int)
deriving Repr, BEq

inductive Index where
| ellipsis
| coord (e : Option IndexExpr)
| range (l u step : Option IndexExpr)
deriving Repr, BEq

inductive Expr where
| var (x : String)
| const (c : Const)
| tensor (name : String) (shape : List Int)
| tuple (xs : List Expr)
| list (xs : List Expr)
| access (t : Expr) (ix : List Index)
| binop (op : String) (left right : Expr)
| unop (op : String) (e : Expr)
| call (f : Expr) (args : List Expr) (keywords : List (String × Expr))
deriving Repr, BEq

namespace Expr

-- TODO: Just a place-holder for now
def toAffine : Expr -> Except String IndexExpr
| .var v => return .var v
| .const (.int i) => return .int i
| e => throw s!"toAffine unimp {repr e}"

-- TODO: Just a place-holder for now
def simplify : Expr -> Expr :=
fun x => x

end Expr

inductive Stmt where
| pass
| expr (v : Expr)
| ret (v : Expr)
| assign (x : String) (e : Expr)
| loop (x : String) (l u step : IndexExpr) (body : List Stmt)
deriving Repr, BEq
import NKL.KLR.Basic
import NKL.KLR.Encode
104 changes: 104 additions & 0 deletions NKL/KLR/Basic.lean
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
--import TensorLib.Tensor

/-!
# Abstract syntax of Core NKL language
This language is the result of "tracing", and is used as the
portable format, a.k.a. Kernel Language Representation (KLR).
-/

namespace NKL.KLR

-- TODO switch to tensor lib
--export TensorLib (Tensor Dtype Shape)
-- Mostly, NKL deals with empty tensors, so just check dtype and shape
-- TODO: talk to Sean about a more general BEq for Tensor
--instance : BEq Tensor where
-- beq t₁ t₂ := t₁.dtype == t₂.dtype && t₁.shape == t₂.shape

abbrev Dtype := String
abbrev Shape := List Int
structure Tensor where
dtype : Dtype
shape : Shape
deriving Repr, BEq

-- TODO
inductive Typ where

inductive Const where
| none
| bool (value : Bool)
| int (value : Int)
| float (value : Float)
| string (value : String)
deriving Repr, BEq

namespace Const

-- Python-like rules for conversion to boolean
def isTrue : Const -> Bool
| .none => false
| .bool b => b
| .int i => i != 0
| .float f => f != 0.0
| .string s => s != ""

-- Python-like rules for conversion to integer
def toInt : Const -> Except String Int
| .none => throw "none cannot be converted to an integer"
| .bool true => return 1
| .bool false => return 0
| .int i => return i
| .float f =>
-- Python is a bit strange here, it truncates both
-- positive and negative numbers toward zero
if f < 0.0 then
return (Int.ofNat (Float.floor (-f)).toUInt64.toBitVec.toNat).neg
else
return Int.ofNat (Float.floor f).toUInt64.toBitVec.toNat
| .string s =>
match s.toInt? with
| .none => throw s!"string {s} cannot be converted to an integer"
| .some i => return i

end Const

inductive IndexExpr where
| var (name : String)
| int (i : Int)
| neg (expr : IndexExpr)
| add (left right : IndexExpr)
| mul (scalar : Int) (expr : IndexExpr)
| floor (expr : IndexExpr) (scalar : Int)
| ceil (expr : IndexExpr) (scalar : Int)
| mod (expr : IndexExpr) (scalar : Int)
deriving Repr, BEq

-- Note: `np.newindex` is represented as `(.coord none)`
inductive Index where
| ellipsis
| coord (e : Option IndexExpr)
| slice (l u step : Option IndexExpr)
deriving Repr, BEq

inductive Expr where
| var (x : String)
| const (c : Const)
| tensor (t : Tensor)
| access (t : Expr) (ix : List Index)
| call (f : Expr) (args : List Expr) (kwargs : List (String × Expr))
deriving Repr, BEq

inductive Stmt where
| pass
| expr (v : Expr)
| ret (v : Expr)
| assign (x : String) (e : Expr)
| loop (x : String) (l u step : IndexExpr) (body : List Stmt)
deriving Repr, BEq
38 changes: 13 additions & 25 deletions NKL/Encode.lean → NKL/KLR/Encode.lean
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Paul Govereau
-/
import NKL.KLR
import NKL.KLR.Basic

/-!
# Serialization and Deserialization
Expand Down Expand Up @@ -239,15 +239,15 @@ private def ie_var : IndexExpr := .var "s"
def encIndex : Index -> ByteArray
| .ellipsis => tag 0x20 []
| .coord e => tag 0x21 [enc e]
| .range l u s => tag 0x22 [enc l, enc u, enc s]
| .slice l u s => tag 0x22 [enc l, enc u, enc s]
where
enc := encOption encIndexExpr

def decIndex : DecodeM Index := do
match (<- next) with
| 0x20 => return .ellipsis
| 0x21 => return .coord (<- dec)
| 0x22 => return .range (<- dec) (<- dec) (<- dec)
| 0x22 => return .slice (<- dec) (<- dec) (<- dec)
| t => throw s!"Unknown tag in Index {t}"
where
dec:= decOption decIndexExpr
Expand All @@ -258,36 +258,28 @@ private def chkIndex (i : Index) : Bool :=
#guard chkIndex .ellipsis
#guard chkIndex (.coord none)
#guard chkIndex (.coord $ some ie_var)
#guard chkIndex (.range (some ie_var) none none)
#guard chkIndex (.slice (some ie_var) none none)

------------------------------------------------------------------------------
-- Expressions

partial def encExpr : Expr -> ByteArray
| .var s => tag 0x30 [encString s]
| .tensor t s => tag 0x31 [encString t, encList encInt s]
| .const c => tag 0x32 [encConst c]
| .tuple es => tag 0x33 [encList encExpr es]
| .list es => tag 0x34 [encList encExpr es]
| .access e ix => tag 0x35 [encExpr e, encList encIndex ix]
| .binop op l r => tag 0x36 [encString op, encExpr l, encExpr r]
| .unop op e => tag 0x37 [encString op, encExpr e]
| .call f ax kw => tag 0x38 [encExpr f, encList encExpr ax, encList encKeyword kw]
| .var s => tag 0x30 [encString s]
| .tensor t => tag 0x31 [encString t.dtype, encList encInt t.shape]
| .const c => tag 0x32 [encConst c]
| .access e ix => tag 0x33 [encExpr e, encList encIndex ix]
| .call f ax kw => tag 0x34 [encExpr f, encList encExpr ax, encList encKeyword kw]
where
encKeyword : String × Expr -> ByteArray
| (key, expr) => (encString key).append (encExpr expr)

partial def decExpr : DecodeM Expr := do
match (<- next) with
| 0x30 => return .var (<- decString)
| 0x31 => return .tensor (<- decString) (<- decList decInt)
| 0x31 => return .tensor $ .mk (<- decString) (<- decList decInt)
| 0x32 => return .const (<- decConst)
| 0x33 => return .tuple (<- decList decExpr)
| 0x34 => return .list (<- decList decExpr)
| 0x35 => return .access (<- decExpr) (<- decList decIndex)
| 0x36 => return .binop (<- decString) (<- decExpr) (<- decExpr)
| 0x37 => return .unop (<- decString) (<- decExpr)
| 0x38 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword)
| 0x33 => return .access (<- decExpr) (<- decList decIndex)
| 0x34 => return .call (<- decExpr) (<- decList decExpr) (<- decList decKeyword)
| t => throw s!"Unknown tag in Expr {t}"
where
decKeyword : DecodeM (String × Expr) :=
Expand All @@ -301,13 +293,9 @@ private def ixz := Index.coord (IndexExpr.int 0)

#guard chkExpr nil
#guard chkExpr (.var "var")
#guard chkExpr (.tensor "float32" [1,2,3])
#guard chkExpr (.tensor $ .mk "float32" [1,2,3])
#guard chkExpr (.const (.int 1))
#guard chkExpr (.tuple [nil, nil, nil])
#guard chkExpr (.list [nil, nil, nil])
#guard chkExpr (.access nil [ixz, ixz, ixz])
#guard chkExpr (.binop "op" nil nil)
#guard chkExpr (.unop "op" nil)
#guard chkExpr (.call nil [nil, nil, nil] [("a", nil), ("b", nil)])

------------------------------------------------------------------------------
Expand Down

0 comments on commit e4dc75f

Please sign in to comment.