diff --git a/bench/Main.hs b/bench/Main.hs index ba1d5de..f39eacc 100644 --- a/bench/Main.hs +++ b/bench/Main.hs @@ -1,44 +1,45 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE ScopedTypeVariables #-} + import TinyLang.Field.Generator () import TinyLang.Field.Typed.Core +import Control.Lens import Control.Monad import Data.List +import Data.Semigroup import Test.QuickCheck --- A couple of functions for checking the output of generators -progNodes :: Program f -> Int -progNodes = stmtsNodes . _programStatements +progNodes :: Program f -> Sum Int +progNodes prog = + foldMapOf progSubExt (const (Sum 1)) prog + <> foldMapOf progSubStatements stmtsNodes prog -stmtsNodes :: Statements f -> Int -stmtsNodes = sum . map stmtNodes . unStatements +stmtsNodes :: Statements f -> Sum Int +stmtsNodes stmts = foldMapOf stmtsSubStatement stmtNodes stmts -stmtNodes :: Statement f -> Int -stmtNodes (ELet _ e) = 1 + exprNodes e -stmtNodes (EAssert e) = 1 + exprNodes e +stmtNodes :: Statement f -> Sum Int +stmtNodes stmt = Sum 1 <> foldMapOf stmtSubExpr exprNodes stmt -exprNodes :: Expr f a -> Int -exprNodes (EConst _) = 1 -exprNodes (EVar _) = 1 -exprNodes (EAppUnOp _ e) = 1 + exprNodes e -exprNodes (EAppBinOp _ e1 e2) = 1 + exprNodes e1 + exprNodes e2 -exprNodes (EIf e e1 e2) = 1 + exprNodes e + exprNodes e1 + exprNodes e2 +exprNodes :: SomeUniExpr f -> Sum Int +exprNodes e = Sum 1 <> foldMapOf exprSubExpr exprNodes e -progDepth :: Program f -> Int -progDepth = stmtsDepth . _programStatements +-- NOTE: We need Max 0 for the empty case +progDepth :: Program f -> Max Int +progDepth prog = + Max 0 + <> foldMapOf progSubExt (const (Max 1)) prog + <> foldMapOf progSubStatements stmtsDepth prog -stmtsDepth :: Statements f -> Int -stmtsDepth = maximum . (0:) . map stmtDepth . unStatements +stmtsDepth :: Statements f -> Max Int +stmtsDepth stmts = Max 0 <> foldMapOf stmtsSubStatement stmtDepth stmts -stmtDepth :: Statement f -> Int -stmtDepth (ELet _ e) = 1 + exprDepth e -stmtDepth (EAssert e) = 1 + exprDepth e +stmtDepth :: Statement f -> Max Int +stmtDepth stmt = (+1) <$> Max 0 <> foldMapOf stmtSubExpr exprDepth stmt -exprDepth :: Expr f a -> Int -exprDepth (EConst _) = 1 -exprDepth (EVar _) = 1 -exprDepth (EAppUnOp _ e) = 1 + exprDepth e -exprDepth (EAppBinOp _ e1 e2) = 1 + max (exprDepth e1) (exprDepth e2) -exprDepth (EIf e e1 e2) = 1 + max (exprDepth e) (max (exprDepth e1) (exprDepth e2)) +exprDepth :: SomeUniExpr f -> Max Int +exprDepth expr = (+1) <$> Max 0 <> foldMapOf exprSubExpr exprDepth expr data TestResult = TestResult { nodes :: Int , depth :: Int @@ -48,7 +49,7 @@ data TestResult = TestResult { nodes :: Int runGen :: Int -> IO TestResult runGen size = do prog <- generate (resize size arbitrary) :: IO (Program (AField Rational)) - pure $ TestResult (progNodes prog) (progDepth prog) + pure $ TestResult (getSum (progNodes prog)) (getMax (progDepth prog)) average :: (Real a, Fractional b) => [a] -> b average xs = realToFrac (sum xs) / genericLength xs diff --git a/field/TinyLang/Field/Core.hs b/field/TinyLang/Field/Core.hs index 40aeac4..dd50119 100644 --- a/field/TinyLang/Field/Core.hs +++ b/field/TinyLang/Field/Core.hs @@ -4,9 +4,14 @@ module TinyLang.Field.Core ( Program (..) , Statements (..) + , progSubExt + , progSubStatements + , progSubStatement + , stmtsSubStatement ) where import Data.Bifunctor +import Control.Lens import GHC.Generics import Quiet @@ -16,8 +21,8 @@ newtype Statements stmt = Statements { unStatements :: [stmt] } deriving (Show) via (Quiet (Statements stmt)) -- | Basic wrapper of program -data Program var stmt = Program - { _programExts :: [var] +data Program ext stmt = Program + { _programExts :: [ext] , _programStatements :: Statements stmt } deriving (Eq, Foldable, Traversable, Functor) @@ -26,5 +31,21 @@ instance Bifunctor Program where bimap f g (Program exts stmts) = Program (fmap f exts) (fmap g stmts) -- NOTE: Adding explicit Show instance to avoid record syntax -instance (Show var, Show stmt) => Show (Program var stmt) where +instance (Show ext, Show stmt) => Show (Program ext stmt) where show (Program exts stmts) = "Program " ++ show exts ++ " " ++ show stmts + +-- Some Traversals +progSubExt :: Traversal' (Program ext stmts) ext +progSubExt f = \case + Program exts stmts -> Program <$> traverse f exts <*> pure stmts + +progSubStatements :: Traversal' (Program ext stmt) (Statements stmt) +progSubStatements f = \case + Program exts stmts -> Program exts <$> f stmts + +progSubStatement :: Traversal' (Program ext stmt) stmt +progSubStatement f = \case + Program exts stmts -> Program exts <$> traverse f stmts + +stmtsSubStatement :: Traversal' (Statements stmt) stmt +stmtsSubStatement = traverse diff --git a/field/TinyLang/Field/Raw/Core.hs b/field/TinyLang/Field/Raw/Core.hs index e430a66..288ae69 100644 --- a/field/TinyLang/Field/Raw/Core.hs +++ b/field/TinyLang/Field/Raw/Core.hs @@ -8,20 +8,32 @@ module TinyLang.Field.Raw.Core , UnOp(..) , Statement(..) , Program - , pattern C.Program + , pattern Program , C._programStatements + , C._programExts , Statements - , pattern C.Statements + , pattern Statements , C.unStatements , RawProgram , RawStatements , RawStatement , RawExpr + , progSubExt + , progSubStatements + , progSubStatement + , stmtsSubStatement + , stmtSubStatement + , progSubExpr + , stmtsSubExpr + , stmtSubExpr + , exprSubExpr ) where -import TinyLang.Field.Uni hiding (Uni) +import qualified TinyLang.Field.Uni as U +import qualified TinyLang.Field.Type as T import qualified TinyLang.Field.Core as C +import Control.Lens import GHC.Generics import Quiet @@ -50,22 +62,34 @@ statement level; the operations acting on statement level are not necessarily mappable over a list of statements. -} -type Program v f = C.Program (v, SomeUni f) (Statement v f) -type Statements v f = C.Statements (Statement v f) +-- NOTE: Exts are annotated with SomeUni, as we only support constant externals. +type Program v f = C.Program (v, U.SomeUni f) (Statement v f) + +{-# COMPLETE Program #-} +pattern Program :: [(v, U.SomeUni f)] -> Statements v f -> Program v f +pattern Program exts stmts = C.Program exts stmts + + +type Statements v f = C.Statements (Statement v f) + +{-# COMPLETE Statements #-} +pattern Statements :: [Statement v f] -> Statements v f +pattern Statements stmts = C.Statements stmts + data Statement v f - = ELet (v, SomeUni f) (Expr v f) + = ELet (v, T.Type f) (Expr v f) | EAssert (Expr v f) | EFor v Integer Integer (Statements v f) deriving (Show) data Expr v f - = EConst (SomeUniConst f) + = EConst (U.SomeUniConst f) | EVar v - | EAppBinOp BinOp (Expr v f) (Expr v f) - | EAppUnOp UnOp (Expr v f) - | EIf (Expr v f) (Expr v f) (Expr v f) - | ETypeAnn (SomeUni f) (Expr v f) + | EAppBinOp BinOp (Expr v f) (Expr v f) + | EAppUnOp UnOp (Expr v f) + | EIf (Expr v f) (Expr v f) (Expr v f) + | ETypeAnn (T.Type f) (Expr v f) deriving (Show) data BinOp @@ -98,3 +122,42 @@ type RawProgram f = Program Var f type RawStatements f = Statements Var f type RawStatement f = Statement Var f type RawExpr f = Expr Var f + + +-- Traversals +progSubExt :: Traversal' (Program v f) (v, U.SomeUni f) +progSubExt = C.progSubExt + +progSubStatements :: Traversal' (Program v f) (Statements v f) +progSubStatements = C.progSubStatements + +progSubStatement :: Traversal' (Program v f) (Statement v f) +progSubStatement = C.progSubStatement + +stmtsSubStatement :: Traversal' (Statements v f) (Statement v f) +stmtsSubStatement = C.stmtsSubStatement + +stmtSubStatement :: Traversal' (Statement v f) (Statement v f) +stmtSubStatement f = \case + EFor var i j stmts -> EFor var i j <$> stmtsSubStatement f stmts + x -> pure x + +progSubExpr :: Traversal' (Program v f) (Expr v f) +progSubExpr = progSubStatements . stmtsSubExpr + +stmtsSubExpr :: Traversal' (Statements v f) (Expr v f) +stmtsSubExpr = stmtsSubStatement . stmtSubExpr + +stmtSubExpr :: Traversal' (Statement v f) (Expr v f) +stmtSubExpr f = \case + ELet var expr -> ELet var <$> f expr + EAssert expr -> EAssert <$> f expr + EFor var i j stmts -> EFor var i j <$> stmtsSubExpr f stmts + +exprSubExpr :: Traversal' (Expr v f) (Expr v f) +exprSubExpr f = \case + EAppUnOp unOp expr1 -> EAppUnOp unOp <$> f expr1 + EAppBinOp binOp expr1 expr2 -> EAppBinOp binOp <$> f expr1 <*> f expr2 + EIf expr expr1 expr2 -> EIf <$> f expr <*> f expr1 <*> f expr2 + ETypeAnn typ expr -> ETypeAnn typ <$> f expr + x -> pure x diff --git a/field/TinyLang/Field/Raw/Parser.hs b/field/TinyLang/Field/Raw/Parser.hs index 5ae9fff..7a4cfbc 100644 --- a/field/TinyLang/Field/Raw/Parser.hs +++ b/field/TinyLang/Field/Raw/Parser.hs @@ -54,13 +54,14 @@ keyword ::= == Types/Universes -Currently we only support 3 types/universes: +We distinguish between types and universes. Types include function types, while +universes include the following types of contants: * booleans, * fields, and * vectors. -We use them to annotate expressions with their desired type. +We can use types to annotate expressions. @ uni ::= @@ -69,6 +70,12 @@ uni ::= "vector" @ +@ +type ::= + uni + type -> type +@ + == Identifiers To avoid a syntactic clash with /bool-literals/, identifiers start with a @@ -85,7 +92,7 @@ We follow the ML-family syntax for variable declarations, where the identifier @ var-decl ::= - ident ":" uni + ident ":" type @ @@ -117,7 +124,7 @@ expr ::= expr "[" expr "]" statement ";" expr "if" expr "then" expr "else" expr - expr ":" uni + expr ":" type infix-op ::= "and" @@ -165,7 +172,7 @@ ext-decls ::= (ext-decl ";")* ext-decl ::= - "ext" var-decl + "ext" ident ":" uni @ == Operator Precedence @@ -195,7 +202,8 @@ import TinyLang.Prelude hiding (many, option, try) import Data.Field import TinyLang.Field.Existential import TinyLang.Field.Raw.Core -import TinyLang.Field.Uni +import qualified TinyLang.Field.Uni as U +import qualified TinyLang.Field.Type as T import TinyLang.ParseUtils import qualified Control.Monad.Combinators.Expr as Comb @@ -249,14 +257,28 @@ isKeyword :: String -> Bool isKeyword = (`member` keywords) -pUni :: ParserT m (SomeUni f) +pUni :: ParserT m (U.SomeUni f) pUni = choice - [ Some Bool <$ keyword "bool" - , Some Field <$ keyword "field" - , Some Vector <$ keyword "vector" + [ Some U.Bool <$ keyword "bool" + , Some U.Field <$ keyword "field" + , Some U.Vector <$ keyword "vector" ] +pType :: forall m f. ParserT m (T.Type f) +pType = + choice + -- NOTE: this backtracks when no parentheses + [ try (parens pType) + -- NOTE: this backtracks if not a function type + , try pFunType + , pUniType + ] where + pUniType = fromSomeUni <$> pUni + pFunType = T.TyFun <$> pUniType <*> (symbol "->" *> pType) + fromSomeUni = forget T.UniType + + -- TODO: Consider merging Identifier with Variable pIdentifier :: ParserT m Identifier pIdentifier = @@ -273,8 +295,8 @@ pVar = do pure $ Var ident -- variable declaration -pVarDecl :: ParserT m (Var, SomeUni f) -pVarDecl = (,) <$> pVar <*> (symbol ":" *> pUni) +pVarDecl :: ParserT m (Var, T.Type f) +pVarDecl = (,) <$> pVar <*> (symbol ":" *> pType) pBoolLiteral :: ParserT m Bool pBoolLiteral = @@ -344,16 +366,16 @@ operatorTable = ] ] -vBool :: ParserT m (SomeUniConst f) -vBool = Some . UniConst Bool <$> pBoolLiteral +vBool :: ParserT m (U.SomeUniConst f) +vBool = Some . U.UniConst U.Bool <$> pBoolLiteral -vVec :: ParserT m (SomeUniConst f) -vVec = Some . UniConst Vector <$> pVecLiteral +vVec :: ParserT m (U.SomeUniConst f) +vVec = Some . U.UniConst U.Vector <$> pVecLiteral -vField :: Field f => ParserT m (SomeUniConst f) -vField = Some . UniConst Field . fromInteger <$> signedDecimal +vField :: Field f => ParserT m (U.SomeUniConst f) +vField = Some . U.UniConst U.Field . fromInteger <$> signedDecimal -pConst :: Field f => ParserT m (SomeUniConst f) +pConst :: Field f => ParserT m (U.SomeUniConst f) pConst = choice [ vBool , vVec @@ -361,8 +383,8 @@ pConst = choice ] -pAnn :: ParserT m (SomeUni f) -pAnn = symbol ":" *> pUni +pAnn :: ParserT m (T.Type f) +pAnn = symbol ":" *> pType pTerm :: Field f => ParserT m (RawExpr f) pTerm = @@ -395,10 +417,10 @@ pStatement = <* keyword "end" ] -pExtDecl :: ParserT m (Var, SomeUni f) -pExtDecl = keyword "ext" *> pVarDecl +pExtDecl :: ParserT m (Var, U.SomeUni f) +pExtDecl = (,) <$> (keyword "ext" *> pVar) <*> (symbol ":" *> pUni) -pExtDecls :: ParserT m [(Var, SomeUni f)] +pExtDecls :: ParserT m [(Var, U.SomeUni f)] pExtDecls = many (pExtDecl <* symbol ";") pStatements :: Field f => ParserT m (RawStatements f) diff --git a/field/TinyLang/Field/Type.hs b/field/TinyLang/Field/Type.hs new file mode 100644 index 0000000..bd5ebdd --- /dev/null +++ b/field/TinyLang/Field/Type.hs @@ -0,0 +1,44 @@ +module TinyLang.Field.Type + ( Type(..) + , pattern UniType + , pattern Bool + , pattern Field + , pattern Vector + ) where + +import Data.Field +import qualified TinyLang.Field.Uni as U +import TinyLang.Field.Existential + +-- Types +data Type f + = BuiltIn (Some (TypeIn U.Uni f)) + | TyFun (Type f) (Type f) + + +pattern UniType :: U.Uni f a -> Type f +pattern UniType uni = BuiltIn (Some (TypeIn uni)) + +deriving instance (TextField f) => Show (Type f) +deriving instance Eq (Type f) +deriving instance Ord (Type f) + +pattern Bool :: Type f +pattern Bool = UniType U.Bool + +pattern Field :: Type f +pattern Field = UniType U.Field + +pattern Vector :: Type f +pattern Vector = UniType U.Vector + +newtype TypeIn uni f a = TypeIn (uni f a) + deriving newtype (Show, Eq, Ord) + +instance Eq (Some (TypeIn U.Uni f)) where + Some (TypeIn u1) == Some (TypeIn u2) = Some u1 == Some u2 + +instance Ord (Some (TypeIn U.Uni f)) where + Some (TypeIn u1) `compare` Some (TypeIn u2) = (Some u1) `compare` (Some u2) + +deriving instance (TextField f) => Show (Some (TypeIn U.Uni f)) diff --git a/field/TinyLang/Field/Typed/Core.hs b/field/TinyLang/Field/Typed/Core.hs index ca3706b..7af6ebb 100644 --- a/field/TinyLang/Field/Typed/Core.hs +++ b/field/TinyLang/Field/Typed/Core.hs @@ -34,6 +34,14 @@ module TinyLang.Field.Typed.Core , progExtVarSigs , progSupplyFromAtLeastFree , uniOfExpr + , progSubExt + , progSubStatements + , progSubStatement + , stmtsSubStatement + , progSubExpr + , stmtsSubExpr + , stmtSubExpr + , exprSubExpr ) where import Prelude hiding (div) @@ -41,11 +49,11 @@ import TinyLang.Prelude import qualified TinyLang.Field.Core as C import Data.Field as Field +import Control.Lens import TinyLang.Environment as Env import TinyLang.Field.Existential import TinyLang.Field.Uni import TinyLang.Var as Var --- import TinyLang.Field.Printer (progToString, PrintStyle(..)) type SomeUniExpr f = SomeOf (Uni f) (Expr f) @@ -140,8 +148,8 @@ withBinOpUnis BAt k = k knownUni knownUni knownUni uniOfExpr :: Expr f a -> Uni f a uniOfExpr (EConst (UniConst uni _)) = uni uniOfExpr (EVar (UniVar uni _)) = uni -uniOfExpr (EAppUnOp op _) = withUnOpUnis op $ \_ resUni -> resUni -uniOfExpr (EAppBinOp op _ _) = withBinOpUnis op $ \_ _ resUni -> resUni +uniOfExpr (EAppUnOp unOp _) = withUnOpUnis unOp $ \_ resUni -> resUni +uniOfExpr (EAppBinOp binOp _ _) = withBinOpUnis binOp $ \_ _ resUni -> resUni uniOfExpr (EIf _ x _) = uniOfExpr x withGeqUnOp :: UnOp f a1 b1 -> UnOp f a2 b2 -> d -> ((a1 ~ a2, b1 ~ b2) => d) -> d @@ -291,3 +299,64 @@ progSupplyFromAtLeastFree = . _scopedVarSigsFree . execSVS . progVS + +-- Traversals, specialised for types +progSubExt :: Traversal' (Program f) (SomeUniVar f) +progSubExt = C.progSubExt + +progSubStatements :: Traversal' (Program f) (Statements f) +progSubStatements = C.progSubStatements + +progSubStatement :: Traversal' (Program f) (Statement f) +progSubStatement = C.progSubStatement + +stmtsSubStatement :: Traversal' (Statements f) (Statement f) +stmtsSubStatement = C.stmtsSubStatement + +-- Some helper methods for handling transitions between Expr and SomeUniExpr +box :: (KnownUni f a) => Expr f a -> SomeUniExpr f +box expr = SomeOf knownUni expr + +unbox :: forall f a. (KnownUni f a) => SomeUniExpr f -> Expr f a +unbox (SomeOf uni expr) = withGeqUni uni (knownUni :: Uni f a) (error message) expr where + message = "Uni mismatch!" + +-- boxF :: (KnownUni f a, Functor t) => t (Expr f a) -> t (SomeUniExpr f) +-- boxF = fmap box + +unboxF :: (KnownUni f a, Functor t) => t (SomeUniExpr f) -> t (Expr f a) +unboxF = fmap unbox + +wrapF :: (KnownUni f a, KnownUni f b, Functor t) => (SomeUniExpr f -> t (SomeUniExpr f)) -> Expr f a -> t (Expr f b) +wrapF f = unboxF . f . box + +progSubExpr :: Traversal' (Program f) (SomeUniExpr f) +progSubExpr = progSubStatements . stmtsSubExpr + +stmtsSubExpr :: Traversal' (Statements f) (SomeUniExpr f) +stmtsSubExpr = stmtsSubStatement . stmtSubExpr + +stmtSubExpr :: forall f. Traversal' (Statement f) (SomeUniExpr f) +stmtSubExpr f = \case + ELet uniVar@(UniVar uni _) expr -> + withKnownUni uni $ + ELet uniVar <$> wrapF f expr + EAssert expr -> + EAssert <$> wrapF f expr + +exprSubExpr :: Traversal' (SomeUniExpr f) (SomeUniExpr f) +exprSubExpr f = \case + SomeOf uni e0 -> SomeOf uni <$> case e0 of + EAppUnOp unOp e -> + withKnownUni (uniOfExpr e) $ + EAppUnOp unOp <$> wrapF f e + EAppBinOp binOp e1 e2 -> + withKnownUni (uniOfExpr e1) $ + withKnownUni (uniOfExpr e2) $ + EAppBinOp binOp <$> wrapF f e1 <*> wrapF f e2 + EIf e e1 e2 -> + withKnownUni (uniOfExpr e) $ + withKnownUni (uniOfExpr e1) $ + withKnownUni (uniOfExpr e2) $ + EIf <$> wrapF f e <*> wrapF f e1 <*> wrapF f e2 + x -> pure x diff --git a/field/TinyLang/Field/Typed/TypeChecker.hs b/field/TinyLang/Field/Typed/TypeChecker.hs index 3805fac..d3fe5c1 100644 --- a/field/TinyLang/Field/Typed/TypeChecker.hs +++ b/field/TinyLang/Field/Typed/TypeChecker.hs @@ -28,14 +28,15 @@ import TinyLang.Prelude hiding (TypeError) import Data.Field import TinyLang.Field.Existential +import qualified TinyLang.Field.Type as R import qualified TinyLang.Field.Raw.Core as R import qualified TinyLang.Field.Typed.Core as T -import TinyLang.Field.Uni +import qualified TinyLang.Field.Uni as T import TinyLang.Var import Control.Monad.Cont -- import qualified Data.Set as Set -import Data.Kind +import qualified Data.Kind as K import qualified Data.Map.Strict as Map import qualified Data.String.Interpolate as QQ @@ -51,11 +52,11 @@ type MonadTypeChecker m f = ( MonadSupply m {-| == Type Environments -} -type TyEnv f = Map R.Var (SomeUniVar f) +type TyEnv f = Map R.Var (T.SomeUniVar f) {-| @TypeChecker@ Transformer -} -newtype TypeCheckerT (m :: Type -> Type) f a = +newtype TypeCheckerT (m :: K.Type -> K.Type) f a = TypeChecker { runTypeCheckerT :: (ExceptT TypeCheckError (ReaderT (TyEnv f) (SupplyT m))) a } deriving newtype ( Monad , Functor @@ -91,14 +92,19 @@ typeProgram = runTypeChecker . checkProgram {-| Add a variable to type environment -} -- NOTE: At the moment this mimics the old scope -withSomeUniVar :: (Monad m) => (R.Var, SomeUni f) -> forall r. (SomeUniVar f -> TypeCheckerT m f r) -> TypeCheckerT m f r -withSomeUniVar (var, uni) kont = do - someUniVar <- mkSomeUniVar uni <$> (freshVar . R.unVar $ var) +withSomeUniVar :: (Monad m) => (R.Var, T.SomeUni f) -> forall r. (T.SomeUniVar f -> TypeCheckerT m f r) -> TypeCheckerT m f r +withSomeUniVar (var, someUni) kont = do + someUniVar <- T.mkSomeUniVar someUni <$> (freshVar . R.unVar $ var) local (Map.insert var someUniVar) $ kont someUniVar -withVar :: (Monad m) => (R.Var, SomeUni f) -> forall r. (T.Var -> TypeCheckerT m f r) -> TypeCheckerT m f r +withSomeUniVar' :: (Monad m) => (R.Var, R.Type f) -> forall r. (T.SomeUniVar f -> TypeCheckerT m f r) -> TypeCheckerT m f r +withSomeUniVar' (var, R.UniType uni) kont = do + someUniVar <- T.mkSomeUniVar (Some uni) <$> (freshVar . R.unVar $ var) + local (Map.insert var someUniVar) $ kont someUniVar + +withVar :: (Monad m) => (R.Var, R.Type f) -> forall r. (T.Var -> TypeCheckerT m f r) -> TypeCheckerT m f r withVar (var, uni) kont = - withSomeUniVar (var, uni) $ \ (Some (UniVar _ tVar)) -> kont tVar + withSomeUniVar' (var, uni) $ \ (Some (T.UniVar _ tVar)) -> kont tVar {-| Type inference for variables -} @@ -121,22 +127,22 @@ inferExpr (R.EVar v) = do pure $ SomeOf uni $ T.EVar uniVar inferExpr (R.EAppBinOp rBinOp l m) = withTypedBinOp rBinOp $ \tBinOp -> - SomeOf knownUni <$> (T.EAppBinOp tBinOp <$> checkExpr l <*> checkExpr m) + SomeOf T.knownUni <$> (T.EAppBinOp tBinOp <$> checkExpr l <*> checkExpr m) inferExpr (R.EAppUnOp rUnOp l) = withTypedUnOp rUnOp $ \tUnOp -> - SomeOf knownUni <$> (T.EAppUnOp tUnOp <$> checkExpr l) + SomeOf T.knownUni <$> (T.EAppUnOp tUnOp <$> checkExpr l) inferExpr (R.EIf l m n) = do tL <- checkExpr l SomeOf uni tM <- inferExpr m tN <- T.withKnownUni uni $ checkExpr n pure $ SomeOf uni $ T.EIf tL tM tN -inferExpr (R.ETypeAnn (Some uni) m) = T.withKnownUni uni $ SomeOf uni <$> checkExpr m +inferExpr (R.ETypeAnn (R.UniType uni) m) = T.withKnownUni uni $ SomeOf uni <$> checkExpr m {-| Mapping from Raw UnOp to Typed UnOp -} withTypedBinOp :: forall f r. - R.BinOp -> (forall a b c. ( KnownUni f a, KnownUni f b, KnownUni f c) => T.BinOp f a b c -> r) -> r + R.BinOp -> (forall a b c. ( T.KnownUni f a, T.KnownUni f b, T.KnownUni f c) => T.BinOp f a b c -> r) -> r withTypedBinOp R.Or k = k T.Or withTypedBinOp R.And k = k T.And withTypedBinOp R.Xor k = k T.Xor @@ -155,7 +161,7 @@ withTypedBinOp R.BAt k = k T.BAt -} withTypedUnOp :: forall f r. - R.UnOp -> (forall a b. (KnownUni f a, KnownUni f b) => T.UnOp f a b -> r) -> r + R.UnOp -> (forall a b. (T.KnownUni f a, T.KnownUni f b) => T.UnOp f a b -> r) -> r withTypedUnOp R.Not k = k T.Not withTypedUnOp R.Neq0 k = k T.Neq0 withTypedUnOp R.Neg k = k T.Neg @@ -165,14 +171,14 @@ withTypedUnOp R.Unp k = k T.Unp {-| Type checking for expressions -} checkExpr :: - forall m f a. (Monad m, TextField f, KnownUni f a) + forall m f a. (Monad m, TextField f, T.KnownUni f a) => R.Expr R.Var f -> TypeCheckerT m f (T.Expr f a) checkExpr (R.EIf l m n) = T.EIf <$> checkExpr l <*> checkExpr m <*> checkExpr n checkExpr m = do SomeOf mUni tM <- inferExpr m - let uni = knownUni @f @a + let uni = T.knownUni @f @a let uniMismatch = typeMismatch tM uni mUni - withGeqUniM uni mUni uniMismatch tM + T.withGeqUniM uni mUni uniMismatch tM checkProgram :: forall m f. (Monad m, TextField f) @@ -191,17 +197,17 @@ checkStatements (R.Statements stmts) kont = checkStatement :: forall m f. (Monad m , TextField f) => R.Statement R.Var f -> forall r. ([T.Statement f] -> TypeCheckerT m f r) -> TypeCheckerT m f r -checkStatement (R.ELet (var, someUni@(Some uni)) m) kont = do +checkStatement (R.ELet (var, someUni@(R.UniType uni)) m) kont = do tM <- T.withKnownUni uni $ checkExpr m - withVar (var, someUni) $ \ tVar -> kont [T.ELet (UniVar uni tVar) tM] + withVar (var, someUni) $ \ tVar -> kont [T.ELet (T.UniVar uni tVar) tM] checkStatement (R.EAssert m) kont = do tM <- checkExpr m kont [T.EAssert tM] checkStatement (R.EFor var start end stmts) kont = do runContT (foldMapA (ContT . iter) [start .. end]) kont where - iter i ikont = withVar (var, Some Field) $ \ tVar -> + iter i ikont = withVar (var, R.Field) $ \ tVar -> checkStatements stmts $ \ (T.Statements tStmts) -> do - let uVar = T.UniVar Field tVar + let uVar = T.UniVar T.Field tVar ikont $ T.ELet uVar (T.EConst . fromIntegral $ i) : tStmts {-| Error message for a failed type equality diff --git a/test/Field/Raw/golden/00-bool-literals.golden b/test/Field/Raw/golden/00-bool-literals.golden index 0b32f99..8ccb271 100644 --- a/test/Field/Raw/golden/00-bool-literals.golden +++ b/test/Field/Raw/golden/00-bool-literals.golden @@ -1 +1 @@ -Program [] Statements [ELet (Var "false",Some Bool) (EConst (Some (UniConst Bool False))),ELet (Var "true",Some Bool) (EConst (Some (UniConst Bool True)))] \ No newline at end of file +Program [] Statements [ELet (Var "false",BuiltIn (Some Bool)) (EConst (Some (UniConst Bool False))),ELet (Var "true",BuiltIn (Some Bool)) (EConst (Some (UniConst Bool True)))] \ No newline at end of file diff --git a/test/Field/Raw/golden/01-field-literals.golden b/test/Field/Raw/golden/01-field-literals.golden index 308b0bd..24d75c7 100644 --- a/test/Field/Raw/golden/01-field-literals.golden +++ b/test/Field/Raw/golden/01-field-literals.golden @@ -1 +1 @@ -Program [] Statements [ELet (Var "zero",Some Field) (EConst (Some 0)),ELet (Var "one",Some Field) (EConst (Some 1)),ELet (Var "two",Some Field) (EConst (Some 2)),ELet (Var "half",Some Field) (EAppBinOp Div (EConst (Some 1)) (EConst (Some 2))),ELet (Var "third",Some Field) (EAppBinOp Div (EConst (Some 1)) (EConst (Some 3)))] \ No newline at end of file +Program [] Statements [ELet (Var "zero",BuiltIn (Some Field)) (EConst (Some 0)),ELet (Var "one",BuiltIn (Some Field)) (EConst (Some 1)),ELet (Var "two",BuiltIn (Some Field)) (EConst (Some 2)),ELet (Var "half",BuiltIn (Some Field)) (EAppBinOp Div (EConst (Some 1)) (EConst (Some 2))),ELet (Var "third",BuiltIn (Some Field)) (EAppBinOp Div (EConst (Some 1)) (EConst (Some 3)))] \ No newline at end of file diff --git a/test/Field/Raw/golden/02-vector-literals.golden b/test/Field/Raw/golden/02-vector-literals.golden index 9e2ac95..23fad52 100644 --- a/test/Field/Raw/golden/02-vector-literals.golden +++ b/test/Field/Raw/golden/02-vector-literals.golden @@ -1 +1 @@ -Program [] Statements [ELet (Var "one",Some Vector) (EConst (Some (UniConst Vector [True]))),ELet (Var "two",Some Vector) (EConst (Some (UniConst Vector [True,False]))),ELet (Var "three",Some Vector) (EConst (Some (UniConst Vector [True,False,True]))),ELet (Var "four",Some Vector) (EConst (Some (UniConst Vector [True,False,True,False])))] \ No newline at end of file +Program [] Statements [ELet (Var "one",BuiltIn (Some Vector)) (EConst (Some (UniConst Vector [True]))),ELet (Var "two",BuiltIn (Some Vector)) (EConst (Some (UniConst Vector [True,False]))),ELet (Var "three",BuiltIn (Some Vector)) (EConst (Some (UniConst Vector [True,False,True]))),ELet (Var "four",BuiltIn (Some Vector)) (EConst (Some (UniConst Vector [True,False,True,False])))] \ No newline at end of file diff --git a/test/Field/Raw/golden/03-lexer-whitespace.golden b/test/Field/Raw/golden/03-lexer-whitespace.golden index d100a60..dc9607f 100644 --- a/test/Field/Raw/golden/03-lexer-whitespace.golden +++ b/test/Field/Raw/golden/03-lexer-whitespace.golden @@ -1 +1 @@ -Program [] Statements [ELet (Var "x",Some Field) (EAppBinOp Div (EConst (Some 1)) (EConst (Some 2)))] \ No newline at end of file +Program [] Statements [ELet (Var "x",BuiltIn (Some Field)) (EAppBinOp Div (EConst (Some 1)) (EConst (Some 2)))] \ No newline at end of file diff --git a/test/Field/Raw/golden/10-for-loop.golden b/test/Field/Raw/golden/10-for-loop.golden index ad5bd0e..a4e21fb 100644 --- a/test/Field/Raw/golden/10-for-loop.golden +++ b/test/Field/Raw/golden/10-for-loop.golden @@ -1 +1 @@ -Program [] Statements [EFor (Var "i") 1 2 (Statements [ELet (Var "i'",Some Field) (EVar (Var "i")),EFor (Var "j") 2 3 (Statements [ELet (Var "k",Some Field) (EAppBinOp Mul (EVar (Var "i")) (EVar (Var "j"))),EAssert (EAppBinOp FEq (EVar (Var "k")) (EAppBinOp Mul (EVar (Var "i'")) (EVar (Var "j"))))]),ELet (Var "p",Some Field) (EVar (Var "i")),EFor (Var "l") 1 2 (Statements [ELet (Var "p",Some Field) (EAppBinOp Mul (EVar (Var "p")) (EVar (Var "l")))])])] \ No newline at end of file +Program [] Statements [EFor (Var "i") 1 2 (Statements [ELet (Var "i'",BuiltIn (Some Field)) (EVar (Var "i")),EFor (Var "j") 2 3 (Statements [ELet (Var "k",BuiltIn (Some Field)) (EAppBinOp Mul (EVar (Var "i")) (EVar (Var "j"))),EAssert (EAppBinOp FEq (EVar (Var "k")) (EAppBinOp Mul (EVar (Var "i'")) (EVar (Var "j"))))]),ELet (Var "p",BuiltIn (Some Field)) (EVar (Var "i")),EFor (Var "l") 1 2 (Statements [ELet (Var "p",BuiltIn (Some Field)) (EAppBinOp Mul (EVar (Var "p")) (EVar (Var "l")))])])] \ No newline at end of file diff --git a/test/Field/Raw/golden/11-everything.golden b/test/Field/Raw/golden/11-everything.golden index df1f018..3de5455 100644 --- a/test/Field/Raw/golden/11-everything.golden +++ b/test/Field/Raw/golden/11-everything.golden @@ -1 +1 @@ -Program [(Var "eb",Some Bool),(Var "ef",Some Field),(Var "ev",Some Vector)] Statements [ELet (Var "a",Some Bool) (EConst (Some (UniConst Bool True))),ELet (Var "b",Some Bool) (EConst (Some (UniConst Bool False))),ELet (Var "or'",Some Bool) (EAppBinOp Or (EVar (Var "a")) (EVar (Var "b"))),ELet (Var "or''",Some Bool) (EAppBinOp Or (EVar (Var "a")) (EVar (Var "eb"))),ELet (Var "and'",Some Bool) (EAppBinOp And (EVar (Var "or''")) (EVar (Var "b"))),ELet (Var "xor'",Some Bool) (EAppBinOp Xor (EVar (Var "and'")) (EVar (Var "or''"))),ELet (Var "f",Some Field) (EConst (Some 0)),ELet (Var "g",Some Field) (EConst (Some 1)),ELet (Var "h",Some Field) (EConst (Some 2)),ELet (Var "feq",Some Bool) (EAppBinOp FEq (EVar (Var "g")) (EVar (Var "f"))),ELet (Var "fe'",Some Bool) (EAppBinOp FEq (EVar (Var "ef")) (EVar (Var "f"))),ELet (Var "fle",Some Bool) (EAppBinOp FLe (EVar (Var "h")) (EVar (Var "f"))),ELet (Var "flt",Some Bool) (EAppBinOp FLt (EVar (Var "f")) (EVar (Var "g"))),ELet (Var "fge",Some Bool) (EAppBinOp FGe (EVar (Var "g")) (EVar (Var "h"))),ELet (Var "fgt",Some Bool) (EAppBinOp FGt (EVar (Var "h")) (EVar (Var "f"))),ELet (Var "add",Some Field) (EAppBinOp Add (EVar (Var "f")) (EVar (Var "g"))),ELet (Var "ad'",Some Field) (EAppBinOp Add (EVar (Var "f")) (EVar (Var "ef"))),ELet (Var "sub",Some Field) (EAppBinOp Sub (EVar (Var "g")) (EVar (Var "h"))),ELet (Var "mul",Some Field) (EAppBinOp Mul (EVar (Var "h")) (EVar (Var "f"))),ELet (Var "div",Some Field) (EAppBinOp Div (EVar (Var "f")) (EVar (Var "g"))),ELet (Var "v",Some Vector) (EConst (Some (UniConst Vector [True,False,True]))),ELet (Var "bat",Some Bool) (EAppBinOp BAt (EVar (Var "f")) (EVar (Var "v"))),ELet (Var "ba'",Some Bool) (EAppBinOp BAt (EVar (Var "f")) (EVar (Var "ev"))),ELet (Var "not'",Some Bool) (EAppUnOp Not (EVar (Var "a"))),ELet (Var "neq0'",Some Bool) (EAppUnOp Neq0 (EVar (Var "f"))),ELet (Var "neg'",Some Field) (EAppUnOp Neg (EVar (Var "g"))),ELet (Var "inv'",Some Field) (EAppUnOp Inv (EVar (Var "h"))),ELet (Var "unp",Some Vector) (EAppUnOp Unp (EVar (Var "h"))),ELet (Var "let'",Some Bool) (EConst (Some (UniConst Bool True))),EAssert (EConst (Some (UniConst Bool True))),EFor (Var "j") 0 2 (Statements []),ELet (Var "if'",Some Bool) (EIf (EConst (Some (UniConst Bool True))) (EConst (Some (UniConst Bool True))) (EConst (Some (UniConst Bool True)))),ELet (Var "asf",Some Field) (ETypeAnn (Some Field) (EConst (Some 1))),ELet (Var "asb",Some Bool) (ETypeAnn (Some Bool) (EConst (Some (UniConst Bool True)))),ELet (Var "asv",Some Vector) (ETypeAnn (Some Vector) (EConst (Some (UniConst Vector [True]))))] \ No newline at end of file +Program [(Var "eb",Some Bool),(Var "ef",Some Field),(Var "ev",Some Vector)] Statements [ELet (Var "a",BuiltIn (Some Bool)) (EConst (Some (UniConst Bool True))),ELet (Var "b",BuiltIn (Some Bool)) (EConst (Some (UniConst Bool False))),ELet (Var "or'",BuiltIn (Some Bool)) (EAppBinOp Or (EVar (Var "a")) (EVar (Var "b"))),ELet (Var "or''",BuiltIn (Some Bool)) (EAppBinOp Or (EVar (Var "a")) (EVar (Var "eb"))),ELet (Var "and'",BuiltIn (Some Bool)) (EAppBinOp And (EVar (Var "or''")) (EVar (Var "b"))),ELet (Var "xor'",BuiltIn (Some Bool)) (EAppBinOp Xor (EVar (Var "and'")) (EVar (Var "or''"))),ELet (Var "f",BuiltIn (Some Field)) (EConst (Some 0)),ELet (Var "g",BuiltIn (Some Field)) (EConst (Some 1)),ELet (Var "h",BuiltIn (Some Field)) (EConst (Some 2)),ELet (Var "feq",BuiltIn (Some Bool)) (EAppBinOp FEq (EVar (Var "g")) (EVar (Var "f"))),ELet (Var "fe'",BuiltIn (Some Bool)) (EAppBinOp FEq (EVar (Var "ef")) (EVar (Var "f"))),ELet (Var "fle",BuiltIn (Some Bool)) (EAppBinOp FLe (EVar (Var "h")) (EVar (Var "f"))),ELet (Var "flt",BuiltIn (Some Bool)) (EAppBinOp FLt (EVar (Var "f")) (EVar (Var "g"))),ELet (Var "fge",BuiltIn (Some Bool)) (EAppBinOp FGe (EVar (Var "g")) (EVar (Var "h"))),ELet (Var "fgt",BuiltIn (Some Bool)) (EAppBinOp FGt (EVar (Var "h")) (EVar (Var "f"))),ELet (Var "add",BuiltIn (Some Field)) (EAppBinOp Add (EVar (Var "f")) (EVar (Var "g"))),ELet (Var "ad'",BuiltIn (Some Field)) (EAppBinOp Add (EVar (Var "f")) (EVar (Var "ef"))),ELet (Var "sub",BuiltIn (Some Field)) (EAppBinOp Sub (EVar (Var "g")) (EVar (Var "h"))),ELet (Var "mul",BuiltIn (Some Field)) (EAppBinOp Mul (EVar (Var "h")) (EVar (Var "f"))),ELet (Var "div",BuiltIn (Some Field)) (EAppBinOp Div (EVar (Var "f")) (EVar (Var "g"))),ELet (Var "v",BuiltIn (Some Vector)) (EConst (Some (UniConst Vector [True,False,True]))),ELet (Var "bat",BuiltIn (Some Bool)) (EAppBinOp BAt (EVar (Var "f")) (EVar (Var "v"))),ELet (Var "ba'",BuiltIn (Some Bool)) (EAppBinOp BAt (EVar (Var "f")) (EVar (Var "ev"))),ELet (Var "not'",BuiltIn (Some Bool)) (EAppUnOp Not (EVar (Var "a"))),ELet (Var "neq0'",BuiltIn (Some Bool)) (EAppUnOp Neq0 (EVar (Var "f"))),ELet (Var "neg'",BuiltIn (Some Field)) (EAppUnOp Neg (EVar (Var "g"))),ELet (Var "inv'",BuiltIn (Some Field)) (EAppUnOp Inv (EVar (Var "h"))),ELet (Var "unp",BuiltIn (Some Vector)) (EAppUnOp Unp (EVar (Var "h"))),ELet (Var "let'",BuiltIn (Some Bool)) (EConst (Some (UniConst Bool True))),EAssert (EConst (Some (UniConst Bool True))),EFor (Var "j") 0 2 (Statements []),ELet (Var "if'",BuiltIn (Some Bool)) (EIf (EConst (Some (UniConst Bool True))) (EConst (Some (UniConst Bool True))) (EConst (Some (UniConst Bool True)))),ELet (Var "asf",BuiltIn (Some Field)) (ETypeAnn (BuiltIn (Some Field)) (EConst (Some 1))),ELet (Var "asb",BuiltIn (Some Bool)) (ETypeAnn (BuiltIn (Some Bool)) (EConst (Some (UniConst Bool True)))),ELet (Var "asv",BuiltIn (Some Vector)) (ETypeAnn (BuiltIn (Some Vector)) (EConst (Some (UniConst Vector [True]))))] \ No newline at end of file diff --git a/tiny-lang.cabal b/tiny-lang.cabal index 9184a69..07c2aeb 100644 --- a/tiny-lang.cabal +++ b/tiny-lang.cabal @@ -30,6 +30,7 @@ library TinyLang.Field.Printer TinyLang.Field.Rename TinyLang.Field.Jubjub + TinyLang.Field.Type TinyLang.Field.Uni TinyLang.Field.Raw.Core TinyLang.Field.Raw.Parser @@ -80,7 +81,8 @@ benchmark bench-generators build-depends: tiny-lang, base >= 4.7 && < 5, - QuickCheck + QuickCheck, + lens default-extensions: BangPatterns ghc-options: