Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add newtype to lift classes to monad transformers #272

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions monad-bayes.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,14 @@ library

default-extensions:
BlockArguments
DerivingVia
FlexibleContexts
GeneralizedNewtypeDeriving
ImportQualifiedPost
KindSignatures
LambdaCase
OverloadedStrings
StandaloneDeriving
TupleSections

if flag(dev)
Expand Down
91 changes: 54 additions & 37 deletions src/Control/Monad/Bayes/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ module Control.Monad.Bayes.Class
Measure,
Kernel,
Log (ln, Exp),
MonadMeasureTrans (..),
)
where

Expand All @@ -82,9 +83,11 @@ import Control.Monad.Identity (IdentityT)
import Control.Monad.List (ListT)
import Control.Monad.Reader (ReaderT)
import Control.Monad.State (StateT)
import Control.Monad.Trans (MonadTrans)
import Control.Monad.Writer (WriterT)
import Data.Histogram qualified as H
import Data.Histogram.Fill qualified as H
import Data.Kind (Type)
import Data.Matrix
( Matrix,
cholDecomp,
Expand Down Expand Up @@ -342,68 +345,82 @@ histogramToList = H.asList
----------------------------------------------------------------------------
-- Instances that lift probabilistic effects to standard tranformers.

instance MonadDistribution m => MonadDistribution (IdentityT m) where
random = lift random
bernoulli = lift . bernoulli
deriving via (MonadMeasureTrans IdentityT m) instance MonadDistribution m => MonadDistribution (IdentityT m)

instance MonadFactor m => MonadFactor (IdentityT m) where
score = lift . score
deriving via (MonadMeasureTrans IdentityT m) instance MonadFactor m => MonadFactor (IdentityT m)

instance MonadMeasure m => MonadMeasure (IdentityT m)

instance MonadDistribution m => MonadDistribution (ExceptT e m) where
random = lift random
uniformD = lift . uniformD
deriving via (MonadMeasureTrans (ExceptT e) m) instance MonadDistribution m => MonadDistribution (ExceptT e m)

instance MonadFactor m => MonadFactor (ExceptT e m) where
score = lift . score
deriving via (MonadMeasureTrans (ExceptT e) m) instance MonadFactor m => MonadFactor (ExceptT e m)

instance MonadMeasure m => MonadMeasure (ExceptT e m)

instance MonadDistribution m => MonadDistribution (ReaderT r m) where
random = lift random
bernoulli = lift . bernoulli
deriving via (MonadMeasureTrans (ReaderT r) m) instance MonadDistribution m => MonadDistribution (ReaderT r m)

instance MonadFactor m => MonadFactor (ReaderT r m) where
score = lift . score
deriving via (MonadMeasureTrans (ReaderT r) m) instance MonadFactor m => MonadFactor (ReaderT r m)

instance MonadMeasure m => MonadMeasure (ReaderT r m)

instance (Monoid w, MonadDistribution m) => MonadDistribution (WriterT w m) where
random = lift random
bernoulli = lift . bernoulli
categorical = lift . categorical
deriving via (MonadMeasureTrans (WriterT w) m) instance (Monoid w, MonadDistribution m) => MonadDistribution (WriterT w m)

instance (Monoid w, MonadFactor m) => MonadFactor (WriterT w m) where
score = lift . score
deriving via (MonadMeasureTrans (WriterT w) m) instance (Monoid w, MonadFactor m) => MonadFactor (WriterT w m)

instance (Monoid w, MonadMeasure m) => MonadMeasure (WriterT w m)

instance MonadDistribution m => MonadDistribution (StateT s m) where
random = lift random
bernoulli = lift . bernoulli
categorical = lift . categorical
uniformD = lift . uniformD
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only note here would be that some of the previous lifting instances appear to have varied slightly. This actually may have been my doing, because I recall that working in a transformed Enumerator didn't lift bernoulli, so that discrete distributions got calculated via random and thus failed with enumerate.

deriving via (MonadMeasureTrans (StateT s) m) instance MonadDistribution m => MonadDistribution (StateT s m)

instance MonadFactor m => MonadFactor (StateT s m) where
score = lift . score
deriving via (MonadMeasureTrans (StateT s) m) instance MonadFactor m => MonadFactor (StateT s m)

instance MonadMeasure m => MonadMeasure (StateT s m)

instance MonadDistribution m => MonadDistribution (ListT m) where
random = lift random
bernoulli = lift . bernoulli
categorical = lift . categorical
deriving via (MonadMeasureTrans ListT m) instance MonadDistribution m => MonadDistribution (ListT m)

instance MonadFactor m => MonadFactor (ListT m) where
score = lift . score
deriving via (MonadMeasureTrans ListT m) instance MonadFactor m => MonadFactor (ListT m)

instance MonadMeasure m => MonadMeasure (ListT m)

instance MonadDistribution m => MonadDistribution (ContT r m) where
deriving via (MonadMeasureTrans (ContT r) m) instance MonadDistribution m => MonadDistribution (ContT r m)

deriving via (MonadMeasureTrans (ContT r) m) instance MonadFactor m => MonadFactor (ContT r m)

instance MonadMeasure m => MonadMeasure (ContT r m)

-- * Utility for deriving MonadDistribution, MonadFactor and MonadMeasure

-- | Newtype to derive 'MonadDistribution', 'MonadFactor' and 'MonadMeasure' automatically for monad transformers.
--
-- The typical usage is with the `StandaloneDeriving` and `DerivingVia` extensions.
-- For example, to derive all instances for the 'IdentityT' transformer, one writes:
--
-- @
-- deriving via (MonadMeasureTrans IdentityT m) instance MonadDistribution m => MonadDistribution (IdentityT m)
-- deriving via (MonadMeasureTrans IdentityT m) instance MonadFactor m => MonadFactor (IdentityT m)
-- instance MonadMeasure m => MonadMeasure (IdentityT m)
-- @
-- (The final 'MonadMeasure' could also be derived `via`, but this isn't necessary because it doesn't contain any methods.)
newtype MonadMeasureTrans (t :: (Type -> Type) -> Type -> Type) (m :: Type -> Type) a = MonadMeasureTrans {getMonadMeasureTrans :: t m a}
deriving (Functor, Applicative, Monad)

instance MonadTrans t => MonadTrans (MonadMeasureTrans t) where
lift = MonadMeasureTrans . lift

instance (MonadTrans t, MonadDistribution m, Monad (t m)) => MonadDistribution (MonadMeasureTrans t m) where
random = lift random
uniform = (lift .) . uniform
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly here, there may have been a reason that many of the instances were limited to mostly just lift random. This doesn't change the denotational semantics, but may have efficiency implications, so make sure to benchmark, particularly PMMH and RMSMC.

normal = (lift .) . normal
gamma = (lift .) . gamma
beta = (lift .) . beta
bernoulli = lift . bernoulli
categorical = lift . categorical
logCategorical = lift . logCategorical
uniformD = lift . uniformD
geometric = lift . geometric
poisson = lift . poisson
dirichlet = lift . dirichlet

instance MonadFactor m => MonadFactor (ContT r m) where
instance (MonadFactor m, MonadTrans t, Monad (t m)) => MonadFactor (MonadMeasureTrans t m) where
score = lift . score

instance MonadMeasure m => MonadMeasure (ContT r m)
instance (MonadDistribution m, MonadFactor m, MonadTrans t, Monad (t m)) => MonadMeasure (MonadMeasureTrans t m)
10 changes: 2 additions & 8 deletions src/Control/Monad/Bayes/Inference/SMC2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ module Control.Monad.Bayes.Inference.SMC2
where

import Control.Monad.Bayes.Class
( MonadDistribution (random),
( MonadDistribution,
MonadFactor (..),
MonadMeasure,
)
Expand All @@ -35,20 +35,14 @@ import Numeric.Log (Log)

-- | Helper monad transformer for preprocessing the model for 'smc2'.
newtype SMC2 m a = SMC2 (Sequential (Traced (Population m)) a)
deriving newtype (Functor, Applicative, Monad)
deriving newtype (Functor, Applicative, Monad, MonadDistribution, MonadFactor)

setup :: SMC2 m a -> Sequential (Traced (Population m)) a
setup (SMC2 m) = m

instance MonadTrans SMC2 where
lift = SMC2 . lift . lift . lift

instance MonadDistribution m => MonadDistribution (SMC2 m) where
random = lift random

instance Monad m => MonadFactor (SMC2 m) where
score = SMC2 . score

instance MonadDistribution m => MonadMeasure (SMC2 m)

-- | Sequential Monte Carlo squared.
Expand Down
8 changes: 3 additions & 5 deletions src/Control/Monad/Bayes/Sequential/Coroutine.hs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ module Control.Monad.Bayes.Sequential.Coroutine
where

import Control.Monad.Bayes.Class
( MonadDistribution (bernoulli, categorical, random),
( MonadDistribution,
MonadFactor (..),
MonadMeasure,
MonadMeasureTrans (..),
)
import Control.Monad.Coroutine
( Coroutine (..),
Expand All @@ -54,10 +55,7 @@ newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
extract :: Await () a -> a
extract (Await f) = f ()

instance MonadDistribution m => MonadDistribution (Sequential m) where
random = lift random
bernoulli = lift . bernoulli
categorical = lift . categorical
deriving via (MonadMeasureTrans Sequential m) instance MonadDistribution m => MonadDistribution (Sequential m)

-- | Execution is 'suspend'ed after each 'score'.
instance MonadFactor m => MonadFactor (Sequential m) where
Expand Down