Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Free list transformer to replace deprecated ListT #253

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions docs/docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ runWeightedT (WeightedT m) = runStateT m 1

`WeightedT m` is not an instance of `MonadDistribution`, but only as instance of `MonadFactor` (and that, only when `m` is an instance of `Monad`). However, since `StateT` is a monad transformer, there is a function `lift :: m Double -> WeightedT m Double`.

So if we take a `MonadDistribution` instance like `SamplerIO`, then `WeightedT SamplerIO` is an instance of both `MonadDistribution` and `MonadFactor`. Which means it is an instance of `MonadMeasure`.
So if we take a `MonadDistribution` instance like `SamplerIO`, then `Weighted SamplerIO` is an instance of both `MonadDistribution` and `MonadFactor`. Which means it is an instance of `MonadMeasure`.

So we can successfully write `(sampler . runWeightedT) sprinkler` and get a program of type `IO (Bool, Log Double)`. When run, this will draw a sample from `sprinkler` along with an **unnormalized** density for that sample.

Expand Down Expand Up @@ -328,18 +328,20 @@ Summary of key info on `PopulationT`:
- `instance MonadFactor m => instance MonadFactor (PopulationT m)`

```haskell
newtype PopulationT m a = PopulationT (WeightedT (ListT m) a)
newtype PopulationT m a = PopulationT (WeightedT (FreeT [] m) a)
```

So:
The `FreeT []` construction is for branching our probabilistic program into different branches,
corresponding to different choices of a random variable.

It is interpreted, using `runPopulationT`, to:

```haskell
PopulationT m a ~ m [Log Double -> (a, Log Double)]
m [(a, Log Double)]
```

Note that while `ListT` isn't in general a valid monad transformer, we're not requiring it to be one here.

`PopulationT` is used to represent a collection of particles (in the statistical sense), along with their weights.
This shows that `Population` is used to compute a collection of particles (in the statistical sense), along with their weights.
Each `a` corresponds to one particle, and `Log Double` is the type of its weight.

There are several useful functions associated with it:

Expand All @@ -360,7 +362,7 @@ gives
[([((),0.5),((),0.5)],1.0)]
```

Observe how here we have interpreted `(spawn 2)` as of type `PopulationT Enumerator ()`.
Observe how here we have interpreted `(spawn 2)` as of type `Population Enumerator ()`.

`resampleGeneric` takes a function to probabilistically select a set of indices from a vector, and makes a new population by selecting those indices.

Expand Down Expand Up @@ -393,8 +395,8 @@ Summary of key info on `SequentialT`:


```haskell
newtype SequentialT m a =
SequentialT {runSequentialT :: Coroutine (Await ()) m a}
newtype Sequential m a =
Sequential {runSequential :: Coroutine (Await ()) m a}
```

This is a wrapper for the `Coroutine` type applied to the `Await` constructor from `Control.Monad.Coroutine`, which is defined thus:
Expand All @@ -410,7 +412,7 @@ newtype Await x y = Await (x -> y)
Unpacking that:

```haskell
SequentialT m a ~ m (Either (() -> SequentialT m a) a)
Sequential m a ~ m (Either (() -> Sequential m a) a)
```

As usual, `m` is going to be some other probability monad, so understand `SequentialT m a` as representing a program which, after making a random choice or doing conditioning, we either obtain an `a` value, or a paused computation, which when resumed gets us back to a new `SequentialT m a`.
Expand Down Expand Up @@ -501,11 +503,11 @@ The latter is best understood if you're familiar with the standard use of a free

```haskell
newtype SamF a = Random (Double -> a)
newtype DensityT m a =
DensityT {getDensityT :: FT SamF m a}
newtype Density m a =
Density {density :: FT SamF m a}

instance Monad m => MonadDistribution (DensityT m) where
random = DensityT $ liftF (Random id)
instance Monad m => MonadDistribution (Density m) where
random = Density $ liftF (Random id)
```

The monad-bayes implementation uses a more efficient implementation of `FreeT`, namely `FT` from the `free` package, known as the *Church transformed Free monad*. This is a technique explained in https://begriffs.com/posts/2016-02-04-difference-lists-and-codennsity.html. But that only changes the operational semantics - performance aside, it works just the same as the standard `FreeT` datatype.
Expand Down Expand Up @@ -575,7 +577,7 @@ data Trace a = Trace
}
```

