Skip to content

Commit

Permalink
Handle integerNegate in the plugin (#5494)
Browse files Browse the repository at this point in the history
  • Loading branch information
zliu41 authored Aug 29, 2023
1 parent 35b1f1f commit 6c7e230
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 21 deletions.
2 changes: 2 additions & 0 deletions plutus-tx-plugin/plutus-tx-plugin.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ test-suite plutus-tx-tests
main-is: Spec.hs
other-modules:
Budget.Spec
IntegerLiterals.NoStrict.NegativeLiterals.Spec
IntegerLiterals.NoStrict.NoNegativeLiterals.Spec
IsData.Spec
Lib
Lift.Spec
Expand Down
68 changes: 58 additions & 10 deletions plutus-tx-plugin/src/PlutusTx/Compiler/Expr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
Expand Down Expand Up @@ -67,6 +68,7 @@ import Data.Set qualified as Set
import Data.Text qualified as T
import Data.Text.Encoding qualified as TE
import Data.Traversable
import GHC.Num.Integer qualified

{- Note [System FC and System FW]
Haskell uses system FC, which includes type equalities and coercions.
Expand Down Expand Up @@ -710,8 +712,15 @@ compileExpr e = traceCompilation 2 ("Compiling expr:" GHC.<+> GHC.ppr e) $ do
| GHC.getName build == GHC.buildName && GHC.getName unpack == GHC.unpackCStringFoldrName -> compileExpr expr
-- C# is just a wrapper around a literal
GHC.Var (GHC.idDetails -> GHC.DataConWorkId dc) `GHC.App` arg | dc == GHC.charDataCon -> compileExpr arg
-- constructors of 'Integer' just wrap literals
GHC.Var (GHC.idDetails -> GHC.DataConWorkId dc) `GHC.App` arg | GHC.dataConTyCon dc == GHC.integerTyCon -> compileExpr arg
-- Handle constructors of 'Integer'
GHC.Var (GHC.idDetails -> GHC.DataConWorkId dc) `GHC.App` arg | GHC.dataConTyCon dc == GHC.integerTyCon -> do
i <- compileExpr arg
-- IN is a negative integer!
if GHC.dataConName dc == GHC.integerINDataConName
then do
negateTerm <- lookupIntegerNegate
pure $ PIR.mkIterApp negateTerm [(annMayInline, i)]
else pure i
-- Unboxed unit, (##).
GHC.Var (GHC.idDetails -> GHC.DataConWorkId dc) | dc == GHC.unboxedUnitDataCon -> pure (PIR.mkConstant annMayInline ())
-- Ignore the magic 'noinline' function, it's the identity but has no unfolding.
Expand Down Expand Up @@ -741,9 +750,6 @@ compileExpr e = traceCompilation 2 ("Compiling expr:" GHC.<+> GHC.ppr e) $ do
GHC.Var (lookupName scope . GHC.getName -> Just var) -> pure $ PIR.mkVar annMayInline var
-- Special kinds of id
GHC.Var (GHC.idDetails -> GHC.DataConWorkId dc) -> compileDataConRef dc
-- See Note [Unfoldings]
-- The "unfolding template" includes things with normal unfoldings and also dictionary functions
GHC.Var n@(GHC.maybeUnfoldingTemplate . GHC.realIdUnfolding -> Just unfolding) -> hoistExpr n unfolding
-- Class ops don't have unfoldings in general (although they do if they're for one-method classes, so we
-- want to check the unfoldings case first), see the GHC Note [ClassOp/DFun selection] for why. That
-- means we have to reconstruct the RHS ourselves, though, which is a pain.
Expand All @@ -765,11 +771,17 @@ compileExpr e = traceCompilation 2 ("Compiling expr:" GHC.<+> GHC.ppr e) $ do
case maybeDef of
Just term -> pure term
Nothing ->
throwSd FreeVariableError $
"Variable"
GHC.<+> GHC.ppr n
GHC.$+$ (GHC.ppr $ GHC.idDetails n)
GHC.$+$ (GHC.ppr $ GHC.realIdUnfolding n)
-- No other cases apply; compile the unfolding of the var
case GHC.maybeUnfoldingTemplate (GHC.realIdUnfolding n) of
-- See Note [Unfoldings]
-- The "unfolding template" includes things with normal unfoldings and also dictionary functions
Just unfolding -> hoistExpr n unfolding
Nothing ->
throwSd FreeVariableError $
"Variable"
GHC.<+> GHC.ppr n
GHC.$+$ (GHC.ppr $ GHC.idDetails n)
GHC.$+$ (GHC.ppr $ GHC.realIdUnfolding n)

-- ignoring applications to types of 'RuntimeRep' kind, see Note [Unboxed tuples]
l `GHC.App` GHC.Type t | GHC.isRuntimeRepKindedTy t -> compileExpr l
Expand Down Expand Up @@ -1050,13 +1062,49 @@ coverageCompile originalExpr exprType src compiledTerm covT =
findHeadSymbol (GHC.Cast t _) = findHeadSymbol t
findHeadSymbol _ = Nothing

-- | We cannot compile the unfolding of `GHC.Num.Integer.integerNegate`, which is
-- important because GHC inserts calls to it when it sees negations, even negations
-- of literals (unless NegativeLiterals is on, which it usually isn't). So we directly
-- define a PIR term for it: @integerNegate = \x -> 0 - x@.
defineIntegerNegate :: (CompilingDefault PLC.DefaultUni fun m ann) => m ()
defineIntegerNegate = do
ghcId <- GHC.tyThingId <$> getThing 'GHC.Num.Integer.integerNegate
-- Always inline `integerNegate`.
-- `let integerNegate = \x -> 0 - x in integerNegate 1 + integerNegate 2`
-- is much more expensive than `(-1) + (-2)`. The inliner cannot currently
-- make this transformation without `annAlwaysInline`, because it is not aware
-- of constant folding.
var <- compileVarFresh annAlwaysInline ghcId
let ann = annMayInline
x <- safeFreshName "x"
let
-- body = 0 - x
body =
PIR.LamAbs ann x (PIR.mkTyBuiltin @_ @Integer @PLC.DefaultUni ann) $
PIR.mkIterApp
(PIR.Builtin ann PLC.SubtractInteger)
[ (ann, PIR.mkConstant @Integer ann 0)
, (ann, PIR.Var ann x)
]
def = PIR.Def var (body, PIR.Strict)
PIR.defineTerm (LexName GHC.integerNegateName) def mempty

lookupIntegerNegate :: (Compiling uni fun m ann) => m (PIRTerm uni fun)
lookupIntegerNegate = do
ghcName <- GHC.getName <$> getThing 'GHC.Num.Integer.integerNegate
PIR.lookupTerm annMayInline (LexName ghcName) >>= \case
Just t -> pure t
Nothing -> throwPlain $
CompilationError "Cannot find the definition of integerNegate. Please file a bug report."

compileExprWithDefs ::
(CompilingDefault uni fun m ann) =>
GHC.CoreExpr ->
m (PIRTerm uni fun)
compileExprWithDefs e = do
defineBuiltinTypes
defineBuiltinTerms
defineIntegerNegate
compileExpr e

{- Note [We always need DEFAULT]
Expand Down
4 changes: 3 additions & 1 deletion plutus-tx-plugin/src/PlutusTx/Plugin.hs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ import Data.ByteString.Unsafe qualified as BSUnsafe
import Data.Either.Validation
import Data.Map qualified as Map
import Data.Set qualified as Set
import GHC.Num.Integer qualified
import PlutusIR.Compiler.Provenance (noProvenance, original)
import Prettyprinter qualified as PP
import System.IO (openTempFile)
Expand Down Expand Up @@ -386,7 +387,8 @@ compileMarkedExpr locStr codeTy origE = do
let moduleNameStr =
GHC.showSDocForUser flags GHC.emptyUnitState GHC.alwaysQualify (GHC.ppr moduleName)
-- We need to do this out here, since it has to run in CoreM
nameInfo <- makePrimitiveNameInfo $ builtinNames ++ [''Bool, 'False, 'True, 'traceBool]
nameInfo <- makePrimitiveNameInfo $
builtinNames ++ [''Bool, 'False, 'True, 'traceBool, 'GHC.Num.Integer.integerNegate]
modBreaks <- asks pcModuleModBreaks
let coverage = CoverageOpts . Set.fromList $
[ l | _posCoverageAll opts, l <- [minBound .. maxBound]]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
(program
1.1.0
(lam
x
(con integer)
[
[ (builtin addInteger) (con integer 24680135792468013579) ]
[
[ (builtin addInteger) (con integer -99887766554433221100) ]
[
[ (builtin addInteger) (con integer 98765432109876543210) ]
[
[ (builtin addInteger) (con integer -654) ]
[
[ (builtin addInteger) (con integer 456) ]
[
[ (builtin addInteger) (con integer 13579246801357924680) ]
[
[ (builtin addInteger) (con integer -11223344556677889900) ]
[
[ (builtin addInteger) (con integer 12345678901234567890) ]
[
[ (builtin addInteger) (con integer -321) ]
[ [ (builtin multiplyInteger) (con integer 123) ] x ]
]
]
]
]
]
]
]
]
]
)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
(program
1.1.0
(lam
x
(con integer)
[
[ (builtin addInteger) (con integer 24680135792468013579) ]
[
[ (builtin addInteger) (con integer -99887766554433221100) ]
[
[ (builtin addInteger) (con integer 98765432109876543210) ]
[
[ (builtin addInteger) (con integer -654) ]
[
[ (builtin addInteger) (con integer 456) ]
[
[ (builtin addInteger) (con integer 13579246801357924680) ]
[
[ (builtin addInteger) (con integer -11223344556677889900) ]
[
[ (builtin addInteger) (con integer 12345678901234567890) ]
[
[ (builtin addInteger) (con integer -321) ]
[ [ (builtin multiplyInteger) (con integer 123) ] x ]
]
]
]
]
]
]
]
]
]
)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
(program
1.1.0
(lam
x
(con integer)
[
[ (builtin addInteger) (con integer 24680135792468013579) ]
[
[ (builtin addInteger) (con integer -99887766554433221100) ]
[
[ (builtin addInteger) (con integer 98765432109876543210) ]
[
[ (builtin addInteger) (con integer -654) ]
[
[ (builtin addInteger) (con integer 456) ]
[
[ (builtin addInteger) (con integer 13579246801357924680) ]
[
[ (builtin addInteger) (con integer -11223344556677889900) ]
[
[ (builtin addInteger) (con integer 12345678901234567890) ]
[
[ (builtin addInteger) (con integer -321) ]
[ [ (builtin multiplyInteger) (con integer 123) ] x ]
]
]
]
]
]
]
]
]
]
)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
(program
1.1.0
(lam
x
(con integer)
[
[ (builtin addInteger) (con integer 24680135792468013579) ]
[
[ (builtin addInteger) (con integer -99887766554433221100) ]
[
[ (builtin addInteger) (con integer 98765432109876543210) ]
[
[ (builtin addInteger) (con integer -654) ]
[
[ (builtin addInteger) (con integer 456) ]
[
[ (builtin addInteger) (con integer 13579246801357924680) ]
[
[ (builtin addInteger) (con integer -11223344556677889900) ]
[
[ (builtin addInteger) (con integer 12345678901234567890) ]
[
[ (builtin addInteger) (con integer -321) ]
[ [ (builtin multiplyInteger) (con integer 123) ] x ]
]
]
]
]
]
]
]
]
]
)
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}

{-# LANGUAGE NegativeLiterals #-}
{-# LANGUAGE NoStrict #-}

{-# OPTIONS_GHC -fplugin PlutusTx.Plugin #-}

-- | This module tests that integer literals are handled correctly, when @Strict@ is off
-- and @NegativeLiterals@ is on. These two extensions affect the Core we get. When
-- @NegativeLiterals@ is on, we can get @IN@ for negative integers.
--
-- This module runs the PIR and UPLC simplifiers, because (1) we want to verify that
-- `integerNegate` is compiled away; (2) it is easier to tell from the optimized PIR
-- whether or not the signs of the numbers are correct, which is ultimately what we
-- care about.
module IntegerLiterals.NoStrict.NegativeLiterals.Spec where

import PlutusTx.Code
import PlutusTx.Prelude qualified as PlutusTx
import PlutusTx.Test
import PlutusTx.TH (compile)

import Test.Tasty.Extras

tests :: TestNested
tests = testNestedGhc "IntegerLiterals"
[ goldenPir "integerLiterals-NoStrict-NegativeLiterals" integerLiterals
]

integerLiterals :: CompiledCode (Integer -> Integer)
integerLiterals =
$$( compile
[||
\x ->
let !smallStrict = 123
!smallNegStrict = -321
!bigStrict = 12345678901234567890
!bigNegStrict = -11223344556677889900
!bigDoubleNegStrict = -(-13579246801357924680)
~smallLazy = 456
~smallNegLazy = -654
~bigLazy = 98765432109876543210
~bigNegLazy = -99887766554433221100
~bigDoubleNegLazy = -(-24680135792468013579)
in x PlutusTx.* smallStrict
PlutusTx.+ smallNegStrict
PlutusTx.+ bigStrict
PlutusTx.+ bigNegStrict
PlutusTx.+ bigDoubleNegStrict
PlutusTx.+ smallLazy
PlutusTx.+ smallNegLazy
PlutusTx.+ bigLazy
PlutusTx.+ bigNegLazy
PlutusTx.+ bigDoubleNegLazy
||]
)
Loading

0 comments on commit 6c7e230

Please sign in to comment.