We also need a specification of the probabilistic program in question, free of any particular interpretation. That is precisely what `DensityT` is for.
We also need a specification of the probabilistic program in question, free of any particular interpretation. That is precisely what `Density` is for.

The simplest version of `TracedT` is in `Control.Monad.Bayes.TracedT.Basic`

Expand Down Expand Up @@ -635,13 +637,13 @@ example = do
return x
```

`(enumerator . runWeightedT) example` gives `[((False,0.0),0.5),((True,1.0),0.5)]`. This is quite edifying for understanding `(sampler . runWeightedT) example`. What it says is that there are precisely two ways the program will run, each with equal probability: either you get `False` with weight `0.0` or `True` with weight `1.0`.
`(enumerator . weighted) example` gives `[((False,0.0),0.5),((True,1.0),0.5)]`. This is quite edifying for understanding `(sampler . weighted) example`. What it says is that there are precisely two ways the program will run, each with equal probability: either you get `False` with weight `0.0` or `True` with weight `1.0`.

### Quadrature

As described on the section on `Integrator`, we can interpret our probabilistic program of type `MonadDistribution m => m a` as having concrete type `Integrator a`. This views our program as an integrator, allowing us to calculate expectations, probabilities and so on via quadrature (i.e. numerical approximation of an integral).

This can also handle programs of type `MonadMeasure m => m a`, that is, programs with `factor` statements. For these cases, a function `normalize :: WeightedT Integrator a -> Integrator a` is employed. For example,
This can also handle programs of type `MonadMeasure m => m a`, that is, programs with `factor` statements. For these cases, a function `normalize :: Weighted Integrator a -> Integrator a` is employed. For example,

```haskell
model :: MonadMeasure m => m Double
Expand All @@ -652,7 +654,7 @@ model = do
return var
```

is really an unnormalized measure, rather than a probability distribution. `normalize` views it as of type `WeightedT Integrator Double`, which is isomorphic to `(Double -> (Double, Log Double) -> Double)`. This can be used to compute the normalization constant, and divide the integrator's output by it, all within `Integrator`.
is really an unnormalized measure, rather than a probability distribution. `normalize` views it as of type `Weighted Integrator Double`, which is isomorphic to `(Double -> (Double, Log Double) -> Double)`. This can be used to compute the normalization constant, and divide the integrator's output by it, all within `Integrator`.

### Independent forward sampling

Expand Down Expand Up @@ -796,7 +798,7 @@ pmmh ::
pmmh mcmcConf smcConf param model =
(mcmc mcmcConf :: T m [(a, Log Double)] -> m [[(a, Log Double)]])
((param :: T m b) >>=
(runPopulationT :: P (T m) a -> T m [(a, Log Double)])
(population :: P (T m) a -> T m [(a, Log Double)])
. (pushEvidence :: P (T m) a -> P (T m) a)
. Pop.hoist (lift :: forall x. m x -> T m x)
. (smc smcConf :: S (P m) a -> P m a)
Expand Down
13 changes: 9 additions & 4 deletions monad-bayes.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,21 @@ common deps
, scientific ^>=0.3
, statistics >=0.14.0 && <0.17
, text >=1.2 && <2.1
, transformers ^>=0.5.6
, vector >=0.12.0 && <0.14
, vty ^>=5.38

common test-deps
build-depends:
, abstract-par ^>=0.3
, criterion >=1.5 && <1.7
, criterion >=1.5 && <1.7
, directory ^>=1.3
, hspec ^>=2.11
, monad-bayes
, optparse-applicative >=0.17 && <0.19
, optparse-applicative >=0.17 && <0.19
, process ^>=1.6
, QuickCheck ^>=2.14
, time >=1.9 && <1.13
, transformers ^>=0.5.6
, time >=1.9 && <1.13
, typed-process ^>=0.2

autogen-modules: Paths_monad_bayes
Expand All @@ -86,6 +86,7 @@ common test-deps
library
import: deps
exposed-modules:
Control.Applicative.List
Control.Monad.Bayes.Class
Control.Monad.Bayes.Density.Free
Control.Monad.Bayes.Density.State
Expand All @@ -100,6 +101,7 @@ library
Control.Monad.Bayes.Inference.TUI
Control.Monad.Bayes.Integrator
Control.Monad.Bayes.Population
Control.Monad.Bayes.Population.Applicative
Control.Monad.Bayes.Sampler.Lazy
Control.Monad.Bayes.Sampler.Strict
Control.Monad.Bayes.Sequential.Coroutine
Expand All @@ -114,8 +116,11 @@ library
other-modules: Control.Monad.Bayes.Traced.Common
default-language: Haskell2010
default-extensions:
ApplicativeDo
BlockArguments
DerivingStrategies
FlexibleContexts
GeneralizedNewtypeDeriving
ImportQualifiedPost
LambdaCase
OverloadedStrings
Expand Down
23 changes: 23 additions & 0 deletions src/Control/Applicative/List.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{-# LANGUAGE StandaloneDeriving #-}

module Control.Applicative.List where

-- base
import Control.Applicative
import Data.Functor.Compose

-- * Applicative ListT

-- | _Applicative_ transformer adding a list/nondeterminism/choice effect.
-- It is not a valid monad transformer, but it is a valid 'Applicative'.
newtype ListT m a = ListT {getListT :: Compose m [] a}
deriving newtype (Functor, Applicative, Alternative)

listT :: m [a] -> ListT m a
listT = ListT . Compose

lift :: (Functor m) => m a -> ListT m a
lift = ListT . Compose . fmap pure

runListT :: ListT m a -> m [a]
runListT = getCompose . getListT
8 changes: 4 additions & 4 deletions src/Control/Monad/Bayes/Class.hs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,9 @@ import Control.Monad (replicateM, when)
import Control.Monad.Cont (ContT)
import Control.Monad.Except (ExceptT, lift)
import Control.Monad.Identity (IdentityT)
import Control.Monad.List (ListT)
import Control.Monad.Reader (ReaderT)
import Control.Monad.State (StateT)
import Control.Monad.Trans.Free (FreeT)
import Control.Monad.Writer (WriterT)
import Data.Histogram qualified as H
import Data.Histogram.Fill qualified as H
Expand Down Expand Up @@ -390,15 +390,15 @@ instance (MonadFactor m) => MonadFactor (StateT s m) where

instance (MonadMeasure m) => MonadMeasure (StateT s m)

instance (MonadDistribution m) => MonadDistribution (ListT m) where
instance (Applicative f, (MonadDistribution m)) => MonadDistribution (FreeT f m) where
random = lift random
bernoulli = lift . bernoulli
categorical = lift . categorical

instance (MonadFactor m) => MonadFactor (ListT m) where
instance (Applicative f, (MonadFactor m)) => MonadFactor (FreeT f m) where
score = lift . score

instance (MonadMeasure m) => MonadMeasure (ListT m)
instance (Applicative f, (MonadMeasure m)) => MonadMeasure (FreeT f m)

instance (MonadDistribution m) => MonadDistribution (ContT r m) where
random = lift random
Expand Down
11 changes: 6 additions & 5 deletions src/Control/Monad/Bayes/Inference/RMSMC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..))
import Control.Monad.Bayes.Inference.SMC
import Control.Monad.Bayes.Population
( PopulationT,
spawn,
flatten,
single,
withParticles,
)
import Control.Monad.Bayes.Sequential.Coroutine as Seq
Expand All @@ -50,8 +51,8 @@ rmsmc ::
PopulationT m a
rmsmc (MCMCConfig {..}) (SMCConfig {..}) =
marginal
. S.sequentially (composeCopies numMCMCSteps mhStep . TrStat.hoist resampler) numSteps
. S.hoistFirst (TrStat.hoist (spawn numParticles >>))
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoistModel (single . flatten) . TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps
. S.hoistFirst (TrStat.hoistModel (single . flatten) . TrStat.hoist (withParticles numParticles))

-- | Resample-move Sequential Monte Carlo with a more efficient
-- tracing representation.
Expand All @@ -64,7 +65,7 @@ rmsmcBasic ::
PopulationT m a
rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) =
TrBas.marginal
. S.sequentially (composeCopies numMCMCSteps TrBas.mhStep . TrBas.hoist resampler) numSteps
. S.sequentially (TrBas.hoist (single . flatten) . composeCopies numMCMCSteps (TrBas.hoist (single . flatten) . TrBas.mhStep) . TrBas.hoist resampler) numSteps
. S.hoistFirst (TrBas.hoist (withParticles numParticles))

-- | A variant of resample-move Sequential Monte Carlo
Expand All @@ -79,7 +80,7 @@ rmsmcDynamic ::
PopulationT m a
rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) =
TrDyn.marginal
. S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps TrDyn.mhStep . TrDyn.hoist resampler) numSteps
. S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist (single . flatten) . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps
. S.hoistFirst (TrDyn.hoist (withParticles numParticles))

-- | Apply a function a given number of times.
Expand Down
29 changes: 26 additions & 3 deletions src/Control/Monad/Bayes/Inference/SMC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,38 @@ where

import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure)
import Control.Monad.Bayes.Population
( PopulationT,
( PopulationT (..),
flatten,
pushEvidence,
single,
withParticles,
)
import Control.Monad.Bayes.Population.Applicative qualified as Applicative
import Control.Monad.Bayes.Sequential.Coroutine as Coroutine
import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT
import Control.Monad.Bayes.Weighted (WeightedT (..), weightedT)
import Control.Monad.Coroutine
import Control.Monad.Trans.Free (FreeF (..), FreeT (..))

data SMCConfig m = SMCConfig
{ resampler :: forall x. PopulationT m x -> PopulationT m x,
numSteps :: Int,
numParticles :: Int
}

sequentialToPopulation :: (Monad m) => Coroutine.SequentialT (Applicative.PopulationT m) a -> PopulationT m a
sequentialToPopulation =
PopulationT
. weightedT
. coroutineToFree
. Coroutine.runSequentialT
where
coroutineToFree =
FreeT
. fmap (Free . fmap (\(cont, p) -> either (coroutineToFree . extract) (pure . (,p)) cont))
. Applicative.runPopulationT
. resume

-- | Sequential importance resampling.
-- Basically an SMC template that takes a custom resampler.
smc ::
Expand All @@ -42,12 +62,15 @@ smc ::
Coroutine.SequentialT (PopulationT m) a ->
PopulationT m a
smc SMCConfig {..} =
Coroutine.sequentially resampler numSteps
(single . flatten)
. Coroutine.sequentially resampler numSteps
. SequentialT.hoist (single . flatten)
. Coroutine.hoistFirst (withParticles numParticles)
. SequentialT.hoist (single . flatten)

-- | Sequential Monte Carlo with multinomial resampling at each timestep.
-- Weights are normalized at each timestep and the total weight is pushed
-- as a score into the transformed monad.
smcPush ::
(MonadMeasure m) => SMCConfig m -> Coroutine.SequentialT (PopulationT m) a -> PopulationT m a
smcPush config = smc config {resampler = (pushEvidence . resampler config)}
smcPush config = smc config {resampler = (single . flatten . pushEvidence . resampler config)}
12 changes: 10 additions & 2 deletions src/Control/Monad/Bayes/Inference/SMC2.hs
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ import Control.Monad.Bayes.Class
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT)
import Control.Monad.Bayes.Population as Pop (PopulationT, flatten, resampleMultinomial, runPopulationT, single)
import Control.Monad.Bayes.Population qualified as PopulationT
import Control.Monad.Bayes.Sequential.Coroutine (SequentialT)
import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT
import Control.Monad.Bayes.Traced
import Control.Monad.Trans (MonadTrans (..))
import Numeric.Log (Log)
Expand Down Expand Up @@ -71,4 +73,10 @@ smc2 k n p t param m =
rmsmc
MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0}
SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial}
(param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . m)
(flattenSequentiallyTraced param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . flattenSMC2 . m)

flattenSequentiallyTraced :: (Monad m) => SequentialT (TracedT (PopulationT m)) a -> SequentialT (TracedT (PopulationT m)) a
flattenSequentiallyTraced = SequentialT.hoist $ hoistModel (single . flatten) . hoist (single . flatten)

flattenSMC2 :: (Monad m) => SequentialT (PopulationT (SMC2 m)) a -> SequentialT (PopulationT (SMC2 m)) a
flattenSMC2 = SequentialT.hoist $ single . flatten . PopulationT.hoist (SMC2 . flattenSequentiallyTraced . setup)
Loading
Loading